In [2]:
%load_ext autoreload
%autoreload 2

import os, json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib import animation
%matplotlib inline

from dataset import SportsDataset
from datatools.trace_animator import TraceAnimator
from datatools.trace_helper import TraceHelper
from datatools.visualize_helper import VisualizeHelper
from datatools.nba_helper import NBADataHelper, NBADataAnimator
from datatools.nfl_helper import NFLDataHelper
from models import load_model
from models.utils import get_dataset_config, print_helper, reshape_tensor, sort_players

from models.graph_imputer.graph_imputer import BidirectionalGraphImputer

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Model evaluation on test data

### Load model

In [42]:
device = "cuda:0"
trial = 232
save_path = f"saved/{trial:03d}"

# Hyperparameters saved alongside the trial's checkpoints.
with open(f"{save_path}/params.json", "r") as f:
    params = json.load(f)

if params["model"] == "nrtsi":
    model = load_model(params["model"], params).to(device)

    # One checkpoint per gap size, keyed by the gap and deserialized onto CPU storage.
    gap_models = {
        gap: torch.load(
            f"{save_path}/model/nrtsi_state_dict_best_gap_{gap}.pt",
            map_location=lambda storage, _: storage,
        )
        for gap in (1, 2, 4, 8, 16)
    }
else:
    # Every other model type has a single best checkpoint named after the model.
    checkpoint_path = f"{save_path}/model/{params['model']}_state_dict_best.pt"
    state_dict = torch.load(checkpoint_path, map_location=lambda storage, _: storage)

    model = load_model(params["model"], params).to(device)
    model.load_state_dict(state_dict)

In [5]:
# Notebook-level settings taken from the loaded trial parameters.
sports, model_type = params["dataset"], params["model"]
naive_baselines = True

# These two flags are only read from params for the "dbhp" model.
if model_type == "dbhp":
    deriv_accum, dynamic_hybrid = params["deriv_accum"], params["dynamic_hybrid"]

summary_lines = [
    f"- Sports: {sports}",
    f"- Model type: {model_type}",
    f"- Compute stats for naive baselines: {naive_baselines}",
]
print("\n".join(summary_lines))

- Sports: soccer
- Model type: dbhp
- Compute stats for naive baselines: True


In [6]:
# Metrica files are listed explicitly; NBA and NFL files are discovered on disk.
metrica_files = ["match1.csv", "match2.csv", "match3_valid.csv", "match3_test.csv"]
metrica_paths = ["data/metrica_traces/" + f for f in metrica_files]

nba_files = os.listdir("data/nba_traces")
nba_paths = sorted("data/nba_traces/" + f for f in nba_files)

# Only .csv files are kept for the NFL directory.
nfl_files = os.listdir("data/nfl_traces")
nfl_paths = sorted("data/nfl_traces/" + f for f in nfl_files if f.endswith(".csv"))

# (helper class, test files) per dataset name; any other name falls back to the NFL setup,
# e.g. "American football".
dataset_setup = {
    "soccer": (TraceHelper, metrica_paths[3:4]),
    "basketball": (NBADataHelper, nba_paths[90:]),
}
trace_helper, test_data_paths = dataset_setup.get(sports, (NFLDataHelper, nfl_paths[0:1]))

print(f"Test data paths: {test_data_paths}")

Test data paths: ['data/metrica_traces/match3_test.csv']


### Function for testing a trial and printing performance statistics

In [56]:
def print_stats(trial, model, params, sports="soccer", naive_baselines=True):
    """Evaluate one trial on the test matches and print a summary of its errors.

    For every path in the notebook-level ``test_data_paths`` the CSV is read, wrapped
    in the notebook-level ``trace_helper`` class, and passed through ``helper.predict``.
    The per-match statistics are summed and then normalized: "pe" and "se" are divided
    by the number of missing frames, "sce" and "ple" by ``total_frames * n_players``.

    Args:
        trial: Trial identifier; only used in the printed header.
        model: Loaded model handed to ``helper.predict``. For a "dbhp" model, its
            ``params`` dict decides which extra prediction keys are tracked.
        params: Parameter dict of this trial; "model", "window_size" and
            "missing_pattern" are read from it.
        sports: Dataset type forwarded to ``helper.predict`` and ``get_dataset_config``.
        naive_baselines: Whether the "linear", "knn" and "ffill" baselines are included.

    Returns:
        A DataFrame indexed by prediction key with the columns "pe", "se", "sce", "ple".

    Note:
        For "nrtsi" models the notebook-level ``gap_models`` dict is also passed to
        ``helper.predict``.
    """
    print(f"\n---------- Trial {trial} ----------")

    metrics = ["pe", "se", "sce", "ple"]
    # Use this trial's own model type; the notebook-level `model_type` may belong to
    # a different trial or may have been rebound by another cell.
    model_name = params["model"]

    pred_keys = ["pred"]
    if model_name == "dbhp":
        if model.params["deriv_accum"]:
            pred_keys += ["dap_f", "dap_b"]
        if model.params["dynamic_hybrid"]:
            pred_keys += ["hybrid_s", "hybrid_s2", "hybrid_d"]
    if naive_baselines:
        pred_keys += ["linear", "knn", "ffill"]

    stat_keys = ["total_frames", "missing_frames"]
    stat_keys += [f"{k}_{m}" for k in pred_keys for m in metrics]
    stats = {k: 0 for k in stat_keys}

    for path in test_data_paths:
        print()
        print(f"{path}:")
        match_traces = pd.read_csv(path, header=0, encoding="utf-8-sig")
        helper = trace_helper(traces=match_traces)

        predict_kwargs = {"dataset_type": sports, "naive_baselines": naive_baselines}
        if model_name == "nrtsi":
            predict_kwargs["gap_models"] = gap_models
        _match_ret, match_stats = helper.predict(model, **predict_kwargs)

        # Accumulate raw sums over matches; normalization happens once at the end.
        for k, v in match_stats.items():
            stats[k] += v

    # print("Total Performance:")
    # print_helper(ret, pred_keys, trial=trial, save_txt=True)

    # torch.save(helper, f"{save_path}/helper")
    # torch.save(ret, f"{save_path}/df_dict")

    n_players, _ = get_dataset_config(sports)
    stats_df = pd.DataFrame(index=pred_keys, columns=metrics)

    n_missing = stats["missing_frames"]
    n_total = stats["total_frames"] * n_players
    denominators = {"pe": n_missing, "se": n_missing, "sce": n_total, "ple": n_total}

    for pred_key in pred_keys:
        for metric in metrics:
            denom = denominators[metric]
            # An empty evaluation (no frames / nothing missing) yields NaN instead of
            # a division-by-zero error.
            if denom:
                stats_df.at[pred_key, metric] = round(stats[f"{pred_key}_{metric}"] / denom, 6)
            else:
                stats_df.at[pred_key, metric] = float("nan")

    missing_rate = n_missing / n_total if n_total else float("nan")

    # print(f"Total frames: {stats['total_frames'] * n_players}")
    # print(f"Missing frames: {stats['missing_frames']}")
    print()
    print(f"Window size: {params['window_size']}")
    print(f"Missing pattern: {params['missing_pattern']}")
    print(f"Missing rate: {missing_rate:.4f}")

    # Only report rows that exist for this model/baseline configuration; selecting a
    # label that is absent from the index would raise a KeyError.
    report_keys = ["pred", "dap_f", "dap_b", "hybrid_s2", "hybrid_d", "linear"]
    report_keys = [k for k in report_keys if k in pred_keys]
    print(stats_df.loc[report_keys, "pe"])

    return stats_df

### Ablation study on Set Transformer architecture

In [63]:
device = "cuda:0"
trial_ids = [310, 320, 330, 331, 332, 220, 341, 342]

for trial in trial_ids:
    save_path = f"saved/{trial:03d}"

    # Each trial directory stores the parameters the model is rebuilt from.
    with open(f"{save_path}/params.json", "r") as f:
        params = json.load(f)

    # Load the best checkpoint onto CPU storage; the model is moved to `device` below.
    checkpoint_path = f"{save_path}/model/{params['model']}_state_dict_best.pt"
    state_dict = torch.load(checkpoint_path, map_location=lambda storage, _: storage)

    model = load_model(params["model"], params).to(device)
    model.load_state_dict(state_dict)

    print_stats(trial, model, params)


---------- Trial 310 ----------

data/metrica_traces/match3_test.csv:


Phase 2: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]
Phase 3: 100%|██████████| 2/2 [00:00<00:00,  3.36it/s]
Phase 4: 0it [00:00, ?it/s]
Phase 5: 0it [00:00, ?it/s]
Phase 6: 100%|██████████| 1/1 [00:00<00:00,  1.12it/s]
Phase 7: 100%|██████████| 2/2 [00:00<00:00,  2.32it/s]
Phase 8: 100%|██████████| 3/3 [00:01<00:00,  1.91it/s]
Phase 9: 100%|██████████| 4/4 [00:02<00:00,  1.53it/s]
Phase 10: 100%|██████████| 6/6 [00:01<00:00,  3.03it/s]
Phase 11: 100%|██████████| 7/7 [00:02<00:00,  2.37it/s]



Window size: 200
Missing pattern: camera
Missing rate: 0.5375
pred         22.88179
dap_f        5.460254
dap_b        5.544753
hybrid_s2    2.748455
hybrid_d     2.578049
linear       3.166009
Name: pe, dtype: object

---------- Trial 320 ----------

data/metrica_traces/match3_test.csv:


Phase 2: 100%|██████████| 10/10 [00:03<00:00,  2.56it/s]
Phase 3: 100%|██████████| 2/2 [00:00<00:00,  4.35it/s]
Phase 4: 0it [00:00, ?it/s]
Phase 5: 0it [00:00, ?it/s]
Phase 6: 100%|██████████| 1/1 [00:00<00:00,  1.41it/s]
Phase 7: 100%|██████████| 2/2 [00:00<00:00,  3.05it/s]
Phase 8: 100%|██████████| 3/3 [00:01<00:00,  2.49it/s]
Phase 9: 100%|██████████| 4/4 [00:02<00:00,  1.92it/s]
Phase 10: 100%|██████████| 6/6 [00:01<00:00,  3.43it/s]
Phase 11: 100%|██████████| 7/7 [00:02<00:00,  2.46it/s]



Window size: 200
Missing pattern: camera
Missing rate: 0.5375
pred         16.337782
dap_f         5.127047
dap_b         5.192201
hybrid_s2      2.51853
hybrid_d      2.530383
linear        3.166009
Name: pe, dtype: object

---------- Trial 330 ----------

data/metrica_traces/match3_test.csv:


Phase 2: 100%|██████████| 10/10 [00:04<00:00,  2.48it/s]
Phase 3: 100%|██████████| 2/2 [00:00<00:00,  3.40it/s]
Phase 4: 0it [00:00, ?it/s]
Phase 5: 0it [00:00, ?it/s]
Phase 6: 100%|██████████| 1/1 [00:00<00:00,  1.37it/s]
Phase 7: 100%|██████████| 2/2 [00:00<00:00,  2.72it/s]
Phase 8: 100%|██████████| 3/3 [00:01<00:00,  2.19it/s]
Phase 9: 100%|██████████| 4/4 [00:02<00:00,  1.69it/s]
Phase 10: 100%|██████████| 6/6 [00:01<00:00,  3.12it/s]
Phase 11: 100%|██████████| 7/7 [00:02<00:00,  2.65it/s]



Window size: 200
Missing pattern: camera
Missing rate: 0.5375
pred         1.553781
dap_f        1.442645
dap_b        1.446085
hybrid_s2    1.258591
hybrid_d     1.248082
linear       3.166009
Name: pe, dtype: object

---------- Trial 331 ----------

data/metrica_traces/match3_test.csv:


Phase 2: 100%|██████████| 10/10 [00:03<00:00,  2.63it/s]
Phase 3: 100%|██████████| 2/2 [00:00<00:00,  4.33it/s]
Phase 4: 0it [00:00, ?it/s]
Phase 5: 0it [00:00, ?it/s]
Phase 6: 100%|██████████| 1/1 [00:00<00:00,  1.41it/s]
Phase 7: 100%|██████████| 2/2 [00:00<00:00,  3.00it/s]
Phase 8: 100%|██████████| 3/3 [00:01<00:00,  2.50it/s]
Phase 9: 100%|██████████| 4/4 [00:02<00:00,  1.98it/s]
Phase 10: 100%|██████████| 6/6 [00:01<00:00,  3.02it/s]
Phase 11: 100%|██████████| 7/7 [00:02<00:00,  2.86it/s]



Window size: 200
Missing pattern: camera
Missing rate: 0.5375
pred         1.545655
dap_f        1.417635
dap_b         1.42606
hybrid_s2    1.246887
hybrid_d     1.235573
linear       3.166009
Name: pe, dtype: object

---------- Trial 332 ----------

data/metrica_traces/match3_test.csv:


Phase 2: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]
Phase 3: 100%|██████████| 2/2 [00:00<00:00,  3.83it/s]
Phase 4: 0it [00:00, ?it/s]
Phase 5: 0it [00:00, ?it/s]
Phase 6: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]
Phase 7: 100%|██████████| 2/2 [00:00<00:00,  2.55it/s]
Phase 8: 100%|██████████| 3/3 [00:01<00:00,  1.89it/s]
Phase 9: 100%|██████████| 4/4 [00:02<00:00,  1.68it/s]
Phase 10: 100%|██████████| 6/6 [00:01<00:00,  3.54it/s]
Phase 11: 100%|██████████| 7/7 [00:02<00:00,  2.99it/s]



Window size: 200
Missing pattern: camera
Missing rate: 0.5375
pred         1.536491
dap_f        1.461935
dap_b        1.425656
hybrid_s2    1.246257
hybrid_d     1.234534
linear       3.166009
Name: pe, dtype: object

---------- Trial 220 ----------

data/metrica_traces/match3_test.csv:


Phase 2: 100%|██████████| 10/10 [00:03<00:00,  2.50it/s]
Phase 3: 100%|██████████| 2/2 [00:00<00:00,  4.30it/s]
Phase 4: 0it [00:00, ?it/s]
Phase 5: 0it [00:00, ?it/s]
Phase 6: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s]
Phase 7: 100%|██████████| 2/2 [00:00<00:00,  2.78it/s]
Phase 8: 100%|██████████| 3/3 [00:01<00:00,  2.36it/s]
Phase 9: 100%|██████████| 4/4 [00:02<00:00,  1.85it/s]
Phase 10: 100%|██████████| 6/6 [00:01<00:00,  3.38it/s]
Phase 11: 100%|██████████| 7/7 [00:02<00:00,  2.98it/s]



Window size: 200
Missing pattern: camera
Missing rate: 0.5375
pred          1.53002
dap_f        1.461889
dap_b        1.442529
hybrid_s2    1.273531
hybrid_d     1.263448
linear       3.166009
Name: pe, dtype: object

---------- Trial 341 ----------

data/metrica_traces/match3_test.csv:


Phase 2: 100%|██████████| 10/10 [00:03<00:00,  2.61it/s]
Phase 3: 100%|██████████| 2/2 [00:00<00:00,  4.30it/s]
Phase 4: 0it [00:00, ?it/s]
Phase 5: 0it [00:00, ?it/s]
Phase 6: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]
Phase 7: 100%|██████████| 2/2 [00:00<00:00,  2.50it/s]
Phase 8: 100%|██████████| 3/3 [00:01<00:00,  2.14it/s]
Phase 9: 100%|██████████| 4/4 [00:02<00:00,  1.93it/s]
Phase 10: 100%|██████████| 6/6 [00:01<00:00,  3.54it/s]
Phase 11: 100%|██████████| 7/7 [00:02<00:00,  3.03it/s]



Window size: 200
Missing pattern: camera
Missing rate: 0.5375
pred         1.533742
dap_f        1.517984
dap_b        1.515321
hybrid_s2    1.286369
hybrid_d     1.273371
linear       3.166009
Name: pe, dtype: object

---------- Trial 342 ----------

data/metrica_traces/match3_test.csv:


Phase 2: 100%|██████████| 10/10 [00:03<00:00,  2.56it/s]
Phase 3: 100%|██████████| 2/2 [00:00<00:00,  4.36it/s]
Phase 4: 0it [00:00, ?it/s]
Phase 5: 0it [00:00, ?it/s]
Phase 6: 100%|██████████| 1/1 [00:00<00:00,  1.39it/s]
Phase 7: 100%|██████████| 2/2 [00:00<00:00,  2.33it/s]
Phase 8: 100%|██████████| 3/3 [00:01<00:00,  1.82it/s]
Phase 9: 100%|██████████| 4/4 [00:02<00:00,  1.70it/s]
Phase 10: 100%|██████████| 6/6 [00:01<00:00,  3.52it/s]
Phase 11: 100%|██████████| 7/7 [00:02<00:00,  3.00it/s]


Window size: 200
Missing pattern: camera
Missing rate: 0.5375
pred         1.624392
dap_f         1.47451
dap_b        1.446096
hybrid_s2    1.256186
hybrid_d     1.244057
linear       3.166009
Name: pe, dtype: object





### Ablation study on window size and missing rate

In [64]:
device = "cuda:0"
# trial_ids = np.sort([int(i) for i in os.listdir("saved") if int(i) >= 200 and int(i) < 250])
trial_ids = [205]

for trial in trial_ids:
    save_path = f"saved/{trial:03d}"

    # Each trial directory stores the parameters the model is rebuilt from.
    with open(f"{save_path}/params.json", "r") as f:
        params = json.load(f)

    # Load the best checkpoint onto CPU storage; the model is moved to `device` below.
    checkpoint_path = f"{save_path}/model/{params['model']}_state_dict_best.pt"
    state_dict = torch.load(checkpoint_path, map_location=lambda storage, _: storage)

    model = load_model(params["model"], params).to(device)
    model.load_state_dict(state_dict)

    print_stats(trial, model, params)


---------- Trial 205 ----------

data/metrica_traces/match3_test.csv:


Phase 2: 100%|██████████| 10/10 [00:05<00:00,  1.71it/s]
Phase 3: 100%|██████████| 2/2 [00:00<00:00,  2.79it/s]
Phase 4: 0it [00:00, ?it/s]
Phase 5: 0it [00:00, ?it/s]
Phase 6: 100%|██████████| 1/1 [00:01<00:00,  1.05s/it]
Phase 7: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s]
Phase 8: 100%|██████████| 3/3 [00:01<00:00,  1.65it/s]
Phase 9: 100%|██████████| 4/4 [00:03<00:00,  1.12it/s]
Phase 10: 100%|██████████| 6/6 [00:02<00:00,  2.06it/s]
Phase 11: 100%|██████████| 7/7 [00:04<00:00,  1.66it/s]


Window size: 50
Missing pattern: playerwise
Missing rate: 0.9000
pred         5.267954
dap_f        1.723606
dap_b        1.748702
hybrid_s2     0.54647
hybrid_d     0.477616
linear       1.501628
Name: pe, dtype: object





In [18]:
# NOTE(review): `helper` is not assigned at notebook level in any visible cell above;
# it must already exist in the kernel. `self` is just an alias for it.
self = helper
in_episode = helper.traces["episode"] == 36
ep_traces = self.traces[in_episode]

# Column names are "<player><feature suffix>" for every player of both teams.
feature_types = ["_x", "_y", "_vx", "_vy", "_ax", "_ay"]
players = self.team1_players + self.team2_players
player_cols = [f"{player}{suffix}" for player in players for suffix in feature_types]

phase_gks = SportsDataset.detect_goalkeepers(ep_traces)
team1_code = phase_gks[0][0]
team2_code = phase_gks[1][0]

# Keep only the player columns without any missing value in this episode,
# then group them by team prefix.
observed_cols = ep_traces[player_cols].dropna(axis=1).columns
team1_cols = [col for col in observed_cols if col.startswith(team1_code)]
team2_cols = [col for col in observed_cols if col.startswith(team2_code)]
ball_cols = ["ball_x", "ball_y"]

ep_player_cols = team1_cols + team2_cols
ep_player_values = ep_traces[ep_player_cols].values
# Add a leading axis of size 1 in front of the (time, features) array.
ep_player_traces = torch.FloatTensor(ep_player_values).unsqueeze(0)
ep_player_traces.shape

torch.Size([1, 253, 132])

In [20]:
# Split the flat feature axis into (players, features per player); 22 players are assumed.
bs, seq_len = ep_player_traces.shape[:2]
tensor = ep_player_traces.reshape(bs, seq_len, 22, -1)

x = tensor[..., 0:1]  # [bs, time, players, 1]
y = tensor[..., 1:2]
xy = torch.cat([x, y], dim=-1)  # [bs, time, players, 2]

x_plus_y = torch.sum(xy, dim=-1)  # [bs, time, players]

sorted_tensor = tensor.clone()
# Size the index buffer from the tensor's own player axis. The visible notebook only
# assigns `n_players` inside print_stats, so relying on it here breaks on a fresh kernel.
sort_idxs = torch.zeros(bs, tensor.shape[2], dtype=int)

# x + y of every player at the first time step of the first batch entry.
x_plus_y[0, 0]

tensor([0.6058, 1.0526, 0.8447, 0.5932, 1.2615, 0.9120, 0.8317, 0.9120, 1.1420,
        0.7295, 0.7110, 1.4746, 1.0250, 1.2611, 1.5325, 0.9509, 1.1192, 1.3261,
        1.2470, 0.8331, 0.9942, 1.0473])

## Performance analysis

##### (1) Get main model results

In [None]:
trial = 3003
save_path = f"saved/{trial:03d}"

# Everything is loaded only when the saved results file exists; otherwise the
# cell leaves `helper`, `df_dict` and `params` untouched.
df_dict_path = f"{save_path}/df_dict"
if os.path.isfile(df_dict_path):
    helper = torch.load(f"{save_path}/helper")
    df_dict = torch.load(df_dict_path)
    with open(f"{save_path}/params.json", "r") as f:
        params = json.load(f)

##### (2) Add baseline model results

In [None]:
# trial_dict = {4000 : "brits", 5000 : "naomi", 214 : "nrtsi"} # Metrica
trial_dict = {4003 : "brits", 5001 : "naomi", 6001 : "nrtsi", 9996 : "graphimputer"} # NBA

# Loop names carry a `baseline_` prefix so this cell does not overwrite the
# notebook-level `model_type` and `save_path` that other cells read.
for baseline_trial, baseline_name in trial_dict.items():
    baseline_path = f"saved/{baseline_trial:03d}/df_dict"
    if os.path.isfile(baseline_path):
        baseline_dict = torch.load(baseline_path)
        # Store the baseline's "pred" entry next to the main results as "<name>_df".
        df_dict[f"{baseline_name}_df"] = baseline_dict["pred"]
    else:
        # Make a skipped baseline visible instead of silently ignoring it.
        print(f"Skipping {baseline_name}: {baseline_path} not found")

In [None]:
# Display the keys currently stored in `df_dict`.
df_dict.keys()

### Animation

##### (1) Soccer Animator

In [None]:
# Display the distinct values of the "episode" column in the helper's traces.
helper.traces["episode"].unique()

In [None]:
# Slice bounds applied to the target, prediction and mask entries below.
i0 = 479
i1 = 873

animator = TraceAnimator(
    trace_dict={"main": df_dict["target"][i0:i1], "pred": df_dict["dbhp_df"][i0:i1]},
    mask = df_dict["mask"][i0:i1],
    show_episodes=True,
    show_events=False,
    show_frames=False,
    show_polygon=True,
    annot_cols=None,
)
anim = animator.run()

# Create the output directory first: saving fails when it does not exist.
os.makedirs("animations", exist_ok=True)
path = f"animations/trial_{trial}.mp4"

writer = animation.FFMpegWriter(fps=10)
anim.save(path, writer=writer)

##### (2) Basketball Animator

In [None]:
# Slice bounds applied to the target, prediction and mask entries below.
i0 = 326
i1 = 737
animator = NBADataAnimator(
    trace_dict={"main": df_dict["target"][i0:i1], "pred": df_dict["dbhp_df"][i0:i1]},
    show_episodes=True,
    show_frames=True,
    masks = df_dict["mask"][i0:i1],
)
anim = animator.run()

# Create the output directory first: saving fails when it does not exist.
os.makedirs("animations", exist_ok=True)
path = f"animations/trial_{trial}.mp4"

writer = animation.FFMpegWriter(fps=10)
anim.save(path, writer=writer)

### Plotting

In [None]:
# Plot type handed to the visualizer; the trailing comment lists the available options.
plot_mode = "imputed_traj" # "imputed_traj", "dist_heatmap", "weights_heatmap"
sports = params["dataset"]
# NOTE(review): relies on `trial`, `df_dict`, `helper` and `params` set by earlier cells.
visualizer = VisualizeHelper(trial, df_dict, plot_mode, dataset=sports, helper=helper)
# Last expression of the cell, so its return value is shown as the cell output.
visualizer.valid_episodes()

In [None]:
# Run the plot for epi_idx=0.
# NOTE(review): presumably an index into the episodes reported by valid_episodes() — confirm.
visualizer.plot_run(epi_idx=0)
# Close the current matplotlib figure.
plt.close()