In [1]:
import os
import sys
import warnings
import pandas as pd
from tqdm.notebook import tqdm

# Resolve the directory one level above the notebook's working directory
# (presumably the repository root) so the local `express` package imported in
# the next cell can be found from source.
base_path = os.path.abspath(os.path.join(os.getcwd(), ".."))
print(f"base_path: {base_path}")
# Guard against duplicates so re-running this cell is idempotent, and insert at
# the front so the local `express` source takes precedence over any installed
# copy of a package with the same name.
if base_path not in sys.path:
    sys.path.insert(0, base_path)

base_path: /home/toc3/press


In [2]:
from express.databases import SQLiteDatabase
from express.datasets import PressingDataset

from express import features as fs
from express import labels as ls

In [3]:
# Train and test SQLite databases, both stored under <base_path>/stores.
STORES_DIR = os.path.join(base_path, "stores")
TRAIN_DB_PATH = os.path.join(STORES_DIR, "train_database.sqlite")
TEST_DB_PATH = os.path.join(STORES_DIR, "test_database.sqlite")

# Fail fast on a wrong path. sqlite3.connect creates a missing database file on
# the fly, so a typo could otherwise surface later as a confusing empty-table
# error. NOTE(review): assumes SQLiteDatabase opens the file via sqlite3 — confirm.
for db_path in (TRAIN_DB_PATH, TEST_DB_PATH):
    if not os.path.isfile(db_path):
        raise FileNotFoundError(f"SQLite database not found: {db_path}")

train_db = SQLiteDatabase(TRAIN_DB_PATH)
test_db = SQLiteDatabase(TEST_DB_PATH)

print("train_db:", train_db)
print("test_db:", test_db)

train_db: <express.databases.sqlite.SQLiteDatabase object at 0x7f2a45bb5e80>
test_db: <express.databases.sqlite.SQLiteDatabase object at 0x7f2a45bb5a90>


In [4]:
# Shape of the games returned by each database, train first, then test.
print(*(db.games().shape for db in (train_db, test_db)))

(151, 11) (49, 11)


In [5]:
# Names of every registered feature / label function; names like these are what
# PressingDataset receives as `xfns` / `yfns`.
all_features = list(map(lambda fn: fn.__name__, fs.all_features))
all_labels = list(map(lambda fn: fn.__name__, ls.all_labels))
for title, names in (("Features:", all_features), ("Labels:", all_labels)):
    print(title, names)

Features: ['actiontype', 'actiontype_onehot', 'result', 'result_onehot', 'bodypart', 'bodypart_onehot', 'time', 'startlocation', 'relative_startlocation', 'endlocation', 'relative_endlocation', 'startpolar', 'endpolar', 'movement', 'team', 'time_delta', 'space_delta', 'goalscore', 'angle', 'under_pressure', 'packing_rate', 'ball_height_onehot', 'speed', 'nb_opp_in_path', 'dist_opponent', 'defenders_in_3m_radius', 'closest_3_players', 'closest_11_players', 'expected_3_receiver_and_presser_by_distance']
Labels: ['concede_shots', 'counterpress']


In [6]:
# Both splits share one configuration so train and test features cannot drift apart.
DATASETS_DIR = os.path.join(base_path, "stores", "datasets")
FEATURE_FNS = ["startlocation", "closest_11_players"]  # names from fs.all_features
LABEL_FNS = ["counterpress"]  # names from ls.all_labels
NB_PREV_ACTIONS = 3  # feature columns come out suffixed _a0.._a2, one per action


def make_pressing_dataset(split: str) -> "PressingDataset":
    """Build the (uncached) PressingDataset for one split stored at DATASETS_DIR/<split>.

    Fresh copies of the feature/label name lists are passed so the two datasets
    never share one list object, matching the separate literals used originally.
    """
    return PressingDataset(
        path=os.path.join(DATASETS_DIR, split),
        xfns=list(FEATURE_FNS),
        yfns=list(LABEL_FNS),
        load_cached=False,
        nb_prev_actions=NB_PREV_ACTIONS,
    )


train_dataset = make_pressing_dataset("train")
test_dataset = make_pressing_dataset("test")

In [7]:
# Compute each split's features and labels from its own database (train, then test).
for dataset, database in ((train_dataset, train_db), (test_dataset, test_db)):
    dataset.create(database)

Output()

Output()

Output()

Output()

In [14]:
# Train feature matrix. Judging by the rendered columns, each feature is repeated
# per previous action with suffixes _a0.._a2 (nb_prev_actions=3).
# NOTE(review): many teammate_*/opponent_* columns are NaN (presumably fewer
# players located than slots) — decide whether to impute before modelling.
train_dataset.features

Unnamed: 0,start_x_a0,start_y_a0,start_x_a1,start_y_a1,start_x_a2,start_y_a2,teammate_1_x_a0,teammate_1_y_a0,teammate_1_distance_a0,teammate_2_x_a0,...,opponent_8_distance_a2,opponent_9_x_a2,opponent_9_y_a2,opponent_9_distance_a2,opponent_10_x_a2,opponent_10_y_a2,opponent_10_distance_a2,opponent_11_x_a2,opponent_11_y_a2,opponent_11_distance_a2
0,41.7375,61.285,73.7625,25.415,77.0000,24.055,53.579218,52.737630,14.604240,34.377131,...,33.197579,41.340865,36.142563,37.652132,44.092050,51.091348,42.589874,,,
1,84.5250,59.500,41.4750,59.500,38.5875,54.400,76.496764,53.723000,9.890718,65.244864,...,,,,,,,,,,
2,27.2125,13.940,28.6125,9.945,32.8125,0.000,21.941713,14.556432,5.306711,29.895983,...,,,,,,,,,,
3,75.7750,60.690,79.1000,58.480,79.7125,56.185,73.332487,65.151690,5.086506,79.741558,...,14.911395,83.436499,31.972434,24.497275,,,,,,
4,79.9750,58.055,84.0875,57.290,75.7750,60.690,78.241322,57.741576,1.761781,81.094272,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
39041,82.0750,20.230,67.9000,0.000,4.2875,31.195,67.832653,16.084311,14.833448,58.134253,...,,,,,,,,,,
39042,34.3000,61.710,71.6625,67.575,13.0375,47.345,42.269977,58.389380,8.634063,41.571906,...,,,,,,,,,,
39043,47.7750,31.790,69.6500,55.930,66.4125,56.440,50.047145,37.296331,5.956704,41.740027,...,,,,,,,,,,
39044,49.6125,35.700,47.7750,31.790,47.7750,31.790,49.390287,32.930291,2.778608,42.628292,...,31.145234,39.830609,64.457432,33.619555,,,,,,


In [15]:
# Test feature matrix, with the same column layout as the train features.
# NOTE(review): some rows (e.g. row 0 in the saved output) have no teammate
# columns at all — confirm how such samples should be handled downstream.
test_dataset.features

Unnamed: 0,start_x_a0,start_y_a0,start_x_a1,start_y_a1,start_x_a2,start_y_a2,teammate_1_x_a0,teammate_1_y_a0,teammate_1_distance_a0,teammate_2_x_a0,...,opponent_8_distance_a2,opponent_9_x_a2,opponent_9_y_a2,opponent_9_distance_a2,opponent_10_x_a2,opponent_10_y_a2,opponent_10_distance_a2,opponent_11_x_a2,opponent_11_y_a2,opponent_11_distance_a2
0,85.6625,55.760,76.4750,59.075,62.0375,67.575,,,,,...,,,,,,,,,,
1,50.9250,57.375,56.8750,57.460,59.1500,38.420,53.924928,46.867274,10.927574,44.996415,...,27.756077,35.112788,53.538855,28.396608,,,,,,
2,56.4375,10.880,61.3375,12.240,67.7250,25.075,47.181661,17.103069,11.153347,60.019528,...,,,,,,,,,,
3,65.0125,43.010,56.4375,10.880,64.2250,14.280,64.424627,30.495363,12.528437,56.708205,...,31.876755,41.370071,50.755914,43.044629,,,,,,
4,39.6375,24.990,59.5875,27.200,65.1875,46.070,41.608359,29.447944,4.874172,42.013672,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
12401,51.3625,38.590,58.1000,34.255,58.4500,32.980,44.518551,38.812903,6.847578,57.624924,...,22.092892,36.395506,27.607061,22.699542,42.342041,54.637436,26.990940,,,
12402,34.4750,38.080,27.9125,34.255,21.2625,33.320,28.821107,39.729508,5.889599,43.932708,...,25.822966,,,,,,,,,
12403,4.9875,58.395,16.0125,62.730,25.3750,55.675,7.743661,58.476807,2.757374,8.010600,...,26.393789,13.937399,29.743701,28.341683,3.526294,32.509401,31.843538,16.310699,19.229769,37.555511
12404,6.6500,60.520,4.9875,58.395,16.0125,62.730,5.021726,59.088590,2.167997,7.086932,...,,,,,,,,,,


In [17]:
# Boolean `counterpress` target for the train split.
# NOTE(review): presumably row-aligned with train_dataset.features — confirm.
train_dataset.labels

Unnamed: 0,counterpress
0,False
1,False
2,False
3,True
4,True
...,...
39041,False
39042,False
39043,False
39044,False


In [18]:
# Boolean `counterpress` target for the test split.
# NOTE(review): presumably row-aligned with test_dataset.features — confirm.
test_dataset.labels

Unnamed: 0,counterpress
0,False
1,False
2,False
3,False
4,False
...,...
12401,False
12402,False
12403,False
12404,False


In [19]:
# Class balance of the train target. Check first that every feature row has a label.
assert len(train_dataset.labels) == len(train_dataset.features), (
    "train features and labels have different row counts"
)
# NOTE(review): the saved counts summed to 39046 while the displayed index ended
# at 39044 — confirm the features/labels index is unique and aligned.
# dropna=False keeps missing labels visible, so the counts always sum to the row count.
train_dataset.labels["counterpress"].value_counts(dropna=False)

False    31605
True      7441
Name: counterpress, dtype: int64

In [20]:
# Class balance of the test target. Check first that every feature row has a label.
assert len(test_dataset.labels) == len(test_dataset.features), (
    "test features and labels have different row counts"
)
# NOTE(review): the saved counts summed to 12406 while the displayed index ended
# at 12404 — confirm the features/labels index is unique and aligned.
# dropna=False keeps missing labels visible, so the counts always sum to the row count.
test_dataset.labels["counterpress"].value_counts(dropna=False)

False    9954
True     2452
Name: counterpress, dtype: int64