In [1]:
%load_ext autoreload
%autoreload 2

import os

import numpy as np
import torch

from dataset import SportsDataset
from models.utils import generate_mask

In [17]:
# Choose the sport and the trace files to load for it.
sports = "soccer"  # one of "soccer", "basketball", "afootball"

if sports == "soccer":
    metrica_files = ["match1.csv", "match2.csv", "match3_valid.csv", "match3_test.csv"]
    metrica_paths = [f"data/metrica_traces/{f}" for f in metrica_files]
    data_paths = metrica_paths[-1:]  # only match3_test.csv

elif sports == "basketball":
    nba_files = os.listdir("data/nba_traces")
    # listdir order is arbitrary; sort so the "first 10" selection is deterministic
    nba_paths = sorted(f"data/nba_traces/{f}" for f in nba_files)
    data_paths = nba_paths[:10]

elif sports == "afootball":
    nfl_paths = ["data/nfl_traces/nfl_train.csv", "data/nfl_traces/nfl_test.csv"]
    data_paths = nfl_paths[:-1]  # only nfl_train.csv

else:
    # Previously any unknown value silently fell through to the afootball branch.
    raise ValueError(
        f"Unknown sports {sports!r}; expected 'soccer', 'basketball' or 'afootball'"
    )

data_paths

['data/metrica_traces/match3_test.csv']

In [35]:
# Build the dataset from the selected trace files. Samples are presumably
# windows of `window_size` frames (SportsDataset is defined elsewhere).
window_size = 600
dataset = SportsDataset(
    sports=sports,
    data_paths=data_paths,
    window_size=window_size,
    n_features=6,  # NOTE(review): presumably features per player — confirm in SportsDataset
    normalize=True,
    flip_pitch=True,
)
len(dataset)

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:04<00:00,  4.16s/it]


49

In [36]:
# Move the traces to the GPU when one is available, otherwise stay on CPU,
# so the notebook also runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
player_data = dataset.player_data.to(device)
ball_data = dataset.ball_data.to(device)

# Count, per sample, the frames whose first feature is not the -100 sentinel
# (treated here as padding / invalid). `.cpu()` is a no-op for CPU tensors,
# so no device branch is needed before converting to NumPy.
PAD_VALUE = -100
valid_frames = (player_data[..., 0] != PAD_VALUE).cpu().numpy().astype(int).sum(axis=-1)

# Number of samples, and how many of them are shorter than a full window.
len(valid_frames), (valid_frames < window_size).astype(int).sum()

(49, 21)

In [37]:
# Build a mask with the project's "camera" scheme and show its shape together
# with the missing rate reported by generate_mask (rounded to 4 decimals).
data_dict = dict(target=player_data, ball=ball_data)
mask, missing_rate = generate_mask(data_dict, mode="camera")
(mask.shape, round(missing_rate, 4))

((49, 600, 22), 0.5593)

In [6]:
# Prototype random masking: for each sample, hide int(valid_frames * n_players
# * missing_rate) player-frames, split at random across players, with each
# player's hidden frames forming one contiguous gap.
# NOTE: this overwrites `mask` and `missing_rate` from the generate_mask cell.
n_players = 22
missing_rate = 0.5
verbose = False
margin = 10  # gaps lie in [1, valid_frames - margin + 1), so >= margin frames stay observed
seed = 0     # fixed seed so the sampled mask is reproducible

rng = np.random.default_rng(seed)

mask = np.ones((player_data.shape[0], player_data.shape[1], n_players))
missing_frames = np.zeros((mask.shape[0], n_players), dtype=int)  # [bs, players]

# Per-sample number of player-frames to hide, and the per-player cap.
target_missing = (valid_frames * n_players * missing_rate).astype(int)  # [bs,]
max_shares = np.tile(valid_frames - margin, (n_players, 1)).T  # [bs, players]
# Guarantees at least one player is always below its cap in the loop below.
assert np.all(target_missing < max_shares.sum(axis=-1)), (
    f"some samples have too few valid frames to hide {missing_rate:.0%} of them "
    f"while keeping a {margin}-frame margin"
)

residue = target_missing.copy()  # frames still to be assigned, per sample
for i in range(mask.shape[0]):
    # Each pass splits the residue at random cut points among the players still
    # below their cap; any overflow beyond a cap becomes the next residue. An
    # overflowing player is capped on the next pass, so the loop terminates.
    while residue[i] > 0:
        slots = missing_frames[i] < max_shares[i]  # [players,] still below cap
        breakpoints = rng.choice(residue[i] + 1, slots.sum() - 1, replace=True)  # [n_slots - 1,]
        shares = np.diff(np.sort(np.concatenate([breakpoints, [0, residue[i]]])))  # [n_slots,]
        if verbose:
            print()
            print(residue[i])
            print(slots.astype(int))
            print(shares)

        missing_frames[i, ~slots] = max_shares[i, ~slots]
        missing_frames[i, slots] += shares
        residue[i] = np.clip(missing_frames[i] - max_shares[i], 0, None).sum()
        if verbose:
            print(missing_frames[i])

# Random gap start in [1, max_share - missing + 1], so end <= max_share + 1.
start_idxs = rng.integers(1, max_shares - missing_frames + 2)  # [bs, players]
end_idxs = start_idxs + missing_frames  # [bs, players]

# Zero every frame t with start <= t < end for each (sample, player) at once.
frame_idxs = np.arange(mask.shape[1])[None, :, None]  # [1, time, 1]
in_gap = (frame_idxs >= start_idxs[:, None, :]) & (frame_idxs < end_idxs[:, None, :])  # [bs, time, players]
mask[in_gap] = 0

# Sanity checks: each sample hides exactly its target number of player-frames.
assert np.array_equal(missing_frames.sum(axis=1), target_missing)
assert np.array_equal((mask == 0).sum(axis=(1, 2)), target_missing)

missing_frames, missing_frames.sum(axis=1) // n_players

(array([[190,  39,  78,  93,  50, 190,  49, 100, 190, 119,  82, 188,  20,
          73,  41, 190,  80,  47, 190,  28,  91,  72],
        [101, 103,  64,  95, 190, 136,  36,  20,  25, 138, 190, 127,  32,
         113,  53,  57, 190, 190, 190,  89,  29,  32],
        [ 26,  37,  27,  70,  19, 190, 190, 161, 190,  59, 142,  57,  37,
         171, 123,  15, 141, 187,  31,  81, 190,  56],
        [ 43, 116, 177,  61, 114, 130, 130,  19,  80, 156,  95, 177,  41,
           7, 169,  91,  28,  66,  21, 133,  26, 177],
        [190,  47,  77,  47, 136, 105, 154,  35, 190,  80,  84, 154, 160,
          35, 118, 190,  31,  69,  34, 185,  11,  68]]),
 array([100, 100, 100,  93, 100]))