In [1]:
import os
import os.path as osp

from datetime import datetime
import time
from IPython.display import clear_output

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch_geometric.loader.dataloader import DataLoader
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
import torch.multiprocessing

from torch_geometric.data import Data
import torch_geometric
import networkx as nx


from sklearn.model_selection import train_test_split
import joblib

from GNNDataset import GNNDataset
from ClusterDataset import ClusterDataset
from ClusterDatasetBuilder import ClusterDatasetBuilder
from train_transformer import *
from data_statistics import *
from GNN_TrackLinkingNet import EarlyStopping, weight_init

from IPython.display import display

from Transformer import Transformer
from lang import Lang
from LossFunction import Loss

In [2]:
# CUDA setup: use the first GPU when CUDA is available, otherwise the CPU.
# (To force CPU execution, assign torch.device("cpu") instead.)
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device("cpu")

n_visible_gpus = torch.cuda.device_count()
print(f"Using device: {device}, number of devices: {n_visible_gpus}")

Using device: cuda:0, number of devices: 1


In [3]:
# Sequence and batching configuration.
input_length, max_seq_length, batch_size = 60, 60, 64

# Global Lang instance. Its "<PAD>" index is read during model setup, and the
# instance is passed to the loss function.
converter = Lang(0)

In [4]:
# Use "spawn" for worker processes. force=True makes the cell safe to re-run:
# without it, a second call raises RuntimeError("context has already been set").
torch.multiprocessing.set_start_method('spawn', force=True)

In [5]:
# Dataset / model locations. Every path lives under one EOS root, so
# relocating the data means changing a single string.
eos_root = "/eos/user/c/czeh"

# Trailing separators are kept deliberately; the original strings ended in "/".
model_folder = osp.join(eos_root, "")                    # "/eos/user/c/czeh/"
hist_folder = osp.join(eos_root, "histo_10pion0PU", "")  # ".../histo_10pion0PU/"
data_folder_training = osp.join(eos_root, "graph_data", "processed")
store_folder_training = osp.join(eos_root, "graph_data_trans")
data_folder_test = osp.join(eos_root, "graph_data_test", "processed")
store_folder_test = osp.join(eos_root, "graph_data_trans_test")

# SECURITY: joblib.load unpickles the file, which can execute arbitrary code.
# Only load scaler files from a trusted location.
scaler = joblib.load(osp.join(eos_root, "graph_data", "scaler.joblib"))
# Per-feature scale factors of the fitted scaler, placed on the compute device.
scale = torch.tensor(scaler.scale_).to(device)

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [None]:
# Build the cached test-split cluster dataset, but only if its metadata is missing.
testBuilder = ClusterDatasetBuilder(store_folder_test, data_folder_test, input_length=input_length)

needs_build = not testBuilder.metadataExists()
if needs_build:
    # NOTE(review): 24 is generate()'s first positional argument, presumably a
    # worker/process count. Confirm against ClusterDatasetBuilder.
    testBuilder.generate(24, device)

Processing...
 11%|█         | 43/400 [04:40<27:06,  4.56s/it] 

In [None]:
# Build the cached training-split cluster dataset, but only if its metadata is missing.
trainBuilder = ClusterDatasetBuilder(store_folder_training, data_folder_training, input_length=input_length)

needs_build = not trainBuilder.metadataExists()
if needs_build:
    # NOTE(review): 24 is generate()'s first positional argument, presumably a
    # worker/process count. Confirm against ClusterDatasetBuilder.
    trainBuilder.generate(24, device)

In [None]:
# Cached cluster datasets for both splits, sharing the same feature scaling.
# NOTE(review): num_workers is passed only for the test split. Confirm whether
# the training split should get it too.
dataset_training = ClusterDataset(
    store_folder_training,
    output_group=False,
    scale=scale,
    input_length=input_length,
)
dataset_test = ClusterDataset(
    store_folder_test,
    num_workers=25,
    output_group=False,
    scale=scale,
    input_length=input_length,
)

In [None]:
# Training and architecture hyperparameters.
epochs = 100
num_heads = 2
num_layers = 3
d_model = 128
d_ff = 256
dropout = 0.2
padding = converter.word2index["<PAD>"]
feature_num = len(dataset_test.model_feature_keys)
max_nodes = max(dataset_test.max_nodes, dataset_training.max_nodes)
# max_nodes node tokens plus 4 extra tokens (presumably the Lang special tokens
# such as "<PAD>" — confirm against Lang).
vocab_size = max_nodes + 4

# Seed torch's global RNG so weight initialisation is reproducible. DataLoader
# shuffling also draws its seed from this generator by default.
SEED = 42
torch.manual_seed(SEED)

# Model and loss (the optimizer is created in a later cell).
model = Transformer(vocab_size, d_model, num_heads, num_layers, d_ff, feature_num, max_nodes, max_seq_length, dropout)
weight_init(model)
criterion = Loss(converter, vocab_size, device=device)

In [None]:
# Put the parameters on the primary device first. DataParallel requires the
# wrapped module to already live on device_ids[0].
model.to(device)

if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # Single-process multi-GPU. DistributedDataParallel would additionally
    # need torch.distributed.init_process_group(), which this notebook never
    # calls, so it raised here.
    # NOTE: the wrapped model's state_dict keys gain a "module." prefix.
    model = nn.DataParallel(model)

model

In [None]:
# Shuffle only the training batches. The validation loader keeps a fixed order,
# so evaluation is deterministic and comparable across epochs.
train_dl = DataLoader(dataset_training, shuffle=True, batch_size=batch_size)
test_dl = DataLoader(dataset_test, shuffle=False, batch_size=batch_size)

In [None]:
# Adam with a fixed base learning rate. A transformer-style alternative is:
#   optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)
base_lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)

# Halve the learning rate every 5 scheduler steps; the training loop steps the
# scheduler once per epoch.
# NOTE(review): after 100 epochs this gives 1e-4 * 0.5**20, roughly 1e-10, so
# late epochs barely move the weights. Consider a larger step_size.
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)

# NOTE(review): delta is negative. Its meaning is defined by
# GNN_TrackLinkingNet.EarlyStopping — verify.
early_stopping = EarlyStopping(patience=5, delta=-0.02)

In [None]:
# Optionally resume from a saved checkpoint. Set RESUME_CHECKPOINT to a file
# containing "model_state_dict", "optimizer_state_dict" and "epoch" entries
# (e.g. "/eos/user/c/czeh/tranformer_2.pt"). While it is None, this cell does nothing.
RESUME_CHECKPOINT = None
if RESUME_CHECKPOINT is not None:
    weights = torch.load(RESUME_CHECKPOINT, weights_only=True)
    model.load_state_dict(weights["model_state_dict"])
    optimizer.load_state_dict(weights["optimizer_state_dict"])
    # NOTE(review): start_epoch is not read by the training loop, which starts at epoch 1.
    start_epoch = weights["epoch"]

In [None]:
# Per-epoch loss histories (re-running this cell resets them).
train_loss_hist, val_loss_hist = [], []

In [None]:
# Training loop with a live loss plot, per-epoch LR decay and early stopping.
#
# Gradient clipping (see
# https://stats.stackexchange.com/questions/352036/what-should-i-do-when-my-neural-network-doesnt-learn)
# is not applied here. clip_grad_norm_ only rescales gradients that already
# exist, so it must run between loss.backward() and optimizer.step(), i.e.
# inside train():
#     torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.25)

fig_loss, ax_loss = plt.subplots(1, 1)
fig_loss.set_figwidth(6)
fig_loss.set_figheight(3)

# Display handle; its content is replaced with the loss figure after every epoch.
display_loss = display(1, display_id=True)

optimizer.zero_grad()

for epoch in range(1, epochs + 1):
    print(f'Epoch: {epoch}')

    # Fit on the training split; the test split is used only for validation.
    loss = train(model, optimizer, train_dl, epoch, criterion, vocab_size, device=device)
    print(f"Training loss: {loss}")
    train_loss_hist.append(loss)

    val_loss = test(model, test_dl, epoch, criterion, vocab_size, device=device)
    val_loss_hist.append(val_loss)
    print(f"Validation loss: {val_loss}")

    ax_loss.clear()
    plot_loss(train_loss_hist, val_loss_hist, ax=ax_loss, n=0)
    display_loss.update(fig_loss)
    time.sleep(1)

    scheduler.step()
    print(f"Epoch {epoch}, LR: {scheduler.get_last_lr()[0]}")

    early_stopping(model, val_loss)
    if early_stopping.early_stop:
        # `epoch` counts from 1, so it already equals the number of epochs run.
        print(f"Early stopping after {epoch} epochs with best score {early_stopping.best_score}")
        early_stopping.load_best_model(model)
        break

In [None]:
# Smoothed training / validation loss curves (8-epoch moving average).
# The `epochs` hyperparameter is deliberately not reassigned here.
fig, ax = plt.subplots(1, 1)
fig.set_figheight(6)
fig.set_figwidth(10)

for hist, label in ((train_loss_hist, 'train'), (val_loss_hist, 'val')):
    smoothed = moving_average(hist, 8)
    # Align the smoothed curve with the last epochs. If moving_average returns
    # one value per epoch, this is exactly epochs 1..len(hist); if it returns a
    # shorter ("valid"-style) series, the x-axis still matches its length.
    first_epoch = len(hist) - len(smoothed) + 1
    ax.plot(range(first_epoch, len(hist) + 1), smoothed, label=label, linewidth=2)

ax.set_ylabel("Loss", fontsize=14)
ax.set_xlabel("Epochs", fontsize=14)
ax.set_title("Training and Validation Loss", fontsize=14)
ax.legend();

In [None]:
# Checkpoint name stamped with the current hour, e.g. "tranformer_date_2025-06-02-16:00.pt".
# NOTE(review): ':' is fine in file names on Linux/EOS but invalid on Windows.
date = datetime.now().strftime("%Y-%m-%d-%H") + ":00"
checkpoint_name = f"tranformer_date_{date}.pt"
# `epoch` is the last value left over from the training loop.
save_model(model, epoch, optimizer, train_loss_hist, val_loss_hist, model_folder, checkpoint_name)

## Test on a Full Event

In [None]:
from EventGrouping import EventGrouping

In [None]:
# Rebuild the architecture and load a saved checkpoint for full-event inference.
# .eval() disables dropout (p=0.2), which would otherwise make predictions random.
model = Transformer(vocab_size, d_model, num_heads, num_layers, d_ff, feature_num, max_nodes, max_seq_length, dropout).to(device).eval()
weights = torch.load(osp.join(model_folder, "tranformer_date_2025-06-02-16:00.pt"), weights_only=True)
model.load_state_dict(weights["model_state_dict"])

In [None]:
# Fetch stored entry 0 of the training dataset (a sized collection of components).
components = dataset_training.get(0)
n_components = len(components)
print(n_components)

In [None]:
# Group every component of the event with the model and merge the per-component
# assignments into one `edges` array.
model.eval()  # inference: disable dropout so each component gets one deterministic result
runner = EventGrouping(model, seq_length=input_length)

# NOTE(review): this takes the maximum over *all* entries of component["x"]
# (the whole feature tensor), not a dedicated count field. Confirm it is the
# intended trackster bound. np.full(nTrackster, -1) only has indices
# 0..nTrackster-1.
nTrackster = max([0] + [int(torch.max(component["x"]).item()) for component in components])

group = 0
edges = np.full(nTrackster, -1)

for component in components:
    # Per-component vocabulary. It is kept in its own name so the global
    # `converter` is not overwritten.
    component_lang = Lang(trackster_list=component["lang"])
    print("Goal", component["seq"], component_lang.word2index[component["root"]])

    # Run the model once per component (inference only, so no autograd graph).
    with torch.no_grad():
        res = runner(component)[-1]
    print(res)

    new_groups = component_lang.seq2y(res.cpu().numpy(), nodes=nTrackster, start_group=group)
    # Indices that this component assigned to a group (value >= 0).
    print(np.flatnonzero(new_groups >= 0))
    edges = np.maximum(edges, new_groups)
    print(np.max(edges))
    # The next component's groups start after the highest group id so far.
    group = np.max(edges) + 1

In [None]:
# Keep only the entries that rose above the initial fill value -1 (any negative value is dropped).
assigned = edges >= 0
edges[assigned]

## Ad-hoc Tests (scratch cells)

In [None]:
# Rebuild the architecture and load a saved checkpoint for ad-hoc inspection.
# .eval() disables dropout so predictions are deterministic. The path is built
# from model_folder, so it resolves to "/eos/user/c/czeh/tranformer_date_...".
model = Transformer(vocab_size, d_model, num_heads, num_layers, d_ff, feature_num, max_nodes, max_seq_length, dropout).to(device).eval()
weights = torch.load(osp.join(model_folder, "tranformer_date_2025-06-02-16:00.pt"), weights_only=True)
model.load_state_dict(weights["model_state_dict"])

In [None]:
components = dataset_training.get(0)
first_component = components[0]
# The trackster list that Lang(trackster_list=...) is built from for this component.
first_component["lang"]

In [None]:
# Raw training sample 0, via the dataset's indexing protocol.
dataset_training[0]

In [None]:
# Rank the model's next-token candidates for the first component of sample 0.
model.eval()  # disable dropout for a deterministic prediction

num_nodes = components[0]["nTrackster"]
sample_lang = Lang(trackster_list=components[0]["lang"])
sample_seq = sample_lang.starting_seq(components[0]["root"], input_length).to(device)
print(sample_seq)

# Move the features to the model's device; `scale` and `model` live there.
# Divide out-of-place so the dataset's tensor is never modified. Cast back to
# float32, because `scale` is built from the scaler's NumPy array and may be
# float64 (the original in-place `/=` kept float32 implicitly).
X = components[0]["x"].float().to(device)
X = (X / scale).float()
# Prepend padding rows so the input has max_nodes rows.
X = F.pad(X, pad=(0, 0, max_nodes - num_nodes, 0), value=sample_lang.word2index["<PAD>"])
X = X[:, list(map(dataset_test.node_feature_dict.get, dataset_test.model_feature_keys))]

with torch.no_grad():
    predictions = model(torch.unsqueeze(X, dim=0), torch.unsqueeze(sample_seq, dim=0))
# Scores at the last sequence position for the real nodes only, highest first.
predicted_index = torch.argsort(-predictions[0, -1, :num_nodes], dim=0)
print(predicted_index)

In [None]:
# NOTE(review): `targets` is not defined by any cell in this notebook. It is
# leftover kernel state, and a later cell reassigns it, so this cell fails on a
# fresh kernel. Shows the rows of `targets` whose last column is not -4
# (presumably -4 marks padding — confirm).
targets[targets[:, -1] != -4, :]

In [None]:
# NOTE(review): relies on `targets` and `mask`, which no cell here defines.
# Size of the first dimension of the masked selection divided by 3
# (presumably one triple per target — confirm).
targets[mask].shape[0]/3

In [None]:
# Second element of training sample 0, shifted left by one along dim 0. The
# last slot, which receives the wrapped-around first element, is then set to 5.
# NOTE(review): the meaning of 5 is not visible here — confirm.
sample_item = dataset_training[0]
opts = torch.roll(sample_item[1], shifts=-1, dims=0)
opts[-1] = 5
opts

In [None]:
# Number of entries of `opts` that differ from -4.
out_mask = opts.ne(-4)
opts[out_mask].size(0)

In [None]:
# NOTE(review): depends on the undefined `targets`/`mask` and overwrites
# `targets` with a reshaped copy of itself, so re-running this cell may fail or
# give a different result. Reshapes the masked selection into rows of 3 columns.
targets = torch.reshape(targets[mask], (int(targets[mask].shape[0]/3), 3))

In [None]:
# First row of `targets`. Requires a 2-D tensor, e.g. the reshaped (N, 3) version.
targets[0, :]