In [1]:
%load_ext autoreload
%autoreload 2
import yaml
import os
import torch
import numpy as np
import uproot
import matplotlib.pyplot as plt
from src.config import Config
from src.dataset import Dataset 
from src.visualiser import Visualiser
from src.models import SimCLRModel

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import importlib
import src.visualiser as visualiser_module
importlib.reload(visualiser_module)
from src.visualiser import Visualiser

pth = "cfg/default.yaml"
cfg = Config(pth)

print(cfg.as_dict())
print(dir(Visualiser))

{'train': {'n_epochs': 20, 'batch_size': 8, 'learning_rate': 0.001, 'checkpoint_directory': './chk', 'input_channels': 3, 'numer_classes': 6, 'segmentation_classes': 5, 'filters': 32, 'dropout': 0.1, 'feature_dimensions': 128, 'optimiser': 'Adam', 'weight_decay': 0.0001}, 'dataset': {'path': '/gluster/data/dune/niclane/nlane_prod_strange_resample_fhc_run2_fhc_reco2_reco2_trainingimage_signal_lambdamuon_100_ana.root', 'tree': 'imageanalyser/ImageTree', 'width': 512, 'height': 512, 'planes': ['U', 'V', 'W'], 'induction_plane': 2}}
['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_get_random', 'visualise_input_event', 'visualise_overlay_event', 'visualise_truth_event']


In [2]:
ds = Dataset(cfg)
vis = Visualiser(cfg)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = DataLoader(ds, batch_size=16, shuffle=True, num_workers=4)


In [4]:
model = SimCLRModel(in_channels=1, feature_dim=128, projection_hidden_dim=512, projection_dim=128).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [5]:
def multi_positive_info_nce_loss(features, num_views=3, temperature=0.5):
    features = F.normalize(features, dim=1)
    batch_size = features.shape[0] // num_views
    loss = 0.0
    total_count = 0
    similarity_matrix = torch.matmul(features, features.T)
    for i in range(batch_size):
        indices = torch.arange(i * num_views, (i + 1) * num_views, device=features.device)
        for anchor in indices:
            positives = indices[indices != anchor]
            pos_sim = torch.exp(similarity_matrix[anchor, positives] / temperature).sum()
            mask = torch.ones(features.shape[0], dtype=torch.bool, device=features.device)
            mask[indices] = False
            neg_sim = torch.exp(similarity_matrix[anchor][mask] / temperature).sum()
            loss += -torch.log(pos_sim / (pos_sim + neg_sim))
            total_count += 1
    return loss / total_count

In [6]:
num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    num_batches = 0
    print("Starting epoch", epoch)
    for batch_idx, (images, _, _, _, _) in enumerate(train_loader):
        B, planes, H, W = images.shape
        print(f"Batch {batch_idx}: Original shape = {images.shape}")
        if planes < 3:
            print(f"Batch {batch_idx}: Skipped because number of planes ({planes}) is less than 3")
            continue
        images = images.view(B * 3, 1, H, W).to(device)
        print(f"Batch {batch_idx}: Reshaped to {images.shape}")
        optimizer.zero_grad()
        projections = model(images)
        print(f"Batch {batch_idx}: Projections shape = {projections.shape}")
        loss = multi_positive_info_nce_loss(projections, num_views=3, temperature=0.5)
        print(f"Batch {batch_idx}: Loss = {loss.item():.4f}")
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    avg_loss = total_loss / num_batches if num_batches > 0 else float('inf')
    print(f"Epoch {epoch}: Average Loss = {avg_loss:.4f}")
torch.save(model.encoder.state_dict(), "pretrained_encoder.pth")
print("Pretrained encoder saved as 'pretrained_encoder.pth'")

Starting epoch 0
Batch 0: Original shape = torch.Size([16, 3, 512, 512])
Batch 0: Reshaped to torch.Size([48, 1, 512, 512])
Batch 0: Projections shape = torch.Size([48, 128])
Batch 0: Loss = 3.1506
Batch 1: Original shape = torch.Size([16, 3, 512, 512])
Batch 1: Reshaped to torch.Size([48, 1, 512, 512])
Batch 1: Projections shape = torch.Size([48, 128])
Batch 1: Loss = 3.0020
Batch 2: Original shape = torch.Size([16, 3, 512, 512])
Batch 2: Reshaped to torch.Size([48, 1, 512, 512])
Batch 2: Projections shape = torch.Size([48, 128])
Batch 2: Loss = 2.7980
Batch 3: Original shape = torch.Size([16, 3, 512, 512])
Batch 3: Reshaped to torch.Size([48, 1, 512, 512])
Batch 3: Projections shape = torch.Size([48, 128])
Batch 3: Loss = 3.0182
Batch 4: Original shape = torch.Size([16, 3, 512, 512])
Batch 4: Reshaped to torch.Size([48, 1, 512, 512])
Batch 4: Projections shape = torch.Size([48, 128])
Batch 4: Loss = 3.0607
Batch 5: Original shape = torch.Size([16, 3, 512, 512])
Batch 5: Reshaped to t

In [5]:
vis.visualise_input_event(ds)
vis.visualise_input_event(ds)
vis.visualise_input_event(ds)

RuntimeError: Class values must be smaller than num_classes.

In [None]:
vis.visualise_overlay_event(ds)
vis.visualise_overlay_event(ds)
vis.visualise_overlay_event(ds)
vis.visualise_overlay_event(ds)

In [27]:
vis.visualise_overlay_event(ds)
vis.visualise_overlay_event(ds)
vis.visualise_overlay_event(ds)
vis.visualise_overlay_event(ds)

In [28]:
vis.visualise_overlay_event(ds)
vis.visualise_overlay_event(ds)
vis.visualise_overlay_event(ds)
vis.visualise_overlay_event(ds)