In [9]:
%load_ext autoreload
%autoreload 2
import numpy as np
import os
import pickle
from sklearn.model_selection import train_test_split
import xgboost as xgb
from sklearn.metrics import accuracy_score, roc_auc_score
import sys
sys.path.append(os.path.abspath(".."))
import torch

from datasets import PressingSequenceDataset, SoccerMapInputDataset

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
data_path = "/data/MHL/pressing-intensity-feat"

with open(f"{data_path}/train_dataset.pkl", "rb") as f:
    train_dataset = pickle.load(f)

with open(f"{data_path}/valid_dataset.pkl", "rb") as f:
    valid_dataset = pickle.load(f)


with open(f"{data_path}/test_dataset.pkl", "rb") as f:
    test_dataset = pickle.load(f)

len(train_dataset),len(valid_dataset), len(test_dataset)

(6568, 601, 655)

In [13]:
sample = train_dataset[0]
sample.keys()

dict_keys(['features', 'pressing_intensity', 'label', 'pressed_id', 'presser_id', 'agent_order', 'match_info'])

In [14]:
print(f"Features : {sample['features'].shape}")
print(f"Pressing Intensity : {sample['pressing_intensity'].shape}")
print(f"Labels : {sample['label']}")
print(f"Presser ID : {sample['presser_id']}")
print(f"Players Order : {sample['agent_order']}")

Features : torch.Size([2, 23, 18])
Pressing Intensity : torch.Size([2, 11, 11])
Labels : 0
Presser ID : 77414
Players Order : ['188178', '250079', '250101', '250102', '500133', '500140', '500141', '500142', '62365', '62386', '77414', '187259', '343587', '408792', '500113', '500115', '500116', '500117', '500118', '500121', '500502', '83615', 'ball']


In [18]:
from config import FEAT_MIN, FEAT_MAX

for i in range(18):
    print(f"{FEAT_MIN[i]} ~ {FEAT_MAX[i]}")

-52.5 ~ 52.5
-32.0 ~ 32.0
-4.0 ~ 1.0
-5.0 ~ 3.0
0.0 ~ 5.0
-4.0 ~ 5.0
-4.0 ~ 5.0
0.0 ~ 6.0
0.0 ~ 1.0
0.0 ~ 1.0
21.0 ~ 100.0
-1.0 ~ 1.0
0.0 ~ 1.0
0.0 ~ 52.0
-1.0 ~ 1.0
-1.0 ~ 1.0
-1.0 ~ 1.0
-1.0 ~ 1.0


In [15]:
x_tensor_lst = [sample['features'] for sample in train_dataset]
x_tensor_lst = torch.cat(x_tensor_lst)
feature_cols = ['x', 'y', 'vx', 'vy', 'v', 'ax', 'ay', 'a']

for i in range(18):
    # print(f"{feature_cols[i]} : {x_tensor_lst[-1, ...][..., i].min()} ~ {x_tensor_lst[-1, ...][..., i].max()}")
    print(f"{x_tensor_lst[-1, ...][..., i].min()} ~ {x_tensor_lst[-1, ...][..., i].max()}")

-35.245399475097656 ~ 36.746498107910156
-23.04159927368164 ~ 31.23430061340332
-3.957848072052002 ~ 1.528892159461975
-4.297361373901367 ~ 2.6944620609283447
0.4733882546424866 ~ 4.297878265380859
-3.3421547412872314 ~ 4.828290939331055
-3.5408642292022705 ~ 4.212172985076904
0.023648617789149284 ~ 5.198577404022217
0.0 ~ 1.0
0.0 ~ 1.0
15.768238067626953 ~ 87.75919342041016
-0.5956571102142334 ~ 0.4275398254394531
0.8032388091087341 ~ 0.9998427629470825
0.0 ~ 51.872928619384766
-0.9545497894287109 ~ 0.9746968150138855
-0.9999775290489197 ~ 1.0
-0.9782698750495911 ~ 1.0
-0.9976794719696045 ~ 0.9850823283195496


In [17]:
x_tensor_lst = [sample['features'] for sample in valid_dataset]
x_tensor_lst = torch.cat(x_tensor_lst)
feature_cols = ['x', 'y', 'vx', 'vy', 'v', 'ax', 'ay', 'a']

for i in range(18):
    # print(f"{feature_cols[i]} : {x_tensor_lst[-1, ...][..., i].min()} ~ {x_tensor_lst[-1, ...][..., i].max()}")
    print(f"{x_tensor_lst[-1, ...][..., i].min()} ~ {x_tensor_lst[-1, ...][..., i].max()}")

-47.30929946899414 ~ 27.947900772094727
-26.779399871826172 ~ 24.279199600219727
-2.786034345626831 ~ 1.6843657493591309
-1.6886903047561646 ~ 4.080182075500488
0.12533733248710632 ~ 4.414178371429443
-2.033162832260132 ~ 20.059629440307617
-32.14739990234375 ~ 2.9072623252868652
0.14909246563911438 ~ 13.5
0.0 ~ 1.0
0.0 ~ 1.0
24.691688537597656 ~ 99.8106460571289
-0.3865571916103363 ~ 0.3389405906200409
0.9222654104232788 ~ 0.9999997019767761
0.0 ~ 53.62963104248047
0.0 ~ 0.9939579367637634
-0.8388556241989136 ~ 1.0
0.028892746195197105 ~ 1.0
-0.923808753490448 ~ 0.9995825290679932


In [16]:
x_tensor_lst = [sample['features'] for sample in test_dataset]
x_tensor_lst = torch.cat(x_tensor_lst)
feature_cols = ['x', 'y', 'vx', 'vy', 'v', 'ax', 'ay', 'a']

for i in range(18):
    # print(f"{feature_cols[i]} : {x_tensor_lst[-1, ...][..., i].min()} ~ {x_tensor_lst[-1, ...][..., i].max()}")
    print(f"{x_tensor_lst[-1, ...][..., i].min()} ~ {x_tensor_lst[-1, ...][..., i].max()}")

-22.494892120361328 ~ 50.170204162597656
-33.574310302734375 ~ 15.24506950378418
-6.2336344718933105 ~ 1.2372031211853027
-5.7715325355529785 ~ 1.795527696609497
1.2239831686019897 ~ 6.246331691741943
-2.944819688796997 ~ 5.235048770904541
-6.547811985015869 ~ 1.806138038635254
0.14878597855567932 ~ 6.0
0.0 ~ 1.0
0.0 ~ 1.0
3.920572519302368 ~ 75.00214385986328
-0.5290470719337463 ~ 0.9910622239112854
0.13340020179748535 ~ 0.9999033212661743
0.0 ~ 46.72101593017578
-0.9999926090240479 ~ 0.16790957748889923
-0.9999999403953552 ~ 1.0
0.09680623561143875 ~ 1.0
-0.9953032732009888 ~ 0.9049455523490906


# 1. XGBoost

In [101]:
num_seq = 150
num_agents = 11
use_pressing_intensity = True
selected_features_idx = [i for i in range(8)]

train_features = []
train_labels = []
for i in range(len(train_dataset)):
    sample = train_dataset[i]
    x_tensor = sample['features'][..., selected_features_idx][-1:]
    press_intensity = sample['pressing_intensity'][-1:]
    y_tensor = sample['label']
    
    # Flatten the sequence data: shape (sequence_length, num_features) -> (sequence_length*num_features,)
    feature_vector = x_tensor.flatten().numpy()
    if use_pressing_intensity:
        if press_intensity.shape[1] != num_agents:
            pad_tensor = torch.zeros(press_intensity.shape[0], num_agents-press_intensity.shape[1], press_intensity.shape[2])
            press_intensity = torch.cat([press_intensity, pad_tensor], dim=1)
            
        if press_intensity.shape[2] != num_agents:
            pad_tensor = torch.zeros(press_intensity.shape[0], press_intensity.shape[1], num_agents-press_intensity.shape[2])
            press_intensity = torch.cat([press_intensity, pad_tensor], dim=2)
        
        press_vector = press_intensity.flatten().numpy()
        # Concatenate the flattened sequence data with the pressing intensity
        feature_vector = np.concatenate((feature_vector, press_vector))
    train_features.append(feature_vector)
    train_labels.append(y_tensor.item())


test_features = []
test_labels = []
for i in range(len(test_dataset)):
    sample = test_dataset[i]
    x_tensor = sample['features'][..., selected_features_idx][-1:]
    press_intensity = sample['pressing_intensity'][-1:]
    y_tensor = sample['label']
    
    # Flatten the sequence data: shape (sequence_length, num_features) -> (sequence_length*num_features,)
    feature_vector = x_tensor.flatten().numpy()
    
    if use_pressing_intensity:
        if press_intensity.shape[1] != num_agents:
            pad_tensor = torch.zeros(press_intensity.shape[0], num_agents-press_intensity.shape[1], press_intensity.shape[2])
            press_intensity = torch.cat([press_intensity, pad_tensor], dim=1)
        if press_intensity.shape[2] != num_agents:
            pad_tensor = torch.zeros(press_intensity.shape[0], press_intensity.shape[1], num_agents-press_intensity.shape[2])
            press_intensity = torch.cat([press_intensity, pad_tensor], dim=2)
        
        press_vector = press_intensity.flatten().numpy()
        # Concatenate the flattened sequence data with the pressing intensity
        feature_vector = np.concatenate((feature_vector, press_vector))
    test_features.append(feature_vector)
    test_labels.append(y_tensor.item())

X_train = np.array(train_features)
y_train = np.array(train_labels)

# Split the data into training and testing sets (80% training, 20% testing)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

X_test = np.array(test_features)
y_test = np.array(test_labels)

In [102]:
def print_dataset_distribution(y_train, y_val, y_test):
    def _print_split(name, labels):
        total = len(labels)
        unique, counts = np.unique(labels, return_counts=True)
        print(f"{name} Set:")
        print(f"  Total samples: {total}")
        for label, count in zip(unique, counts):
            percent = (count / total) * 100
            print(f"    Label {label}: {count:>5} samples ({percent:5.2f}%)")
        print("-" * 40)

    print("\nüìä Dataset Distribution Summary")
    print("=" * 40)
    _print_split("Train", y_train)
    _print_split("Validation", y_val)
    _print_split("Test", y_test)

print_dataset_distribution(y_train, y_val, y_test)


üìä Dataset Distribution Summary
Train Set:
  Total samples: 5573
    Label 0:  4003 samples (71.83%)
    Label 1:  1570 samples (28.17%)
----------------------------------------
Validation Set:
  Total samples: 1394
    Label 0:  1029 samples (73.82%)
    Label 1:   365 samples (26.18%)
----------------------------------------
Test Set:
  Total samples: 867
    Label 0:   640 samples (73.82%)
    Label 1:   227 samples (26.18%)
----------------------------------------


In [103]:
# Create XGBoost DMatrix objects for train and test sets
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_val, label=y_val)

# Set XGBoost training parameters
params = {
    'objective': 'binary:logistic',  # binary classification
    'eval_metric': 'auc',            # evaluation metric: AUC
    'max_depth': 6,                  # maximum depth of trees
    'eta': 0.1,                      # learning rate
    'seed': 42
}

# Specify the watchlist to evaluate performance on training and test sets during training
watchlist = [(dtrain, 'train'), (dtest, 'eval')]
num_rounds = 100

In [104]:
# Train XGBoost model with early stopping on the evaluation set
bst = xgb.train(params, dtrain, num_rounds, watchlist, early_stopping_rounds=10)



[0]	train-auc:0.75959	eval-auc:0.61427
[1]	train-auc:0.79490	eval-auc:0.62917
[2]	train-auc:0.82032	eval-auc:0.64981
[3]	train-auc:0.83279	eval-auc:0.65480
[4]	train-auc:0.84991	eval-auc:0.65820
[5]	train-auc:0.86352	eval-auc:0.66439
[6]	train-auc:0.87259	eval-auc:0.66480
[7]	train-auc:0.87885	eval-auc:0.66355
[8]	train-auc:0.88916	eval-auc:0.66851
[9]	train-auc:0.89755	eval-auc:0.66977
[10]	train-auc:0.90254	eval-auc:0.67374
[11]	train-auc:0.91078	eval-auc:0.67134
[12]	train-auc:0.91646	eval-auc:0.67048
[13]	train-auc:0.92399	eval-auc:0.67064
[14]	train-auc:0.92672	eval-auc:0.67047
[15]	train-auc:0.93116	eval-auc:0.67298
[16]	train-auc:0.93318	eval-auc:0.67497
[17]	train-auc:0.93773	eval-auc:0.67577
[18]	train-auc:0.94166	eval-auc:0.67500
[19]	train-auc:0.94384	eval-auc:0.67588
[20]	train-auc:0.94658	eval-auc:0.67564
[21]	train-auc:0.94912	eval-auc:0.67658
[22]	train-auc:0.95171	eval-auc:0.67768
[23]	train-auc:0.95417	eval-auc:0.67648
[24]	train-auc:0.95753	eval-auc:0.67740
[25]	train

In [100]:
# W/O Pressing Intensity
# Get predictions on the test set
y_pred = bst.predict(dtest)
y_pred_label = (y_pred > 0.5).astype(int)

# Calculate and print evaluation metrics
accuracy = accuracy_score(y_val, y_pred_label)
auc = roc_auc_score(y_val, y_pred)

print("Test Accuracy: {:.4f}".format(accuracy))
print("Test AUC: {:.4f}".format(auc))

Test Accuracy: 0.7453
Test AUC: 0.6730


In [105]:
# W/ Pressing Intensity
# Get predictions on the test set
y_pred = bst.predict(dtest)
y_pred_label = (y_pred > 0.5).astype(int)

# Calculate and print evaluation metrics
accuracy = accuracy_score(y_val, y_pred_label)
auc = roc_auc_score(y_val, y_pred)

print("Test Accuracy: {:.4f}".format(accuracy))
print("Test AUC: {:.4f}".format(auc))

Test Accuracy: 0.7475
Test AUC: 0.6753


# 2. SoccerMap / exPress Evaluation

In [14]:
from datasets import exPressInputDataset
from tqdm import tqdm

In [15]:
data_path = "/data/MHL/pressing-intensity-feat"
train_dataset = exPressInputDataset(f"{data_path}/train_dataset.pkl")

Loading dataset from /data/MHL/pressing-intensity-feat/train_dataset.pkl...


In [16]:
train_dataset[0].keys()

dict_keys(['features', 'pressing_intensity', 'label', 'pressed_id', 'presser_id', 'agent_order', 'match_info'])

In [None]:
import torch
from torch.nn.utils.rnn import pad_sequence

def custom_temporal_collate(batch):
    """
    Í∞ÄÎ≥Ä Í∏∏Ïù¥Ïùò ÏãúÍ≥ÑÏó¥ Îç∞Ïù¥ÌÑ∞Î•º Ìè¨Ìï®Ìïú Î∞∞ÏπòÎ•º Ï≤òÎ¶¨ÌïòÎäî collate_fn.
    
    Args:
        batch (list): DatasetÏùò __getitem__Ïù¥ Î∞òÌôòÌïòÎäî ÎîïÏÖîÎÑàÎ¶¨Îì§Ïùò Î¶¨Ïä§Ìä∏.
                      Ïòà: [{'features': [T1,A,F], ...}, {'features': [T2,A,F], ...}]
    """
    # 1. Î∞∞Ïπò ÎÇ¥Ïùò Îç∞Ïù¥ÌÑ∞Îì§ÏùÑ ÌÇ§(key)Î≥ÑÎ°ú Î∂ÑÎ¶¨ÌïòÏó¨ Í∞ÅÍ∞ÅÏùò Î¶¨Ïä§Ìä∏Ïóê Îã¥ÏäµÎãàÎã§.
    features_list = [item['features'] for item in batch]
    intensity_list = [item['pressing_intensity'] for item in batch]
    labels_list = [item['label'] for item in batch]
    
    # Î©îÌÉÄÎç∞Ïù¥ÌÑ∞
    pressed_id_list = [item['pressed_id'] for item in batch]
    presser_id_list = [item['presser_id'] for item in batch]
    agent_order_list = [item['agent_order'] for item in batch]
    match_info_list = [item['match_info'] for item in batch]

     # Ìå®Îî© Ï†Ñ, Í∞Å ÏãúÌÄÄÏä§Ïùò Ïã§Ï†ú Í∏∏Ïù¥Î•º Ï†ÄÏû•Ìï©ÎãàÎã§.
    seq_lengths = torch.tensor([f.shape[0] for f in features_list], dtype=torch.long)
    
    # 2. torch.nn.utils.rnn.pad_sequenceÎ•º ÏÇ¨Ïö©ÌïòÏó¨ ÏãúÌÄÄÏä§ Îç∞Ïù¥ÌÑ∞Îì§ÏùÑ Ìå®Îî©Ìï©ÎãàÎã§.
    #    batch_first=TrueÎäî Í≤∞Í≥º ÌÖêÏÑúÏùò Ï≤´ Î≤àÏß∏ Ï∞®ÏõêÏù¥ Î∞∞Ïπò ÌÅ¨Í∏∞Í∞Ä ÎêòÎèÑÎ°ù Ìï©ÎãàÎã§.
    #    [B, max_T, A, F] ÌòïÌÉúÍ∞Ä Îê©ÎãàÎã§.
    padded_features = pad_sequence(features_list, batch_first=True, padding_value=0.0)
    
    # pressing_intensityÎèÑ ÎèôÏùºÌïòÍ≤å Ìå®Îî©Ìï©ÎãàÎã§.
    # [B, max_T, 11, 11] ÌòïÌÉúÍ∞Ä Îê©ÎãàÎã§.
    padded_intensities = pad_sequence(intensity_list, batch_first=True, padding_value=0.0)

    # 3. ÌÅ¨Í∏∞Í∞Ä Í≥†Ï†ïÎêú ÌÖêÏÑú Îç∞Ïù¥ÌÑ∞Îì§ÏùÄ torch.stackÏùÑ ÏÇ¨Ïö©ÌïòÏó¨ Î¨∂ÏäµÎãàÎã§.
    labels = torch.stack(labels_list)

    # 4. ÏµúÏ¢ÖÏ†ÅÏúºÎ°ú, Ï≤òÎ¶¨Îêú Îç∞Ïù¥ÌÑ∞Îì§ÏùÑ Îã¥ÏùÄ ÎîïÏÖîÎÑàÎ¶¨Î•º Î∞òÌôòÌï©ÎãàÎã§.
    return {
        'features': padded_features,           # Ìå®Îî©Îêú ÌÖêÏÑú
        'pressing_intensity': padded_intensities, # Ìå®Îî©Îêú ÌÖêÏÑú
        'label': labels,           
        'seq_lengths': seq_lengths,           # Î∞∞ÏπòÎêú ÌÖêÏÑú
        'agent_order': agent_order_list,      # ÌååÏù¥Ïç¨ Î¶¨Ïä§Ìä∏
        'presser_id': presser_id_list,        # ÌååÏù¥Ïç¨ Î¶¨Ïä§Ìä∏
        'pressed_id': pressed_id_list,        # ÌååÏù¥Ïç¨ Î¶¨Ïä§Ìä∏
        'match_info': match_info_list         # ÌååÏù¥Ïç¨ Î¶¨Ïä§Ìä∏
    }

In [18]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=custom_temporal_collate)

In [19]:
batch = next(iter(train_loader))

for batch in train_loader:
    print(batch['features'].shape)

torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 3, 2

torch.Size([16, 4, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 6, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 4, 23, 18])
torch.Size([16, 5, 23, 18])
torch.Size([16, 5, 2

In [126]:
from torch_geometric.utils import dense_to_sparse

A = 23
W = torch.zeros((A, A))

avg_press = train_dataset[0]['pressing_intensity'][0, ...]
P, O = avg_press.size(0), avg_press.size(1) # P: number of players, O: number of opponents

W[:P, P:P+O] = avg_press
W[P:P+O, :P] = avg_press.t()

# ball (last node) edges remain zeros
adj = torch.ones((A, A)) - torch.eye(A)
edge_index, _ = dense_to_sparse(adj)
edge_attr = W[edge_index[0], edge_index[1]]

In [130]:
feats = train_dataset[0]['features'][0, ...]

In [131]:
source_nodes, dest_nodes = edge_index[0], edge_index[1]
pos_source = feats[source_nodes, :2]
pos_dest = feats[dest_nodes, :2]
edge_distances = torch.linalg.norm(pos_source - pos_dest, dim=1).unsqueeze(1)

In [133]:
source_is_home = source_nodes < 11
dest_is_home = dest_nodes < 11
source_is_away = (source_nodes >= 11) & (source_nodes < 22)
dest_is_away = (dest_nodes >= 11) & (dest_nodes < 22)

# Í∞ôÏùÄ ÌåÄÏù∏ Í≤ΩÏö∞: (Îëò Îã§ ÌôàÌåÄ) ÎòêÎäî (Îëò Îã§ ÏõêÏ†ïÌåÄ)
is_same = (source_is_home & dest_is_home) | (source_is_away & dest_is_away)
edge_same_team = is_same.float().unsqueeze(1)

In [137]:
edge_attr = torch.cat([edge_distances, edge_same_team, edge_attr.unsqueeze(1)], dim=-1)

In [136]:
edge_attr.shape

torch.Size([506])

In [132]:
edge_distances.shape

torch.Size([506, 1])

In [127]:
edge_attr.shape

torch.Size([506])

In [121]:
train_dataset[0].keys()

dict_keys(['features', 'pressing_intensity', 'label', 'pressed_id', 'presser_id', 'agent_order', 'match_info'])

In [114]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=16)

In [115]:
feat_shape = []
press_shape = []
for batch in train_loader:
    feat_shape.append(batch['features'].shape)

In [39]:
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import json
import os
os.chdir('/home/work/MHL/express-v2')
import argparse # To accept checkpoint path as argument

# Import project modules
# import config  # Import static configurations
from model import PytorchSoccerMapModel # Import Lightning model
from datasets import PressingSequenceDataset, SoccerMapInputDataset 


In [40]:
pl.seed_everything(42, workers=True) # Ensure reproducibility

DATA_PATH = "/data/MHL/pressing-intensity" # Path where pickled datasets are saved
test_dataset = SoccerMapInputDataset(os.path.join(DATA_PATH, "test_dataset.pkl"))

if len(test_dataset) == 0:
    print("Loaded test dataset is empty. Exiting.")

# Custom collate function to handle potential None values from dataset errors
def collate_fn_skip_none(batch):
    batch = list(filter(lambda x: x is not None, batch))
    if not batch: return None
    try: return torch.utils.data.dataloader.default_collate(batch)
    except RuntimeError: return None # Skip batch if collation error

test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
    # collate_fn=collate_fn_skip_none
)

Seed set to 42


Loading dataset from /data/MHL/pressing-intensity/test_dataset.pkl...


In [41]:
import argparse

parser = argparse.ArgumentParser(description="Train a pressing evaluation model.")
# parser.add_argument("--model_type", type=str, default="soccermap", choices=['soccermap', 'xgboost', 'exPress'], help="Path to the model checkpoint (.ckpt) file saved during training.")
# parser.add_argument("--root_path", type=str, default="/data/MHL/pressing-intensity", help="Path to the data file.")
parser.add_argument("--model_type", type=str, default="soccermap", choices=['soccermap', 'xgboost', 'exPress'], help="Path to the model checkpoint (.ckpt) file saved during training.")
parser.add_argument("--root_path", type=str, default="/data/MHL/pressing-intensity", help="Path to the data file.")
parser.add_argument("--mode", type=str, default="train", choices=['train', 'test'], help="Mode: 'train' or 'test'.")
parser.add_argument("--ckpt_path", type=str, default=None, help="Path to checkpoint file (Required for 'test' mode).")
parser.add_argument("--params_path", type=str, default="params.json", help="Path to the JSON containing configurations.")
parser.add_argument("--seed", type=int, default=42, help="Seed number.")

args = parser.parse_args([])

args.mode = 'test'
args.model_type = "exPress"
args.ckpt_path = "/data/MHL/pressing-intensity/checkpoints/exPress-epoch=28-val_loss=0.49.ckpt"

In [42]:
from components import press


component_dict = {
                    "soccermap": press.SoccerMapComponent,
                    "exPress": press.exPressComponent,
                }

exp = component_dict[args.model_type](args)

Seed set to 42


Configurations loaded from params.json.
