In [1]:
%load_ext autoreload
%autoreload 2

import os

import numpy as np
import torch

from dataset import SportsDataset
from models.utils import generate_mask

In [None]:
# Sport whose trace files are selected below: "soccer", "basketball" or "afootball".
sports = "soccer"

if sports == "soccer":
    metrica_files = ["match1.csv", "match2.csv", "match3_valid.csv", "match3_test.csv"]
    metrica_paths = [f"data/metrica_traces/{f}" for f in metrica_files]
    data_paths = metrica_paths[-1:]  # keep only the last file (match3_test.csv)

elif sports == "basketball":
    nba_files = os.listdir("data/nba_traces")
    # os.listdir returns entries in arbitrary order; sort so the selection is reproducible.
    nba_paths = sorted(f"data/nba_traces/{f}" for f in nba_files)
    data_paths = nba_paths[:10]  # first ten files in sorted order

elif sports == "afootball":
    nfl_paths = ["data/nfl_traces/nfl_train.csv", "data/nfl_traces/nfl_test.csv"]
    data_paths = nfl_paths[:-1]  # everything except the last file (nfl_train.csv)

else:
    # Fail loudly instead of silently falling back to the NFL files on a typo.
    raise ValueError(f"Unknown sports {sports!r}; expected 'soccer', 'basketball' or 'afootball'.")

data_paths

In [None]:
# Window size handed to the dataset; a later cell compares valid-frame counts against it.
window_size = 600

# Build the dataset from the trace files selected in the previous cell.
# NOTE(review): the meaning of n_features / normalize / flip_pitch is defined by
# SportsDataset in dataset.py — check there before changing these values.
dataset = SportsDataset(
    data_paths=data_paths,
    sports=sports,
    window_size=window_size,
    n_features=6,
    flip_pitch=True,
    normalize=True,
)

# Displayed as the cell output.
len(dataset)

In [None]:
# Use the first GPU when CUDA is available, otherwise stay on the CPU so the cell
# still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
player_data = dataset.player_data.to(device)
ball_data = dataset.ball_data.to(device)

# Per sample, count the frames whose first feature differs from -100
# (presumably the padding value used by SportsDataset — TODO confirm).
# .cpu() returns the tensor unchanged when it already lives on the CPU, so no
# device-specific branch is required.
valid_frames = (player_data[..., 0] != -100).cpu().numpy().astype(int).sum(axis=-1)

# Displayed: number of samples, and how many of them have fewer valid frames than window_size.
len(valid_frames), (valid_frames < window_size).astype(int).sum()

In [None]:
# Bundle the tensors under the keys passed to generate_mask and request the "camera" mode.
data_dict = dict(target=player_data, ball=ball_data)
mask, missing_rate = generate_mask(data_dict, mode="camera")

# Displayed: mask shape and the returned missing rate rounded to four decimals.
(mask.shape, round(missing_rate, 4))

In [None]:
# Build a mask in which every player gets one contiguous run of zeros, with the run
# lengths drawn at random so that they add up to a fixed budget per sample.
# NOTE(review): `mask` and `missing_rate` overwrite the values returned by
# generate_mask in the previous cell.
n_players = 22  # NOTE(review): hard-coded; presumably the soccer player count — confirm for other sports
missing_rate = 0.5  # fraction of (valid frame, player) entries to zero out
verbose = False  # set True to print the intermediate allocation state

# mask starts as all ones; zeros are written below for the frames selected as missing.
mask = np.ones((player_data.shape[0], player_data.shape[1], n_players))
missing_frames = np.zeros((mask.shape[0], n_players)).astype(int)  # [bs, players]

# residue: per-sample budget of entries still to be distributed among the players.
residue = (valid_frames * n_players * missing_rate).astype(int)  # [bs,]
# max_shares: per-player cap on the run length, 10 frames fewer than the valid frames.
# NOTE(review): the margin of 10 is a magic number — presumably keeps some observed context; confirm.
max_shares = np.tile(valid_frames - 10, (n_players, 1)).T  # [bs, players]
# The budget must be strictly below the total capacity, otherwise the loop below cannot terminate.
assert np.all(residue < max_shares.sum(axis=-1))

for i in range(mask.shape[0]):
    while residue[i] > 0:
        # Players that have not reached their cap yet can still receive frames.
        slots = missing_frames[i] < max_shares[i]  # [players,]
        # Cut the interval [0, residue[i]] at random breakpoints (repeats allowed);
        # the gaps between consecutive cut points are the shares, which sum to residue[i].
        breakpoints = np.random.choice(residue[i] + 1, slots.astype(int).sum() - 1, replace=True)  # [players - 1,]
        shares = np.diff(np.sort(breakpoints.tolist() + [0, residue[i]]))  # [players,]
        if verbose:
            print()
            print(residue)
            print(slots.astype(int))
            print(shares)

        # Players without room are clamped to their cap (this also discards any overflow
        # they accumulated in the previous pass); the others receive their new share.
        missing_frames[i, ~slots] = max_shares[i, ~slots]
        missing_frames[i, slots] += shares
        # Whatever now exceeds a cap becomes the budget to redistribute in the next pass.
        residue[i] = np.clip(missing_frames[i] - max_shares[i], 0, None).sum()
        if verbose:
            print(missing_frames[i])

# Random start of each run; randint's upper bound is exclusive, so starts fall in
# [1, max_shares - missing_frames + 1] and frame 0 is never part of a run.
start_idxs = np.random.randint(1, max_shares - missing_frames + 2)  # [bs, players]
end_idxs = start_idxs + missing_frames  # [bs, players]

# Zero out the half-open range [start, end) for every (sample, player) pair.
for i in range(mask.shape[0]):
    for p in range(n_players):
        mask[i, start_idxs[i, p] : end_idxs[i, p], p] = 0

# Displayed: per-player run lengths and the average run length per sample.
missing_frames, missing_frames.sum(axis=1) // n_players

In [11]:
import random
import numpy as np
import torch

# Toy experiment on a dummy all-zero batch.
# NOTE(review): this rebinds player_data, mask and window_size used by earlier cells.
player_data = torch.zeros(32, 200, 22)

n_windows, window_size = player_data.shape[0], player_data.shape[1]
mask = np.ones((n_windows, window_size, 22))  # [bs, time, players]
missing_len = int(window_size * 0.5)

# Draw the hidden frame indices once, never picking the first or the last frame;
# the same frames are zeroed for every sample and every column of the last axis.
hidden_frames = random.sample(range(1, window_size - 1), missing_len)
mask[:, hidden_frames] = 0

# Report the share of zero entries in each sample's mask.
for i, window_mask in enumerate(mask):
    print(f"Missing rate for batch {i} : {(window_mask == 0).sum() / window_mask.size}")

Missing rate for batch 0 : 0.5
Missing rate for batch 1 : 0.5
Missing rate for batch 2 : 0.5
Missing rate for batch 3 : 0.5
Missing rate for batch 4 : 0.5
Missing rate for batch 5 : 0.5
Missing rate for batch 6 : 0.5
Missing rate for batch 7 : 0.5
Missing rate for batch 8 : 0.5
Missing rate for batch 9 : 0.5
Missing rate for batch 10 : 0.5
Missing rate for batch 11 : 0.5
Missing rate for batch 12 : 0.5
Missing rate for batch 13 : 0.5
Missing rate for batch 14 : 0.5
Missing rate for batch 15 : 0.5
Missing rate for batch 16 : 0.5
Missing rate for batch 17 : 0.5
Missing rate for batch 18 : 0.5
Missing rate for batch 19 : 0.5
Missing rate for batch 20 : 0.5
Missing rate for batch 21 : 0.5
Missing rate for batch 22 : 0.5
Missing rate for batch 23 : 0.5
Missing rate for batch 24 : 0.5
Missing rate for batch 25 : 0.5
Missing rate for batch 26 : 0.5
Missing rate for batch 27 : 0.5
Missing rate for batch 28 : 0.5
Missing rate for batch 29 : 0.5
Missing rate for batch 30 : 0.5
Missing rate for b

In [None]:
import torch

# Load two previously saved objects from the current working directory.
# NOTE(review): presumably ground-truth data and the matching model imputations —
# confirm against the code that saved them.
# NOTE(review): torch.load unpickles the file contents; only load files from a trusted source.
gt_data_ = torch.load("gt_data_")
imputations_ = torch.load("imputations_")

In [1]:
import torch

# Load the saved dictionary from the current working directory; the recorded output
# shows that it holds at least a 'target_df' DataFrame.
# NOTE(review): torch.load unpickles the file contents; only load files from a trusted source.
df_dict = torch.load("./df_dict")
# Bare expression so the notebook renders the dictionary as the cell output.
# NOTE(review): this prints every frame in full; a summary (keys and shapes) would keep the output short.
df_dict

{'target_df':        episode  frame  quarter  game_clock  shot_clock  player0_x  player0_y   
 0            1      1      1.0   719.96875    23.92625  13.068354   9.743013  \
 1            1      2      1.0   719.87000    23.81300  12.874110   9.814178   
 2            1      3      1.0   719.76500    23.72100  12.676634   9.870276   
 3            1      4      1.0   719.65700    23.63000  12.475927   9.911305   
 4            1      5      1.0   719.54900    23.53900  12.271988   9.937267   
 ...        ...    ...      ...         ...         ...        ...        ...   
 28690        0  28691      4.0     0.23000     5.37000  25.750887   5.878579   
 28691        0  28692      4.0     0.12300     5.37000  25.752978   5.885402   
 28692        0  28693      4.0     0.16800     5.37000  15.475092   3.493432   
 28693        0  28694      4.0     0.20000     5.37000  22.549985   4.684078   
 28694        0  28695      4.0     0.04000     5.37000  22.544872   4.685783   
 
        playe