In [None]:
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from model import *

print(f"PyTorch : {torch.__version__}")
print(f"CUDA    : {torch.cuda.is_available()}")
print(f"GPU     : {torch.cuda.get_device_name(0)}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with h5py.File("dataset/Dataset_Specific_Unlabelled.h5", "r") as f:
    for key in f.keys():
        print(key, f[key].shape)



PyTorch : 2.12.0.dev20260217+cu128
CUDA    : True
GPU     : NVIDIA GeForce RTX 5060 Laptop GPU
jet (60000, 125, 125, 8)


In [None]:
with h5py.File("dataset/Dataset_Specific_labelled.h5", "r") as f:
    for key in f.keys():
        print(key, f[key].shape, f[key].dtype)

Y (10000, 1) float32
jet (10000, 125, 125, 8) float32


In [None]:
with h5py.File("dataset/Dataset_Specific_Labelled.h5", "r") as f:
    print("X shape:", f["jet"].shape)
    print("Y shape:", f["Y"].shape)
    print("Y dtype:", f["Y"].dtype)
    y_sample = f["Y"][:100].flatten().tolist() 
    print("Y unique values:", set(y_sample))
    print("Y min:", min(y_sample))
    print("Y max:", max(y_sample))

X shape: (10000, 125, 125, 8)
Y shape: (10000, 1)
Y dtype: float32
Y unique values: {0.0, 1.0}
Y min: 0.0
Y max: 1.0


In [None]:
# %%
# ─── Classifier Model ─────────────────────────────────────────────────────────
class SparseClassifier(nn.Module):
    def __init__(self, in_channels=8):
        super().__init__()

        #Encoder
        self.enc1  = SubmanifoldSparseConv2d(in_channels, 32,  3)
        self.enc2  = SubmanifoldSparseConv2d(32,          64,  3)
        self.down1 = StridedSparseConv2d    (64,          64,  3, stride=2)
        self.enc3  = SubmanifoldSparseConv2d(64,          128, 3)
        self.down2 = StridedSparseConv2d    (128,         128, 3, stride=2)

        #Binary Classification
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  
            nn.Flatten(),        
            nn.Dropout(0.9),    
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.9),
            nn.Linear(64, 32),    
            nn.ReLU(),
            nn.Dropout(0.9),
            nn.Linear(32, 16),    
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Dropout(0.9),
            nn.Linear(8, 1)          
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, mask):
        # Encoder (frozen or fine-tuned)
        z, m = self.enc1(x, mask);  z = self.relu(z)
        z, m = self.enc2(z, m);     z = self.relu(z)
        z, m = self.down1(z, m);    z = self.relu(z)
        z, m = self.enc3(z, m);     z = self.relu(z)
        z, m = self.down2(z, m);    z = self.relu(z)

        z = z * m.float()

        logit = self.classifier(z) 
        return logit.squeeze(1)     


def load_pretrained_encoder(classifier, autoencoder_path):

    checkpoint = torch.load(autoencoder_path, map_location=device)
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    # Only load encoder layers
    encoder_keys = ['enc1', 'enc2', 'down1', 'enc3', 'down2']
    model_dict   = classifier.state_dict()

    pretrained = {k: v for k, v in state_dict.items()
                  if any(k.startswith(key) for key in encoder_keys)}

    model_dict.update(pretrained)
    classifier.load_state_dict(model_dict)

    loaded = list(pretrained.keys())
    print(f"Loaded {len(loaded)} pretrained encoder weights")
    return classifier

In [5]:
class H5LabelledDataset(Dataset):
    def __init__(self, file_path, x_key="jet", y_key="Y", threshold=0.0):
        self.file_path = file_path
        self.x_key     = x_key
        self.y_key     = y_key
        self.threshold = threshold
        with h5py.File(file_path, "r") as f:
            self.length = len(f[x_key])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        with h5py.File(self.file_path, "r") as f:
            x = torch.tensor(f[self.x_key][idx], dtype=torch.float32)
            y = torch.tensor(f[self.y_key][idx], dtype=torch.float32).squeeze()  # (1,) → scalar
        x    = x.permute(2, 0, 1)
        mask = (x.abs().sum(dim=0, keepdim=True) > self.threshold)
        return x, mask, y

In [None]:
#Fine-tuning Loop

def finetune(autoencoder_path, labelled_file, 
             n_epochs=50, freeze_epochs=10):


    dataset = H5LabelledDataset(labelled_file)

    val_size   = int(0.2 * len(dataset))
    train_size = len(dataset) - val_size
    train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size],generator=torch.Generator().manual_seed(42))

    train_loader = DataLoader(train_ds, batch_size=100, shuffle=True,  num_workers=0)
    val_loader   = DataLoader(val_ds,   batch_size=100, shuffle=True, num_workers=0)

    print(f"Train: {train_size} | Val: {val_size}")

    # Model 
    classifier = SparseClassifier(in_channels=8).to(device)
    classifier = load_pretrained_encoder(classifier, autoencoder_path)

    #Loss
    criterion = nn.BCEWithLogitsLoss() 

    for epoch in range(n_epochs):

        if epoch < freeze_epochs:
            for name, param in classifier.named_parameters():
                if any(name.startswith(k) for k in ['enc1','enc2','down1','enc3','down2']):
                    param.requires_grad = False
            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, classifier.parameters()),
                lr=1e-2, momentum=0.9, weight_decay=1e-4
            )
            if epoch == 0:
                print(f"Epochs 1-{freeze_epochs}: Encoder frozen, training head only")

        elif epoch == freeze_epochs:

            for param in classifier.parameters():
                param.requires_grad = True
            optimizer = optim.Adam ([
                {'params': classifier.enc1.parameters(),       'lr': 5e-4},
                {'params': classifier.enc2.parameters(),       'lr': 5e-4},
                {'params': classifier.down1.parameters(),      'lr': 5e-4},
                {'params': classifier.enc3.parameters(),       'lr': 5e-4},
                {'params': classifier.down2.parameters(),      'lr': 5e-4},
                {'params': classifier.classifier.parameters(), 'lr': 5e-4   },
            ], weight_decay=5e-4)
            scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
            print(f"\nEpoch {epoch+1}: Encoder unfrozen, full fine-tuning")

        # Train
        classifier.train()
        train_loss, train_correct, train_total = 0.0, 0, 0

        for X, mask, y in train_loader:
            X, mask, y = X.to(device), mask.to(device), y.to(device)

            optimizer.zero_grad()
            logits = classifier(X, mask)      
            loss   = criterion(logits, y)
            loss.backward()
            optimizer.step()

            preds         = (logits > 0).float() 
            train_correct += (preds == y).sum().item()
            train_total   += y.size(0)
            train_loss    += loss.item()

        # Validate 
        classifier.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0

        with torch.no_grad():
            for X, mask, y in val_loader:
                X, mask, y = X.to(device), mask.to(device), y.to(device)
                logits      = classifier(X, mask)
                loss        = criterion(logits, y)
                preds       = (logits > 0).float()
                val_correct += (preds == y).sum().item()
                val_total   += y.size(0)
                val_loss    += loss.item()

        if epoch >= freeze_epochs:
            scheduler.step()

        train_acc = train_correct / train_total * 100
        val_acc   = val_correct   / val_total   * 100

        print(f"Epoch {epoch+1:02d} | "
              f"Train Loss {train_loss/len(train_loader):.4f} | "
              f"Train Acc {train_acc:.2f}% | "
              f"Val Loss {val_loss/len(val_loader):.4f} | "
              f"Val Acc {val_acc:.2f}%")

    # Save
    torch.save({
        'model_state_dict': classifier.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': n_epochs,
    }, "models/sparse_classifier.pth")
    print("\nClassifier saved → sparse_classifier.pth")

    return classifier




classifier = finetune(
    autoencoder_path = "sparse_autoencoder_checkpoint_2.pth",
    labelled_file    = "Dataset_Specific_Labelled.h5",
    n_epochs         = 50,
    freeze_epochs    = 10   # freeze encoder for first 10 epochs
)

Train: 8000 | Val: 2000
Loaded 30 pretrained encoder weights
Epochs 1-10: Encoder frozen, training head only
Epoch 01 | Train Loss 0.7033 | Train Acc 50.31% | Val Loss 0.6934 | Val Acc 50.75%
Epoch 02 | Train Loss 0.6927 | Train Acc 51.08% | Val Loss 0.6932 | Val Acc 50.75%
Epoch 03 | Train Loss 0.6937 | Train Acc 51.19% | Val Loss 0.6930 | Val Acc 50.75%
Epoch 04 | Train Loss 0.6934 | Train Acc 51.10% | Val Loss 0.6930 | Val Acc 50.75%
Epoch 05 | Train Loss 0.6934 | Train Acc 51.02% | Val Loss 0.6931 | Val Acc 50.75%
Epoch 06 | Train Loss 0.6931 | Train Acc 51.24% | Val Loss 0.6931 | Val Acc 50.75%
Epoch 07 | Train Loss 0.6932 | Train Acc 51.18% | Val Loss 0.6931 | Val Acc 50.75%
Epoch 08 | Train Loss 0.6927 | Train Acc 51.60% | Val Loss 0.6930 | Val Acc 50.75%
Epoch 09 | Train Loss 0.6930 | Train Acc 51.10% | Val Loss 0.6931 | Val Acc 50.75%
Epoch 10 | Train Loss 0.6927 | Train Acc 51.52% | Val Loss 0.6930 | Val Acc 50.75%

Epoch 11: Encoder unfrozen, full fine-tuning
Epoch 11 | Trai