In [1]:
import pandas as pd
from lib.pitch_control import plot_pitch_control, KNNPitchControl, SpearmanPitchControl, FernandezPitchControl
from lib.draw import pitch
import matplotlib.pyplot as plt
from time import time
from tqdm import tqdm

In [2]:
# Load the tracking data and fill missing colour values in one step.
# `DataFrame.fillna` with a per-column dict returns a new frame, which avoids
# calling an in-place method on a chained column selection
# (`df.col.fillna(..., inplace=True)`): that form triggers a chained-assignment
# warning in recent pandas and does not update `df` under Copy-on-Write.
df = (
    pd.read_csv('data/all_goals.csv', low_memory=False)
    .fillna({'edgecolor': 'white', 'bgcolor': 'black'})
)
df.sample(5)

Unnamed: 0,bgcolor,dx,dy,edgecolor,frame,play,player,player_num,team,x,y,z,coords,player_obj,num,name,play_id
4038,red,0.290751,0.014463,blue,103,Bayern 0 - [1] Liverpool,3007,,defense,78.104794,48.164377,0.0,,,,,1
50015,red,-0.370069,-0.055725,white,56,Liverpool [2] - 0 Man City,13,26.0,attack,44.436803,14.227521,0.0,,,,,10
39888,black,0.239421,0.003983,white,112,Liverpool [2] - 1 Newcastle,9776,,defense,75.618815,53.289366,0.0,,,,,14
49042,blue,0.055111,0.083781,white,165,Genk 0 - [3] Liverpool,11234,,defense,50.625571,17.891725,0.0,,,,,4
69387,yellow,-0.411018,-0.220276,black,126,Liverpool [1] - 0 Watford,6052,,defense,45.074174,55.909799,0.0,,,,,7


In [3]:
# Distinct values of the `play` column in the loaded data.
df['play'].unique()

array(['Liverpool [3] - 0 Bournemouth', 'Bayern 0 - [1] Liverpool',
       'Fulham 0 - [1] Liverpool', 'Southampton 1 - [2] Liverpool',
       'Liverpool [2] - 0 Porto', 'Porto 0 - [2] Liverpool',
       'Liverpool [4] - 0 Barcelona', 'Liverpool [1] - 0 Wolves',
       'Liverpool [3] - 0 Norwich', 'Liverpool [2] - 1 Chelsea',
       'Liverpool [2] - 1 Newcastle', 'Liverpool [2] - 0 Salzburg',
       'Genk 0 - [3] Liverpool', 'Liverpool [2] - 0 Man City',
       'Liverpool [1] - 0 Everton', 'Liverpool [2] - 0 Everton',
       'Bournemouth 0 - 3 Liverpool', 'Liverpool [1] - 0 Watford',
       'Leicester 0 - [3] Liverpool', 'Barcelona 1 - [2] Real Madrid'],
      dtype=object)

In [4]:
# Initialize results file
import os

def open_files(results_file_path = 'res/results.csv', grids_file_path = 'res/grids.csv', n_grid_cells = 7314):
    """Open the results and grids CSV files, resuming a previous run if one exists.

    When the results file already holds data, it is loaded into a DataFrame and
    both files are opened in append mode. Otherwise both files are created
    (truncating any previous content) and their header rows are written.

    Parameters
    ----------
    results_file_path : str
        Path of the per-frame results CSV.
    grids_file_path : str
        Path of the CSV that stores one row of grid values per frame.
    n_grid_cells : int
        Number of `c0..c{n-1}` value columns written to the grids header.
        Defaults to 7314, the value previously hard-coded here.

    Returns
    -------
    tuple
        `(df_results, csv_results, csv_grids)`: the previously recorded results
        (an empty frame with `play_frame_id` and `model` columns on a fresh
        run) and the two open, writable file handles. Closing the handles is
        the caller's responsibility.
    """
    def _has_content(path):
        # A zero-byte file is treated like a missing one: it has no header
        # and `pd.read_csv` would raise on it.
        return os.path.isfile(path) and os.path.getsize(path) > 0

    # Create the output directories up front so `open` does not fail on a
    # fresh checkout where they do not exist yet.
    for path in (results_file_path, grids_file_path):
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    resuming = _has_content(results_file_path)
    # When resuming, the grids file only needs a header if it is missing or empty.
    grids_need_header = not (resuming and _has_content(grids_file_path))

    if resuming:
        df_results = pd.read_csv(results_file_path)
    else:
        df_results = pd.DataFrame(columns=['play_frame_id', 'model'])

    mode = 'a' if resuming else 'w'
    csv_results = open(results_file_path, mode)
    try:
        csv_grids = open(grids_file_path, mode)
    except OSError:
        # Do not leak the first handle when the second file cannot be opened.
        csv_results.close()
        raise

    if not resuming:
        csv_results.write('play_frame_id,model,play,frame,inference_time\n')
    if grids_need_header:
        csv_grids.write('play_frame_id,model_name' + ''.join(f',c{i}' for i in range(n_grid_cells)) + '\n')

    return df_results, csv_results, csv_grids

In [5]:
def close_and_reopen_files(csv_results, csv_grids, results_file_path = 'res/results.csv', grids_file_path = 'res/grids.csv'):
    """Close both output handles and return fresh append-mode handles.

    The passed-in handles are closed (results first, then grids), after which
    the files at `results_file_path` and `grids_file_path` are reopened with
    mode `'a'` and returned as `(results_handle, grids_handle)`.
    """
    for handle in (csv_results, csv_grids):
        handle.close()

    reopened_results = open(results_file_path, 'a')
    reopened_grids = open(grids_file_path, 'a')
    return reopened_results, reopened_grids

In [6]:
def run_model(csv_results, csv_grids, model, model_name, play_frame_id, play, frame_no, df_frame, df_play=None):
    """Run one prediction, time it, and append the outcome to both CSV handles.

    `model.predict` is called with `df_frame` alone, or with `df_frame` and
    `df_play` when `df_play` is given. Its result, rounded to two decimals, is
    stored in `model.grid['control']`; `model.grid` is also tagged with the
    frame id and model name. One line is written to `csv_results`
    (`play_frame_id,model_name,play,frame_no,inference_time`) and one line to
    `csv_grids` (`play_frame_id,model_name` followed by every control value).

    The recorded time spans the predict call, the rounding and the assignment
    into `model.grid`. Nothing is returned.
    """
    predict_inputs = (df_frame,) if df_play is None else (df_frame, df_play)

    started_at = time()
    model.grid['control'] = model.predict(*predict_inputs).round(2)
    inference_time = time() - started_at

    model.grid['play_frame_id'] = play_frame_id
    model.grid['model'] = model_name

    csv_results.write(f'{play_frame_id},{model_name},{play},{frame_no},{inference_time}\n')

    # Row layout: id, model name, then one field per grid value.
    grid_fields = [f'{play_frame_id}', f'{model_name}']
    grid_fields.extend(f'{value}' for value in model.grid['control'].tolist())
    csv_grids.write(','.join(grid_fields) + '\n')

In [7]:
def run_pitch_control_model(df, model, model_name, requires_full_data=False):
    """Run `model` on every frame of every play in `df`, skipping frames already recorded.

    Frames are numbered with a running `play_frame_id` (starting at 1) in the
    order produced by `df.play.unique()` and, within each play,
    `df_play.frame.unique()`. A frame is skipped when the results loaded by
    `open_files` already contain its id for `model_name`, so an interrupted
    run can be resumed. Resuming relies on `df` yielding the same ordering as
    the earlier run.

    Parameters
    ----------
    df : DataFrame
        Data with at least `play` and `frame` columns.
    model
        Object passed through to `run_model`.
    model_name : str
        Label written to the output files and used to find finished frames.
    requires_full_data : bool
        When True, the whole play is passed to `run_model` alongside the
        single frame.
    """
    df_results, csv_results, csv_grids = open_files()

    # `df_results` does not change during the run, so compute the finished ids
    # once; a set keeps the per-frame membership test cheap.
    frames_done = set(df_results[df_results.model == model_name].play_frame_id.unique().tolist())

    play_frame_id = 0
    try:
        for play in df.play.unique():
            df_play = df[df.play == play]
            for frame_no in tqdm(df_play.frame.unique()):
                play_frame_id += 1
                if play_frame_id in frames_done:
                    continue

                df_frame = df_play[df_play.frame == frame_no]
                run_model(csv_results, csv_grids, model, model_name, play_frame_id, play, frame_no, df_frame,
                          df_play if requires_full_data else None)

                # Closing and reopening after each frame pushes the rows to
                # disk, so progress survives an interrupted run.
                csv_results, csv_grids = close_and_reopen_files(csv_results, csv_grids)
    finally:
        # Release the handles on normal completion and on error; closing an
        # already-closed file is a no-op.
        csv_results.close()
        csv_grids.close()

In [8]:
# Switches selecting which pitch-control models the following cell runs;
# each flag guards one `run_pitch_control_model` call there.
run_knn_pitch_control_voronoi = False
run_knn_pitch_control_spearman = True
run_knn_pitch_control_fernandez = True
run_spearman_pitch_control = False
run_fernandez_pitch_control = False

In [9]:
# Run every model whose flag is enabled. Each branch builds the model and hands
# it to `run_pitch_control_model` under the label stored in the output files.

# KNN model with default constructor arguments.
if run_knn_pitch_control_voronoi:
    knn_pitch_control_voronoi = KNNPitchControl()
    run_pitch_control_model(df, knn_pitch_control_voronoi, 'KNN (Voronoi)')

# KNN model with explicit lags and smoothing.
# NOTE(review): the meaning of `lags` / `smoothing` is defined in lib.pitch_control — confirm there.
if run_knn_pitch_control_spearman:
    knn_pitch_control_spearman = KNNPitchControl(lags=[5, 15, 25], smoothing=6)
    run_pitch_control_model(df, knn_pitch_control_spearman, 'KNN (Spearman)')

# Same lags and smoothing as above, plus `distance_basis=350`.
if run_knn_pitch_control_fernandez:
    knn_pitch_control_fernandez = KNNPitchControl(lags=[5, 15, 25], smoothing=6, distance_basis=350)
    run_pitch_control_model(df, knn_pitch_control_fernandez, 'KNN (Fernandez)')

# SpearmanPitchControl with default constructor arguments.
if run_spearman_pitch_control:
    spearman_pitch_control = SpearmanPitchControl()
    run_pitch_control_model(df, spearman_pitch_control, 'Spearman Pitch Control')

# FernandezPitchControl is the only model run with `requires_full_data=True`,
# so its predict call receives the whole play in addition to the frame.
if run_fernandez_pitch_control:
    fernandez_pitch_control = FernandezPitchControl()
    run_pitch_control_model(df, fernandez_pitch_control, 'Fernandez Pitch Control', requires_full_data=True)

100%|██████████| 155/155 [00:00<00:00, 105589.92it/s]
100%|██████████| 165/165 [00:00<00:00, 122880.00it/s]
100%|██████████| 183/183 [00:00<00:00, 120742.12it/s]
100%|██████████| 257/257 [00:10<00:00, 23.71it/s] 
100%|██████████| 195/195 [02:04<00:00,  1.56it/s]
100%|██████████| 257/257 [02:54<00:00,  1.47it/s]
100%|██████████| 139/139 [01:26<00:00,  1.60it/s]
100%|██████████| 157/157 [01:37<00:00,  1.62it/s]
100%|██████████| 149/149 [01:32<00:00,  1.62it/s]
100%|██████████| 195/195 [02:09<00:00,  1.50it/s]
100%|██████████| 149/149 [01:35<00:00,  1.56it/s]
100%|██████████| 191/191 [02:01<00:00,  1.57it/s]
100%|██████████| 183/183 [01:55<00:00,  1.58it/s]
100%|██████████| 167/167 [01:48<00:00,  1.54it/s]
100%|██████████| 199/199 [02:10<00:00,  1.53it/s]
100%|██████████| 287/287 [03:08<00:00,  1.52it/s]
100%|██████████| 171/171 [01:47<00:00,  1.59it/s]
100%|██████████| 225/225 [02:19<00:00,  1.61it/s]
100%|██████████| 125/125 [01:17<00:00,  1.60it/s]
100%|██████████| 289/289 [03:01<00:00