In [1]:
import os
import sys
import warnings
import pandas as pd
from tqdm.notebook import tqdm

# Project root: assumes the notebook runs from a directory one level below
# the repository root (the parent of the current working directory).
base_path = os.path.abspath(os.path.join(os.getcwd(), ".."))
print(f"base_path: {base_path}")
# Make the project-local `express` package (imported in the next cell)
# importable. Guarded so that re-running this cell does not keep appending
# duplicate entries to sys.path.
if base_path not in sys.path:
    sys.path.append(base_path)

base_path: /root/sr-press


In [2]:
from express.databases import SQLiteDatabase
from express.datasets import PressingDataset

from express import features as fs
from express import labels as ls

In [3]:
# SQLite stores holding the train and test games. Paths are joined
# component-wise, consistent with the dataset paths built further below.
TRAIN_DB_PATH = os.path.join(base_path, "stores", "train_database.sqlite")
TEST_DB_PATH = os.path.join(base_path, "stores", "test_database.sqlite")

# Fail fast on a missing store. Opening a non-existent SQLite path usually
# creates a new empty file rather than raising, so the problem would only show
# up later as empty games()/datasets.
# NOTE(review): confirm SQLiteDatabase opens via sqlite3.connect semantics.
for db_path in (TRAIN_DB_PATH, TEST_DB_PATH):
    if not os.path.isfile(db_path):
        raise FileNotFoundError(f"SQLite database not found: {db_path}")

train_db = SQLiteDatabase(TRAIN_DB_PATH)
test_db = SQLiteDatabase(TEST_DB_PATH)

print("train_db:", train_db)
print("test_db:", test_db)

train_db: <express.databases.sqlite.SQLiteDatabase object at 0x7fbdd7dbd6d0>
test_db: <express.databases.sqlite.SQLiteDatabase object at 0x7fbdd7dbd0d0>


In [4]:
# (rows, columns) of the games table in each store, train first.
train_games_shape = train_db.games().shape
test_games_shape = test_db.games().shape
print(train_games_shape, test_games_shape)

(136, 11) (64, 11)


In [5]:
# Names of every registered feature and label function; the xfns/yfns
# passed to PressingDataset below are chosen from these names.
all_features = list(map(lambda feature_fn: feature_fn.__name__, fs.all_features))
all_labels = list(map(lambda label_fn: label_fn.__name__, ls.all_labels))
print("Features:", all_features)
print("Labels:", all_labels)

Features: ['actiontype', 'actiontype_onehot', 'result', 'result_onehot', 'bodypart', 'bodypart_onehot', 'time', 'startlocation', 'relative_startlocation', 'endlocation', 'relative_endlocation', 'startpolar', 'endpolar', 'movement', 'team', 'time_delta', 'space_delta', 'goalscore', 'angle', 'under_pressure', 'speed', 'freeze_frame_360', 'dist_opponent', 'defenders_in_3m_radius', 'closest_11_players', 'get_column_sum_to_player']
Labels: ['concede_shots', 'counterpress', 'possession_change_by_2_actions', 'possession_change_by_4_actions', 'possession_change_by_6_actions', 'possession_change_by_2_actions_and_3m_distance', 'possession_change_by_4_actions_and_3m_distance', 'possession_change_by_6_actions_and_3m_distance', 'possession_change_by_2_actions_and_5m_distance', 'possession_change_by_4_actions_and_5m_distance', 'possession_change_by_6_actions_and_5m_distance', 'possession_change_by_2_actions_and_7m_distance', 'possession_change_by_4_actions_and_7m_distance', 'possession_change_by_6_a

In [None]:
# Shared dataset configuration. Train and test must be built from the same
# feature/label functions and history length, or their columns will not line
# up, so the settings live in one place instead of being copy-pasted per split.
XFNS = ["startlocation", "closest_11_players", "freeze_frame_360"]
YFNS = ["possession_change_by_5_seconds"]
NB_PREV_ACTIONS = 3


def make_pressing_dataset(split: str) -> PressingDataset:
    """Create an uncached PressingDataset stored under stores/datasets/<split>.

    Args:
        split: Sub-directory name under ``stores/datasets`` (e.g. "train", "test").

    Returns:
        A PressingDataset configured with XFNS, YFNS and NB_PREV_ACTIONS.
    """
    return PressingDataset(
        path=os.path.join(base_path, "stores", "datasets", split),
        # Fresh list copies per dataset, so neither dataset can mutate the
        # other's configuration (the original passed separate literals).
        xfns=list(XFNS),
        yfns=list(YFNS),
        load_cached=False,
        nb_prev_actions=NB_PREV_ACTIONS,
    )


train_dataset = make_pressing_dataset("train")
test_dataset = make_pressing_dataset("test")

In [None]:
# Loading Time: ~27m in total. Each create() call reports two progress bars
# over the split's games (136 train, 64 test in the recorded run).
for dataset, database in ((train_dataset, train_db), (test_dataset, test_db)):
    dataset.create(database)

  0%|          | 0/136 [00:00<?, ?it/s]

100%|██████████| 136/136 [18:42<00:00,  8.25s/it]
100%|██████████| 136/136 [00:22<00:00,  5.99it/s]
100%|██████████| 64/64 [08:00<00:00,  7.50s/it]
100%|██████████| 64/64 [00:10<00:00,  6.04it/s]


In [8]:
# Feature table indexed by (game_id, action_id). Column suffixes _a0.._a3
# presumably refer to the action and its nb_prev_actions=3 predecessors.
# In the recorded output some freeze_frame_360_a* cells are NaN, i.e. no
# 360 freeze frame for that action.
train_features = train_dataset.features
train_features

Unnamed: 0_level_0,Unnamed: 1_level_0,start_x_a0,start_y_a0,start_x_a1,start_y_a1,start_x_a2,start_y_a2,start_x_a3,start_y_a3,teammate_1_x_a0,teammate_1_y_a0,...,opponent_10_x_a3,opponent_10_y_a3,opponent_10_distance_a3,opponent_11_x_a3,opponent_11_y_a3,opponent_11_distance_a3,freeze_frame_360_a0,freeze_frame_360_a1,freeze_frame_360_a2,freeze_frame_360_a3
game_id,action_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
3788741,3,41.7375,61.285,31.2375,42.585,28.0000,43.945,52.0625,34.425,53.579218,52.737630,...,,,,,,,"[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': False, 'actor': False, 'keeper':..."
3788741,19,84.5250,59.500,41.4750,59.500,38.5875,54.400,29.9250,34.170,76.496764,53.723000,...,,,,,,,"[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': False, 'keeper': ..."
3788741,27,27.2125,13.940,76.3875,58.055,72.1875,68.000,83.0375,56.865,21.941713,14.556432,...,,,,,,,"[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': False, 'keeper': ...",,"[{'teammate': True, 'actor': False, 'keeper': ..."
3788741,31,75.7750,60.690,25.9000,9.520,25.2875,11.815,24.4125,10.370,73.332487,65.151690,...,,,,,,,"[{'teammate': False, 'actor': False, 'keeper':...","[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': False, 'actor': False, 'keeper':...","[{'teammate': True, 'actor': False, 'keeper': ..."
3788741,35,79.9750,58.055,20.9125,10.710,29.2250,7.310,25.7250,10.965,78.241322,57.741576,...,,,,,,,"[{'teammate': False, 'actor': False, 'keeper':...","[{'teammate': True, 'actor': False, 'keeper': ...",,"[{'teammate': True, 'actor': False, 'keeper': ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3943043,2283,82.0750,20.230,37.1000,68.000,100.7125,36.805,98.5250,33.235,67.832653,16.084311,...,91.875211,35.69973,7.091868,91.756405,46.553855,14.940073,"[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': True, 'keeper': F...",,"[{'teammate': True, 'actor': False, 'keeper': ..."
3943043,2286,34.3000,61.710,33.3375,0.425,91.9625,20.655,67.9000,0.000,42.269977,58.389380,...,,,,,,,"[{'teammate': True, 'actor': True, 'keeper': F...","[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': True, 'keeper': T...","[{'teammate': True, 'actor': True, 'keeper': F..."
3943043,2302,47.7750,31.790,35.3500,12.070,38.5875,11.560,35.9625,7.055,50.047145,37.296331,...,,,,,,,"[{'teammate': True, 'actor': False, 'keeper': ...",,,
3943043,2304,49.6125,35.700,47.7750,31.790,69.6500,55.930,66.4125,56.440,49.390287,32.930291,...,,,,,,,"[{'teammate': True, 'actor': False, 'keeper': ...",,,


In [14]:
# Boolean target per (game_id, action_id).
train_labels = train_dataset.labels
train_labels

Unnamed: 0_level_0,Unnamed: 1_level_0,possession_change_by_5_seconds
game_id,action_id,Unnamed: 2_level_1
3788741,3,True
3788741,19,True
3788741,27,True
3788741,31,False
3788741,35,False
...,...,...
3943043,2283,True
3943043,2286,False
3943043,2302,False
3943043,2304,False


In [15]:
# Class balance of the training target; in the recorded run ~28% of actions
# are positive (10339 / 36948), so the target is imbalanced.
train_dataset.labels.possession_change_by_5_seconds.value_counts()

False    26609
True     10339
Name: possession_change_by_5_seconds, dtype: int64

In [16]:
# Class balance of the test target; in the recorded run ~30% positive
# (4290 / 14504), close to the training split.
test_dataset.labels.possession_change_by_5_seconds.value_counts()

False    10214
True      4290
Name: possession_change_by_5_seconds, dtype: int64