In [10]:
%load_ext autoreload
%autoreload 2

import os

import numpy as np
import torch

from dataset import SportsDataset
from models.utils import generate_mask

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Sport to load; must be one of "soccer", "basketball" or "afootball".
sports = "soccer"

# Select the trace files for the chosen sport. An unrecognized value now raises
# instead of silently falling through to the NFL branch (e.g. on a typo).
if sports == "soccer":
    metrica_files = ["match1.csv", "match2.csv", "match3_valid.csv"]
    metrica_paths = [f"data/metrica_traces/{f}" for f in metrica_files]
    # Every file except the last one is used here.
    data_paths = metrica_paths[:-1]

elif sports == "basketball":
    nba_files = os.listdir("data/nba_traces")
    nba_paths = [f"data/nba_traces/{f}" for f in nba_files]
    # os.listdir order is arbitrary; sort so the selection below is deterministic.
    nba_paths.sort()
    data_paths = nba_paths[:10]

elif sports == "afootball":
    nfl_paths = ["data/nfl_traces/nfl_train.csv", "data/nfl_traces/nfl_test.csv"]
    # Every file except the last one is used here.
    data_paths = nfl_paths[:-1]

else:
    raise ValueError(
        f"Unknown sports {sports!r}; expected 'soccer', 'basketball' or 'afootball'."
    )

In [3]:
# Collect the dataset options in one place, then build the dataset from the
# paths selected above. (The meaning of each option is defined by
# SportsDataset, which is not shown in this notebook.)
dataset_kwargs = dict(
    sports=sports,
    data_paths=data_paths,
    n_features=6,
    normalize=True,
    flip_pitch=True,
)
train_dataset = SportsDataset(**dataset_kwargs)

# Display the number of items in the dataset.
len(train_dataset)

100%|██████████| 2/2 [00:08<00:00,  4.15s/it]


9098

In [6]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Index the dataset a single time: this avoids building the same slice twice and
# guarantees that the player and ball tensors come from the same draw.
# NOTE(review): if SportsDataset applies random augmentation per access (e.g.
# via flip_pitch), two separate accesses could be inconsistent — confirm.
batch = train_dataset[75:80]
player_data = batch[0].to(device)
ball_data = batch[1].to(device)

# Per-sequence count of time steps whose first feature is not -100.
# NOTE(review): -100 looks like the padding value used by the dataset — confirm.
# Tensor.cpu() is a no-op for tensors already on the CPU, so no device branch
# is needed before converting to NumPy.
valid_lens = (player_data[..., 0] != -100).cpu().numpy().astype(int).sum(axis=-1)

valid_lens

array([200, 200, 200, 187, 200])

In [7]:
# Prototype of a player-wise missing mask: for every sample, randomly split a
# total "missing budget" across the players, then blank out one contiguous
# window per player whose length equals that player's share.
# Relies on `player_data` and `valid_lens` computed in an earlier cell.
n_players = 22
missing_rate = 0.5

# The mask starts as all ones; entries inside each player's missing window are set to 0 below.
mask = np.ones((player_data.shape[0], player_data.shape[1], n_players))
missing_lens = np.zeros((mask.shape[0], n_players)).astype(int)  # [bs, players]

# residue: number of (frame, player) entries still to be assigned as missing for each sample.
residue = (valid_lens * n_players * missing_rate).astype(int)  # [bs,]
# max_shares: cap on a single player's missing length (valid length minus 10 frames).
max_shares = np.tile(valid_lens - 10, (n_players, 1)).T  # [bs, players]
# The budget must fit strictly within the total capacity, otherwise the loop below could not finish.
assert np.all(residue < max_shares.sum(axis=-1))

for i in range(mask.shape[0]):
    # Repeat until no player exceeds its cap: every pass redistributes the overflow
    # among the players that still have room.
    while residue[i] > 0:
        # Players whose current missing length is still below their cap.
        slots = missing_lens[i] < max_shares[i]  # [players,]
        # Random cut points in [0, residue]; the gaps between the sorted cut points
        # (with 0 and residue added as end points) are non-negative shares summing to residue.
        breakpoints = np.random.choice(residue[i] + 1, slots.astype(int).sum() - 1, replace=True)  # [players - 1,]
        shares = np.diff(np.sort(breakpoints.tolist() + [0, residue[i]]))  # [players,]

        # Clamp players without room to their cap, hand the shares to the others,
        # then measure how much now exceeds the caps (the next pass's residue).
        missing_lens[i, ~slots] = max_shares[i, ~slots]
        missing_lens[i, slots] += shares
        residue[i] = np.clip(missing_lens[i] - max_shares[i], 0, None).sum()

# Random window start per (sample, player). The lower bound 1 keeps frame 0 unmasked;
# the exclusive upper bound keeps end_idxs <= max_shares + 1.
start_idxs = np.random.randint(1, max_shares - missing_lens + 2)  # [bs, players]
end_idxs = start_idxs + missing_lens  # [bs, players]

# Zero out the window [start, end) of every player in every sample.
for i in range(mask.shape[0]):
    for p in range(n_players):
        mask[i, start_idxs[i, p] : end_idxs[i, p], p] = 0

# Display the per-player missing lengths and their per-sample average (integer division).
missing_lens, missing_lens.sum(axis=1) // n_players

(array([[ 22, 150, 190, 115,  43, 190,  96, 113, 142,  21,  76, 190,  78,
         110, 190,   4,   8,  59,   5,  75, 190, 133],
        [144,  22, 135,  57,  34,  46,  84,  65,  80,  43,  17, 105,  67,
         190, 158,  60, 170,  88, 164, 190, 190,  91],
        [ 54,  18,  48, 190, 155,  96, 190,  50, 190,  22,  52, 103,  53,
         153, 151,  89, 184,  85,  79,  57,  33, 148],
        [177,  11, 177, 177,  70,  58,  52, 118,  45, 161,  90, 137,  62,
          57, 101,  73,  91,  38,  91,  34,  60, 177],
        [ 85, 140,  70,  42,  92,  51,  38,  37, 190, 167,  63,  21,  31,
         127,  94,  53, 190, 167, 122, 190, 172,  58]]),
 array([100, 100, 100,  93, 100]))

In [9]:
# Bundle the player and ball tensors under the keys passed to generate_mask,
# then request a player-wise mask with half of the entries missing.
data_dict = dict(target=player_data, ball=ball_data)
mask = generate_mask(
    data_dict,
    mode="playerwise",
    missing_rate=0.5,
)

# Display the shape of the generated mask.
mask.shape

(5, 200, 22)