In [1]:
import multiprocessing as mp
mp.set_start_method('spawn', force=True)

In [None]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import logging
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from torch.utils.data import DataLoader, Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.amp import autocast, GradScaler
from models.diff_modules import diff_CSDI
from models.diff_model import DiffusionTrajectoryModel
from models.encoder import InteractionGraphEncoder, TargetTrajectoryEncoder
from make_dataset import MultiMatchSoccerDataset, organize_and_process
from utils.utils import set_evertyhing, worker_init_fn, generator, plot_trajectories_on_pitch, log_graph_stats
from utils.data_utils import split_dataset_indices, custom_collate_fn
from utils.graph_utils import build_graph_sequence_from_condition

# SEED Fix
SEED = 42
set_evertyhing(SEED)


# Save Log / Logger Setting
model_save_path = './results/logs/'
os.makedirs(model_save_path, exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
    filename=os.path.join(model_save_path, 'train.log'),
    filemode='w'
)
logger = logging.getLogger()

# 1. Model Config & Hyperparameter Setting
csdi_config = {
    "num_steps": 1000,
    "channels": 256,
    "diffusion_embedding_dim": 256,
    "nheads": 4,
    "layers": 5,
    # "side_dim": 128
    "side_dim": 512
}
hyperparams = {
    'raw_data_path': "idsse-data", # raw_data_path = "Download raw file path"
    'data_save_path': "match_data",
    'train_batch_size': 16,
    'val_batch_size': 16,
    'test_batch_size': 16,
    'num_workers': 8,
    'epochs': 50,
    'learning_rate': 1e-4,
    'self_conditioning_ratio': 0.5,
    'num_samples': 10,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    'ddim_step': 200,
    'eta': 0.1,
    **csdi_config
}
raw_data_path = hyperparams['raw_data_path']
data_save_path = hyperparams['data_save_path']
train_batch_size = hyperparams['train_batch_size']
val_batch_size = hyperparams['val_batch_size']
test_batch_size = hyperparams['test_batch_size']
num_workers = hyperparams['num_workers']
epochs = hyperparams['epochs']
learning_rate = hyperparams['learning_rate']
self_conditioning_ratio = hyperparams['self_conditioning_ratio']
num_samples = hyperparams['num_samples']
device = hyperparams['device']
ddim_step = hyperparams['ddim_step']
eta = hyperparams['eta']
side_dim = hyperparams['side_dim']

logger.info(f"Hyperparameters: {hyperparams}")

# 2. Data Loading
print("---Data Loading---")
if not os.path.exists(data_save_path) or len(os.listdir(data_save_path)) == 0:
    organize_and_process(raw_data_path, data_save_path)
else:
    print("Skip organize_and_process")

dataset = MultiMatchSoccerDataset(data_root=data_save_path)
train_idx, val_idx, test_idx = split_dataset_indices(dataset, val_ratio=1/6, test_ratio=1/6, random_seed=SEED)

train_dataloader = DataLoader(
    Subset(dataset, train_idx),
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=False,
    collate_fn=custom_collate_fn,
    worker_init_fn=worker_init_fn,
    generator=generator(SEED)
)

val_dataloader = DataLoader(
    Subset(dataset, val_idx),
    batch_size=val_batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=False,
    collate_fn=custom_collate_fn,
    worker_init_fn=worker_init_fn,
)

test_dataloader = DataLoader(
    Subset(dataset, test_idx),
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=False,
    collate_fn=custom_collate_fn,
    worker_init_fn=worker_init_fn
)

print("---Data Load!---")

# 3. Model Define
# Extract node feature dimension
sample = dataset[0]
graph = build_graph_sequence_from_condition({
    "condition": sample["condition"],
    "condition_columns": sample["condition_columns"],
    "pitch_scale": sample["pitch_scale"]
}).to(device)

log_graph_stats(graph, logger, prefix="InitGraphSample")

in_dim = graph['Node'].x.size(1)

# Extract target's history trajectories from condition
condition_columns = sample["condition_columns"]
target_columns = sample["target_columns"]
target_idx = [condition_columns.index(col) for col in target_columns if col in condition_columns]

# graph_encoder = InteractionGraphEncoder(in_dim=in_dim, hidden_dim=128, out_dim=128, heads = 2).to(device)
graph_encoder = InteractionGraphEncoder(in_dim=in_dim, hidden_dim=side_dim // 2, out_dim=side_dim // 2).to(device)
history_encoder = TargetTrajectoryEncoder(num_layers=5, hidden_dim = side_dim // 4, bidirectional=True).to(device)
denoiser = diff_CSDI(csdi_config)
model = DiffusionTrajectoryModel(denoiser, num_steps=csdi_config["num_steps"]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, threshold=1e-4)
scaler = GradScaler()

logger.info(f"Device: {device}")
logger.info(f"GraphEncoder: {graph_encoder}")
logger.info(f"HistoryEncoder: {history_encoder}")
logger.info(f"Denoiser (diff_CSDI): {denoiser}")
logger.info(f"DiffusionTrajectoryModel: {model}")

---Data Loading---
Skip organize_and_process


Loading Matches: 100%|██████████| 6/6 [00:40<00:00,  6.70s/it]


---Data Load!---


: 

In [None]:

# 4. Train
best_state_dict = None
best_val_loss = float("inf")

train_losses = []
val_losses   = []

for epoch in tqdm(range(1, epochs + 1), desc="Training..."):
    model.train()
    train_noise_mse = 0
    train_noise_nll = 0
    # train_mse_loss = 0
    # train_dtw_loss = 0
    # train_fde_loss = 0
    train_loss = 0

    for batch in tqdm(train_dataloader, desc = "Batch Training..."):
        cond = batch["condition"].to(device)
        B, T, _ = cond.shape
        target = batch["target"].to(device).view(-1, T, 11, 2)  # [B, T, 11, 2]
        graph_batch = batch["graph"].to(device)                              # HeteroData batch

        # graph → H
        H = graph_encoder(graph_batch)                                       # [B, 128]
        cond_H = H.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
        
        # Target's history trajectories
        hist = cond[:, :, target_idx].to(device) 
        hist_rep = history_encoder(hist)  # [B, 128]
        cond_hist = hist_rep.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
        
        # Concat conditions
        cond_info = torch.cat([cond_H, cond_hist], dim=1)
        # Preparing Self-conditioning data
        # timestep (consistency)
        t = torch.randint(0, model.num_steps, (target.size(0),), device=device)
        if torch.rand(1, device=device) < self_conditioning_ratio:
            with torch.no_grad():
                x_t, noise = model.q_sample(target, t)
                x_t = x_t.permute(0,3,2,1)
                
                z1 = model.model(x_t, t, cond_info, self_cond=None)
                
                eps_pred1 = z1[:, :2, :, :]
                a_hat = model.alpha_hat.to(device)[t].view(-1,1,1,1)
                x0_hat = (x_t - (1 - a_hat).sqrt() * eps_pred1) / a_hat.sqrt()
                x0_hat = x0_hat.permute(0,3,2,1)
                                
            s = x0_hat
        else:
            s = torch.zeros_like(target)

        
        # noise_loss, player_loss_mse, player_loss_frechet = model(target, cond_info=cond_info, self_cond=s)
        # loss = noise_loss + player_loss_mse + player_loss_frechet * 0.2
        noise_mse, noise_nll = model(target, t=t, cond_info=cond_info, self_cond=s)
        loss = noise_mse + noise_nll * 0.001
            
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # train_noise_loss += (noise_loss).item()
        # # train_mse_loss += (player_loss_mse).item()
        # train_dtw_loss += (player_loss_dtw).item()
        # train_fde_loss += (player_loss_fde).item()
        train_noise_mse += (noise_mse).item()
        train_noise_nll += (noise_nll * 0.001).item()
        train_loss += loss.item()

    num_batches = len(train_dataloader)
    
    # avg_noise_loss = train_noise_loss / num_batches
    # # avg_mse_loss = train_mse_loss / num_batches
    # avg_dtw_loss = train_dtw_loss / num_batches
    # avg_fde_loss = train_fde_loss / num_batches
    avg_train_noise_mse = train_noise_mse / num_batches
    avg_train_noise_nll = train_noise_nll / num_batches
    avg_train_loss = train_loss / num_batches


    # --- Validation ---
    model.eval()
    # val_noise_loss = 0
    # val_mse_loss = 0
    # val_dtw_loss = 0
    # val_fde_loss = 0
    val_noise_mse = 0
    val_noise_nll = 0
    val_total_loss = 0

    with torch.no_grad():
        for batch in tqdm(val_dataloader, desc="Validation"):
            cond = batch["condition"].to(device)
            B, T, _ = cond.shape
            target = batch["target"].to(device).view(-1, T, 11, 2)  # [B, T, 11, 2]
            graph_batch = batch["graph"].to(device)                              # HeteroData batch

            # graph → H
            H = graph_encoder(graph_batch)                                       # [B, 128]
            cond_H = H.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
            
            # Target's history trajectories
            hist = cond[:, :, target_idx].to(device)  # [B,128,11,T]
            hist_rep = history_encoder(hist)  # [B, 128]
            cond_hist = hist_rep.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
            
            # Concat conditions
            cond_info = torch.cat([cond_H, cond_hist], dim=1)
            
            s = torch.zeros_like(target)
            
            # noise_loss, player_loss_mse, player_loss_frechet = model(target, cond_info=cond_info, self_cond=s)
            # val_loss = noise_loss + player_loss_mse + player_loss_frechet * 0.2
            noise_mse, noise_nll = model(target, cond_info=cond_info, self_cond=s)
            val_loss = noise_mse + noise_nll * 0.001
        
            # val_noise_loss += (noise_loss).item()
            # # val_mse_loss += (player_loss_mse).item()
            # val_dtw_loss += (player_loss_dtw).item()
            # val_fde_loss += (player_loss_fde).item()
            val_noise_mse += (noise_mse).item()
            val_noise_nll += (noise_nll * 0.001).item()
            val_total_loss += val_loss.item()

    # avg_val_noise_loss = val_noise_loss / len(val_dataloader)
    # # avg_val_mse_loss = val_mse_loss / len(val_dataloader)
    # avg_val_dtw_loss = val_dtw_loss / len(val_dataloader)
    # avg_val_fde_loss = val_fde_loss / len(val_dataloader)
    # avg_val_loss = val_total_loss / len(val_dataloader)
    
    num_batches = len(val_dataloader)
    
    avg_val_noise_mse = val_noise_mse / num_batches
    avg_val_noise_nll = val_noise_nll / num_batches
    avg_val_loss = val_total_loss / num_batches
  
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    
    current_lr = scheduler.get_last_lr()[0]
    logger.info(f"[Epoch {epoch}/{epochs}] Train Loss={avg_train_loss:.6f} (Noise simple={avg_train_noise_mse:.6f}, Noise NLL={avg_train_noise_nll:.6f}) | Val Loss={avg_val_loss:.6f} | LR={current_lr:.6e}")
    
    tqdm.write(f"[Epoch {epoch}]\n"
               f"[Train] Cost: {avg_train_loss:.6f} | Noise Loss: {avg_train_noise_mse:.6f} | NLL Loss: {avg_train_noise_nll:.6f} | LR: {current_lr:.6f}\n"
               f"[Validation] Val Loss: {avg_val_loss:.6f} | Noise Loss: {avg_val_noise_mse:.6f} | NLL Loss: {avg_val_noise_nll:.6f}")
    
    scheduler.step(avg_val_loss)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_state_dict = model.state_dict()

logger.info(f"Training complete. Best val loss: {best_val_loss:.6f}")
        
# 4-1. Plot learning_curve
plt.figure(figsize=(8, 6))
plt.plot(range(1, epochs+1), train_losses, label='Train Loss')
plt.plot(range(1, epochs+1), val_losses, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title(f"Train & Validation Loss, {csdi_config['num_steps']} steps, {csdi_config['channels']} channels, "
          f"{csdi_config['diffusion_embedding_dim']} embedding dim, {csdi_config['nheads']} heads, {csdi_config['layers']} layers "
          f"self-conditioning ratio: {self_conditioning_ratio}")
plt.legend()
plt.tight_layout()

plt.savefig('results/0512_diffusion_lr_curve.png')

plt.show()


Training...:   0%|          | 0/50 [00:00<?, ?it/s]

In [None]:

# 5. Inference (Best-of-N Sampling) & Visualization
model.load_state_dict(best_state_dict)
model.eval()
all_best_ades_test = []
all_best_fdes_test = []
visualize_samples = 5
visualized = False

ddim_step = 50
eta = 0.0

with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="Test Streaming Inference"):
        cond = batch["condition"].to(device)
        B, T, _ = cond.shape
        target = batch["target"].to(device).view(B, T, 11, 2)

        graph_batch  = batch["graph"].to(device)
        H = graph_encoder(graph_batch)
        cond_H = H.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
        
        hist = cond[:, :, target_idx].to(device)
        hist_rep = history_encoder(hist)
        cond_hist = hist_rep.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
        cond_info = torch.cat([cond_H, cond_hist], dim=1)

        best_ade_t = torch.full((B,), float("inf"), device=device)
        best_pred_t = torch.zeros_like(target)
        best_fde_t = torch.full((B,), float("inf"), device=device)
                    
        scales = torch.tensor(batch["pitch_scale"], device=device, dtype=torch.float32)  
        scales = scales.view(B, 1, 1, 2)

        for _ in tqdm(range(num_samples), desc="Generating..."):
            pred_i = model.generate(shape=target.shape, cond_info=cond_info, ddim_steps=ddim_step, eta=eta, num_samples=1)[0]

            pred_i_den = pred_i * scales
            target_den = target * scales
            
            ade_i = ((pred_i_den - target_den)**2).sum(-1).sqrt().mean((1,2))
            fde_i = ((pred_i_den[:,-1] - target_den[:,-1])**2).sum(-1).sqrt().mean(1)
            
            better = ade_i < best_ade_t
            
            best_pred_t[better] = pred_i_den[better]
            best_ade_t[better] = ade_i[better]
            best_fde_t[better] = fde_i[better]

        all_best_ades_test.extend(best_ade_t.cpu().tolist())
        all_best_fdes_test.extend(best_fde_t.cpu().tolist())

        # Visualization
        if not visualized:
            base_dir = "results/test_trajs"
            os.makedirs(base_dir, exist_ok=True)
            for i in range(min(B, visualize_samples)):
                sample_dir = os.path.join(base_dir, f"sample{i:02d}")
                os.makedirs(sample_dir, exist_ok=True)
                
                other_cols  = batch["other_columns"][i]
                target_cols = batch["target_columns"][i]
                defender_nums = [int(col.split('_')[1]) for col in target_cols[::2]]

                others_seq = batch["other"][i].view(T, 12, 2).cpu().numpy()
                target_traj = target_den[i].cpu().numpy()
                pred_traj = best_pred_t[i].cpu().numpy()

                for idx, jersey in enumerate(defender_nums):
                    save_path = os.path.join(sample_dir, f"player_{jersey:02d}.png")
                    plot_trajectories_on_pitch(others_seq, target_traj, pred_traj,
                                               other_columns=other_cols, target_columns=target_cols,
                                               player_idx=idx, annotate=True, save_path=save_path)

            visualized = True
            print(all_best_ades_test)
            print(all_best_fdes_test)
avg_test_ade = np.mean(all_best_ades_test)
avg_test_fde = np.mean(all_best_fdes_test)
print(f"[Test Best-of-{num_samples}] Average ADE: {avg_test_ade:.4f} | Average FDE: {avg_test_fde:.4f}")
print(f"[Test Best-of-{num_samples}] Best ADE overall: {min(all_best_ades_test):.4f} | Best FDE overall: {min(all_best_fdes_test):.4f}")


Test Streaming Inference:   0%|          | 0/98 [00:56<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

[47.12474060058594, 48.35272216796875, 47.058170318603516, 48.267372131347656, 48.18181610107422, 48.32517623901367, 47.882080078125, 47.73262405395508, 48.303955078125, 47.62089538574219, 47.876163482666016, 48.19601058959961, 48.23334503173828, 48.0052490234375, 48.99375915527344, 48.5924072265625]
[52.43627166748047, 53.271663665771484, 57.27537536621094, 49.59867858886719, 48.58372116088867, 47.1732292175293, 51.14168167114258, 62.31167984008789, 62.052852630615234, 38.381710052490234, 51.8764762878418, 47.30390167236328, 44.88669967651367, 42.53123092651367, 47.781394958496094, 59.682098388671875]


Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

Generating...:   0%|          | 0/10 [00:00<?, ?it/s]

[Test Best-of-10] Average ADE: 46.9732 | Average FDE: 47.3212
[Test Best-of-10] Best ADE overall: 42.4612 | Best FDE overall: 22.6125


In [None]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import logging
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from torch.utils.data import DataLoader, Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.amp import autocast, GradScaler
from models.diff_modules import diff_CSDI
from models.diff_model import DiffusionTrajectoryModel
from models.encoder import InteractionGraphEncoder, TargetTrajectoryEncoder
from make_dataset import MultiMatchSoccerDataset, organize_and_process
from utils.utils import set_evertyhing, worker_init_fn, generator, plot_trajectories_on_pitch, log_graph_stats
from utils.data_utils import split_dataset_indices, custom_collate_fn
from utils.graph_utils import build_graph_sequence_from_condition

# SEED Fix
SEED = 42
set_evertyhing(SEED)


# Save Log / Logger Setting
model_save_path = './results/logs/'
os.makedirs(model_save_path, exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
    filename=os.path.join(model_save_path, 'train.log'),
    filemode='w'
)
logger = logging.getLogger()

# 1. Model Config & Hyperparameter Setting
csdi_config = {
    "num_steps": 1000,
    "channels": 256,
    "diffusion_embedding_dim": 256,
    "nheads": 4,
    "layers": 5,
    # "side_dim": 128
    "side_dim": 512
}
hyperparams = {
    'raw_data_path': "idsse-data", # raw_data_path = "Download raw file path"
    'data_save_path': "match_data",
    'train_batch_size': 16,
    'val_batch_size': 16,
    'test_batch_size': 16,
    'num_workers': 8,
    'epochs': 50,
    'learning_rate': 1e-4,
    'self_conditioning_ratio': 0.5,
    'num_samples': 10,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    'ddim_step': 200,
    'eta': 0.1,
    **csdi_config
}
raw_data_path = hyperparams['raw_data_path']
data_save_path = hyperparams['data_save_path']
train_batch_size = hyperparams['train_batch_size']
val_batch_size = hyperparams['val_batch_size']
test_batch_size = hyperparams['test_batch_size']
num_workers = hyperparams['num_workers']
epochs = hyperparams['epochs']
learning_rate = hyperparams['learning_rate']
self_conditioning_ratio = hyperparams['self_conditioning_ratio']
num_samples = hyperparams['num_samples']
device = hyperparams['device']
ddim_step = hyperparams['ddim_step']
eta = hyperparams['eta']
side_dim = hyperparams['side_dim']

logger.info(f"Hyperparameters: {hyperparams}")

# 2. Data Loading
print("---Data Loading---")
if not os.path.exists(data_save_path) or len(os.listdir(data_save_path)) == 0:
    organize_and_process(raw_data_path, data_save_path)
else:
    print("Skip organize_and_process")

dataset = MultiMatchSoccerDataset(data_root=data_save_path)
train_idx, val_idx, test_idx = split_dataset_indices(dataset, val_ratio=1/6, test_ratio=1/6, random_seed=SEED)

train_dataloader = DataLoader(
    Subset(dataset, train_idx),
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=False,
    collate_fn=custom_collate_fn,
    worker_init_fn=worker_init_fn,
    generator=generator(SEED)
)

val_dataloader = DataLoader(
    Subset(dataset, val_idx),
    batch_size=val_batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=False,
    collate_fn=custom_collate_fn,
    worker_init_fn=worker_init_fn,
)

test_dataloader = DataLoader(
    Subset(dataset, test_idx),
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=False,
    collate_fn=custom_collate_fn,
    worker_init_fn=worker_init_fn
)

print("---Data Load!---")

# 3. Model Define
# Extract node feature dimension
sample = dataset[0]
graph = build_graph_sequence_from_condition({
    "condition": sample["condition"],
    "condition_columns": sample["condition_columns"],
    "pitch_scale": sample["pitch_scale"]
}).to(device)

log_graph_stats(graph, logger, prefix="InitGraphSample")

in_dim = graph['Node'].x.size(1)

# Extract target's history trajectories from condition
condition_columns = sample["condition_columns"]
target_columns = sample["target_columns"]
target_idx = [condition_columns.index(col) for col in target_columns if col in condition_columns]

# graph_encoder = InteractionGraphEncoder(in_dim=in_dim, hidden_dim=128, out_dim=128, heads = 2).to(device)
graph_encoder = InteractionGraphEncoder(in_dim=in_dim, hidden_dim=side_dim // 2, out_dim=side_dim // 2).to(device)
history_encoder = TargetTrajectoryEncoder(num_layers=5, hidden_dim = side_dim // 4, bidirectional=True).to(device)
denoiser = diff_CSDI(csdi_config)
model = DiffusionTrajectoryModel(denoiser, num_steps=csdi_config["num_steps"]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, threshold=1e-4)
scaler = GradScaler()

logger.info(f"Device: {device}")
logger.info(f"GraphEncoder: {graph_encoder}")
logger.info(f"HistoryEncoder: {history_encoder}")
logger.info(f"Denoiser (diff_CSDI): {denoiser}")
logger.info(f"DiffusionTrajectoryModel: {model}")

# 4. Train
best_state_dict = None
best_val_loss = float("inf")

train_losses = []
val_losses   = []

for epoch in tqdm(range(1, epochs + 1), desc="Training..."):
    model.train()
    train_noise_mse = 0
    train_noise_nll = 0
    train_frechet_loss
    # train_mse_loss = 0
    # train_dtw_loss = 0
    # train_fde_loss = 0
    train_loss = 0

    for batch in tqdm(train_dataloader, desc = "Batch Training..."):
        cond = batch["condition"].to(device)
        B, T, _ = cond.shape
        target = batch["target"].to(device).view(-1, T, 11, 2)  # [B, T, 11, 2]
        graph_batch = batch["graph"].to(device)                              # HeteroData batch

        # graph → H
        H = graph_encoder(graph_batch)                                       # [B, 128]
        cond_H = H.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
        
        # Target's history trajectories
        hist = cond[:, :, target_idx].to(device) 
        hist_rep = history_encoder(hist)  # [B, 128]
        cond_hist = hist_rep.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
        
        # Concat conditions
        cond_info = torch.cat([cond_H, cond_hist], dim=1)
        # Preparing Self-conditioning data
        # timestep (consistency)
        t = torch.randint(0, model.num_steps, (target.size(0),), device=device)
        if torch.rand(1, device=device) < self_conditioning_ratio:
            with torch.no_grad():
                x_t, noise = model.q_sample(target, t)
                x_t = x_t.permute(0,3,2,1)
                
                z1 = model.model(x_t, t, cond_info, self_cond=None)
                
                eps_pred1 = z1[:, :2, :, :]
                a_hat = model.alpha_hat.to(device)[t].view(-1,1,1,1)
                x0_hat = (x_t - (1 - a_hat).sqrt() * eps_pred1) / a_hat.sqrt()
                x0_hat = x0_hat.permute(0,3,2,1)
                                
            s = x0_hat
        else:
            s = torch.zeros_like(target)

        
        # noise_loss, player_loss_mse, player_loss_frechet = model(target, cond_info=cond_info, self_cond=s)
        # loss = noise_loss + player_loss_mse + player_loss_frechet * 0.2
        noise_mse, noise_nll, player_frechet_loss = model(target, t=t, cond_info=cond_info, self_cond=s)
        loss = noise_mse + noise_nll * 0.001 + player_frechet_loss
            
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # train_noise_loss += (noise_loss).item()
        # train_mse_loss += (player_loss_mse).item()
        # train_dtw_loss += (player_loss_dtw).item()
        # train_fde_loss += (player_loss_fde).item()
        train_noise_mse += (noise_mse).item()
        train_noise_nll += (noise_nll * 0.001).item()
        train_frechet_loss += (player_frechet_loss).item()
        train_loss += loss.item()

    num_batches = len(train_dataloader)
    
    # avg_noise_loss = train_noise_loss / num_batches
    # # avg_mse_loss = train_mse_loss / num_batches
    # avg_dtw_loss = train_dtw_loss / num_batches
    # avg_fde_loss = train_fde_loss / num_batches
    avg_train_noise_mse = train_noise_mse / num_batches
    avg_train_noise_nll = train_noise_nll / num_batches
    avg_train_frechet_loss = train_frechet_loss / num_batches
    avg_train_loss = train_loss / num_batches


    # --- Validation ---
    model.eval()
    # val_noise_loss = 0
    # val_mse_loss = 0
    # val_dtw_loss = 0
    # val_fde_loss = 0
    val_noise_mse = 0
    val_noise_nll = 0
    val_frechet_loss = 0
    val_total_loss = 0

    with torch.no_grad():
        for batch in tqdm(val_dataloader, desc="Validation"):
            cond = batch["condition"].to(device)
            B, T, _ = cond.shape
            target = batch["target"].to(device).view(-1, T, 11, 2)  # [B, T, 11, 2]
            graph_batch = batch["graph"].to(device)                              # HeteroData batch

            # graph → H
            H = graph_encoder(graph_batch)                                       # [B, 128]
            cond_H = H.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
            
            # Target's history trajectories
            hist = cond[:, :, target_idx].to(device)  # [B,128,11,T]
            hist_rep = history_encoder(hist)  # [B, 128]
            cond_hist = hist_rep.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
            
            # Concat conditions
            cond_info = torch.cat([cond_H, cond_hist], dim=1)
            
            s = torch.zeros_like(target)
            
            # noise_loss, player_loss_mse, player_loss_frechet = model(target, cond_info=cond_info, self_cond=s)
            # val_loss = noise_loss + player_loss_mse + player_loss_frechet * 0.2
            noise_mse, noise_nll, player_frechet_loss = model(target, cond_info=cond_info, self_cond=s)
            val_loss = noise_mse + noise_nll * 0.001 + player_frechet_loss
        
            # val_noise_loss += (noise_loss).item()
            # # val_mse_loss += (player_loss_mse).item()
            # val_dtw_loss += (player_loss_dtw).item()
            # val_fde_loss += (player_loss_fde).item()
            val_noise_mse += (noise_mse).item()
            val_noise_nll += (noise_nll * 0.001).item()
            val_frechet_loss += (player_frechet_loss).item()
            val_total_loss += val_loss.item()

    # avg_val_noise_loss = val_noise_loss / len(val_dataloader)
    # # avg_val_mse_loss = val_mse_loss / len(val_dataloader)
    # avg_val_dtw_loss = val_dtw_loss / len(val_dataloader)
    # avg_val_fde_loss = val_fde_loss / len(val_dataloader)
    # avg_val_loss = val_total_loss / len(val_dataloader)
    
    num_batches = len(val_dataloader)
    
    avg_val_noise_mse = val_noise_mse / num_batches
    avg_val_noise_nll = val_noise_nll / num_batches
    avg_val_frechet_loss = val_frechet_loss / num_batches
    avg_val_loss = val_total_loss / num_batches
  
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    
    current_lr = scheduler.get_last_lr()[0]
    logger.info(f"[Epoch {epoch}/{epochs}] Train Loss={avg_train_loss:.6f} (Noise simple={avg_train_noise_mse:.6f}, Noise NLL={avg_train_noise_nll:.6f}, Frechet={avg_train_frechet_loss:.6f}) |"
                f"Val Loss={avg_val_loss:.6f} | LR={current_lr:.6e}")
    
    tqdm.write(f"[Epoch {epoch}]\n"
               f"[Train] Cost: {avg_train_loss:.6f} | Noise Loss: {avg_train_noise_mse:.6f} | NLL Loss: {avg_train_noise_nll:.6f} | Frechet: {avg_train_frechet_loss:.6f}) | LR: {current_lr:.6f}\n"
               f"[Validation] Val Loss: {avg_val_loss:.6f} | Noise Loss: {avg_val_noise_mse:.6f} | NLL Loss: {avg_val_noise_nll:.6f} | Frechet: {avg_val_frechet_loss:.6f}) |")
    
    scheduler.step(avg_val_loss)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_state_dict = model.state_dict()

logger.info(f"Training complete. Best val loss: {best_val_loss:.6f}")
        
# 4-1. Plot learning_curve
plt.figure(figsize=(8, 6))
plt.plot(range(1, epochs+1), train_losses, label='Train Loss')
plt.plot(range(1, epochs+1), val_losses, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title(f"Train & Validation Loss, {csdi_config['num_steps']} steps, {csdi_config['channels']} channels, "
          f"{csdi_config['diffusion_embedding_dim']} embedding dim, {csdi_config['nheads']} heads, {csdi_config['layers']} layers "
          f"self-conditioning ratio: {self_conditioning_ratio}")
plt.legend()
plt.tight_layout()

plt.savefig('results/0512_diffusion_lr_curve.png')

plt.show()


# 5. Inference (Best-of-N Sampling) & Visualization
model.load_state_dict(best_state_dict)
model.eval()
all_best_ades_test = []
all_best_fdes_test = []
visualize_samples = 5
visualized = False

with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="Test Streaming Inference"):
        cond = batch["condition"].to(device)
        B, T, _ = cond.shape
        target = batch["target"].to(device).view(B, T, 11, 2)

        graph_batch  = batch["graph"].to(device)
        H = graph_encoder(graph_batch)
        cond_H = H.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
        
        hist = cond[:, :, target_idx].to(device)
        hist_rep = history_encoder(hist)
        cond_hist = hist_rep.unsqueeze(-1).unsqueeze(-1).expand(-1, H.size(1), 11, T)
        cond_info = torch.cat([cond_H, cond_hist], dim=1)

        best_ade_t = torch.full((B,), float("inf"), device=device)
        best_pred_t = torch.zeros_like(target)
        best_fde_t = torch.full((B,), float("inf"), device=device)
                    
        scales = torch.tensor(batch["pitch_scale"], device=device, dtype=torch.float32)  
        scales = scales.view(B, 1, 1, 2)

        for _ in tqdm(range(num_samples), desc="Generating..."):
            pred_i = model.generate(shape=target.shape, cond_info=cond_info, ddim_steps=ddim_step, eta=eta, num_samples=1)[0]

            pred_i_den = pred_i * scales
            target_den = target * scales
            
            ade_i = ((pred_i_den - target_den)**2).sum(-1).sqrt().mean((1,2))
            fde_i = ((pred_i_den[:,-1] - target_den[:,-1])**2).sum(-1).sqrt().mean(1)
            
            better = ade_i < best_ade_t
            
            best_pred_t[better] = pred_i_den[better]
            best_ade_t[better] = ade_i[better]
            best_fde_t[better] = fde_i[better]

        all_best_ades_test.extend(best_ade_t.cpu().tolist())
        all_best_fdes_test.extend(best_fde_t.cpu().tolist())

        # Visualization
        if not visualized:
            base_dir = "results/test_trajs"
            os.makedirs(base_dir, exist_ok=True)
            for i in range(min(B, visualize_samples)):
                sample_dir = os.path.join(base_dir, f"sample{i:02d}")
                os.makedirs(sample_dir, exist_ok=True)
                
                other_cols  = batch["other_columns"][i]
                target_cols = batch["target_columns"][i]
                defender_nums = [int(col.split('_')[1]) for col in target_cols[::2]]

                others_seq = batch["other"][i].view(T, 12, 2).cpu().numpy()
                target_traj = target_den[i].cpu().numpy()
                pred_traj = best_pred_t[i].cpu().numpy()

                for idx, jersey in enumerate(defender_nums):
                    save_path = os.path.join(sample_dir, f"player_{jersey:02d}.png")
                    plot_trajectories_on_pitch(others_seq, target_traj, pred_traj,
                                               other_columns=other_cols, target_columns=target_cols,
                                               player_idx=idx, annotate=True, save_path=save_path)

            visualized = True
            print(all_best_ades_test)
            print(all_best_fdes_test)
avg_test_ade = np.mean(all_best_ades_test)
avg_test_fde = np.mean(all_best_fdes_test)
print(f"[Test Best-of-{num_samples}] Average ADE: {avg_test_ade:.4f} | Average FDE: {avg_test_fde:.4f}")
print(f"[Test Best-of-{num_samples}] Best ADE overall: {min(all_best_ades_test):.4f} | Best FDE overall: {min(all_best_fdes_test):.4f}")
