In [1]:
import pandas as pd
from lib.pitch_control import plot_pitch_control, KNNPitchControl, SpearmanPitchControl, FernandezPitchControl
from lib.draw import pitch
import matplotlib.pyplot as plt
from time import time
from tqdm import tqdm

In [2]:
# Load the tracking data for all goals (one row per player per frame).
df = pd.read_csv('data/all_goals.csv', low_memory=False)
# Fill missing plot colours with defaults. This is done on the frame itself and
# reassigned: `df.col.fillna(..., inplace=True)` operates on a temporary Series
# and no longer updates `df` under pandas Copy-on-Write.
df = df.fillna({'edgecolor': 'white', 'bgcolor': 'black'})
df.sample(5)

Unnamed: 0,bgcolor,dx,dy,edgecolor,frame,play,player,player_num,team,x,y,z,coords,player_obj,num,name,play_id
56528,red,-0.134625,-0.013595,white,77,Liverpool [1] - 0 Everton,15025,,attack,76.006661,66.4033,0.0,,,,,6
25643,red,-0.002955,0.120871,white,64,Liverpool [4] - 0 Barcelona,13,27.0,attack,92.566989,36.220693,0.0,,,,,17
26688,yellow,0.224974,0.112845,maroon,136,Liverpool [4] - 0 Barcelona,856,,defense,99.317571,49.49721,0.0,,,,,17
13680,red,-0.423001,-0.137764,black,178,Southampton 1 - [2] Liverpool,15145,,defense,39.807289,44.107349,0.0,,,,,19
39514,red,0.000714,-0.000494,white,36,Liverpool [2] - 1 Newcastle,7813,32.0,attack,58.969742,31.249171,0.0,,,,,14


In [3]:
# Distinct play labels present in the data set.
df['play'].unique()

array(['Liverpool [3] - 0 Bournemouth', 'Bayern 0 - [1] Liverpool',
       'Fulham 0 - [1] Liverpool', 'Southampton 1 - [2] Liverpool',
       'Liverpool [2] - 0 Porto', 'Porto 0 - [2] Liverpool',
       'Liverpool [4] - 0 Barcelona', 'Liverpool [1] - 0 Wolves',
       'Liverpool [3] - 0 Norwich', 'Liverpool [2] - 1 Chelsea',
       'Liverpool [2] - 1 Newcastle', 'Liverpool [2] - 0 Salzburg',
       'Genk 0 - [3] Liverpool', 'Liverpool [2] - 0 Man City',
       'Liverpool [1] - 0 Everton', 'Liverpool [2] - 0 Everton',
       'Bournemouth 0 - 3 Liverpool', 'Liverpool [1] - 0 Watford',
       'Leicester 0 - [3] Liverpool', 'Barcelona 1 - [2] Real Madrid'],
      dtype=object)

In [4]:
# Initialize results file
import os

def open_files(results_file_path = 'res/results.csv', grids_file_path = 'res/grids.csv', n_grid_cells = 7314):
    """Open the results and grids CSV files for writing, creating them if needed.

    Parameters
    ----------
    results_file_path : str
        CSV holding one row per (frame, model) with the inference time.
    grids_file_path : str
        CSV holding one row per (frame, model) with one column per grid cell.
    n_grid_cells : int
        Number of ``c<i>`` columns written to the grids header when that file
        is (re)created.

    Returns
    -------
    (df_results, csv_results, csv_grids)
        ``df_results`` is the previously saved results (an empty frame with the
        columns ``play_frame_id`` and ``model`` when there are none);
        ``csv_results`` and ``csv_grids`` are open text handles positioned for
        writing new rows. The caller is responsible for closing them.
    """
    def _missing_or_empty(path):
        # A zero-byte file has no header, so it must be treated like a new file.
        return not os.path.isfile(path) or os.path.getsize(path) == 0

    # Make sure the output directories exist; open() would otherwise fail.
    for path in (results_file_path, grids_file_path):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    results_is_new = _missing_or_empty(results_file_path)
    # With no saved results the grids file is reset as well (as before); an
    # existing results file with a missing/empty grids file still gets a header.
    grids_is_new = results_is_new or _missing_or_empty(grids_file_path)

    if results_is_new:
        df_results = pd.DataFrame(columns=['play_frame_id', 'model'])
    else:
        df_results = pd.read_csv(results_file_path)

    csv_results = open(results_file_path, 'w' if results_is_new else 'a')
    try:
        if results_is_new:
            csv_results.write('play_frame_id,model,play,frame,inference_time\n')
        csv_grids = open(grids_file_path, 'w' if grids_is_new else 'a')
    except Exception:
        # Do not leak the first handle if the second file cannot be opened.
        csv_results.close()
        raise

    if grids_is_new:
        csv_grids.write('play_frame_id,model_name' + ''.join(f',c{i}' for i in range(n_grid_cells)) + '\n')

    return df_results, csv_results, csv_grids

In [5]:
def close_and_reopen_files(csv_results, csv_grids, results_file_path = 'res/results.csv', grids_file_path = 'res/grids.csv'):
    """Close both CSV handles and return fresh append-mode handles for the same paths.

    Closing writes any buffered rows to disk before the files are reopened.
    Returns the pair ``(results_handle, grids_handle)``.
    """
    for handle in (csv_results, csv_grids):
        handle.close()

    reopened_results = open(results_file_path, 'a')
    reopened_grids = open(grids_file_path, 'a')
    return reopened_results, reopened_grids

In [6]:
def run_model(csv_results, csv_grids, model, model_name, play_frame_id, play, frame_no, df_frame, df_play=None):
    """Run `model` on one frame and append the outcome to the two CSV handles.

    Parameters
    ----------
    csv_results, csv_grids : writable text file objects
        Receive one row each: the timing row and the per-cell control row.
    model : object
        Must expose ``predict(df_frame[, df_play])`` whose result supports
        ``.round(2)``, and a ``grid`` that accepts column assignment.
    model_name : str
        Label written to both files and stored in ``model.grid['model']``.
    play_frame_id : int
        Identifier of the frame across all plays.
    play, frame_no
        Play label and frame number, written to the results row.
    df_frame : DataFrame
        Rows of the frame being evaluated.
    df_play : DataFrame, optional
        When given, it is passed to ``predict`` as a second argument.

    Side effects: overwrites ``model.grid['control']``,
    ``model.grid['play_frame_id']`` and ``model.grid['model']``.
    """
    import csv  # local import so this cell does not depend on another cell's imports

    t0 = time()
    if df_play is None:
        model.grid['control'] = model.predict(df_frame).round(2)
    else:
        model.grid['control'] = model.predict(df_frame, df_play).round(2)
    inference_time = time() - t0

    model.grid['play_frame_id'] = play_frame_id
    model.grid['model'] = model_name

    # csv.writer quotes fields that contain commas/quotes, so a play or model
    # label with a comma no longer corrupts the row. For plain values the output
    # is the same as the previous hand-built string.
    csv.writer(csv_results, lineterminator='\n').writerow(
        [play_frame_id, model_name, play, frame_no, inference_time]
    )
    csv.writer(csv_grids, lineterminator='\n').writerow(
        [play_frame_id, model_name, *model.grid['control'].tolist()]
    )

In [7]:
def run_pitch_control_model(df, model, model_name, requires_full_data=False):
    """Run `model` on every frame of every play in `df`, resuming from saved results.

    Frames are numbered with a running ``play_frame_id`` (starting at 1) in the
    order plays and frames appear in `df`. Frames whose id is already recorded
    for `model_name` in the results file are skipped, so an interrupted run can
    be restarted.

    Parameters
    ----------
    df : DataFrame
        Must contain the columns ``play`` and ``frame``.
    model : object
        Model passed through to ``run_model``.
    model_name : str
        Label under which results are stored and looked up.
    requires_full_data : bool
        When True, the rows of the whole play are passed to the model in
        addition to the rows of the current frame.
    """
    df_results, csv_results, csv_grids = open_files()

    try:
        # The saved results do not change during the run, so compute the set of
        # finished frames once; a set gives constant-time membership checks.
        frames_done = set(
            df_results.loc[df_results.model == model_name, 'play_frame_id'].tolist()
        )

        play_frame_id = 0
        for play in df.play.unique():
            df_play = df[df.play == play]
            for frame_no in tqdm(df_play.frame.unique()):
                play_frame_id += 1
                if play_frame_id in frames_done:
                    continue

                df_frame = df_play[df_play.frame == frame_no]

                if requires_full_data:
                    run_model(csv_results, csv_grids, model, model_name, play_frame_id, play, frame_no, df_frame, df_play)
                else:
                    run_model(csv_results, csv_grids, model, model_name, play_frame_id, play, frame_no, df_frame)

                # Closing and reopening pushes each frame's rows to disk so that
                # an interrupted run loses nothing already computed.
                csv_results, csv_grids = close_and_reopen_files(csv_results, csv_grids)
    finally:
        # Always release the handles, on normal completion and on errors alike.
        csv_results.close()
        csv_grids.close()

In [8]:
# Switches selecting which pitch-control models are executed by the run cell
# that follows; a model is run only when its flag is True.
run_knn_pitch_control_voronoi = False
run_knn_pitch_control_spearman = True
run_knn_pitch_control_fernandez = True
run_spearman_pitch_control = False
run_fernandez_pitch_control = False

In [9]:
# Run each enabled model over all frames in `df`. run_pitch_control_model
# appends to the default results/grids CSV files and skips frames already
# recorded under the same model name.

# KNN model built with its default constructor arguments.
if run_knn_pitch_control_voronoi:
    knn_pitch_control_voronoi = KNNPitchControl()
    run_pitch_control_model(df, knn_pitch_control_voronoi, 'KNN (Voronoi)')

# KNN model configured with lags and smoothing.
if run_knn_pitch_control_spearman:
    knn_pitch_control_spearman = KNNPitchControl(lags=[5, 15, 25], smoothing=6)
    run_pitch_control_model(df, knn_pitch_control_spearman, 'KNN (Spearman)')

# Same lags/smoothing as above, plus an explicit distance_basis.
if run_knn_pitch_control_fernandez:
    knn_pitch_control_fernandez = KNNPitchControl(lags=[5, 15, 25], smoothing=6, distance_basis=350)
    run_pitch_control_model(df, knn_pitch_control_fernandez, 'KNN (Fernandez)')

if run_spearman_pitch_control:
    spearman_pitch_control = SpearmanPitchControl()
    run_pitch_control_model(df, spearman_pitch_control, 'Spearman Pitch Control')

# requires_full_data=True: predict() is also given the rows of the whole play.
if run_fernandez_pitch_control:
    fernandez_pitch_control = FernandezPitchControl()
    run_pitch_control_model(df, fernandez_pitch_control, 'Fernandez Pitch Control', requires_full_data=True)

100%|██████████| 155/155 [00:00<00:00, 119726.91it/s]
100%|██████████| 165/165 [00:00<00:00, 103679.42it/s]
100%|██████████| 183/183 [00:00<00:00, 103639.97it/s]
100%|██████████| 257/257 [00:00<00:00, 102445.93it/s]
100%|██████████| 195/195 [00:00<00:00, 164004.27it/s]
100%|██████████| 257/257 [00:00<00:00, 110444.28it/s]
100%|██████████| 139/139 [00:00<00:00, 131991.91it/s]
100%|██████████| 157/157 [00:00<00:00, 135439.27it/s]
100%|██████████| 149/149 [00:00<00:00, 135007.84it/s]
100%|██████████| 195/195 [00:00<00:00, 97021.27it/s]
100%|██████████| 149/149 [00:00<00:00, 93109.55it/s]
100%|██████████| 191/191 [00:00<00:00, 105451.11it/s]
100%|██████████| 183/183 [00:00<00:00, 97085.46it/s]
100%|██████████| 167/167 [00:00<00:00, 89217.78it/s]
100%|██████████| 199/199 [00:00<00:00, 94259.34it/s]
100%|██████████| 287/287 [00:00<00:00, 116070.32it/s]
100%|██████████| 171/171 [00:00<00:00, 143101.75it/s]
100%|██████████| 225/225 [00:00<00:00, 133105.56it/s]
100%|██████████| 125/125 [00:00<0