In [1]:
import json
with open('action_types.json', 'r') as f:
    action_types = json.load(f)

In [2]:
import pandas as pd
from preprocess_data import *

df = (
    pd.read_csv("WSL_actions.csv", index_col = 0)
    .pipe(add_coordinate_bins, n_bins_x = 10, n_bins_y = 10)
    .pipe(add_team_as_dummy)
    .pipe(get_action_type_names, action_types)
    .pipe(get_action_tokens)
    .assign(
        match_id = lambda d: d.groupby(['game_id']).ngroup(),
        action_token = lambda d: pd.Categorical(d.action_token)
    )
    [['match_id', 'action_token']]
)

vocab = df['action_token'].cat.categories

df

Unnamed: 0,match_id,action_token
0,315,"True,pass,4,4"
1,315,"True,receival,4,5"
2,315,"True,dribble,4,5"
3,315,"True,pass,5,5"
4,315,"True,receival,6,4"
...,...,...
3130,82,"False,receival,8,0"
3131,82,"False,dribble,8,0"
3132,82,"False,tackle,9,0"
3133,82,"False,dribble,9,0"


In [3]:
from numpy.random import choice, seed

seed(42)
train_groups = choice(df['match_id'].unique(), int(0.8 * df['match_id'].nunique()), replace = False)
validation_candidates = list(set(df['match_id'].unique()) - set(train_groups))
val_groups = choice(validation_candidates, int(len(validation_candidates) * 0.5), replace = False)
test_groups = np.array(list(set(validation_candidates) - set(val_groups)))
train_groups[:5]

array([ 96, 313,  43, 251, 281], dtype=int64)

In [4]:
counts = pd.read_csv("transition_counts.csv", index_col = 0)
counts_smoothed = counts + 1
probs = counts_smoothed / counts_smoothed.sum()
probs

Unnamed: 0_level_0,"False,bad_touch,0,0","False,bad_touch,0,1","False,bad_touch,0,2","False,bad_touch,0,3","False,bad_touch,0,4","False,bad_touch,0,5","False,bad_touch,0,6","False,bad_touch,0,7","False,bad_touch,0,8","False,bad_touch,0,9",...,"True,yellow_card,6,6","True,yellow_card,6,8","True,yellow_card,6,9","True,yellow_card,7,0","True,yellow_card,7,2","True,yellow_card,7,3","True,yellow_card,7,6","True,yellow_card,7,9","True,yellow_card,8,1","True,yellow_card,9,4"
row_0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
"False,bad_touch,0,0",0.000345,0.000345,0.000346,0.000346,0.000343,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035
"False,bad_touch,0,1",0.000345,0.000345,0.000346,0.000346,0.000343,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035
"False,bad_touch,0,2",0.000345,0.000345,0.000346,0.000346,0.000343,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035
"False,bad_touch,0,3",0.000345,0.000345,0.000346,0.000346,0.000343,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035
"False,bad_touch,0,4",0.000345,0.000345,0.000346,0.000346,0.000343,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"True,yellow_card,7,3",0.000345,0.000345,0.000346,0.000346,0.000343,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035
"True,yellow_card,7,6",0.000345,0.000345,0.000346,0.000346,0.000343,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035
"True,yellow_card,7,9",0.000345,0.000345,0.000346,0.000346,0.000343,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035
"True,yellow_card,8,1",0.000345,0.000345,0.000346,0.000346,0.000343,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035


In [5]:
def pred(tok):
    return vocab[counts[tok].argmax()]

test_accuracy = (
    df
    .query("match_id.isin(@test_groups)")
    .assign(
        y_true = lambda d: d.action_token.shift(-1),
        y_pred = lambda d: d.action_token.apply(pred)
    )
    .dropna()
    .assign(correct = lambda d: d.y_pred == d.y_true)
    ['correct'].value_counts(normalize = True)
)
test_accuracy

correct
False    0.575921
True     0.424079
Name: proportion, dtype: float64

In [6]:
def prob(tok1, tok2):
    return probs[tok1][tok2]

probs = (
    df
    .query("match_id.isin(@test_groups)")
    .assign(
        y_true = lambda d: d.action_token.shift(-1),
        y_pred = lambda d: d.action_token.apply(pred)
    )
    .dropna()
    .assign(prob = lambda d: d.apply(lambda x: prob(x.action_token, x.y_true), axis=1))
)

probs

Unnamed: 0,match_id,action_token,y_true,y_pred,prob
0,212,"False,pass,4,4","False,receival,6,5","False,receival,5,5",0.010901
1,212,"False,receival,6,5","False,pass,6,5","False,dribble,6,5",0.031392
2,212,"False,pass,6,5","False,interception,2,7","False,receival,6,3",0.000661
3,212,"False,interception,2,7","False,pass,2,7","False,dribble,2,7",0.054722
4,212,"False,pass,2,7","False,interception,4,8","False,receival,3,9",0.001967
...,...,...,...,...,...
3129,82,"False,pass,7,0","False,receival,8,0","False,receival,6,0",0.026809
3130,82,"False,receival,8,0","False,dribble,8,0","False,dribble,8,0",0.276458
3131,82,"False,dribble,8,0","False,tackle,9,0","False,pass,8,0",0.004378
3132,82,"False,tackle,9,0","False,dribble,9,0","False,dribble,9,0",0.024542


In [7]:
np.log(probs['prob']).sum()

-383490.09516024316