In [171]:
%load_ext autoreload
%autoreload 2

import os
import random

import numpy as np

from dataset import SportsDataset
from models.utils import generate_mask

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
sports = "soccer"  # one of "soccer", "basketball", "afootball"

# Select the trace files used for training:
# - soccer / afootball: the last listed file (match3_valid.csv / nfl_test.csv)
#   is excluded via [:-1];
# - basketball: the first 10 files in sorted order.
if sports == "soccer":
    metrica_files = ["match1.csv", "match2.csv", "match3_valid.csv"]
    metrica_paths = [f"data/metrica_traces/{f}" for f in metrica_files]
    data_paths = metrica_paths[:-1]

elif sports == "basketball":
    # os.listdir order is arbitrary, so sort before slicing to keep runs reproducible.
    nba_paths = sorted(f"data/nba_traces/{f}" for f in os.listdir("data/nba_traces"))
    data_paths = nba_paths[:10]

elif sports == "afootball":
    nfl_paths = ["data/nfl_traces/nfl_train.csv", "data/nfl_traces/nfl_test.csv"]
    data_paths = nfl_paths[:-1]

else:
    # Previously any unknown value silently fell through to the NFL branch.
    raise ValueError(f"Unknown sports: {sports!r}")

In [32]:
# Build the training dataset from the selected trace files.
# The meaning of n_features / normalize / flip_pitch lives in SportsDataset,
# which is not visible here.
train_dataset = SportsDataset(
    data_paths=data_paths,
    sports=sports,
    n_features=6,
    flip_pitch=True,
    normalize=True,
)
len(train_dataset)

100%|██████████| 2/2 [00:08<00:00,  4.00s/it]


9098

In [183]:
# Grab a small batch (samples 75-79) and count the valid time steps per sample.
# Index the dataset only once: the original indexed it twice, which loads the
# slice twice. If __getitem__ is stochastic (e.g. random pitch flipping — TODO
# confirm in SportsDataset), two calls could also yield mismatched player/ball data.
batch = train_dataset[75:80]
player_data = batch[0]  # [bs, time, features]
ball_data = batch[1]

# A time step counts as valid when the first player feature is not -100
# (presumably the dataset's padding value — confirm against SportsDataset).
valid_lens = np.array(player_data[..., 0] != -100).astype(int).sum(axis=-1)  # [bs,]
valid_lens

array([200, 200, 200, 187, 200])

In [173]:
# Prototype of "playerwise" masking. For each sample, `missing_rate` of all
# valid player-frames (valid_lens * n_players) are marked missing. They are
# split randomly across players, and each player loses one contiguous window.
n_players = 22
missing_rate = 0.5
min_observed = 10  # per-player cap on missing frames is valid_len - min_observed

n_samples, seq_len = player_data.shape[0], player_data.shape[1]
mask = np.ones((n_samples, seq_len, n_players))  # [bs, time, players]; 0 marks a missing frame
missing_lens = np.zeros((n_samples, n_players), dtype=int)  # [bs, players]

residue = (valid_lens * n_players * missing_rate).astype(int)  # [bs,] frames still to distribute
max_shares = np.tile(valid_lens - min_observed, (n_players, 1)).T  # [bs, players] per-player cap
assert np.all(residue < max_shares.sum(axis=-1)), (
    "missing_rate is too high: not enough capacity under the per-player caps"
)

# Randomly split residue[i] among the players still under their cap. Any
# amount that overshoots a cap is clipped off and redistributed on the next
# pass. The assert above guarantees that at least one player is under its cap
# whenever residue[i] > 0.
for i in range(n_samples):
    while residue[i] > 0:
        slots = missing_lens[i] < max_shares[i]  # [players,] players that can absorb more
        breakpoints = np.random.choice(residue[i] + 1, slots.sum() - 1, replace=True)  # [n_slots - 1,]
        shares = np.diff(np.sort(breakpoints.tolist() + [0, residue[i]]))  # [n_slots,], sums to residue[i]
        missing_lens[i, ~slots] = max_shares[i, ~slots]  # clip the previous pass's overshoot
        missing_lens[i, slots] += shares
        residue[i] = np.clip(missing_lens[i] - max_shares[i], 0, None).sum()

# Window starts are drawn from [1, cap - len + 1]. So frame 0 is always
# observed, and each window's exclusive end is at most valid_len - min_observed + 1.
start_idxs = np.random.randint(1, max_shares - missing_lens + 2)  # [bs, players]
end_idxs = start_idxs + missing_lens  # [bs, players]

# Vectorized fill (replaces a per-sample / per-player Python loop):
# frame t of player p is missing iff start_idxs <= t < end_idxs.
frame_idxs = np.arange(seq_len)[None, :, None]  # [1, time, 1]
missing_frames = (frame_idxs >= start_idxs[:, None, :]) & (frame_idxs < end_idxs[:, None, :])
mask[missing_frames] = 0

missing_lens, missing_lens.sum(axis=1) // n_players

(array([[ 89, 103, 190,  91,  20,  94, 190,  18, 131,  78,  99,  20,  20,
          74,  40, 139, 115, 190, 103, 190,  16, 190],
        [ 38,   7,  48,  65, 170, 190, 163,  57, 190,  32,  33, 128, 190,
          45,  55, 102, 175, 190, 109,  48, 138,  27],
        [151,  91,  55,  39,  76,  98, 181, 105,  16,  11, 162,  46,  10,
          38, 190,   6, 190, 190, 185,  17, 154, 189],
        [133,  82, 105, 104,  36,  64, 177,  39, 114,  87,  39,  73,  46,
         177, 177,  50, 177,  35,  22,  31, 177, 112],
        [ 49, 105,  73,  74,  64,  43,  60, 190,  62,  92, 190,  31, 178,
          98, 190, 104,  85,  41, 126,  68,  99, 178]]),
 array([100, 100, 100,  93, 100]))

In [188]:
# Same masking via the library helper: playerwise mode at a 0.5 missing rate.
# generate_mask is project code; its return type is not visible here.
data_dict = {
    "target": player_data,
    "ball": ball_data,
}
mask = generate_mask(
    data_dict,
    mode="playerwise",
    missing_rate=0.5,
)