In [1]:
import json
with open('action_types.json', 'r') as f:
    action_types = json.load(f)
action_types

{'0': 'pass',
 '1': 'cross',
 '2': 'throw_in',
 '3': 'freekick_crossed',
 '4': 'freekick_short',
 '5': 'corner_crossed',
 '6': 'corner_short',
 '7': 'take_on',
 '8': 'foul',
 '9': 'tackle',
 '10': 'interception',
 '11': 'shot',
 '12': 'shot_penalty',
 '13': 'shot_freekick',
 '14': 'keeper_save',
 '15': 'keeper_claim',
 '16': 'keeper_punch',
 '17': 'keeper_pick_up',
 '18': 'clearance',
 '19': 'bad_touch',
 '20': 'non_action',
 '21': 'dribble',
 '22': 'goalkick',
 '23': 'receival',
 '24': 'interception',
 '25': 'out',
 '26': 'offside',
 '27': 'goal',
 '28': 'owngoal',
 '29': 'yellow_card',
 '30': 'red_card',
 '31': 'corner',
 '32': 'freekick'}

In [2]:
import pandas as pd
from preprocess_data import *

df = (
    pd.read_csv("WSL_actions.csv", index_col = 0)
    .pipe(add_coordinate_bins, n_bins_x = 10, n_bins_y = 10)
    .pipe(add_team_as_dummy)
    .pipe(get_action_type_names, action_types)
    .pipe(get_action_tokens)
    .assign(
        group_id = lambda d: d.groupby(['game_id', 'period_id']).ngroup(),
        action_token = lambda d: pd.Categorical(d.action_token)
    )
    [['group_id', 'action_token']]
)

vocab = df['action_token'].cat.categories

df

Unnamed: 0,group_id,action_token
0,630,"True,pass,4,4"
1,630,"True,receival,4,5"
2,630,"True,dribble,4,5"
3,630,"True,pass,5,5"
4,630,"True,receival,6,4"
...,...,...
3130,165,"False,receival,8,0"
3131,165,"False,dribble,8,0"
3132,165,"False,tackle,9,0"
3133,165,"False,dribble,9,0"


In [3]:
from numpy.random import choice, seed

seed(42)
train_groups = choice(df['group_id'].unique(), int(0.8 * df['group_id'].nunique()), replace = False)
train_groups[:5]

array([ 55, 363, 406, 428, 402], dtype=int64)

In [4]:
counts = pd.read_csv("transition_counts.csv", index_col = 0)
counts_smoothed = counts + 1
probs = counts_smoothed / counts_smoothed.sum()
probs

Unnamed: 0_level_0,"False,bad_touch,0,0","False,bad_touch,0,1","False,bad_touch,0,2","False,bad_touch,0,3","False,bad_touch,0,4","False,bad_touch,0,5","False,bad_touch,0,6","False,bad_touch,0,7","False,bad_touch,0,8","False,bad_touch,0,9",...,"True,yellow_card,6,6","True,yellow_card,6,8","True,yellow_card,6,9","True,yellow_card,7,0","True,yellow_card,7,2","True,yellow_card,7,3","True,yellow_card,7,6","True,yellow_card,7,9","True,yellow_card,8,1","True,yellow_card,9,4"
row_0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
"False,bad_touch,0,0",0.000344,0.000345,0.000346,0.000346,0.000344,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035
"False,bad_touch,0,1",0.000344,0.000345,0.000346,0.000346,0.000344,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035
"False,bad_touch,0,2",0.000344,0.000345,0.000346,0.000346,0.000344,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035
"False,bad_touch,0,3",0.000344,0.000345,0.000346,0.000346,0.000344,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035
"False,bad_touch,0,4",0.000344,0.000345,0.000346,0.000346,0.000344,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"True,yellow_card,7,3",0.000344,0.000345,0.000346,0.000346,0.000344,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035
"True,yellow_card,7,6",0.000344,0.000345,0.000346,0.000346,0.000344,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035
"True,yellow_card,7,9",0.000344,0.000345,0.000346,0.000346,0.000344,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035
"True,yellow_card,8,1",0.000344,0.000345,0.000346,0.000346,0.000344,0.000344,0.000347,0.000344,0.000342,0.000342,...,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035,0.00035


In [5]:
def pred(tok):
    return vocab[counts[tok].argmax()]

test_accuracy = (
    df
    .query("~group_id.isin(@train_groups)")
    .assign(
        y_true = lambda d: d.action_token.shift(-1),
        y_pred = lambda d: d.action_token.apply(pred)
    )
    .dropna()
    .assign(correct = lambda d: d.y_pred == d.y_true)
    ['correct'].value_counts(normalize = True)
)
test_accuracy

correct
False    0.582035
True     0.417965
Name: proportion, dtype: float64

In [6]:
def prob(tok1, tok2):
    return probs[tok1][tok2]

probs = (
    df
    .query("~group_id.isin(@train_groups)")
    .assign(
        y_true = lambda d: d.action_token.shift(-1),
        y_pred = lambda d: d.action_token.apply(pred)
    )
    .dropna()
    .assign(prob = lambda d: d.apply(lambda x: prob(x.action_token, x.y_true), axis=1))
)

probs

Unnamed: 0,group_id,action_token,y_true,y_pred,prob
1650,631,"False,pass,5,5","False,receival,3,1","False,receival,4,4",0.000971
1651,631,"False,receival,3,1","False,dribble,3,1","False,dribble,3,1",0.348099
1652,631,"False,dribble,3,1","False,pass,3,1","False,pass,3,1",0.182469
1653,631,"False,pass,3,1","True,interception,6,0","False,receival,4,0",0.000192
1654,631,"True,interception,6,0","True,pass,6,0","True,dribble,6,0",0.006196
...,...,...,...,...,...
2657,31,"False,foul,0,1","False,freekick,0,1","False,freekick,0,1",0.005207
2658,31,"False,freekick,0,1","False,interception,4,1","False,interception,0,5",0.001041
2659,31,"False,interception,4,1","False,pass,4,1","False,pass,4,1",0.097674
2660,31,"False,pass,4,1","False,interception,3,1","False,receival,4,0",0.005691


In [7]:
np.log(probs['prob']).sum()

-726670.0684217903