In [1]:
%load_ext autoreload
%autoreload 2

import os, json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from IPython.display import HTML
from matplotlib import animation
from matplotsoccer import field
from torch.nn.functional import softplus
from tqdm import tqdm
%matplotlib inline

from dataset import SportsDataset
# from models import load_model
from models.utils import reshape_tensor

### Load data

### (1) Load single model result

In [2]:
# Trial to analyse; its artifacts live under saved/<zero-padded trial id>/.
trial = 508
save_path = f"saved/{trial:03d}"

# NOTE(review): both files are unpickled as full Python objects. Recent PyTorch
# releases default torch.load to weights_only=True, which may reject them --
# confirm against the installed version.
helper = torch.load(save_path + "/helper")
match_ret = torch.load(save_path + "/match_ret")

# Ground truth and mask are copied over unconditionally.
rets = {key: match_ret[key] for key in ("target", "mask")}

# The prediction is stored under "hybrid_d" for midas and under "pred" otherwise;
# either way it is registered in `rets` under the model's own name.
model_name = helper.params["model"]
pred_key = "hybrid_d" if model_name == "midas" else "pred"
rets[model_name] = match_ret[pred_key]

# Optional extras, controlled by the trial's parameters.
if helper.params["navie_baselines"]:
    rets["linear"] = match_ret["linear"]
    # rets["knn"] = match_ret["knn"]
    # rets["ffill"] = match_ret["ffill"]

if helper.params["dataset"] == "soccer":
    rets["ball"] = match_ret["ball"]

rets.keys()

dict_keys(['target', 'mask', 'imputeformer', 'linear', 'ball'])

### (2) Add baseline model results

In [3]:
# trial_ids = [3, 902,9932]

# for trial in trial_ids:
#     save_path = f"saved/{trial:03d}"
#     helper =  torch.load(save_path + "/helper")
#     match_ret = torch.load(save_path + "/match_ret")

#     rets[helper.params["model"]] = match_ret["pred"]

In [4]:
# Sanity check: list the result sets currently collected in `rets`.
rets.keys()

dict_keys(['target', 'mask', 'imputeformer', 'linear', 'ball'])

### Define the pitch control computation function

In [5]:
def _compute_pitch_control(model_type: str = "midas",
                           traces: torch.Tensor = None,
                           ball_traces: torch.Tensor = None,
                           n_grid_x=50, n_grid_y=30,
                           team1_size: int = 11,
                           device=None):
    '''
    Compute, for every frame, the pitch control surface of team 1 on a regular grid.

    model_type : currently unused; kept so existing calls keep working
    traces : [time, players * feats (=x_dim)], team-1 players first
    ball_traces : [time, 2]
    n_grid_x, n_grid_y : number of grid points along x (0..108) and y (0..72)
    team1_size : number of leading players that belong to team 1
    device : torch device for the heavy computation; defaults to "cuda:0" when
             CUDA is available and to "cpu" otherwise

    Returns a CPU float tensor of shape [time, n_grid_x, n_grid_y].
    Raises ValueError when `traces` or `ball_traces` is missing.
    '''
    if traces is None or ball_traces is None:
        raise ValueError("`traces` and `ball_traces` must both be provided")

    dtype = torch.float32
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    window_size, min_window_size = 200, 100
    frame_dt = 0.1  # seconds between consecutive frames, as used for the finite differences
    jitter = 1e-12  # tiny offset added to velocities (kept from the original code)

    ## pitch control parameters (these are taken from the FoT code)
    reaction_time = 0.7
    max_player_speed = 5.
    average_ball_speed = 15.
    sigma = float(np.pi / np.sqrt(3.) / 0.45)
    lamb = 4.3
    integration_horizon = 10.  # quadrature nodes in [-1, 1] are mapped onto [0, horizon]
    n_quadrature = 50

    seq_len = traces.shape[0]
    pc = torch.zeros(seq_len, n_grid_x, n_grid_y)
    if seq_len == 0:
        return pc

    # Split the sequence into windows of `window_size` frames. A tail shorter than
    # `min_window_size` is merged into the previous window instead of getting its own,
    # and a sequence shorter than one window is still processed as a single window.
    n_windows, remainder = divmod(seq_len, window_size)
    if remainder >= min_window_size or n_windows == 0:
        n_windows += 1

    with torch.no_grad():
        # Positions and velocities are derived once for the whole sequence so that
        # window boundaries do not introduce artificial velocity spikes.
        pos_all = reshape_tensor(traces, mode="pos").transpose(0, 1).to(dtype)  # [players, time, 2]
        vel_all = torch.zeros_like(pos_all)
        if seq_len > 1:
            vel_all[:, 1:] = (pos_all[:, 1:] - pos_all[:, :-1]) / frame_dt
            # No previous frame exists for t=0, so reuse the first available difference.
            vel_all[:, 0] = vel_all[:, 1]
        vel_all = vel_all + jitter  # [players, time, 2]

        ## set up the evaluation grid and the quadrature rule (loop-invariant)
        grid_x = torch.linspace(0, 108, n_grid_x, device=device, dtype=dtype)
        grid_y = torch.linspace(0, 72, n_grid_y, device=device, dtype=dtype)
        XX, YY = torch.meshgrid(grid_x, grid_y, indexing="ij")
        target_position = torch.stack([XX, YY], 2)[None, None, :, :, :]  # [1, 1, n_grid_x, n_grid_y, 2]

        ti, wi = np.polynomial.legendre.leggauss(n_quadrature)  ## used for numerical integration
        ti = torch.tensor(ti, device=device, dtype=dtype)
        wi = torch.tensor(wi, device=device, dtype=dtype)
        quadrature_offsets = sigma * 0.5 * (ti + 1) * integration_horizon  # [n_quadrature]

        for i in range(n_windows):
            i_from = window_size * i
            i_to = i_from + window_size if i < n_windows - 1 else seq_len

            pos = pos_all[:, i_from:i_to, None, None, :].to(device, dtype=dtype)  # [players, time, 1, 1, 2]
            vel = vel_all[:, i_from:i_to, None, None, :].to(device, dtype=dtype)  # [players, time, 1, 1, 2]
            bp = ball_traces[None, i_from:i_to, None, None, :].to(device, dtype=dtype)  # [1, time, 1, 1, 2]

            ## compute pitch control
            ball_travel_time = torch.norm(target_position - bp, dim=4).div_(average_ball_speed)

            # Out-of-place multiply: `vel` may alias `vel_all` when everything lives on the CPU.
            r_reaction = pos + vel * reaction_time - target_position  # [players, time, gx, gy, 2]

            # NOTE(review): the reaction time is added to the distance *before* dividing by
            # the player speed. Kept exactly as in the original code -- confirm intended.
            tti = torch.norm(r_reaction, dim=4).add_(reaction_time).div_(max_player_speed)

            tmp2 = (sigma * (ball_travel_time - tti)).unsqueeze(-1)  # [players, time, gx, gy, 1]
            tmp1 = quadrature_offsets + tmp2                          # [players, time, gx, gy, n_quadrature]
            h = torch.sigmoid(tmp1[:team1_size]).mul_(lamb).sum(0)
            S = torch.exp(-lamb * torch.sum(softplus(tmp1) - softplus(tmp2), dim=0).div_(sigma))

            # Gauss-Legendre sum; the factor horizon/2 rescales from [-1, 1] to [0, horizon].
            pc[i_from:i_to] = torch.matmul(S * h, wi).mul_(integration_horizon / 2).cpu()

            # Free the large per-window intermediates before the next window.
            del pos, vel, bp, ball_travel_time, r_reaction, tti, tmp1, tmp2, h, S
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return pc

In [6]:
# Sanity check: confirm which result sets are available before computing pitch control.
rets.keys()

dict_keys(['target', 'mask', 'imputeformer', 'linear', 'ball'])

In [13]:
players = helper.team1_players + helper.team2_players
# feature_types = ["_x", "_y"]
feature_types = ["_x", "_y", "_vx", "_vy", "_ax", "_ay"]
player_cols = [f"{p}{x}" for p in players for x in feature_types]

x_cols = [c for c in helper.traces.columns if c.endswith("_x")]
y_cols = [c for c in helper.traces.columns if c.endswith("_y")]

# Pick which trajectories the pitch control is computed from.
traces = rets['target']
# traces = rets['imputeformer']
masks = rets["mask"]

ball_cols = ["ball_x", "ball_y"]

outputs = []
for phase in traces["phase"].unique():
    phase_traces = traces[traces["phase"] == phase]

    # Detect the goalkeepers on this phase only (not on the whole match), so the
    # team ordering below reflects the phase being processed.
    # presumably each entry starts with the team code -- verify against SportsDataset
    phase_gks = SportsDataset.detect_goalkeepers(phase_traces)
    team1_code, team2_code = phase_gks[0][0], phase_gks[1][0]

    # Keep only players whose columns have no missing values in this phase.
    phase_player_cols = phase_traces[player_cols].dropna(axis=1).columns
    team1_cols = [c for c in phase_player_cols if c.startswith(team1_code)]
    team2_cols = [c for c in phase_player_cols if c.startswith(team2_code)]

    # reorder teams so that the left team comes first
    phase_player_cols = team1_cols + team2_cols

    # Skip phases in which either team has fewer than `team_size` complete players.
    if min(len(team1_cols), len(team2_cols)) < helper.params["team_size"] * len(feature_types):
        continue

    player_traces = torch.FloatTensor(phase_traces[phase_player_cols].values) # [time, x_dim]
    ball_traces = torch.FloatTensor(phase_traces[ball_cols].values) # [time, 2]
    mask = torch.FloatTensor(masks.loc[phase_traces.index, phase_player_cols].values) # [time, x_dim]

    phase_pc = _compute_pitch_control(traces=player_traces, ball_traces=ball_traces) # [time, n_grid_x, n_grid_y]

    outputs.append([player_traces, ball_traces, phase_pc])

### Animating Pitch Control result

In [14]:
# Take the first processed phase and unpack its three tensors.
phase_output = outputs[0]
traces, ball_traces, pc = phase_output

# One shape per line: player traces, ball traces, pitch control surface.
print(traces.shape, ball_traces.shape, pc.shape, sep="\n")

torch.Size([8014, 132])
torch.Size([8014, 2])
torch.Size([8014, 50, 30])


In [15]:
# Persist the pitch control surface of the selected phase; the comparison cell
# later reads it back from 'target.pt'.
# NOTE(review): the file name is fixed to 'target.pt' regardless of which entry of
# `rets` was used as `traces` -- rename the file by hand when saving a model's surface.
torch.save(pc, 'target.pt')

In [16]:
# Player positions with the player axis first.
pos_traces = reshape_tensor(traces, mode="pos").transpose(0, 1)  # [players, time, 2]

# Insert two singleton axes between the time and coordinate axes.
team1_pos = pos_traces[:11].unsqueeze(2).unsqueeze(3)            # [team1_players, time, 1, 1, 2]
team2_pos = pos_traces[11:].unsqueeze(2).unsqueeze(3)            # [team2_players, time, 1, 1, 2]
ball_pos = ball_traces.unsqueeze(0).unsqueeze(2).unsqueeze(3)    # [1, time, 1, 1, 2]

print(*(t.shape for t in (team1_pos, team2_pos, ball_pos)))
print(pc.shape)

torch.Size([11, 8014, 1, 1, 2]) torch.Size([11, 8014, 1, 1, 2]) torch.Size([1, 8014, 1, 1, 2])
torch.Size([8014, 50, 30])


In [17]:
## use these parameters to set which frames you want to see
first_frame_to_plot = 6460
n_frames_to_plot = 10

# Fail early with a readable message instead of an IndexError in the middle of rendering.
assert 0 <= first_frame_to_plot and first_frame_to_plot + max(n_frames_to_plot, 1) <= pc.shape[0], (
    f"requested frames [{first_frame_to_plot}, {first_frame_to_plot + n_frames_to_plot}) "
    f"exceed the {pc.shape[0]} available frames"
)

# NOTE(review): the surface is drawn on a 105 x 68 pitch here, while
# _compute_pitch_control evaluates its grid on 0..108 x 0..72 -- confirm which
# coordinate system the traces use.
ps = (105, 68)

# Grid resolution follows the computed surface instead of repeating 50 / 30 by hand.
xx = np.linspace(0, ps[0], pc.shape[1])
yy = np.linspace(0, ps[1], pc.shape[2])

# `extent` is omitted: contourf ignores it when explicit x / y coordinates are given.
contour_kwargs = dict(levels=np.linspace(0, 1, 100), cmap='coolwarm', extend='both')

locs_ball_reduced = ball_pos[0, :, 0, 0] # [time, 2]
locs_home_reduced = team1_pos[:, :, 0, 0] # [players, time, 2]
locs_away_reduced = team2_pos[:, :, 0, 0]

fig, ax = plt.subplots()
field(ax=ax, show=False)
ax.set_xlim(0, ps[0])
ax.set_ylim(0, ps[1])

# The ball is drawn as a white dot on top of a slightly larger black one.
ball_points = ax.scatter(locs_ball_reduced[first_frame_to_plot, 0], locs_ball_reduced[first_frame_to_plot, 1],
                         color='black', zorder=15, s=16)
ball_points2 = ax.scatter(locs_ball_reduced[first_frame_to_plot, 0], locs_ball_reduced[first_frame_to_plot, 1],
                          color='white', zorder=15, s=9)
home_points = ax.scatter(locs_home_reduced[:, first_frame_to_plot, 0], locs_home_reduced[:, first_frame_to_plot, 1],
                         color='red', zorder=10)
away_points = ax.scatter(locs_away_reduced[:, first_frame_to_plot, 0], locs_away_reduced[:, first_frame_to_plot, 1],
                         color='blue', zorder=10)

# Single-element list so that update() can swap the contour set in place.
p = [ax.contourf(xx, yy, pc[first_frame_to_plot].t().cpu().numpy(), **contour_kwargs)]


def _remove_contours(contour_set):
    """Remove a contour set from its axes on both old and new Matplotlib versions."""
    if hasattr(contour_set, "remove"):
        # Newer Matplotlib: the contour set is itself removable.
        contour_set.remove()
    else:
        # Older Matplotlib: remove each underlying collection.
        for collection in contour_set.collections:
            collection.remove()


def update(i):
    """Redraw the contours and move the scatter points to frame `first_frame_to_plot + i`."""
    fr = i + first_frame_to_plot

    _remove_contours(p[0])
    p[0] = ax.contourf(xx, yy, pc[fr].t().cpu().numpy(), **contour_kwargs)

    ball_xy = locs_ball_reduced[fr:fr + 1].numpy()  # [1, 2]
    ball_points.set_offsets(ball_xy)
    ball_points2.set_offsets(ball_xy)
    home_points.set_offsets(locs_home_reduced[:, fr].numpy())  # [players, 2]
    away_points.set_offsets(locs_away_reduced[:, fr].numpy())
    return [ball_points, ball_points2, home_points, away_points]


# blit=False: the contour set is recreated on every frame, so every frame is fully redrawn.
ani = animation.FuncAnimation(fig, update, frames=n_frames_to_plot,
                              interval=40, blit=False, repeat=True)

# Save the animation as an MP4 file
ani.save('midas_pc.mp4', writer='ffmpeg', dpi=300)

HTML(ani.to_html5_video())

### Compare Pitch Control Analysis

In [19]:
# Load the model-based and ground-truth pitch control surfaces saved earlier.
# NOTE: despite its name, `midas_pred_pc` holds whichever model file is loaded here.
# midas_pred_pc = torch.load('midas_pc.pt')
midas_pred_pc = torch.load('imputeformer.pt')
target_pc = torch.load('target.pt')

# Both surfaces must describe the same frames and grid, otherwise the difference is meaningless.
assert midas_pred_pc.shape == target_pc.shape, (
    f"shape mismatch: prediction {tuple(midas_pred_pc.shape)} vs target {tuple(target_pc.shape)}"
)

first_frame_to_plot = 6460
n_frames_to_plot = 0  # a single static frame is drawn in this cell

assert 0 <= first_frame_to_plot < target_pc.shape[0], (
    f"frame {first_frame_to_plot} is outside the {target_pc.shape[0]} available frames"
)

ps = (105, 68)

# Grid resolution follows the loaded surface instead of repeating 50 / 30 by hand.
xx = np.linspace(0, ps[0], target_pc.shape[1])
yy = np.linspace(0, ps[1], target_pc.shape[2])

locs_ball_reduced = ball_pos[0, :, 0, 0] # [time, 2]
locs_home_reduced = team1_pos[:, :, 0, 0] # [players, time, 2]
locs_away_reduced = team2_pos[:, :, 0, 0]

fig, ax = plt.subplots()
field(ax=ax, show=False)
ax.set_xlim(0, ps[0])
ax.set_ylim(0, ps[1])
ball_points = ax.scatter(locs_ball_reduced[first_frame_to_plot, 0], locs_ball_reduced[first_frame_to_plot, 1],
                         color='black', zorder=15, s=16)
ball_points2 = ax.scatter(locs_ball_reduced[first_frame_to_plot, 0], locs_ball_reduced[first_frame_to_plot, 1],
                          color='white', zorder=15, s=9)
home_points = ax.scatter(locs_home_reduced[:, first_frame_to_plot, 0], locs_home_reduced[:, first_frame_to_plot, 1],
                         color='red', zorder=10)
away_points = ax.scatter(locs_away_reduced[:, first_frame_to_plot, 0], locs_away_reduced[:, first_frame_to_plot, 1],
                         color='blue', zorder=10)

# Absolute per-cell difference between the two surfaces at the selected frame.
abs_pc_error = (target_pc - midas_pred_pc)[first_frame_to_plot].t().abs().cpu().numpy()
p = [ax.contourf(xx,
                 yy,
                 abs_pc_error,
                 levels=np.linspace(0, 1, 100),
                 cmap='binary',
                 extend='both')]

# Save through the figure handle rather than the implicit "current figure".
fig.savefig('imputeformer.pdf', bbox_inches='tight')