In [None]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.distributions as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

# Show every column of the wide tracking frames instead of eliding the middle ones.
pd.options.display.max_columns = 500

In [None]:
# Position columns: the ball, then all 22 players' x coordinates, then all their y coordinates.
BALL_VARS = ['ballx', 'bally']
VARS_X = [f'p{n}x' for n in range(1, 23)]
VARS_Y = [f'p{n}y' for n in range(1, 23)]
CONTEXT_VARS = [*BALL_VARS, *VARS_X, *VARS_Y]
# Matching per-frame columns: insert 'd' before the trailing axis letter ('p3x' -> 'p3dx', 'bally' -> 'balldy').
FRAME_VARS = [name[:-1] + 'd' + name[-1] for name in CONTEXT_VARS]

In [None]:
# Number of previous frames whose frame (dx/dy) columns are added as lagged features.
N_CONTEXT_FRAMES = 1
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook draws from (weight init, DataLoader shuffling, dropout,
# MDN sampling) so that Restart & Run All reproduces the same numbers and plots.
# torch.manual_seed also seeds all CUDA devices.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Column groups reused by the layer definitions below.
PREV_FRAME_INPUTS = ['playDirection'] + [f'c1_{var}' for var in CONTEXT_VARS] + [f'c1_{var}' for var in FRAME_VARS]
PLAYER_DX_VARS = [var for var in FRAME_VARS if ('dx' in var) and ('ball' not in var)]
PLAYER_DY_VARS = [var for var in FRAME_VARS if ('dy' in var) and ('ball' not in var)]
BALL_DXDY_VARS = [var for var in FRAME_VARS if 'ball' in var]

# Chained model: each layer sees the previous-frame inputs plus every column the layers
# before it predict (layer_1 -> player dx, layer_2 -> player dy, layer_3 -> ball dx/dy).
MODEL_LAYERS = {
    'layer_1': {
        'input': list(PREV_FRAME_INPUTS),
        'output': list(PLAYER_DX_VARS),
    },
    'layer_2': {
        'input': PREV_FRAME_INPUTS + PLAYER_DX_VARS,
        'output': list(PLAYER_DY_VARS),
    },
    'layer_3': {
        'input': PREV_FRAME_INPUTS + PLAYER_DX_VARS + PLAYER_DY_VARS,
        'output': list(BALL_DXDY_VARS),
    },
}

# Print the [start, end) span of each column group inside layer_3's input vector.
# Spans are derived from the group sizes and checked against the actual column names,
# because hand-computed offsets are easy to get wrong (e.g. forgetting the c1 ball dx/dy
# columns, or counting layer_3's own ball dx/dy outputs as inputs).
layer_3_input = MODEL_LAYERS['layer_3']['input']
input_layout = [
    ('playDirection', ['playDirection']),
    ('c1_ball_xy', [f'c1_{var}' for var in BALL_VARS]),
    ('c1_player_x', [f'c1_{var}' for var in VARS_X]),
    ('c1_player_y', [f'c1_{var}' for var in VARS_Y]),
    ('c1_ball_dxdy', [f'c1_{var}' for var in BALL_DXDY_VARS]),
    ('c1_player_dx', [f'c1_{var}' for var in PLAYER_DX_VARS]),
    ('c1_player_dy', [f'c1_{var}' for var in PLAYER_DY_VARS]),
    ('player_dx', PLAYER_DX_VARS),
    ('player_dy', PLAYER_DY_VARS),
]
offset = 0
for label, columns in input_layout:
    end = offset + len(columns)
    assert layer_3_input[offset:end] == columns, f'{label} is not at {offset}:{end} of layer_3 input'
    print(label, offset, end)
    offset = end
assert offset == len(layer_3_input), 'layer_3 input has columns not covered by the layout'
print(len(layer_3_input))

playDirection 0
c1_ball_xy 1 3
c1_player_x 3 25
c1_player_y 25 47
c1_player_dx 47 69
c1_player_dy 69 91
player_dx 91 113
player_dy 113 135
ball_dxdy 135 137
137


# Loading data

In [None]:
# Raw tracking table; presumably one row per frame with ball and p1..p22 columns — TODO confirm schema.
TRACKING_FILE = 'data/tracking2.feather'
data = pd.read_feather(TRACKING_FILE)

In [None]:
def _lagged(frame: pd.DataFrame, columns: list, lag: int) -> pd.DataFrame:
    """Return ``columns`` of ``frame`` shifted down ``lag`` rows and prefixed ``c{lag}_``.

    Rows whose shifted source row belongs to a different game or play are set to NA, so
    context never leaks across sequence boundaries; ``dropna`` below removes those rows.
    Assumes rows are ordered chronologically within each (gameId, playId) — TODO confirm.
    """
    shifted = frame.shift(lag)
    # The first `lag` rows get NaN ids, which compare unequal and are therefore masked too.
    crosses_boundary = (shifted['gameId'] != frame['gameId']) | (shifted['playId'] != frame['playId'])
    # Mask the same columns that are returned (position AND delta columns alike).
    shifted.loc[crosses_boundary, columns] = pd.NA
    return shifted[columns].add_prefix(f'c{lag}_')


features = ['playDirection']
# Previous-frame positions (ball + players); only lag 1 is used as positional context.
position_context = _lagged(data, CONTEXT_VARS, 1)
# Previous-frame dx/dy columns for each of the N_CONTEXT_FRAMES most recent frames.
delta_contexts = [_lagged(data, FRAME_VARS, lag) for lag in range(1, N_CONTEXT_FRAMES + 1)]

data = pd.concat([data, position_context, *delta_contexts], axis=1)
for context in [position_context, *delta_contexts]:
    features += list(context.columns)

data = data.dropna()

In [None]:
# Preview a few rows rather than rendering the whole (very wide) frame.
data.head()

In [None]:
# Keep only frames flagged as after the snap (`.eq(True)` mirrors an `== True` comparison).
data = data.loc[data['is_after_snap'].eq(True)]

In [None]:
# Every engineered feature must exist as a column of `data` before it is used as model input.
assert set(features).issubset(data.columns), 'missing feature columns'
len(features), features

In [None]:
# Split by gameId cutoffs (the ids look date-like, e.g. 2022101700 — TODO confirm it is chronological).
VAL_START_GAME_ID = 2022101700
TEST_START_GAME_ID = 2022102400

train_data = data[data['gameId'] < VAL_START_GAME_ID]
val_data = data[(data['gameId'] >= VAL_START_GAME_ID) & (data['gameId'] < TEST_START_GAME_ID)]
test_data = data[data['gameId'] >= TEST_START_GAME_ID]
# The three splits must partition the rows exactly (fails if the cutoffs are edited out of order).
assert len(train_data) + len(val_data) + len(test_data) == len(data)

# Modeling

In [None]:
# MDN Model Definition
class MDN(nn.Module):
    """Mixture density network with an independent 1-D Gaussian mixture per output column.

    For an input batch of shape (batch, input_dim), ``forward`` returns ``(mu, sigma, pi)``,
    each of shape (batch, output_dim, n_gaussians): component means, standard deviations and
    mixture weights (softmax-normalised over the last axis) for every output column.
    """

    # Floor added to the softplus output: softplus can underflow to exactly 0 in float32,
    # which dist.Normal rejects and which makes the log-likelihood infinite.
    MIN_SIGMA = 1e-6

    def __init__(self, input_dim, output_dim, n_gaussians):
        super().__init__()

        self.output_dim = output_dim
        self.n_gaussians = n_gaussians
        # One parameter per (output column, component); every head is viewed as
        # (batch, output_dim, n_gaussians), so all three must have this many outputs.
        n_params = output_dim * n_gaussians

        self.fc1 = nn.Linear(input_dim, 1024)
        self.dropout1 = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(1024, 1024)
        self.dropout2 = nn.Dropout(p=0.2)
        self.fc_mu = nn.Linear(1024, n_params)     # Means
        self.fc_sigma = nn.Linear(1024, n_params)  # Std deviations (before softplus)
        self.fc_pi = nn.Linear(1024, n_params)     # Mixture weight logits

    def forward(self, x):
        x = self.dropout1(torch.relu(self.fc1(x)))
        x = self.dropout2(torch.relu(self.fc2(x)))
        shape = (-1, self.output_dim, self.n_gaussians)
        mu = self.fc_mu(x).view(shape)
        sigma = F.softplus(self.fc_sigma(x)).view(shape) + self.MIN_SIGMA
        pi = torch.softmax(self.fc_pi(x).view(shape), dim=-1)
        return mu, sigma, pi


# Loss function
def mdn_nll_loss(y, mu, sigma, pi):
    """Mean negative log-likelihood of ``y`` under the per-column Gaussian mixtures.

    Args:
        y: targets, shape (batch, output_dim).
        mu, sigma, pi: MDN outputs, each shape (batch, output_dim, n_gaussians).

    Returns:
        Scalar tensor: mixture NLL averaged over the batch and the output columns.
    """
    # Broadcast each target against all components of its own column.
    log_prob = dist.Normal(mu, sigma).log_prob(y.unsqueeze(-1))
    # Softmax weights can underflow to 0; log(0) = -inf gives NaN gradients, so clamp first.
    log_pi = torch.log(pi.clamp_min(1e-12))
    nll = -torch.logsumexp(log_prob + log_pi, dim=-1)  # log-sum-exp over components
    return nll.mean()


# Sampling function
@torch.no_grad()
def sample_from_mdn(mu, sigma, pi, n_samples=1):
    """Draw samples from the per-column Gaussian mixtures.

    Args:
        mu, sigma, pi: MDN outputs, each shape (batch, output_dim, n_gaussians).
        n_samples: number of independent draws per (batch row, output column).

    Returns:
        Tensor of shape (batch, output_dim, n_samples). Where the chosen component's sigma is
        (numerically) zero, that component's mean is returned instead of a random draw.
    """
    # Pick one component per draw: (n_samples, batch, output_dim) -> (batch, output_dim, n_samples).
    component = dist.Categorical(probs=pi).sample((n_samples,)).permute(1, 2, 0)

    chosen_mu = torch.gather(mu, -1, component)
    chosen_sigma = torch.gather(sigma, -1, component)

    # dist.Normal rejects sigma == 0: sample with sigma = 1 there, then substitute the mean.
    is_sigma_zero = torch.isclose(chosen_sigma, torch.zeros_like(chosen_sigma))
    safe_sigma = torch.where(is_sigma_zero, torch.ones_like(chosen_sigma), chosen_sigma)
    draws = dist.Normal(chosen_mu, safe_sigma).sample()
    return torch.where(is_sigma_zero, chosen_mu, draws)

In [None]:
# Number of Gaussian components in each output column's mixture.
n_gaussians: int = 3

In [None]:
# Training hyper-parameters shared by every layer.
BATCH_SIZE = 1024
LEARNING_RATE = 1e-3
N_EPOCHS = 25


def make_loader(frame: pd.DataFrame, layer: dict, shuffle: bool) -> torch.utils.data.DataLoader:
    """Return a DataLoader over ``frame``'s input/output columns for ``layer`` (float32, on DEVICE)."""
    x = torch.tensor(frame[layer['input']].values, dtype=torch.float32).to(DEVICE)
    y = torch.tensor(frame[layer['output']].values, dtype=torch.float32).to(DEVICE)
    dataset = torch.utils.data.TensorDataset(x, y)
    return torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle)


def evaluate(model: nn.Module, loader: torch.utils.data.DataLoader) -> float:
    """Mean per-batch NLL of ``model`` on ``loader`` with dropout off and no autograd graph."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for x_batch, y_batch in loader:
            total += mdn_nll_loss(y_batch, *model(x_batch)).item()
    return total / max(len(loader), 1)  # guard against an empty validation split


def train_layer(name: str, layer: dict) -> nn.Module:
    """Fit a fresh MDN mapping ``layer['input']`` to ``layer['output']``; return it in eval mode."""
    train_loader = make_loader(train_data, layer, shuffle=True)
    val_loader = make_loader(val_data, layer, shuffle=False)
    model = MDN(input_dim=len(layer['input']), output_dim=len(layer['output']), n_gaussians=n_gaussians).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    for epoch in range(N_EPOCHS):
        model.train()  # evaluate() turned dropout off; re-enable it for the update pass
        loss = None
        for x_batch, y_batch in tqdm(train_loader, desc=f'{name} epoch {epoch}'):
            optimizer.zero_grad()
            loss = mdn_nll_loss(y_batch, *model(x_batch))
            loss.backward()
            optimizer.step()
        val_loss = evaluate(model, val_loader)
        last_batch_loss = float('nan') if loss is None else loss.item()
        print(f"[{name}] Epoch {epoch}, Loss: {last_batch_loss:.4f}, Val Loss: {val_loss:.4f}")

    model.eval()  # later cells only run inference
    return model


# Train every layer of the chain (not only the one the setup loop happened to end on).
models = {name: train_layer(name, layer) for name, layer in MODEL_LAYERS.items()}
# Later cells query `model`; keep it bound to the last layer in MODEL_LAYERS as before.
model = models[list(MODEL_LAYERS)[-1]]

# Plotting

In [None]:
# A single validation play (gameId 2022101700, playId 90) used by the plots below.
play_data = val_data.query('gameId == 2022101700 and playId == 90')
play_data.shape, play_data['playDirection'].iloc[0]

In [None]:
# Observed trajectories: players 1-11 red, 12-22 blue, ball black; both axes fixed to [0, 1].
ax = plt.gca()
for player in range(1, 23):
    team_color = 'red' if player <= 11 else 'blue'
    ax.plot(play_data[f'p{player}x'], play_data[f'p{player}y'], '.', color=team_color, markersize=2, alpha=0.5)
ax.plot(play_data['ballx'], play_data['bally'], 'o', color='black')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1);

In [None]:
# Autoregressive rollout of one play: start from the play's first row, repeatedly sample one
# step from the model, add the sampled values to the position slots of the frame vector, shift
# them into the context slots, and record the position slots of every step in `frames`.
# FIXME(review): this cell predates the three-layer MODEL_LAYERS split and cannot run as is:
#   * `n_points` is not defined in any visible cell (the slices suggest 23 = 22 players + ball).
#   * `model` is the last model built by the training cell (layer_3), whose input is
#     MODEL_LAYERS['layer_3']['input'] (137 columns), but `features` has only 93 columns.
#     A full rollout must chain the layer_1 (player dx) and layer_2 (player dy) models in
#     front of the ball model.
#   * On the first iteration `current_frame` is a Python list, so the slice `+=` below
#     extends the list instead of adding element-wise (convert with np.asarray first).
#   * `frames.append(current_frame[...])` stores a view; the next iteration's `+=` mutates
#     it, so each recorded frame ends up holding the following step's positions (use .copy()).
#   * `np.concat` only exists in NumPy >= 2.0; `np.concatenate` works on all versions.
frames = []

current_frame = play_data.iloc[0][features].values.tolist()
for i in tqdm(range(play_data.shape[0])):
    x = torch.tensor(current_frame, dtype=torch.float32).to(DEVICE).unsqueeze(0)
    mu, sigma, pi = model(x)
    samples = sample_from_mdn(mu, sigma, pi, n_samples=1).squeeze(2).detach().cpu().numpy()
    # Sum the sample with the current frame to get the next frame
    current_frame[1:n_points * 2 + 1] += samples.flatten()
    # shift the context frames
    current_frame = np.concat([current_frame[:n_points * 2 + 1], samples.flatten(), current_frame[2 * n_points * 2 + 1:]])
    frames.append(current_frame[1:n_points * 2 + 1])
    
frames = pd.DataFrame(frames, columns=CONTEXT_VARS)

In [None]:
plt.rcParams['figure.figsize'] = [7.5, 5]
# Simulated trajectories from the rollout, coloured like the observed-play plot.
ax = plt.gca()
for player in range(1, 23):
    team_color = 'red' if player <= 11 else 'blue'
    ax.plot(frames[f'p{player}x'], frames[f'p{player}y'], '.', color=team_color, markersize=2, alpha=0.5)
ax.plot(frames['ballx'], frames['bally'], color='black', marker='o')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1);

In [None]:
# Sample many rollouts from the play's first row and collect the ball position every 10 steps.
# FIXME(review): the batched version below (up to the second `ball_positions_in_lookahead = ...`)
# is dead code; its result is overwritten immediately by the per-rollout loop that follows.
# FIXME(review): both versions share the rollout cell's problems: `n_points` is undefined, `model`
# (layer_3) expects MODEL_LAYERS['layer_3']['input'] (137 columns) while `features` has 93, and the
# player-dx / player-dy layers are never used.
ball_positions_in_lookahead = {}
max_lookahead = 70

current_frame = play_data.iloc[0][features].values.tolist()
current_frame = [current_frame for _ in range(100)]
x = torch.tensor(current_frame, dtype=torch.float32).to(DEVICE)
for i in tqdm(range(max_lookahead + 1)):
    mu, sigma, pi = model(x)
    samples = sample_from_mdn(mu, sigma, pi, n_samples=1).squeeze(2).detach().cpu().numpy()
    samples = torch.tensor([sample.flatten() for sample in samples], dtype=torch.float32).to(DEVICE)
    x[:, 1:n_points * 2 + 1] += samples
    x = torch.concat([x[:, :n_points * 2 + 1], samples], axis=1)

    if (i % 10 == 0) and (i != 0):
        ball_positions_in_lookahead[str(i)] = x[:, 1:3].detach().cpu().numpy()

# Keys are lookahead step counts as strings; each list collects one ball (x, y) per rollout.
ball_positions_in_lookahead = {'10': [], '20': [], '30': [], '40': [], '50': [], '60': [], '70': []}
max_lookahead = 70
# FIXME(review): the inner `for i` below shadows this loop's `i` (harmless only because the
# outer index is unused).
for i in tqdm(range(100)):
    current_frame = play_data.iloc[0][features].values.tolist()
    for i in range(max_lookahead + 1):
        x = torch.tensor(current_frame, dtype=torch.float32).to(DEVICE).unsqueeze(0)
        mu, sigma, pi = model(x)
        samples = sample_from_mdn(mu, sigma, pi, n_samples=1).squeeze(2).detach().cpu().numpy()
        # Sum the sample with the current frame to get the next frame
        # FIXME(review): on the first step `current_frame` is a list, so this `+=` extends it.
        current_frame[1:n_points * 2 + 1] += samples.flatten()
        # shift the context frames
        current_frame = np.concat([current_frame[:n_points * 2 + 1], samples.flatten(), current_frame[2 * n_points * 2 + 1:]])

        if (i % 10 == 0) and (i != 0):
            # FIXME(review): this appends a view that the next step's `+=` mutates, so the stored
            # position is one step later than its key (except at the final step); use .copy().
            ball_positions_in_lookahead[str(i)].append(current_frame[1:3])

In [None]:
plt.rcParams['figure.figsize'] = [7.5, 5]
# One colour per lookahead horizon; the black dot is the ball's position in the play's first row.
start_row = play_data.iloc[0]
for color_idx, positions in enumerate(ball_positions_in_lookahead.values()):
    horizon_df = pd.DataFrame(positions, columns=['ballx', 'bally'])
    plt.scatter(horizon_df['ballx'], horizon_df['bally'], c=f'C{color_idx}', alpha=0.5)
plt.scatter(start_row['ballx'], start_row['bally'], c='black')
plt.xlim(0, 1)
plt.ylim(0, 1);

In [None]:
plt.rcParams['figure.figsize'] = [7.5, 5]
# KDE of the sampled ball positions for the first six horizons, one colour map per horizon.
# (seaborn is imported in the top imports cell.)
cmaps = ['Reds', 'Blues', 'Greens', 'Purples', 'Oranges', 'Greys']
horizons = list(ball_positions_in_lookahead.items())[:len(cmaps)]
fig, ax = plt.subplots()
for cmap, (lookahead, positions) in zip(cmaps, horizons):
    lookahead_data = pd.DataFrame(positions, columns=['ballx', 'bally'])
    sns.kdeplot(x=lookahead_data['ballx'], y=lookahead_data['bally'], cmap=cmap, ax=ax)
ax.scatter(play_data.iloc[0]['ballx'], play_data.iloc[0]['bally'], c='black')
ax.set(
    xlim=(0, 1), ylim=(0, 1), xlabel='ballx', ylabel='bally',
    title='Sampled ball position density, lookaheads ' + ', '.join(key for key, _ in horizons),
);

In [None]:
plt.rcParams['figure.figsize'] = [30, 5]
start_row = play_data.iloc[0]
n_lookaheads = len(ball_positions_in_lookahead)
# One panel per horizon; axes are swapped (bally horizontal, ballx vertical).
for panel, (lookahead, ball_positions) in enumerate(ball_positions_in_lookahead.items(), start=1):
    steps = int(lookahead)  # keys are step counts ('10', '20', ...)
    plt.subplot(1, n_lookaheads, panel)
    # Title from the key itself: a 0-based `index * 10` labels the 10-step panel "Lookahead 0".
    plt.title(f'Lookahead {steps}')
    lookahead_data = pd.DataFrame(ball_positions, columns=['ballx', 'bally'])
    sns.kdeplot(x=lookahead_data['bally'], y=lookahead_data['ballx'], cmap='cividis_r')
    plt.scatter(start_row['bally'], start_row['ballx'], c='black', zorder=10)
    # Observed ball path over the same number of steps (rows 0..steps; presumably aligned with
    # the rollout horizon — confirm). A 0-based `iloc[:10 * index]` was empty for the first panel.
    observed = play_data.iloc[:steps + 1]
    plt.plot(observed['bally'], observed['ballx'], color='black', marker='.', markersize=1)
    plt.xlabel('bally')
    plt.ylabel('ballx')
    plt.xlim(0, 1)
    plt.ylim(0, 1)

In [None]:
# Joint + marginal KDE of the sampled ball positions for the first four horizons.
cmaps = ['Reds', 'Blues', 'Greens', 'Purples']
start_row = play_data.iloc[0]
# Iterate the lookahead dict directly and never reuse its name as the loop variable.
for cmap, (lookahead, ball_positions) in zip(cmaps, ball_positions_in_lookahead.items()):
    lookahead_data = pd.DataFrame(ball_positions, columns=['ballx', 'bally'])
    grid = sns.jointplot(data=lookahead_data, x='ballx', y='bally', kind='kde', cmap=cmap)
    # jointplot opens its own figure; draw on its joint axes rather than pyplot's current axes.
    grid.ax_joint.scatter(start_row['ballx'], start_row['bally'], c='black')
    grid.ax_joint.set(xlim=(0, 1), ylim=(0, 1))
    grid.ax_marg_x.set_title(f'Lookahead {lookahead}')

In [None]:
# Inspect the mixture parameters `model` predicts for the play's first row.
# `model` is the last layer's network, whose input is that layer's column list (wider than
# `features`), so select exactly those columns, as float32 (a mixed-dtype row is object dtype).
model_input_columns = list(MODEL_LAYERS.values())[-1]['input']
current_frame = play_data[model_input_columns].iloc[0].to_numpy(dtype=np.float32)
x = torch.tensor(current_frame, dtype=torch.float32).to(DEVICE).unsqueeze(0)
model.eval()  # disable dropout so the parameters shown are deterministic
with torch.no_grad():
    mu, sigma, pi = model(x)
mu.flatten()

In [None]:
# Model input row, rounded for display.
np.round(current_frame, 4)

In [None]:
# Predicted standard deviations, one per (output column, component), as a flat vector.
sigma.reshape(-1)

In [None]:
# Mixture-mean prediction minus the observed value for each output column of the last layer.
# (mu holds one value per (output, component) while current_frame holds the model *inputs*,
# so subtracting one from the other cannot broadcast.)
output_columns = list(MODEL_LAYERS.values())[-1]['output']
predicted_mean = (pi * mu).sum(dim=-1).flatten().detach().cpu().numpy()
observed = play_data[output_columns].iloc[0].to_numpy(dtype=np.float32)
predicted_mean - observed