In [1]:
import pickle
import numpy as np
import cvxpy as cp
import matplotlib.pyplot as plt
import copy

In [1]:
# Team abbreviations (30 teams); the loading cell reads one observation pickle per entry.
teams = ('BKN MIL GSW LAL IND CHA CHI DET WAS TOR BOS NYK CLE MEM PHI '
         'NOP HOU MIN ORL SAS OKC UTA SAC POR DEN PHX DAL ATL MIA LAC').split()
window = 4  # samples per game: one sample per quarter

In [2]:
# --- model dimensions ---
N_players = 5
players = ['player' + str(k) for k in range(1, N_players + 1)]
n_components = 3   # number of hidden states
n_features = 3     # number of observed states
O_symbols = list(range(n_features))    # 0 / 1 / 2 = under-, avg-, over-performance
H_symbols = list(range(n_components))  # hidden states, one per observed level
T = 100                   # number of time steps decoded per player
learning_iterations = 100

In [3]:
# === M: observation-driven transition weights ===
# Row = a teammate's observed symbol at t-1, column = hidden state at t
# (cond reads M_[O_{t-1}][h]).
avg_transO = np.array([
    [0.50, 0.30, 0.20],
    [0.25, 0.50, 0.25],
    [0.20, 0.30, 0.50],
])
star_transO = np.array([
    [0.7, 0.3, 0.0],
    [0.1, 0.8, 0.1],
    [0.0, 0.3, 0.7],
])

# === N: hidden-state transitions (row = state at t-1, column = state at t) ===
avg_transH = np.array([
    [0.8, 0.2, 0.0],
    [0.1, 0.8, 0.1],
    [0.0, 0.2, 0.8],
])

# === emission_prob: row = hidden state, column = observed symbol ===
avg_emission = np.array([
    [0.7, 0.3, 0.0],
    [0.1, 0.8, 0.1],
    [0.0, 0.3, 0.7],
])

# ===  R: mixture weights used by cond ===
# Every R vector has len(players) + 1 entries. cond dots it with
#   [N[h_prev][h], M[O_{t-1}^{player1}][h], ..., M[O_{t-1}^{playerK}][h]],
# so slot 0 weights the player's own hidden-state transition and slot i
# (1-based) weights the term driven by player i's previous observation.


def _player_index(player):
    """Return the 1-based number at the end of a player name ('player12' -> 12).

    Replaces ``int(player[-1])``, which read only the last character and so
    mapped e.g. 'player10' to slot 0 (the hidden-state slot).

    Raises ValueError if the name has no trailing number, and IndexError if the
    number is outside 1..len(players).
    """
    digits = player[len(player.rstrip('0123456789')):]
    if not digits:
        raise ValueError(f'player name {player!r} has no trailing number')
    index = int(digits)
    if not 1 <= index <= len(players):
        raise IndexError(f'player index {index} outside 1..{len(players)}')
    return index


# Only the player's own hidden-state transition (slot 0).
R_singleH = np.array([1] + [0] * len(players))


def R_singleHO(player):
    """Weights: 0.7 on the own hidden-state transition, 0.3 on this player's own observation."""
    weights = [0] * (len(players) + 1)
    weights[0] = 0.7
    weights[_player_index(player)] = 0.3
    return np.array(weights)


def R_singleO(player):
    """Weights: everything on this player's own previous observation."""
    weights = [0] * (len(players) + 1)
    weights[_player_index(player)] = 1
    return np.array(weights)


def R_star(player, star):
    """Weights for a 'star' setup.

    The star player gets R_singleH (own hidden-state transition only); every
    other player puts all weight on the star's previous observation.
    A copy of R_singleH is returned so callers cannot mutate the shared vector.
    """
    if player == star:
        return R_singleH.copy()
    weights = [0] * (len(players) + 1)
    weights[_player_index(star)] = 1
    return np.array(weights)


# Equal weight on every slot.
R_uniform = np.array([1/(1 + len(players))] * (len(players) + 1))

In [4]:
# Active model: star observation transitions, average hidden transitions and
# emissions, and every player's R weights built around 'player1' as the star.
M, N, E = star_transO, avg_transH, avg_emission
R = dict((p, R_star(p, 'player1')) for p in players)
initial_dist = np.array([0, 1, 0])

In [5]:
# Per-team observation data, keyed by team abbreviation.
# Hs starts as an empty dict per team; nothing in this cell fills it.
Os = {team: {} for team in teams}
Hs = {team: {} for team in teams}

for team in teams:
    obs_path = f'../team-data/observations/{team}_observations.pickle'
    # SECURITY NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(obs_path, 'rb') as fh:
        Os[team] = pickle.load(fh)

In [6]:
def cond(player, h1, h2, t, M_, N_, R_, Os_=None): # P(H_t^player = h1 | H_{t-1}^player = h2, O_{t-1})
    """Conditional weight of hidden state h1 at t given h2 at t-1 and observations at t-1.

    Builds
        v = [N_[h2][h1], M_[O_{t-1}^{q}][h1] for each q in the global ``players``]
    and returns ``np.dot(R_[player], v)``.

    Parameters
    ----------
    player : key into ``R_``.
    h1, h2 : hidden state at time t and t-1.
    t : time step; must be >= 1 because observations at t-1 are read.
    M_, N_, R_ : model parameters (see the matrix/R definition cells).
    Os_ : optional mapping player -> observation sequence. Defaults to the
        notebook-global ``Os`` (the only source the original version used).

    Raises
    ------
    ValueError : if t < 1. Previously t == 0 silently read index -1, the
        *last* observation.

    NOTE(review): the loading cell builds ``Os`` keyed by team abbreviation,
    but here it is indexed by player name. Presumably ``Os`` was rebound to
    a single team's per-player dict in a cell not shown; confirm.
    """
    if t < 1:
        raise ValueError(f't must be >= 1 (observations at t-1 are read), got {t}')
    observations = Os if Os_ is None else Os_
    # Slot 0: own hidden-state transition; then one observation-driven term per player.
    v = np.array([N_[h2][h1]] + [M_[observations[mate][t - 1]][h1] for mate in players])
    return np.dot(R_[player], v)

In [7]:
def _rescale(vec):
    """Divide a non-negative score vector by its sum (left unchanged if the sum is 0).

    Viterbi argmaxes are invariant to positive scaling, so this only prevents
    the running products from underflowing to 0 over long sequences.
    """
    total = vec.sum()
    return vec / total if total > 0 else vec


def reconstruct_hidden(Os, M, N, E, R, initial=None):
    """Viterbi-decode the most likely hidden-state path for every player.

    For each player ``p`` in the global ``players`` list, and t = 0..T-1 (global ``T``):
      * the transition weight from state h_prev (t-1) to h (t) is
        ``cond(p, h, h_prev, t, M, N, R)``;
      * the emission weight at time t is ``E[h][Os[p][t]]``;
      * the start weights are ``initial`` times the emission at t=0.

    Parameters
    ----------
    Os : mapping player -> observation sequence (indexable by t, length >= T).
    M, N, R : forwarded to ``cond``.
    E : emission weights, indexed ``E[hidden][observed]``.
    initial : optional start weights over H_symbols. Defaults to ``[0, 1, 0]``,
        the value previously hard-coded here.

    Returns
    -------
    dict mapping each player to a list of T hidden-state indices (ints).

    NOTE(review): ``cond`` reads teammates' observations from the
    notebook-global ``Os``, not this ``Os`` argument; pass the same mapping.
    """
    if initial is None:
        initial = [0, 1, 0]
    start = np.asarray(initial, dtype=float)

    Hs_predicted = {}
    for p in players:
        obs = Os[p]
        # Emission at t=0 was previously omitted; standard Viterbi includes it.
        delta = _rescale(start * np.array([E[h][obs[0]] for h in H_symbols], dtype=float))
        backpointers = []
        for t in range(1, T):
            # trans[h_prev, h]: computed once and reused for both the max and the argmax.
            trans = np.array([[cond(p, h, h_prev, t, M, N, R) for h in H_symbols]
                              for h_prev in H_symbols], dtype=float)
            scores = delta[:, None] * trans
            backpointers.append(scores.argmax(axis=0))
            emission_t = np.array([E[h][obs[t]] for h in H_symbols], dtype=float)
            delta = _rescale(emission_t * scores.max(axis=0))

        # Backtrack from the best final state; collect in reverse and flip (O(T)).
        state = int(delta.argmax())
        path = [state]
        for best_prev in reversed(backpointers):
            state = int(best_prev[state])
            path.append(state)
        path.reverse()
        Hs_predicted[p] = path

    return Hs_predicted

In [8]:
# Decode every player's most likely hidden-state path under the active model.
Hs_predicted = reconstruct_hidden(Os, M=M, N=N, E=E, R=R)

In [10]:
def find_timestamps(timeseries):
    """Return the start index of every run of zeros in ``timeseries``.

    Consecutive zeros are counted once: only the first index of each
    contiguous run is reported (e.g. the onset of each team collapse).

    Parameters
    ----------
    timeseries : 1-D array-like of numbers (lists are accepted).

    Returns
    -------
    list of int : start index of each zero run in increasing order; empty if
    there are no zeros.
    """
    indices = np.where(np.asarray(timeseries) == 0)[0]
    # Test the size, not `indices.any()`: the latter is False for indices == [0],
    # which used to hide a collapse starting at t = 0.
    if indices.size == 0:
        return []
    # A zero starts a new run unless it directly follows the previous zero.
    is_run_start = np.concatenate(([True], np.diff(indices) != 1))
    return [int(i) for i in indices[is_run_start]]

In [11]:
# Team-level series: at each time step, the sum of all players' decoded states.
Hs_team_predicted = np.array(list(map(sum, zip(*(Hs_predicted[p] for p in players)))))

In [12]:
# Start of each run where the summed team state is 0, i.e. every player is decoded in state 0.
find_timestamps(timeseries=Hs_team_predicted)

[35]