In [1]:
%load_ext autoreload
%autoreload 2

import os, json

import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline

from matplotlib import animation
from datatools.trace_animator import TraceAnimator
from datatools.trace_helper import TraceHelper
from datatools.visualize_helper import VisualizeHelper
from datatools.nba_helper import NBADataHelper, NBADataAnimator
from datatools.nfl_helper import NFLDataHelper

from models import load_model
from models.utils import print_helper, reshape_tensor, get_dataset_config, normalize_tensor

from models.graph_imputer.graph_imputer import BidirectionalGraphImputer

## Model evaluation on test data

### Load model

In [2]:
# Prefer the first GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
trial = 222
save_path = f"saved/{trial:03d}"

# Hyper-parameters that were stored next to the checkpoints of this trial.
with open(f"{save_path}/params.json", "r") as f:
    params = json.load(f)

# The model is built the same way for every model type; only the weight loading differs.
model = load_model(params["model"], params).to(device)

if params["model"] == "nrtsi":
    # NRTSI keeps one checkpoint per gap size; they are loaded onto the CPU and kept in a
    # dict keyed by gap size instead of being loaded into `model` here.
    gap_models = {
        gap: torch.load(f"{save_path}/model/nrtsi_state_dict_best_gap_{gap}.pt", map_location="cpu")
        for gap in (1, 2, 4, 8, 16)
    }
else:
    # Single best checkpoint: deserialize on the CPU, then copy into the model on `device`.
    state_dict = torch.load(
        f"{save_path}/model/{params['model']}_state_dict_best.pt",
        map_location="cpu",
    )
    model.load_state_dict(state_dict)

In [3]:
# Evaluation settings taken from the loaded run parameters.
naive_baselines = True
sports = params["dataset"]
model_type = params["model"]

# These two switches are only read from the parameters when the model type is "dbhp".
if model_type == "dbhp":
    deriv_accum = params["deriv_accum"]
    dynamic_hybrid = params["dynamic_hybrid"]

print(
    f"- Sports: {sports}",
    f"- Model type: {model_type}",
    f"- Compute stats for naive baselines: {naive_baselines}",
    sep="\n",
)

- Sports: soccer
- Model type: dbhp
- Compute stats for naive baselines: True


In [4]:
def _list_dir(data_dir, required):
    """List the entries of `data_dir`.

    A missing directory yields an empty list unless `required` is true, in which case
    `os.listdir` raises `FileNotFoundError` as usual. This keeps the notebook usable
    when only the dataset that is actually being evaluated is present on disk.
    """
    if required or os.path.isdir(data_dir):
        return os.listdir(data_dir)
    return []


metrica_files = ["match1.csv", "match2.csv", "match3_valid.csv", "match3_test.csv"]
metrica_paths = [f"data/metrica_traces/{f}" for f in metrica_files]

# Only the directory of the selected dataset must exist; the others are optional.
nba_files = _list_dir("data/nba_traces", required=(sports == "basketball"))
nba_paths = sorted(f"data/nba_traces/{f}" for f in nba_files)

nfl_files = _list_dir("data/nfl_traces", required=(sports not in ("soccer", "basketball")))
nfl_paths = sorted(f"data/nfl_traces/{f}" for f in nfl_files if f.endswith(".csv"))

# Pick the helper class and the held-out files for the selected dataset.
if sports == "soccer":
    trace_helper = TraceHelper
    test_data_paths = metrica_paths[3:4]
elif sports == "basketball":
    trace_helper = NBADataHelper
    test_data_paths = nba_paths[90:]
else:  # e.g. "American football"
    trace_helper = NFLDataHelper
    test_data_paths = nfl_paths[0:1]

print(f"Test data paths: {test_data_paths}")

Test data paths: ['data/metrica_traces/match3_test.csv']


### Run model

In [30]:
# Names of every prediction variant to be scored; "pred" is always included.
pred_keys = ["pred"]
if model_type == "dbhp" and model.params["deriv_accum"]:
    pred_keys.extend(["dap_f", "dap_b"])
if model_type == "dbhp" and model.params["dynamic_hybrid"]:
    pred_keys.extend(["hybrid_s", "hybrid_s2", "hybrid_d"])
if naive_baselines:
    pred_keys.extend(["linear", "knn", "ffill"])

# One zero-initialised accumulator per frame counter and per (prediction, metric) pair.
stat_keys = ["total_frames", "missing_frames"] + [
    f"{k}_{m}" for k in pred_keys for m in ("pe", "se", "sce", "ple")
]
stats = dict.fromkeys(stat_keys, 0)

for path in test_data_paths:
    print(f"\n{path}:")
    match_traces = pd.read_csv(path, header=0, encoding="utf-8-sig")
    helper = trace_helper(traces=match_traces)

    # "nrtsi" is the only model type that additionally receives `gap_models`.
    predict_kwargs = dict(dataset_type=sports, naive_baselines=naive_baselines)
    if model_type == "nrtsi":
        predict_kwargs["gap_models"] = gap_models
    match_ret, match_stats = helper.predict(model, **predict_kwargs)

    # Add this match's statistics onto the running totals.
    for stat_name, stat_value in match_stats.items():
        stats[stat_name] += stat_value

# print("Total Performance:")
# print_helper(ret, pred_keys, trial=trial, save_txt=True)

# torch.save(helper, f"{save_path}/helper")
# torch.save(ret, f"{save_path}/df_dict")

# Normalise the accumulated statistics in place and collect them in a table with one row
# per prediction key and one column per metric.
n_players, _ = get_dataset_config(sports)
stats_df = pd.DataFrame(index=pred_keys, columns=["pe", "se", "sce", "ple"])

for stat_key in list(stats):
    if stat_key in ("total_frames", "missing_frames"):
        continue

    # Everything before the last underscore is the prediction key, the rest is the metric.
    pred_key, _sep, metric = stat_key.rpartition("_")

    if metric in ("pe", "se"):
        # These metrics are divided by the number of missing frames.
        stats[stat_key] = round(stats[stat_key] / stats["missing_frames"], 6)
    elif metric in ("sce", "ple"):
        # These metrics are divided by total frames times the number of players.
        stats[stat_key] = round(stats[stat_key] / (stats["total_frames"] * n_players), 6)

    stats_df.at[pred_key, metric] = stats[stat_key]

stats_df


data/metrica_traces/match3_test.csv:


Phase 2: 100%|██████████| 10/10 [00:12<00:00,  1.27s/it]
Phase 3: 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]
Phase 4: 0it [00:00, ?it/s]
Phase 5: 0it [00:00, ?it/s]
Phase 6: 100%|██████████| 1/1 [00:01<00:00,  1.88s/it]
Phase 7: 100%|██████████| 2/2 [00:01<00:00,  1.01it/s]
Phase 8: 100%|██████████| 3/3 [00:04<00:00,  1.43s/it]
Phase 9: 100%|██████████| 4/4 [00:07<00:00,  1.80s/it]
Phase 10: 100%|██████████| 6/6 [00:06<00:00,  1.03s/it]
Phase 11: 100%|██████████| 7/7 [00:08<00:00,  1.16s/it]


Unnamed: 0,pe,se,sce,ple
pred,2.135689,0.782768,2.8e-05,0.011214
dap_f,2.107197,0.576114,1.6e-05,0.012526
dap_b,2.118469,0.577828,1.7e-05,0.012497
hybrid_s,2.037109,0.579688,1.6e-05,0.012852
hybrid_s2,2.009928,0.538293,1.5e-05,0.013421
hybrid_d,2.005881,0.536693,1.5e-05,0.013356
linear,5.372363,1.186698,4.2e-05,0.046397
knn,7.809323,6.283521,0.001897,0.142324
ffill,10.354956,3.620982,0.001565,0.046397


## Performance analysis

##### (1) Get Main model results

In [None]:
trial = 3003
save_path = f"saved/{trial:03d}"
if os.path.isfile(save_path + "/df_dict"):
    # NOTE: these files are full pickled Python objects; torch.load can execute arbitrary
    # code while unpickling, so only load files produced by a trusted run.
    helper = torch.load(save_path + "/helper")
    df_dict = torch.load(save_path + "/df_dict")
    with open(f"{save_path}/params.json", "r") as f:
        params = json.load(f)
else:
    # Make the skipped reload visible: later cells would otherwise use stale or undefined
    # `helper` / `df_dict` / `params` without any hint as to why.
    print(f"No saved results found at {save_path}/df_dict; helper, df_dict and params were not reloaded.")

##### (2) Add baseline model results

In [None]:
# trial_dict = {4000 : "brits", 5000 : "naomi", 214 : "nrtsi"} # Metrica
trial_dict = {4003: "brits", 5001: "naomi", 6001: "nrtsi", 9996: "graphimputer"}  # NBA

# Dedicated loop names are used on purpose: iterating with `model_type` / `save_path` would
# overwrite the notebook-level values of the main model and break re-running earlier cells.
for baseline_trial, baseline_name in trial_dict.items():
    baseline_path = f"saved/{baseline_trial:03d}/df_dict"
    if not os.path.isfile(baseline_path):
        print(f"Skipping {baseline_name}: {baseline_path} not found")
        continue

    # NOTE: torch.load unpickles arbitrary objects; only load files from a trusted run.
    # Only the baseline's "pred" entry is kept, stored under "<baseline name>_df".
    df_dict[f"{baseline_name}_df"] = torch.load(baseline_path)["pred"]

In [None]:
# Display the keys currently stored in `df_dict`.
df_dict.keys()

### Animation

##### (1) Soccer Animator

In [None]:
# Display the distinct values of the "episode" column in the helper's traces.
helper.traces["episode"].unique()

In [None]:
# Slice bounds applied to the target, prediction and mask entries that are animated.
i0 = 479
i1 = 873

animator = TraceAnimator(
    trace_dict={"main": df_dict["target"][i0:i1], "pred": df_dict["dbhp_df"][i0:i1]},
    mask=df_dict["mask"][i0:i1],
    show_episodes=True,
    show_events=False,
    show_frames=False,
    show_polygon=True,
    annot_cols=None,
)
anim = animator.run()

# Create the output directory first; saving fails if it does not exist yet.
os.makedirs("animations", exist_ok=True)
path = f"animations/trial_{trial}.mp4"

writer = animation.FFMpegWriter(fps=10)
anim.save(path, writer=writer)

##### (2) Basketball Animator

In [None]:
# Slice bounds applied to the target, prediction and mask entries that are animated.
i0 = 326
i1 = 737

animator = NBADataAnimator(
    trace_dict={"main": df_dict["target"][i0:i1], "pred": df_dict["dbhp_df"][i0:i1]},
    show_episodes=True,
    show_frames=True,
    masks=df_dict["mask"][i0:i1],
)
anim = animator.run()

# Create the output directory first; saving fails if it does not exist yet.
os.makedirs("animations", exist_ok=True)
path = f"animations/trial_{trial}.mp4"

writer = animation.FFMpegWriter(fps=10)
anim.save(path, writer=writer)

### Plotting

In [None]:
# The dataset name is read from the currently loaded run parameters.
sports = params["dataset"]

# Supported modes: "imputed_traj", "dist_heatmap", "weights_heatmap"
plot_mode = "imputed_traj"

visualizer = VisualizeHelper(
    trial,
    df_dict,
    plot_mode,
    dataset=sports,
    helper=helper,
)
visualizer.valid_episodes()

In [None]:
# Run the visualizer's plot with `epi_idx=0`, then close the current matplotlib figure.
visualizer.plot_run(epi_idx=0)
plt.close()