In [1]:
import multiprocessing as mp
mp.set_start_method('spawn', force=True)

In [2]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import logging
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from torch.utils.data import DataLoader, Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.amp import autocast, GradScaler
from models.diff_modules import diff_CSDI
from models.diff_model import DiffusionTrajectoryModel
from models.encoder import InteractionGraphEncoder, TargetTrajectoryEncoder
from make_dataset import MultiMatchSoccerDataset, organize_and_process
from utils.utils import set_evertyhing, worker_init_fn, generator, plot_trajectories_on_pitch, log_graph_stats
from utils.data_utils import split_dataset_indices, custom_collate_fn
from utils.graph_utils import build_graph_sequence_from_condition

# SEED Fix
SEED = 42
set_evertyhing(SEED)


# Save Log / Logger Setting
model_save_path = './results/logs/'
os.makedirs(model_save_path, exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
    filename=os.path.join(model_save_path, 'train.log'),
    filemode='w'
)
logger = logging.getLogger()

# 1. Model Config & Hyperparameter Setting
csdi_config = {
    "num_steps": 1000,
    "channels": 256,
    "diffusion_embedding_dim": 256,
    "nheads": 4,
    "layers": 5,
    # "side_dim": 128
    "side_dim": 512
}
hyperparams = {
    'raw_data_path': "idsse-data", # raw_data_path = "Download raw file path"
    'data_save_path': "match_data",
    'train_batch_size': 16,
    'val_batch_size': 16,
    'test_batch_size': 16,
    'num_workers': 8,
    'epochs': 50,
    'learning_rate': 1e-4,
    'self_conditioning_ratio': 0.5,
    'num_samples': 10,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    'ddim_step': 200,
    'eta': 0.1,
    **csdi_config
}
raw_data_path = hyperparams['raw_data_path']
data_save_path = hyperparams['data_save_path']
train_batch_size = hyperparams['train_batch_size']
val_batch_size = hyperparams['val_batch_size']
test_batch_size = hyperparams['test_batch_size']
num_workers = hyperparams['num_workers']
epochs = hyperparams['epochs']
learning_rate = hyperparams['learning_rate']
self_conditioning_ratio = hyperparams['self_conditioning_ratio']
num_samples = hyperparams['num_samples']
device = hyperparams['device']
ddim_step = hyperparams['ddim_step']
eta = hyperparams['eta']
side_dim = hyperparams['side_dim']

logger.info(f"Hyperparameters: {hyperparams}")

# 2. Data Loading
print("---Data Loading---")
if not os.path.exists(data_save_path) or len(os.listdir(data_save_path)) == 0:
    organize_and_process(raw_data_path, data_save_path)
else:
    print("Skip organize_and_process")

dataset = MultiMatchSoccerDataset(data_root=data_save_path)
train_idx, val_idx, test_idx = split_dataset_indices(dataset, val_ratio=1/6, test_ratio=1/6, random_seed=SEED)

train_dataloader = DataLoader(
    Subset(dataset, train_idx),
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=False,
    collate_fn=custom_collate_fn,
    worker_init_fn=worker_init_fn,
    generator=generator(SEED)
)

val_dataloader = DataLoader(
    Subset(dataset, val_idx),
    batch_size=val_batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=False,
    collate_fn=custom_collate_fn,
    worker_init_fn=worker_init_fn,
)

test_dataloader = DataLoader(
    Subset(dataset, test_idx),
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=False,
    collate_fn=custom_collate_fn,
    worker_init_fn=worker_init_fn
)

print("---Data Load!---")

# 3. Model Define
# Extract node feature dimension
sample = dataset[0]
graph = build_graph_sequence_from_condition({
    "condition": sample["condition"],
    "condition_columns": sample["condition_columns"],
    "pitch_scale": sample["pitch_scale"]
}).to(device)

log_graph_stats(graph, logger, prefix="InitGraphSample")

in_dim = graph['Node'].x.size(1)

# Extract target's history trajectories from condition
condition_columns = sample["condition_columns"]
target_columns = sample["target_columns"]
target_idx = [condition_columns.index(col) for col in target_columns if col in condition_columns]

# graph_encoder = InteractionGraphEncoder(in_dim=in_dim, hidden_dim=128, out_dim=128, heads = 2).to(device)
graph_encoder = InteractionGraphEncoder(in_dim=in_dim, hidden_dim=side_dim // 2, out_dim=side_dim // 2).to(device)
history_encoder = TargetTrajectoryEncoder(num_layers=5, hidden_dim = side_dim // 4, bidirectional=True).to(device)
denoiser = diff_CSDI(csdi_config)
model = DiffusionTrajectoryModel(denoiser, num_steps=csdi_config["num_steps"]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, threshold=1e-4)
scaler = GradScaler()

logger.info(f"Device: {device}")
logger.info(f"GraphEncoder: {graph_encoder}")
logger.info(f"HistoryEncoder: {history_encoder}")
logger.info(f"Denoiser (diff_CSDI): {denoiser}")
logger.info(f"DiffusionTrajectoryModel: {model}")

ImportError: cannot import name 'per_player_frechet_loss' from 'utils.utils' (/home/park/workspace/SportsScience/Soccer-Trajectory-Prediction/utils/utils.py)

In [None]:

# 4. Train
best_state_dict = None
best_val_loss = float("inf")

train_losses = []
val_losses   = []

for epoch in tqdm(range(1, epochs + 1), desc="Training..."):
    model.train()
    train_noise_loss = 0
    # train_mse_loss = 0
    train_dtw_loss = 0
    train_fde_loss = 0
    train_loss = 0

    for batch in tqdm(train_dataloader, desc = "Batch Training..."):
        optimizer.zero_grad()
        cond = batch["condition"].to(device)
        B, T, _ = cond.shape
        target = batch["target"].to(device).view(-1, T, 11, 2)  # [B, T, 11, 2]
        graph_batch = batch["graph"].to(device)                              # HeteroData batch

        # graph → H
        H = graph_encoder(graph_batch)                                       # [B, 128]
        cond_H = H.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
        
        # Target's history trajectories
        hist = cond[:, :, target_idx].to(device) 
        hist_rep = history_encoder(hist)  # [B, 128]
        cond_hist = hist_rep.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
        
        # Concat conditions
        cond_info = torch.cat([cond_H, cond_hist], dim=1)
        # Preparing Self-conditioning data
        if torch.rand(1, device=device) < self_conditioning_ratio:
            s = torch.zeros_like(target)
        else:
            with torch.no_grad():
                t = torch.randint(0, model.num_steps, (target.size(0),), device=device)
                x_t, noise = model.q_sample(target, t)
                x_t = x_t.permute(0,3,2,1)
                
                z1 = model.model(x_t, t, cond_info, self_cond=None)
                
                eps_pred1 = z1[:, :2, :, :]
                a_hat = model.alpha_hat.to(device)[t].view(-1,1,1,1)
                x0_hat = (x_t - (1 - a_hat).sqrt() * eps_pred1) / a_hat.sqrt()
                x0_hat = x0_hat.permute(0,3,2,1)
                
            s = x0_hat
        
        # noise_loss, player_loss_mse, player_loss_frechet = model(target, cond_info=cond_info, self_cond=s)
        # loss = noise_loss + player_loss_mse + player_loss_frechet * 0.2
        noise_loss, player_loss_dtw, player_loss_fde = model(target, cond_info=cond_info, self_cond=s)
        loss = noise_loss + player_loss_dtw + player_loss_fde
            
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_noise_loss += (noise_loss).item()
        # train_mse_loss += (player_loss_mse).item()
        train_dtw_loss += (player_loss_dtw).item()
        train_fde_loss += (player_loss_fde).item()
        train_loss += loss.item()

    num_batches = len(train_dataloader)
    
    avg_noise_loss = train_noise_loss / num_batches
    # avg_mse_loss = train_mse_loss / num_batches
    avg_dtw_loss = train_dtw_loss / num_batches
    avg_fde_loss = train_fde_loss / num_batches
    avg_train_loss = train_loss / num_batches


    # --- Validation ---
    model.eval()
    val_noise_loss = 0
    # val_mse_loss = 0
    val_dtw_loss = 0
    val_fde_loss = 0
    val_total_loss = 0

    with torch.no_grad():
        for batch in tqdm(val_dataloader, desc="Validation"):
            cond = batch["condition"].to(device)
            B, T, _ = cond.shape
            target = batch["target"].to(device).view(-1, T, 11, 2)  # [B, T, 11, 2]
            graph_batch = batch["graph"].to(device)                              # HeteroData batch

            # graph → H
            H = graph_encoder(graph_batch)                                       # [B, 128]
            cond_H = H.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
            
            # Target's history trajectories
            hist = cond[:, :, target_idx].to(device)  # [B,128,11,T]
            hist_rep = history_encoder(hist)  # [B, 128]
            cond_hist = hist_rep.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
            
            # Concat conditions
            cond_info = torch.cat([cond_H, cond_hist], dim=1)
            
            s = torch.zeros_like(target)
            
            # noise_loss, player_loss_mse, player_loss_frechet = model(target, cond_info=cond_info, self_cond=s)
            # val_loss = noise_loss + player_loss_mse + player_loss_frechet * 0.2
            noise_loss, player_loss_dtw, player_loss_fde = model(target, cond_info=cond_info, self_cond=s)
            val_loss = noise_loss + player_loss_dtw + player_loss_fde
        
            val_noise_loss += (noise_loss).item()
            # val_mse_loss += (player_loss_mse).item()
            val_dtw_loss += (player_loss_dtw).item()
            val_fde_loss += (player_loss_fde).item()
            val_total_loss += val_loss.item()

    avg_val_noise_loss = val_noise_loss / len(val_dataloader)
    # avg_val_mse_loss = val_mse_loss / len(val_dataloader)
    avg_val_dtw_loss = val_dtw_loss / len(val_dataloader)
    avg_val_fde_loss = val_fde_loss / len(val_dataloader)
    avg_val_loss = val_total_loss / len(val_dataloader)
  
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    
    current_lr = scheduler.get_last_lr()[0]
    logger.info(f"[Epoch {epoch}/{epochs}] Train Loss={avg_train_loss:.6f} (Noise={avg_noise_loss:.6f}, DTW={avg_dtw_loss:.6f}, FDE={avg_fde_loss:.6f}) | Val Loss={avg_val_loss:.6f} | LR={current_lr:.6e}")
    
    tqdm.write(f"[Epoch {epoch}]\n"
               f"[Train] Cost: {avg_train_loss:.6f} | Noise Loss: {avg_noise_loss:.6f} | DTW Loss: {avg_dtw_loss:.6f} | FDE Loss: {avg_fde_loss:.6f} LR: {current_lr:.6f}\n"
               f"[Validation] Val Loss: {avg_val_loss:.6f} | Noise: {avg_val_noise_loss:.6f} | DTW: {avg_val_dtw_loss:.6f} | FDE: {avg_val_fde_loss:.6f}")
    
    scheduler.step(avg_val_loss)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_state_dict = model.state_dict()

logger.info(f"Training complete. Best val loss: {best_val_loss:.6f}")
        
# 4-1. Plot learning_curve
plt.figure(figsize=(8, 6))
plt.plot(range(1, epochs+1), train_losses, label='Train Loss')
plt.plot(range(1, epochs+1), val_losses, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title(f"Train & Validation Loss, {csdi_config['num_steps']} steps, {csdi_config['channels']} channels, "
          f"{csdi_config['diffusion_embedding_dim']} embedding dim, {csdi_config['nheads']} heads, {csdi_config['layers']} layers "
          f"self-conditioning ratio: {self_conditioning_ratio}")
plt.legend()
plt.tight_layout()

plt.savefig('results/0512_diffusion_lr_curve.png')

plt.show()


Training...:   0%|          | 0/50 [00:00<?, ?it/s]

Batch Training...:   0%|          | 0/264 [00:59<?, ?it/s]

Validation:   0%|          | 0/77 [00:56<?, ?it/s]

[Epoch 1]
[Train] Cost: 1.489763 | Noise Loss: 0.024561 | Frechet Loss: 1.146793 | FDE Loss: 0.318409 LR: 0.000100
[Validation] Val Loss: 0.293558 | Noise: -0.225426 | Frechet: 0.394128 | FDE: 0.124856


Batch Training...:   0%|          | 0/264 [00:55<?, ?it/s]

Validation:   0%|          | 0/77 [00:56<?, ?it/s]

[Epoch 2]
[Train] Cost: 0.136903 | Noise Loss: -0.391952 | Frechet Loss: 0.404770 | FDE Loss: 0.124085 LR: 0.000100
[Validation] Val Loss: -0.192606 | Noise: -0.605984 | Frechet: 0.305869 | FDE: 0.107509


Batch Training...:   0%|          | 0/264 [00:54<?, ?it/s]

Validation:   0%|          | 0/77 [00:55<?, ?it/s]

[Epoch 3]
[Train] Cost: -0.240140 | Noise Loss: -0.681969 | Frechet Loss: 0.332264 | FDE Loss: 0.109566 LR: 0.000100
[Validation] Val Loss: -0.396890 | Noise: -0.786336 | Frechet: 0.286838 | FDE: 0.102607


Batch Training...:   0%|          | 0/264 [00:55<?, ?it/s]

Validation:   0%|          | 0/77 [00:55<?, ?it/s]

[Epoch 4]
[Train] Cost: -0.413216 | Noise Loss: -0.799655 | Frechet Loss: 0.287078 | FDE Loss: 0.099360 LR: 0.000100
[Validation] Val Loss: -0.424431 | Noise: -0.787874 | Frechet: 0.264821 | FDE: 0.098621


Batch Training...:   0%|          | 0/264 [00:56<?, ?it/s]

Validation:   0%|          | 0/77 [00:57<?, ?it/s]

[Epoch 5]
[Train] Cost: -0.503475 | Noise Loss: -0.869097 | Frechet Loss: 0.270614 | FDE Loss: 0.095008 LR: 0.000100
[Validation] Val Loss: -0.445057 | Noise: -0.775082 | Frechet: 0.236751 | FDE: 0.093274


Batch Training...:   0%|          | 0/264 [00:55<?, ?it/s]

Validation:   0%|          | 0/77 [00:57<?, ?it/s]

[Epoch 6]
[Train] Cost: -0.589561 | Noise Loss: -0.947479 | Frechet Loss: 0.264994 | FDE Loss: 0.092925 LR: 0.000100
[Validation] Val Loss: -0.453975 | Noise: -0.790138 | Frechet: 0.240558 | FDE: 0.095604


Batch Training...:   0%|          | 0/264 [00:55<?, ?it/s]

Validation:   0%|          | 0/77 [00:55<?, ?it/s]

[Epoch 7]
[Train] Cost: -0.682408 | Noise Loss: -1.017955 | Frechet Loss: 0.248980 | FDE Loss: 0.086567 LR: 0.000100
[Validation] Val Loss: -0.514655 | Noise: -0.834749 | Frechet: 0.230068 | FDE: 0.090026


Batch Training...:   0%|          | 0/264 [00:55<?, ?it/s]

Validation:   0%|          | 0/77 [00:54<?, ?it/s]

[Epoch 8]
[Train] Cost: -0.739712 | Noise Loss: -1.063200 | Frechet Loss: 0.238397 | FDE Loss: 0.085092 LR: 0.000100
[Validation] Val Loss: -0.452473 | Noise: -0.726155 | Frechet: 0.193101 | FDE: 0.080581


Batch Training...:   0%|          | 0/264 [00:55<?, ?it/s]

Validation:   0%|          | 0/77 [00:54<?, ?it/s]

[Epoch 9]
[Train] Cost: -0.783322 | Noise Loss: -1.094228 | Frechet Loss: 0.229464 | FDE Loss: 0.081442 LR: 0.000100
[Validation] Val Loss: -0.669831 | Noise: -0.952810 | Frechet: 0.199373 | FDE: 0.083605


Batch Training...:   0%|          | 0/264 [00:55<?, ?it/s]

Validation:   0%|          | 0/77 [00:55<?, ?it/s]

[Epoch 10]
[Train] Cost: -0.881608 | Noise Loss: -1.169299 | Frechet Loss: 0.210702 | FDE Loss: 0.076989 LR: 0.000100
[Validation] Val Loss: -0.576599 | Noise: -0.867026 | Frechet: 0.207376 | FDE: 0.083051


Batch Training...:   0%|          | 0/264 [00:55<?, ?it/s]

Validation:   0%|          | 0/77 [00:56<?, ?it/s]

[Epoch 11]
[Train] Cost: -0.914725 | Noise Loss: -1.212314 | Frechet Loss: 0.218895 | FDE Loss: 0.078694 LR: 0.000100
[Validation] Val Loss: -0.647067 | Noise: -0.930733 | Frechet: 0.201961 | FDE: 0.081706


Batch Training...:   0%|          | 0/264 [00:55<?, ?it/s]

Validation:   0%|          | 0/77 [00:56<?, ?it/s]

[Epoch 12]
[Train] Cost: -0.970932 | Noise Loss: -1.246008 | Frechet Loss: 0.201962 | FDE Loss: 0.073113 LR: 0.000100
[Validation] Val Loss: -0.653716 | Noise: -0.936185 | Frechet: 0.197756 | FDE: 0.084714


Batch Training...:   0%|          | 0/264 [00:54<?, ?it/s]

Validation:   0%|          | 0/77 [00:55<?, ?it/s]

[Epoch 13]
[Train] Cost: -1.090868 | Noise Loss: -1.341759 | Frechet Loss: 0.183430 | FDE Loss: 0.067461 LR: 0.000050
[Validation] Val Loss: -0.778439 | Noise: -1.040188 | Frechet: 0.183450 | FDE: 0.078299


Batch Training...:   0%|          | 0/264 [00:53<?, ?it/s]

Validation:   0%|          | 0/77 [00:55<?, ?it/s]

[Epoch 14]
[Train] Cost: -1.131490 | Noise Loss: -1.366414 | Frechet Loss: 0.170349 | FDE Loss: 0.064574 LR: 0.000050
[Validation] Val Loss: -0.733234 | Noise: -0.979740 | Frechet: 0.171186 | FDE: 0.075320


Batch Training...:   0%|          | 0/264 [00:55<?, ?it/s]

Validation:   0%|          | 0/77 [00:56<?, ?it/s]

[Epoch 15]
[Train] Cost: -1.132873 | Noise Loss: -1.354740 | Frechet Loss: 0.159891 | FDE Loss: 0.061976 LR: 0.000050
[Validation] Val Loss: -0.793460 | Noise: -1.042108 | Frechet: 0.172696 | FDE: 0.075951


Batch Training...:   0%|          | 0/264 [00:55<?, ?it/s]

Validation:   0%|          | 0/77 [00:55<?, ?it/s]

[Epoch 16]
[Train] Cost: -1.197687 | Noise Loss: -1.425194 | Frechet Loss: 0.165121 | FDE Loss: 0.062386 LR: 0.000050
[Validation] Val Loss: -0.682655 | Noise: -0.922373 | Frechet: 0.163522 | FDE: 0.076196


Batch Training...:   0%|          | 0/264 [00:56<?, ?it/s]

Validation:   0%|          | 0/77 [01:07<?, ?it/s]

[Epoch 17]
[Train] Cost: -1.223641 | Noise Loss: -1.444113 | Frechet Loss: 0.159680 | FDE Loss: 0.060792 LR: 0.000050
[Validation] Val Loss: -0.635851 | Noise: -0.885681 | Frechet: 0.172206 | FDE: 0.077625


Batch Training...:   0%|          | 0/264 [00:58<?, ?it/s]

Validation:   0%|          | 0/77 [01:06<?, ?it/s]

[Epoch 18]
[Train] Cost: -1.253018 | Noise Loss: -1.468145 | Frechet Loss: 0.155535 | FDE Loss: 0.059592 LR: 0.000050
[Validation] Val Loss: -0.740066 | Noise: -0.971657 | Frechet: 0.158380 | FDE: 0.073211


Batch Training...:   0%|          | 0/264 [00:56<?, ?it/s]

Validation:   0%|          | 0/77 [01:05<?, ?it/s]

[Epoch 19]
[Train] Cost: -1.283412 | Noise Loss: -1.478120 | Frechet Loss: 0.139613 | FDE Loss: 0.055096 LR: 0.000025
[Validation] Val Loss: -0.800288 | Noise: -1.018375 | Frechet: 0.148335 | FDE: 0.069752


Batch Training...:   0%|          | 0/264 [00:56<?, ?it/s]

Validation:   0%|          | 0/77 [01:05<?, ?it/s]

[Epoch 20]
[Train] Cost: -1.351177 | Noise Loss: -1.550036 | Frechet Loss: 0.143372 | FDE Loss: 0.055487 LR: 0.000025
[Validation] Val Loss: -0.861436 | Noise: -1.082674 | Frechet: 0.151693 | FDE: 0.069545


Batch Training...:   0%|          | 0/264 [00:57<?, ?it/s]

Validation:   0%|          | 0/77 [01:08<?, ?it/s]

[Epoch 21]
[Train] Cost: -1.340233 | Noise Loss: -1.533155 | Frechet Loss: 0.138078 | FDE Loss: 0.054844 LR: 0.000025
[Validation] Val Loss: -0.715444 | Noise: -0.927987 | Frechet: 0.144192 | FDE: 0.068351


Batch Training...:   0%|          | 0/264 [00:57<?, ?it/s]

Validation:   0%|          | 0/77 [01:06<?, ?it/s]

[Epoch 22]
[Train] Cost: -1.352620 | Noise Loss: -1.542745 | Frechet Loss: 0.136154 | FDE Loss: 0.053972 LR: 0.000025
[Validation] Val Loss: -0.701862 | Noise: -0.914496 | Frechet: 0.144756 | FDE: 0.067878


Batch Training...:   0%|          | 0/264 [01:39<?, ?it/s]

Validation:   0%|          | 0/77 [01:03<?, ?it/s]

[Epoch 23]
[Train] Cost: -1.375837 | Noise Loss: -1.564741 | Frechet Loss: 0.135593 | FDE Loss: 0.053311 LR: 0.000025
[Validation] Val Loss: -0.854327 | Noise: -1.074483 | Frechet: 0.149116 | FDE: 0.071040


Batch Training...:   0%|          | 0/264 [00:56<?, ?it/s]

Validation:   0%|          | 0/77 [01:06<?, ?it/s]

[Epoch 24]
[Train] Cost: -1.412803 | Noise Loss: -1.596755 | Frechet Loss: 0.131590 | FDE Loss: 0.052362 LR: 0.000013
[Validation] Val Loss: -0.780437 | Noise: -0.989660 | Frechet: 0.141471 | FDE: 0.067752


Batch Training...:   0%|          | 0/264 [00:58<?, ?it/s]

Validation:   0%|          | 0/77 [01:05<?, ?it/s]

[Epoch 25]
[Train] Cost: -1.419854 | Noise Loss: -1.601308 | Frechet Loss: 0.129716 | FDE Loss: 0.051738 LR: 0.000013
[Validation] Val Loss: -0.807755 | Noise: -1.014826 | Frechet: 0.139961 | FDE: 0.067110


Batch Training...:   0%|          | 0/264 [01:00<?, ?it/s]

Validation:   0%|          | 0/77 [01:04<?, ?it/s]

[Epoch 26]
[Train] Cost: -1.438757 | Noise Loss: -1.620317 | Frechet Loss: 0.130092 | FDE Loss: 0.051467 LR: 0.000013
[Validation] Val Loss: -0.796569 | Noise: -1.005769 | Frechet: 0.140883 | FDE: 0.068316


Batch Training...:   0%|          | 0/264 [00:58<?, ?it/s]

Validation:   0%|          | 0/77 [01:05<?, ?it/s]

[Epoch 27]
[Train] Cost: -1.458523 | Noise Loss: -1.640488 | Frechet Loss: 0.130643 | FDE Loss: 0.051321 LR: 0.000006
[Validation] Val Loss: -0.795852 | Noise: -1.002080 | Frechet: 0.139006 | FDE: 0.067222


Batch Training...:   0%|          | 0/264 [00:59<?, ?it/s]

Validation:   0%|          | 0/77 [01:04<?, ?it/s]

[Epoch 28]
[Train] Cost: -1.458501 | Noise Loss: -1.634942 | Frechet Loss: 0.125499 | FDE Loss: 0.050942 LR: 0.000006
[Validation] Val Loss: -0.811538 | Noise: -1.020256 | Frechet: 0.141456 | FDE: 0.067263


Batch Training...:   0%|          | 0/264 [00:59<?, ?it/s]

Validation:   0%|          | 0/77 [01:04<?, ?it/s]

[Epoch 29]
[Train] Cost: -1.473226 | Noise Loss: -1.650689 | Frechet Loss: 0.126846 | FDE Loss: 0.050618 LR: 0.000006
[Validation] Val Loss: -0.804130 | Noise: -1.011212 | Frechet: 0.139451 | FDE: 0.067631


Batch Training...:   0%|          | 0/264 [00:59<?, ?it/s]

Validation:   0%|          | 0/77 [01:04<?, ?it/s]

[Epoch 30]
[Train] Cost: -1.466183 | Noise Loss: -1.642020 | Frechet Loss: 0.125457 | FDE Loss: 0.050380 LR: 0.000003
[Validation] Val Loss: -0.746084 | Noise: -0.944767 | Frechet: 0.132690 | FDE: 0.065992


Batch Training...:   0%|          | 0/264 [01:00<?, ?it/s]

Validation:   0%|          | 0/77 [01:06<?, ?it/s]

[Epoch 31]
[Train] Cost: -1.469402 | Noise Loss: -1.642827 | Frechet Loss: 0.123342 | FDE Loss: 0.050083 LR: 0.000003
[Validation] Val Loss: -0.812675 | Noise: -1.019072 | Frechet: 0.139281 | FDE: 0.067116


Batch Training...:   0%|          | 0/264 [01:03<?, ?it/s]

Validation:   0%|          | 0/77 [01:05<?, ?it/s]

[Epoch 32]
[Train] Cost: -1.439210 | Noise Loss: -1.610396 | Frechet Loss: 0.121720 | FDE Loss: 0.049466 LR: 0.000003
[Validation] Val Loss: -0.778187 | Noise: -0.982257 | Frechet: 0.137938 | FDE: 0.066131


Batch Training...:   0%|          | 0/264 [01:02<?, ?it/s]

Validation:   0%|          | 0/77 [01:04<?, ?it/s]

[Epoch 33]
[Train] Cost: -1.467563 | Noise Loss: -1.637561 | Frechet Loss: 0.120619 | FDE Loss: 0.049380 LR: 0.000002
[Validation] Val Loss: -0.780895 | Noise: -0.981992 | Frechet: 0.134629 | FDE: 0.066468


Batch Training...:   0%|          | 0/264 [01:01<?, ?it/s]

Validation:   0%|          | 0/77 [01:05<?, ?it/s]

[Epoch 34]
[Train] Cost: -1.482609 | Noise Loss: -1.657224 | Frechet Loss: 0.124433 | FDE Loss: 0.050182 LR: 0.000002
[Validation] Val Loss: -0.751550 | Noise: -0.950098 | Frechet: 0.132721 | FDE: 0.065827


Batch Training...:   0%|          | 0/264 [01:00<?, ?it/s]

Validation:   0%|          | 0/77 [01:04<?, ?it/s]

[Epoch 35]
[Train] Cost: -1.455259 | Noise Loss: -1.625655 | Frechet Loss: 0.121127 | FDE Loss: 0.049269 LR: 0.000002
[Validation] Val Loss: -0.781013 | Noise: -0.979596 | Frechet: 0.132843 | FDE: 0.065739


Batch Training...:   0%|          | 0/264 [01:01<?, ?it/s]

Validation:   0%|          | 0/77 [01:05<?, ?it/s]

[Epoch 36]
[Train] Cost: -1.457479 | Noise Loss: -1.629250 | Frechet Loss: 0.122112 | FDE Loss: 0.049660 LR: 0.000001
[Validation] Val Loss: -0.692590 | Noise: -0.890858 | Frechet: 0.133468 | FDE: 0.064800


Batch Training...:   0%|          | 0/264 [01:04<?, ?it/s]

Validation:   0%|          | 0/77 [01:06<?, ?it/s]

[Epoch 37]
[Train] Cost: -1.486829 | Noise Loss: -1.661080 | Frechet Loss: 0.124494 | FDE Loss: 0.049757 LR: 0.000001
[Validation] Val Loss: -0.718185 | Noise: -0.915579 | Frechet: 0.132202 | FDE: 0.065192


Batch Training...:   0%|          | 0/264 [01:02<?, ?it/s]

Validation:   0%|          | 0/77 [00:59<?, ?it/s]

[Epoch 38]
[Train] Cost: -1.457526 | Noise Loss: -1.630214 | Frechet Loss: 0.123446 | FDE Loss: 0.049242 LR: 0.000001
[Validation] Val Loss: -0.771113 | Noise: -0.973734 | Frechet: 0.136336 | FDE: 0.066285


Batch Training...:   0%|          | 0/264 [01:04<?, ?it/s]

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/park/anaconda3/envs/SoccerTraj/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/home/park/anaconda3/envs/SoccerTraj/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
  File "/home/park/workspace/SportsScience/Soccer-Trajectory-Prediction/make_dataset.py", line 15, in <module>
    from utils.utils import calc_velocites, correct_all_player_jumps_adjacent
  File "/home/park/workspace/SportsScience/Soccer-Trajectory-Prediction/utils/utils.py", line 232, in <module>
    from soft_dtw import SoftDTW
ModuleNotFoundError: No module named 'soft_dtw'


In [None]:

# 5. Inference (Best-of-N Sampling) & Visualization
model.load_state_dict(best_state_dict)
model.eval()
all_best_ades_test = []
all_best_fdes_test = []
visualize_samples = 5
visualized = False

with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="Test Streaming Inference"):
        cond = batch["condition"].to(device)
        B, T, _ = cond.shape
        target = batch["target"].to(device).view(B, T, 11, 2)

        graph_batch  = batch["graph"].to(device)
        H = graph_encoder(graph_batch)
        cond_H = H.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
        
        hist = cond[:, :, target_idx].to(device)
        hist_rep = history_encoder(hist)
        cond_hist = hist_rep.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
        cond_info = torch.cat([cond_H, cond_hist], dim=1)

        best_ade_t = torch.full((B,), float("inf"), device=device)
        best_pred_t = torch.zeros_like(target)
        best_fde_t = torch.full((B,), float("inf"), device=device)
                    
        scales = torch.tensor(batch["pitch_scale"], device=device, dtype=torch.float32)  
        scales = scales.view(B, 1, 1, 2)

        for _ in tqdm(range(num_samples), desc="Generating..."):
            pred_i = model.generate(shape=target.shape, cond_info=cond_info, ddim_steps=ddim_step, eta=eta, num_samples=1)[0]

            pred_i_den = pred_i * scales
            target_den = target * scales
            
            ade_i = ((pred_i_den - target_den)**2).sum(-1).sqrt().mean((1,2))
            fde_i = ((pred_i_den[:,-1] - target_den[:,-1])**2).sum(-1).sqrt().mean(1)
            
            better = ade_i < best_ade_t
            
            best_pred_t[better] = pred_i_den[better]
            best_ade_t[better] = ade_i[better]
            best_fde_t[better] = fde_i[better]

        all_best_ades_test.extend(best_ade_t.cpu().tolist())
        all_best_fdes_test.extend(best_fde_t.cpu().tolist())

        # Visualization
        if not visualized:
            base_dir = "results/test_trajs"
            os.makedirs(base_dir, exist_ok=True)
            for i in range(min(B, visualize_samples)):
                sample_dir = os.path.join(base_dir, f"sample{i:02d}")
                os.makedirs(sample_dir, exist_ok=True)
                
                other_cols  = batch["other_columns"][i]
                target_cols = batch["target_columns"][i]
                defender_nums = [int(col.split('_')[1]) for col in target_cols[::2]]

                others_seq = batch["other"][i].view(T, 12, 2).cpu().numpy()
                target_traj = target_den[i].cpu().numpy()
                pred_traj = best_pred_t[i].cpu().numpy()

                for idx, jersey in enumerate(defender_nums):
                    save_path = os.path.join(sample_dir, f"player_{jersey:02d}.png")
                    plot_trajectories_on_pitch(others_seq, target_traj, pred_traj,
                                               other_columns=other_cols, target_columns=target_cols,
                                               player_idx=idx, annotate=True, save_path=save_path)

            visualized = True

avg_test_ade = np.mean(all_best_ades_test)
avg_test_fde = np.mean(all_best_fdes_test)
print(f"[Test Best-of-{num_samples}] Average ADE: {avg_test_ade:.4f} | Average FDE: {avg_test_fde:.4f}")
print(f"[Test Best-of-{num_samples}] Best ADE overall: {min(all_best_ades_test):.4f} | Best FDE overall: {min(all_best_fdes_test):.4f}")


Test Streaming Inference:   0%|          | 0/98 [00:56<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

[47.12474060058594, 48.35272216796875, 47.058170318603516, 48.267372131347656, 48.18181610107422, 48.32517623901367, 47.882080078125, 47.73262405395508, 48.303955078125, 47.62089538574219, 47.876163482666016, 48.19601058959961, 48.23334503173828, 48.0052490234375, 48.99375915527344, 48.5924072265625]
[52.43627166748047, 53.271663665771484, 57.27537536621094, 49.59867858886719, 48.58372116088867, 47.1732292175293, 51.14168167114258, 62.31167984008789, 62.052852630615234, 38.381710052490234, 51.8764762878418, 47.30390167236328, 44.88669967651367, 42.53123092651367, 47.781394958496094, 59.682098388671875]


Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

[Test Best-of-10] Average ADE: 46.9732 | Average FDE: 47.3212
[Test Best-of-10] Best ADE overall: 42.4612 | Best FDE overall: 22.6125
