In [1]:
import uproot
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors



In [2]:
file_path = "/gluster/data/dune/niclane/nlane_prod_strange_resample_fhc_run2_fhc_reco2_reco2_trainingimage_background_lambdamuon_ana.root"  
#file_path = "/gluster/data/dune/niclane/test.root"

root_file = uproot.open(file_path)
print(root_file.keys())

['imageanalyser;1', 'imageanalyser/SampleTree;1', 'imageanalyser/ImageTree;1', 'FRH;1', 'FRV;1', 'rICKR;1', 'rICKI;1', 'PreC;1', 'PostC;1', 'PostO;1', 'PreD;1', 'PostDO;1', 'ER;1']


In [3]:
tree_name = "imageanalyser/ImageTree" 
tree = root_file[tree_name]

print(tree.keys())

['run', 'subrun', 'event', 'event_type', 'planes', 'width', 'height', 'input_data', 'truth_data']


In [4]:
input_data = tree["input_data"].array(library="np")
truth_data = tree["truth_data"].array(library="np")

print("Shape:", input_data[1])
print("Data Type:", type(input_data))

Shape: [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ...]
Data Type: <class 'numpy.ndarray'>


In [5]:
print(f"Number of events: {len(input_data)}")  
print(f"Planes per event: {len(input_data[0])}")  
print(f"Pixels per plane: {len(input_data[0][0])}")

Number of events: 706
Planes per event: 3
Pixels per plane: 262144


In [6]:
event_index = 2
event = input_data[event_index]  
truth_event = truth_data[event_index] 

run_numbers = tree["run"].array(library="np")
subrun_numbers = tree["subrun"].array(library="np")
event_numbers = tree["event"].array(library="np")

event_type = tree["event_type"].array(library="np")

print(event_type[event_index])

print(type(input_data))
plane = event[0]
print(type(plane))

num_entries = tree.num_entries
print(f"Number of entries in {tree_name}: {num_entries}")

1
<class 'numpy.ndarray'>
<class 'uproot.containers.STLVector'>
Number of entries in imageanalyser/ImageTree: 706


In [7]:
W, H = 512, 512  

plane_images = [np.array(list(plane), dtype=np.float32).reshape(H, W) for plane in event]
truth_images = [np.array(list(plane), dtype=np.float32).reshape(H, W) for plane in truth_event]

plane_labels = ["U", "V", "W"]

r, sr, evnum = run_numbers[event_index], subrun_numbers[event_index], event_numbers[event_index]

for i, (input_img, truth_img) in enumerate(zip(plane_images, truth_images)):
    fig, ax = plt.subplots(figsize=(12, 12), dpi=600)

    ax.imshow(input_img,
              origin="lower",
              cmap="jet",
              norm=colors.PowerNorm(gamma=0.35, vmin=input_img.min(), vmax=input_img.max()))

    overlay = False
    if overlay:
        ax.imshow(truth_img, origin="lower", cmap="cool", alpha=0.4)

    ax.set_xticks([0, W - 1])
    ax.set_yticks([0, H - 1])
    ax.tick_params(axis="both", direction="out", length=6, width=1.5, labelsize=18)
    ax.set_xlim(0, W - 1)
    ax.set_ylim(0, H - 1)
    ax.set_xlabel("Local Wire Coord", fontsize=20)
    ax.set_ylabel("Local Drift Time", fontsize=20)
    ax.set_title(f"Plane {plane_labels[i]} (Run {r}, Subrun {sr}, Event {evnum})", fontsize=22)

    plt.tight_layout()
    plt.savefig(f"event_{r}_{sr}_{evnum}_plane_{plane_labels[i]}.png")
    plt.close(fig)

In [8]:
event_indices = [5]

run_numbers = tree["run"].array(library="np")
subrun_numbers = tree["subrun"].array(library="np")
event_numbers = tree["event"].array(library="np")
event_type = tree["event_type"].array(library="np")

W, H = 512, 512  
plane_labels = ["U", "V", "W"]

for event_index in event_indices:
    event = input_data[event_index]  
    truth_event = truth_data[event_index]  

    print(f"Processing Event Index: {event_index}")
    print(f"Event Type: {event_type[event_index]}")

    r, sr, evnum = run_numbers[event_index], subrun_numbers[event_index], event_numbers[event_index]

    plane_images = [np.array(list(plane), dtype=np.float32).reshape(H, W) for plane in event]
    truth_images = [np.array(list(plane), dtype=np.float32).reshape(H, W) for plane in truth_event]

    print(f"Number of truth image channels: {len(truth_images)}")

    for i, (input_img, truth_img) in enumerate(zip(plane_images, truth_images)):
        fig, ax = plt.subplots(figsize=(12, 12), dpi=600)

        ax.imshow(input_img,
                  origin="lower",
                  cmap="jet",
                  norm=colors.PowerNorm(gamma=0.35, vmin=input_img.min(), vmax=input_img.max()))

        overlay = True
        if overlay:
            ax.imshow(truth_img, origin="lower", cmap="cool", alpha=0.4)

        ax.set_xticks([0, W - 1])
        ax.set_yticks([0, H - 1])
        ax.tick_params(axis="both", direction="out", length=6, width=1.5, labelsize=18)
        ax.set_xlim(0, W - 1)
        ax.set_ylim(0, H - 1)
        ax.set_xlabel("Local Wire Coord", fontsize=20)
        ax.set_ylabel("Local Drift Time", fontsize=20)
        ax.set_title(f"Plane {plane_labels[i]} (Run {r}, Subrun {sr}, Event {evnum})", fontsize=22)

        plt.tight_layout()
        plt.savefig(f"event_{r}_{sr}_{evnum}_plane_{plane_labels[i]}.png")
        plt.close(fig)

Processing Event Index: 5
Event Type: 1
Number of truth image channels: 3


In [None]:
%load_ext autoreload
%autoreload 2
import yaml
import os
import torch
import numpy as np
import uproot
import matplotlib.pyplot as plt

from src.config import Config
from src.dataset import SegmentationDataset, ContrastiveDataset  
from src.visualiser import Visualiser
from src.trainers import SegmentationTrainer, ContrastiveTrainer

In [None]:
path = "cfg/default.yaml"
config = Config(path)

print("Loaded configuration:")
print(config.as_dict())
print("Dataset File Path:", config.get("dataset.file_path")) 

In [None]:
try:
    segmentation_dataset = SegmentationDataset(config)
    print(f"SegmentationDataset initialized with {len(segmentation_dataset)} samples.")
except Exception as e:
    print(f"Error loading SegmentationDataset: {e}")

try:
    contrastive_dataset = ContrastiveDataset(config)
    print(f"ContrastiveDataset initialized with {len(contrastive_dataset)} samples.")
except Exception as e:
    print(f"Error loading ContrastiveDataset: {e}")

In [None]:
visualiser = Visualiser(segmentation_dataset, width=512, height=512)

In [None]:
visualiser.visualise_event_planes(idx=10, save=True, show=False)

In [None]:
import yaml
from src.config import Config

config_path = "cfg/default.yaml"  
config = Config(config_path)

print("Training Config:\n", yaml.dump(config.as_dict(), default_flow_style=False))

In [None]:
dataset = SegmentationDataset(config)
trainer = SegmentationTrainer(config, dataset)
#print(trainer.model)

In [None]:
trainer.train()

In [None]:
def load_model_checkpoint(checkpoint_path, config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UResNet(
        in_dim=config.get("model.in_channels"),
        n_classes=config.get("model.num_classes"),
        n_filters=config.get("model.filters"),
        drop_prob=config.get("model.dropout"),
        y_range=None
    )
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()
    return model

def run_inference(model, sample):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image, truth = sample
    image_tensor = image.unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
    prediction = torch.argmax(output, dim=1).cpu().squeeze(0)
    return image, truth, prediction

def visualise_prediction(image, truth, prediction, plane, run, subrun, event):
    fig, ax = plt.subplots(figsize=(12,12), dpi=600)
    norm_img = mcolors.PowerNorm(gamma=0.35, vmin=image.min(), vmax=image.max())
    ax.imshow(image, origin="lower", cmap="jet", norm=norm_img)
    num_classes = int(np.max(prediction)) + 1
    cmap_seg = plt.get_cmap("tab10", num_classes)
    boundaries = np.arange(-0.5, num_classes+0.5, 1)
    norm_seg = mcolors.BoundaryNorm(boundaries, cmap_seg.N)
    ax.imshow(prediction, origin="lower", cmap=cmap_seg, norm=norm_seg, alpha=0.5)
    ax.set_xticks([0, image.shape[1]-1])
    ax.set_yticks([0, image.shape[0]-1])
    ax.tick_params(axis="both", direction="out", length=6, width=1.5, labelsize=18)
    ax.set_xlabel("Local Wire Coord", fontsize=20)
    ax.set_ylabel("Local Drift Time", fontsize=20)
    ax.set_title(f"Plane {plane} | Run {run}, Subrun {subrun}, Event {event}", fontsize=22)
    plt.tight_layout()
    plt.savefig(f"inference_plane_{plane}_run_{run}_subrun_{subrun}_event_{event}.png")
    plt.show()
    plt.close(fig)

def visualise_model_inference(model, dataset, event_idx):
    sample = dataset[event_idx]
    image, truth, prediction = run_inference(model, sample)
    run, subrun, event = "Unknown", "Unknown", event_idx
    planes = ["U", "V", "W"]
    for i, plane in enumerate(planes):
        visualise_prediction(image[i], truth[i], prediction[i], plane, run, subrun, event)

def plot_loss_curve(epoch_loss_history):
    epochs = np.array([e[0] for e in epoch_loss_history])
    train_loss = np.array([e[1] for e in epoch_loss_history])
    train_err = np.array([e[2] for e in epoch_loss_history])
    val_loss = np.array([e[3] for e in epoch_loss_history])
    val_err = np.array([e[4] for e in epoch_loss_history])
    plt.figure(figsize=(10,6))
    plt.errorbar(epochs, train_loss, yerr=train_err, fmt='-o', capsize=5, label='Train Loss', color='blue', markerfacecolor='white')
    plt.errorbar(epochs, val_loss, yerr=val_err, fmt='-o', capsize=5, label='Validation Loss', color='orange', markerfacecolor='white')
    plt.xlabel('Epoch', fontsize=14)
    plt.ylabel('Loss', fontsize=14)
    plt.title('Training and Validation Loss', fontsize=16)
    plt.legend(fontsize=12)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig('loss_curve.png', dpi=300)
    plt.show()


In [None]:
checkpoint_path = "checkpoints/SegmentationTrainer_epoch_1.pth"
model = load_model_checkpoint(checkpoint_path, config)
dataset = SegmentationDataset(config)
visualise_model_inference(model, dataset, event_idx=10)