In [1]:
%run models.py

This is models.py. It contains dataset and model definitions for the project.


In [22]:
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.nn import DataParallel
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
import os
from models import UnlabelledJetDataset, LabelledJetDataset, VAE, Classifier, vae_loss

In [23]:
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load datasets
unlabelled_dataset = UnlabelledJetDataset('../Dataset_Specific_Unlabelled.h5')
unlabelled_loader = DataLoader(unlabelled_dataset, batch_size=256, shuffle=True, num_workers=16)

labelled_dataset = LabelledJetDataset('../Dataset_Specific_labelled.h5')
train_size = int(0.8 * len(labelled_dataset))
test_size = len(labelled_dataset) - train_size
train_labelled, test_labelled = random_split(labelled_dataset, [train_size, test_size])
labelled_train_loader = DataLoader(train_labelled, batch_size=64, shuffle=True, num_workers=4)
labelled_test_loader = DataLoader(test_labelled, batch_size=64, shuffle=False, num_workers=4)

Using device: cuda


In [24]:
# Ensure ./res/ directory exists
os.makedirs('./res/best', exist_ok=True)

In [25]:
# Initialize models
vae = VAE(latent_dim=128)
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    vae = DataParallel(vae)
vae.to(device)

vae_core = vae.module if isinstance(vae, DataParallel) else vae
classifier = Classifier(vae_core).to(device)
if torch.cuda.device_count() > 1:
    classifier = DataParallel(classifier)

Using 4 GPUs


In [26]:
# Optimizers
optimizer_vae = optim.Adam(vae.parameters(), lr=1e-3, weight_decay=1e-5)
optimizer_cls = optim.Adam([
    {'params': classifier.module.vae.parameters() if hasattr(classifier, 'module') else classifier.vae.parameters(), 'lr': 1e-4},
    {'params': classifier.module.classifier_net.parameters() if hasattr(classifier, 'module') else classifier.classifier_net.parameters(), 'lr': 1e-3}
])

# Schedulers
scheduler_vae = optim.lr_scheduler.ReduceLROnPlateau(optimizer_vae, 'min', patience=3, factor=0.5)
scheduler_cls = optim.lr_scheduler.ReduceLROnPlateau(optimizer_cls, 'max', patience=2)

# Training loop
best_auc = 0
best_loss = float('inf')
patience_counter = 0
patience = 10

for epoch in range(30):
    # VAE Training
    vae.train()
    total_vae_loss = 0
    total_recon = 0
    total_kl = 0
    beta = min(1.0, epoch / 20)  # KL annealing

    for batch in tqdm(unlabelled_loader, desc=f"VAE Epoch {epoch+1}"):
        batch = batch.to(device)
        optimizer_vae.zero_grad()
        recon, mu, logvar = vae(batch)
        loss = vae_loss(recon, batch, mu, logvar, beta)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
        optimizer_vae.step()
        total_vae_loss += loss.item()
        total_recon += nn.functional.mse_loss(recon, batch, reduction='sum').item()
        total_kl += -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()).item()

    avg_vae_loss = total_vae_loss / len(unlabelled_loader.dataset)
    scheduler_vae.step(avg_vae_loss)

    # Classifier Training
    classifier.train()
    total_cls_loss = 0
    for inputs, labels in tqdm(labelled_train_loader, desc="Classifier Training"):
        inputs, labels = inputs.to(device), labels.to(device)
        labels = labels.view(-1, 1).float()
        optimizer_cls.zero_grad()
        outputs = classifier(inputs)
        loss = nn.functional.binary_cross_entropy_with_logits(outputs, labels)
        loss.backward()
        optimizer_cls.step()
        total_cls_loss += loss.item()

    # Classifier Evaluation
    classifier.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for inputs, labels in labelled_test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = classifier(inputs)
            all_preds.extend(torch.sigmoid(outputs).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    test_auc = roc_auc_score(all_labels, all_preds)
    scheduler_cls.step(test_auc)

    # Print metrics
    print(f"\nEpoch {epoch+1} Results:")
    print(f"VAE Loss: {avg_vae_loss:.4f} | Recon: {total_recon/len(unlabelled_loader.dataset):.4f} | KL: {total_kl/len(unlabelled_loader.dataset):.4f}")
    print(f"Classifier Loss: {total_cls_loss/len(labelled_train_loader):.4f} | AUC: {test_auc:.4f}")

    # Save best models
    if test_auc > best_auc:
        best_auc = test_auc
        torch.save(classifier.state_dict(), './res/best/best_classifier.pth')
    if avg_vae_loss < best_loss:
        best_loss = avg_vae_loss
        patience_counter = 0
        torch.save(vae.state_dict(), './res/best/best_vae.pth')
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"Early stopping at epoch {epoch+1}")
        break

    if (epoch+1) % 5 == 0:
        torch.save(vae.state_dict(), f'./res/vae_epoch_{epoch+1}.pth')
        torch.save(classifier.state_dict(), f'./res/classifier_epoch_{epoch+1}.pth')

print("\nTraining completed!")
print(f"Best VAE Loss: {best_loss:.4f}")
print(f"Best Classifier AUC: {best_auc:.4f}")

VAE Epoch 1: 100%|██████████| 235/235 [00:16<00:00, 14.32it/s]
Classifier Training: 100%|██████████| 125/125 [00:03<00:00, 31.62it/s]



Epoch 1 Results:
VAE Loss: 8060.7903 | Recon: 8060.7903 | KL: 24866.1203
Classifier Loss: 0.6536 | AUC: 0.8244


VAE Epoch 2: 100%|██████████| 235/235 [00:16<00:00, 14.43it/s]
Classifier Training: 100%|██████████| 125/125 [00:03<00:00, 33.12it/s]



Epoch 2 Results:
VAE Loss: 144.3552 | Recon: 141.3272 | KL: 60.5601
Classifier Loss: 0.5573 | AUC: 0.8711


VAE Epoch 3: 100%|██████████| 235/235 [00:16<00:00, 14.21it/s]
Classifier Training: 100%|██████████| 125/125 [00:03<00:00, 32.67it/s]



Epoch 3 Results:
VAE Loss: 125.3197 | Recon: 122.8775 | KL: 24.4221
Classifier Loss: 0.4977 | AUC: 0.8785


VAE Epoch 4: 100%|██████████| 235/235 [00:16<00:00, 14.42it/s]
Classifier Training: 100%|██████████| 125/125 [00:03<00:00, 33.78it/s]



Epoch 4 Results:
VAE Loss: 124.3792 | Recon: 121.4913 | KL: 19.2521
Classifier Loss: 0.4812 | AUC: 0.8728


VAE Epoch 5: 100%|██████████| 235/235 [00:17<00:00, 13.75it/s]
Classifier Training: 100%|██████████| 125/125 [00:03<00:00, 31.80it/s]



Epoch 5 Results:
VAE Loss: 124.3933 | Recon: 121.1333 | KL: 16.2996
Classifier Loss: 0.4798 | AUC: 0.8680


VAE Epoch 6: 100%|██████████| 235/235 [00:16<00:00, 14.48it/s]
Classifier Training: 100%|██████████| 125/125 [00:03<00:00, 33.54it/s]



Epoch 6 Results:
VAE Loss: 124.7444 | Recon: 121.1875 | KL: 14.2277
Classifier Loss: 0.4898 | AUC: 0.8686


VAE Epoch 7: 100%|██████████| 235/235 [00:16<00:00, 14.15it/s]
Classifier Training: 100%|██████████| 125/125 [00:03<00:00, 33.27it/s]



Epoch 7 Results:
VAE Loss: 125.1322 | Recon: 121.3195 | KL: 12.7089
Classifier Loss: 0.5051 | AUC: 0.8388


VAE Epoch 8: 100%|██████████| 235/235 [00:16<00:00, 13.91it/s]
Classifier Training: 100%|██████████| 125/125 [00:03<00:00, 33.86it/s]



Epoch 8 Results:
VAE Loss: 125.5289 | Recon: 121.5027 | KL: 11.5032
Classifier Loss: 0.5131 | AUC: 0.8268


VAE Epoch 9: 100%|██████████| 235/235 [00:16<00:00, 14.33it/s]
Classifier Training: 100%|██████████| 125/125 [00:03<00:00, 33.99it/s]



Epoch 9 Results:
VAE Loss: 125.6647 | Recon: 121.4670 | KL: 10.4943
Classifier Loss: 0.5112 | AUC: 0.8263


VAE Epoch 10: 100%|██████████| 235/235 [00:16<00:00, 14.56it/s]
Classifier Training: 100%|██████████| 125/125 [00:03<00:00, 34.06it/s]



Epoch 10 Results:
VAE Loss: 126.0195 | Recon: 121.6249 | KL: 9.7659
Classifier Loss: 0.5260 | AUC: 0.8096


VAE Epoch 11: 100%|██████████| 235/235 [00:16<00:00, 14.38it/s]
Classifier Training: 100%|██████████| 125/125 [00:03<00:00, 34.02it/s]



Epoch 11 Results:
VAE Loss: 126.3912 | Recon: 121.8292 | KL: 9.1240
Classifier Loss: 0.5367 | AUC: 0.8000


VAE Epoch 12: 100%|██████████| 235/235 [00:16<00:00, 14.55it/s]
Classifier Training: 100%|██████████| 125/125 [00:03<00:00, 33.16it/s]



Epoch 12 Results:
VAE Loss: 126.7694 | Recon: 122.0714 | KL: 8.5418
Classifier Loss: 0.5388 | AUC: 0.7946


VAE Epoch 13: 100%|██████████| 235/235 [00:16<00:00, 14.12it/s]
Classifier Training: 100%|██████████| 125/125 [00:03<00:00, 33.63it/s]



Epoch 13 Results:
VAE Loss: 127.0074 | Recon: 122.1796 | KL: 8.0463
Classifier Loss: 0.5565 | AUC: 0.7829


VAE Epoch 14: 100%|██████████| 235/235 [00:16<00:00, 14.35it/s]
Classifier Training: 100%|██████████| 125/125 [00:03<00:00, 33.29it/s]



Epoch 14 Results:
VAE Loss: 127.3251 | Recon: 122.3537 | KL: 7.6483
Classifier Loss: 0.5466 | AUC: 0.7792
Early stopping at epoch 14

Training completed!
Best VAE Loss: 124.3792
Best Classifier AUC: 0.8785
