In [1]:
%load_ext autoreload
%autoreload 2

import os, json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib import animation
from tqdm import tqdm
%matplotlib inline

from dataset import SportsDataset
from datatools.trace_animator import TraceAnimator
from datatools.trace_helper import TraceHelper
from datatools.visualize_helper import VisualizeHelper
from datatools.nba_helper import NBADataHelper, NBADataAnimator
from datatools.nfl_helper import NFLDataHelper
from models import load_model
from models.utils import get_dataset_config, print_helper, reshape_tensor, sort_players

from models.graph_imputer.graph_imputer import BidirectionalGraphImputer

## Model evaluation on test data

### Load model

In [2]:
# Evaluate on the GPU when one is available; otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
trial = 300
save_path = f"saved/{trial:03d}"

# Hyper-parameters that were saved alongside the trained model.
with open(f"{save_path}/params.json", "r") as f:
    params = json.load(f)

model = load_model(params["model"], params).to(device)

if params["model"] == "nrtsi":
    # NRTSI stores one checkpoint per gap size. They are loaded onto the CPU
    # here and handed to `helper.predict(..., gap_models=gap_models)` later,
    # so no state dict is loaded into `model` in this branch.
    gap_models = {
        gap: torch.load(
            f"{save_path}/model/nrtsi_state_dict_best_gap_{gap}.pt",
            map_location=lambda storage, _: storage,
        )
        for gap in (1, 2, 4, 8, 16)
    }
else:
    # Every other model has a single "best" checkpoint; load it onto the CPU
    # first and let `load_state_dict` copy it into the model on `device`.
    state_dict = torch.load(
        f"{save_path}/model/{params['model']}_state_dict_best.pt",
        map_location=lambda storage, _: storage,
    )
    model.load_state_dict(state_dict)

In [3]:
# Notebook-level configuration derived from the loaded trial parameters.
sports, model_type = params["dataset"], params["model"]
naive_baselines = True

# DBHP-specific switches are only present in the params of DBHP trials.
if model_type == "dbhp":
    deriv_accum, dynamic_hybrid = params["deriv_accum"], params["dynamic_hybrid"]

print(
    f"- Sports: {sports}",
    f"- Model type: {model_type}",
    f"- Compute stats for naive baselines: {naive_baselines}",
    sep="\n",
)

- Sports: afootball
- Model type: dbhp
- Compute stats for naive baselines: True


In [4]:
# Soccer (Metrica): fixed list of match files.
metrica_files = ["match1.csv", "match2.csv", "match3_valid.csv", "match3_test.csv"]
metrica_paths = [f"data/metrica_traces/{name}" for name in metrica_files]

# Basketball: every file in the directory, in sorted order.
nba_files = os.listdir("data/nba_traces")
nba_paths = sorted(f"data/nba_traces/{name}" for name in nba_files)

# American football: only the CSV files, in sorted order.
nfl_files = os.listdir("data/nfl_traces")
nfl_paths = sorted(f"data/nfl_traces/{name}" for name in nfl_files if name.endswith(".csv"))

# Pick the helper class and the held-out test files for the configured sport.
if sports == "soccer":
    trace_helper, test_data_paths = TraceHelper, metrica_paths[3:4]
elif sports == "basketball":
    trace_helper, test_data_paths = NBADataHelper, nba_paths[90:]
else:  # e.g. "American football"
    trace_helper, test_data_paths = NFLDataHelper, nfl_paths[0:1]

print(f"Test data paths: {test_data_paths}")

Test data paths: ['data/nfl_traces/nfl_test.csv']


In [5]:
# Count the episodes (and their frames) in NBA files 70-79 that have a positive
# episode id and at least 100 rows.
paths = nba_paths[70:80]
n_episodes = 0
n_frames = 0

for f in tqdm(paths):
    match_traces = pd.read_csv(f, header=0)
    # Rows per positive episode id, computed in one pass instead of one
    # boolean filter per episode.
    valid_rows = match_traces.loc[match_traces["episode"] > 0, "episode"]
    episode_sizes = valid_rows.value_counts()
    long_episodes = episode_sizes[episode_sizes >= 100]
    n_episodes += len(long_episodes)
    n_frames += int(long_episodes.sum())

n_episodes, n_frames

100%|██████████| 10/10 [00:02<00:00,  3.49it/s]


(687, 249474)

### Function for testing a trial and printing performance statistics

In [8]:
def _safe_ratio(total, denom, ndigits=6):
    """Return ``round(total / denom, ndigits)``, or NaN when ``denom`` is zero."""
    if not denom:
        return float("nan")
    return round(total / denom, ndigits)


def print_stats(trial, model, params, sports="soccer", naive_baselines=True):
    """Run a model on the test files and print aggregated error statistics.

    Besides its arguments, this function reads the notebook-level globals
    ``test_data_paths`` (CSV files to evaluate), ``trace_helper`` (helper class
    instantiated with ``traces=<DataFrame>``) and, for NRTSI only,
    ``gap_models``.

    Args:
        trial: Trial identifier; only used in the printed header.
        model: Loaded model. Its ``params`` dict is consulted for
            ``missing_pattern`` and, for DBHP, ``deriv_accum`` and
            ``dynamic_hybrid``.
        params: Trial parameter dict. ``params["model"]`` selects the model
            family; ``window_size`` and ``missing_pattern`` are printed.
            ``params["missing_rate"]`` is overwritten in place with the
            measured rate.
        sports: Dataset name forwarded to ``helper.predict`` and
            ``get_dataset_config``.
        naive_baselines: Whether to also collect the naive baseline statistics.

    Returns:
        ``(match_ret, stats_df)``: the first value returned by
        ``helper.predict`` for the last test file (``None`` if there were no
        test files), and a DataFrame indexed by prediction key with columns
        ``pe``, ``se``, ``sce``, ``ple``.
    """
    print(f"\n---------- Trial {trial} ----------")

    metrics = ["pe", "se", "sce", "ple"]

    pred_keys = ["pred"]
    # Decide from this trial's own params, not from the notebook-level
    # `model_type`, which may belong to a different trial or be overwritten by
    # another cell.
    if params["model"] == "dbhp":
        if model.params["deriv_accum"]:
            pred_keys += ["dap_f"]
            if model.params["missing_pattern"] != "forecast":
                pred_keys += ["dap_b"]
        if model.params["dynamic_hybrid"]:
            if model.params["missing_pattern"] == "forecast":
                pred_keys += ["hybrid_d"]
            else:
                pred_keys += ["hybrid_s", "hybrid_s2", "hybrid_d"]
    if naive_baselines:
        if model.params["missing_pattern"] == "forecast":
            pred_keys += ["ffill"]
        else:
            pred_keys += ["linear", "knn", "ffill"]

    stat_keys = ["total_frames", "missing_frames"]
    stat_keys += [f"{k}_{m}" for k in pred_keys for m in metrics]
    stats = {k: 0 for k in stat_keys}

    # Stays None when `test_data_paths` is empty instead of raising NameError.
    match_ret = None

    for path in test_data_paths:
        print()
        print(f"{path}:")
        match_traces = pd.read_csv(path, header=0, encoding="utf-8-sig")
        helper = trace_helper(traces=match_traces)

        if params["model"] == "nrtsi":
            match_ret, match_stats = helper.predict(
                model, dataset_type=sports, naive_baselines=naive_baselines, gap_models=gap_models
            )
        else:
            match_ret, match_stats = helper.predict(model, dataset_type=sports, naive_baselines=naive_baselines)

        # Accumulate the raw sums over all test files.
        for k, v in match_stats.items():
            stats[k] += v

    n_players, _ = get_dataset_config(sports)
    stats_df = pd.DataFrame(index=pred_keys, columns=metrics)

    # Turn the accumulated sums into averages: pe/se are divided by the number
    # of missing frames, sce/ple by total_frames * n_players.
    for k, v in stats.items():
        if k in ("total_frames", "missing_frames"):
            continue

        pred_key, _, metric = k.rpartition("_")

        if metric in ("pe", "se"):
            stats[k] = _safe_ratio(v, stats["missing_frames"])
        elif metric in ("sce", "ple"):
            stats[k] = _safe_ratio(v, stats["total_frames"] * n_players)

        stats_df.at[pred_key, metric] = stats[k]

    params["missing_rate"] = _safe_ratio(stats["missing_frames"], stats["total_frames"] * n_players, 4)

    print()
    print_args = pd.Series(
        {arg: params[arg] for arg in ["window_size", "missing_pattern", "missing_rate"]},
        dtype=object,
    )
    print(print_args)

    print()
    if params["missing_pattern"] == "forecast":
        summary_keys = ["pred", "dap_f", "hybrid_d", "ffill"]
    else:
        summary_keys = ["pred", "dap_f", "dap_b", "hybrid_s2", "hybrid_d", "linear"]
    # Only show the keys that were actually evaluated; `.loc` raises KeyError
    # on missing labels, which used to break every non-DBHP trial.
    summary_keys = [k for k in summary_keys if k in stats_df.index]
    print(stats_df.loc[summary_keys, "pe"])

    return match_ret, stats_df

### Testing imputation performance on a single model

In [10]:
# Forward the notebook-level `naive_baselines` switch so the value printed in
# the configuration cell is the one actually used for the evaluation.
match_ret, stats_df = print_stats(trial, model, params, sports=sports, naive_baselines=naive_baselines)


---------- Trial 300 ----------

data/nfl_traces/nfl_test.csv:


Episode: 100%|██████████| 1043/1043 [01:19<00:00, 13.13it/s]


window_size                50
missing_pattern    playerwise
missing_rate              0.5
dtype: object

pred         0.766386
dap_f        1.133616
dap_b        1.526139
hybrid_s2    0.986628
hybrid_d     1.233043
linear       0.021118
Name: pe, dtype: object





In [11]:
# Full statistics table returned by print_stats: one row per prediction key,
# one column per metric (pe, se, sce, ple).
stats_df

Unnamed: 0,pe,se,sce,ple
pred,0.766386,1.194555,0.024692,0.431697
dap_f,1.133616,2.223005,0.061682,0.882608
dap_b,1.526139,2.522712,0.07625,0.969203
hybrid_s,0.998425,2.564118,0.059008,1.000563
hybrid_s2,0.986628,1.460341,0.051725,0.508137
hybrid_d,1.233043,2.535433,0.064635,0.986676
linear,0.021118,0.018739,0.000136,0.001279
knn,0.032029,0.064812,0.001481,0.006494
ffill,0.040115,0.066952,0.001975,0.001279


### Ablation study on Set Transformer architecture

In [None]:
# Evaluate on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
trial_ids = [150, 153, 152, 160, 161, 220]

for trial in trial_ids:
    save_path = f"saved/{trial:03d}"

    with open(f"{save_path}/params.json", "r") as f:
        params = json.load(f)

    # Load the checkpoint onto the CPU first; load_state_dict copies it into
    # the model that already lives on `device`.
    state_dict = torch.load(
        f"{save_path}/model/{params['model']}_state_dict_best.pt",
        map_location=lambda storage, _: storage,
    )

    model = load_model(params["model"], params).to(device)
    model.load_state_dict(state_dict)

    # Pass the dataset this trial was trained on instead of relying on
    # print_stats' default of "soccer".
    # NOTE(review): print_stats still reads the notebook-level
    # `test_data_paths` / `trace_helper`; confirm they match this dataset.
    print_stats(trial, model, params, sports=params["dataset"])

### Ablation study on window size and missing rate

In [None]:
"EvolveGraph: Multi-Agent Trajectory Prediction with Dynamic Relational Reasoning".lower()

In [None]:
# Evaluate on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# trial_ids = np.sort([int(i) for i in os.listdir("saved") if int(i) >= 200 and int(i) < 250])
trial_ids = [205]

for trial in trial_ids:
    save_path = f"saved/{trial:03d}"

    with open(f"{save_path}/params.json", "r") as f:
        params = json.load(f)

    # Load the checkpoint onto the CPU first; load_state_dict copies it into
    # the model that already lives on `device`.
    state_dict = torch.load(
        f"{save_path}/model/{params['model']}_state_dict_best.pt",
        map_location=lambda storage, _: storage,
    )

    model = load_model(params["model"], params).to(device)
    model.load_state_dict(state_dict)

    # Pass the dataset this trial was trained on instead of relying on
    # print_stats' default of "soccer".
    # NOTE(review): print_stats still reads the notebook-level
    # `test_data_paths` / `trace_helper`; confirm they match this dataset.
    print_stats(trial, model, params, sports=params["dataset"])

In [None]:
# Debug cell: take episode 36 from `helper.traces` and pack the per-player
# feature columns of both teams into a single float tensor.
# NOTE(review): no earlier top-level cell assigns `helper` (inside print_stats
# it is only a local variable); at notebook level it is first assigned by the
# cached-results cell further down, so this cell fails on a fresh
# top-to-bottom run — confirm the intended execution order.
self = helper
ep_traces = self.traces[helper.traces["episode"] == 36]

# Column names are "<player><suffix>" for every player of both teams.
feature_types = ["_x", "_y", "_vx", "_vy", "_ax", "_ay"]
players = self.team1_players + self.team2_players
player_cols = [f"{p}{x}" for p in players for x in feature_types]

# The first element of each detected goalkeeper entry is used as a column-name
# prefix below — presumably the team code; verify against
# SportsDataset.detect_goalkeepers.
phase_gks = SportsDataset.detect_goalkeepers(ep_traces)
team1_code, team2_code = phase_gks[0][0], phase_gks[1][0]

# Keep only player columns without any NaN in this episode, grouped by team prefix.
ep_player_cols = ep_traces[player_cols].dropna(axis=1).columns
team1_cols = [c for c in ep_player_cols if c.startswith(team1_code)]
team2_cols = [c for c in ep_player_cols if c.startswith(team2_code)]
ball_cols = ["ball_x", "ball_y"]  # defined but not used in this cell

# Team 1 columns first, then team 2; unsqueeze(0) adds a leading batch axis of size 1.
ep_player_cols = team1_cols + team2_cols
ep_player_traces = torch.FloatTensor(ep_traces[ep_player_cols].values).unsqueeze(0)
ep_player_traces.shape

In [None]:
# Debug cell: split the flat feature axis into (players, features) and compute
# x + y per player.
# `n_players` was previously never defined at notebook level (it only exists as
# a local inside print_stats), so this cell raised NameError. It is now defined
# here with the value the reshape already hard-coded.
n_players = 22

bs, seq_len = ep_player_traces.shape[:2]
tensor = ep_player_traces.reshape(bs, seq_len, n_players, -1)

x = tensor[..., 0:1]  # [bs, time, players, 1]
y = tensor[..., 1:2]
xy = torch.cat([x, y], dim=-1)  # [bs, time, players, 2]

x_plus_y = torch.sum(xy, dim=-1)  # [bs, time, players]

# Buffers prepared for sorting players; they are not filled in this cell.
sorted_tensor = tensor.clone()
sort_idxs = torch.zeros(bs, n_players, dtype=int)

# x + y of every player at the first time step of the first batch element.
x_plus_y[0, 0]

## Performance analysis

##### (1) Get Main model results

In [None]:
# Load the cached evaluation artefacts (helper object, per-model result dict
# and params) of the main model.
trial = 3003
save_path = f"saved/{trial:03d}"
if os.path.isfile(f"{save_path}/df_dict"):
    # NOTE(review): these files are pickled Python objects, not plain tensors;
    # torch.load unpickles them, so only load caches you produced yourself.
    helper = torch.load(f"{save_path}/helper")
    df_dict = torch.load(f"{save_path}/df_dict")
    with open(f"{save_path}/params.json", "r") as f:
        params = json.load(f)
else:
    # Say so explicitly instead of skipping silently and letting later cells
    # fail with an unrelated NameError on `df_dict` / `helper`.
    print(f"No cached results found at {save_path}/df_dict; `helper`, `df_dict` and `params` were not updated.")

##### (2) Add baseline model results

In [None]:
# Add each baseline's cached "pred" result to `df_dict` under "<name>_df".
# Baselines whose cache file does not exist are skipped.
# trial_dict = {4000 : "brits", 5000 : "naomi", 214 : "nrtsi"} # Metrica
trial_dict = {4003 : "brits", 5001 : "naomi", 6001 : "nrtsi", 9996 : "graphimputer"} # NBA
# Dedicated loop names: the previous loop variable `model_type` overwrote the
# notebook-level `model_type`, and `save_path` no longer pointed at the main trial.
for baseline_trial, baseline_name in trial_dict.items():
    baseline_path = f"saved/{baseline_trial:03d}"
    if os.path.isfile(f"{baseline_path}/df_dict"):
        baseline_df_dict = torch.load(f"{baseline_path}/df_dict")
        df_dict[f"{baseline_name}_df"] = baseline_df_dict["pred"]

In [None]:
# List the result entries collected so far (main model plus any baselines added above).
df_dict.keys()

### Animation

##### (1) Soccer Animator

In [None]:
# List the entries of `match_ret` that can be passed to the animator below.
match_ret.keys()

In [None]:
# Frame range [i0, i1) of the match to animate.
i0 = 479
i1 = 873

# Animate the target traces against the "hybrid_d" predictions over that range.
animator = TraceAnimator(
    match_ret={"main": match_ret["target"][i0:i1], "pred": match_ret["hybrid_d"][i0:i1]},
    mask = match_ret["mask"][i0:i1],
    show_episodes=True,
    show_events=False,
    show_frames=False,
    show_polygon=True,
    annot_cols=None,
)
anim = animator.run()

path = f"animations/trial_{trial}.mp4"
# exist_ok=True replaces the separate exists() check and is safe to re-run.
os.makedirs("animations", exist_ok=True)

# Write the animation as an MP4 at 10 frames per second.
writer = animation.FFMpegWriter(fps=10)
anim.save(path, writer=writer)

##### (2) Basketball Animator

In [None]:
# List the entries of `match_ret` that can be passed to the animator below.
match_ret.keys()

In [None]:
# Frame range [i0, i1) of the match to animate.
i0 = 326
i1 = 737
# Animate the target traces against the "hybrid_d" predictions over that range.
animator = NBADataAnimator(
    match_ret={"main": match_ret["target"][i0:i1], "pred": match_ret["hybrid_d"][i0:i1]},
    show_episodes=True,
    show_frames=True,
    masks = match_ret["mask"][i0:i1],
)
anim = animator.run()

path = f"animations/trial_{trial}.mp4"
# exist_ok=True replaces the separate exists() check and is safe to re-run.
os.makedirs("animations", exist_ok=True)

# Write the animation as an MP4 at 10 frames per second.
writer = animation.FFMpegWriter(fps=10)
anim.save(path, writer=writer)

### Visualizing imputed trajectories

In [None]:
# Build the visualizer from the cached results loaded in the "Performance
# analysis" cells (`trial`, `df_dict`, `helper`, `params`).
# `plot_mode` selects what is drawn; the accepted values are listed in the
# trailing comment.
plot_mode = "imputed_traj" # "imputed_traj", "dist_heatmap", "weights_heatmap"
sports = params["dataset"]
visualizer = VisualizeHelper(trial, df_dict, plot_mode, dataset=sports, helper=helper)
visualizer.valid_episodes()

In [None]:
# Plot the episode at index 0, then close the current matplotlib figure.
visualizer.plot_run(epi_idx=0)
plt.close()