In [10]:
# Extra dependencies for this notebook (uncomment the lines below to install them):
# !pip install h5py
# !pip install torchinfo

In [11]:
import os
import torch
import torchvision
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split, TensorDataset, Dataset
from torchvision.utils import make_grid
import matplotlib
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor
from torchvision import datasets, transforms, models
%matplotlib inline
import h5py
from torchinfo import summary
import tarfile
print(torch.__version__)

2.0.1


In [12]:
# proton_url = "https://cernbox.cern.ch/remote.php/dav/public-files/AtBT8y4MiQYFcgc/SinglePhotonPt50_IMGCROPS_n249k_RHv1.hdf5"
# download_url(proton_url, "./data/proton_data")

In [13]:
# electron_url = "https://cernbox.cern.ch/remote.php/dav/public-files/FbXw3V4XNyYB3oA/SingleElectronPt50_IMGCROPS_n249k_RHv1.hdf5"
# download_url(electron_url, "./data/electron_data")

In [14]:
# NOTE(review): despite the `proton_` prefix, this path points at the *photon*
# crops (SinglePhotonPt50); the name is kept because later cells use it.
# NOTE(review): the commented-out download cells above save into ./data/proton_data
# and ./data/electron_data, not the *_dataset directories used here — align them
# before a fresh Restart & Run All.
proton_path = "./data/proton_dataset/SinglePhotonPt50_IMGCROPS_n249k_RHv1.hdf5"
electron_path = "./data/electron_dataset/SingleElectronPt50_IMGCROPS_n249k_RHv1.hdf5"

# Open read-only explicitly: h5py < 3.0 defaulted to a writable mode that could
# modify (or even create) the raw data file.
# The handle is left open so `xset` / `yset` stay readable; call f.close() when done.
f = h5py.File(electron_path, "r")
xset = f['X']
yset = f['y']
len(xset)

249000

In [15]:
# %%writefile modules/custom_dataset.py
import random
class CombinedDataset(Dataset):
    """Map-style dataset that concatenates two HDF5 image files.

    Indices ``0 .. proton_len - 1`` are served from ``proton_path``; the
    following indices are served from ``electron_path``. Each item is
    ``(image, label)`` read from the file's ``'X'`` and ``'y'`` datasets at
    the same row.

    HDF5 handles are opened lazily, once per process, so the dataset works with
    ``DataLoader(num_workers > 0)``: handles opened in the parent are not reused
    by forked workers, and (because open h5py objects cannot be pickled) they
    are dropped from the state sent to spawned workers.

    Args:
        proton_path: Path of the first HDF5 file. NOTE(review): the file used
            in this notebook is ``SinglePhotonPt50...`` — photons, not protons.
        electron_path: Path of the second HDF5 file.
        transform: Optional callable applied to each image before it is
            returned, e.g. ``transforms.ToTensor()`` to get a CHW tensor.
            Defaults to ``None``, which returns the raw array read from HDF5
            (the previously built ``ToTensor`` pipeline was never applied).
    """

    def __init__(self, proton_path, electron_path, transform=None):
        self.proton_path = proton_path
        self.electron_path = electron_path
        self.transform = transform
        # Only read the sizes here and close the files immediately; the actual
        # data handles are opened per process in `_get_sources`.
        with h5py.File(proton_path, "r") as proton_file:
            self.proton_len = len(proton_file["X"])
        with h5py.File(electron_path, "r") as electron_file:
            electron_len = len(electron_file["X"])
        self.length = self.proton_len + electron_len
        self._files = ()        # h5py.File handles opened by `_owner_pid`
        self._sources = None    # (proton X, proton y, electron X, electron y)
        self._owner_pid = None  # pid of the process that opened `_files`

    def _get_sources(self):
        """Return the four HDF5 datasets, (re)opening the files in a new process."""
        if self._sources is None or self._owner_pid != os.getpid():
            # First access, or we are inside a freshly forked worker: open our
            # own handles instead of using ones inherited from the parent.
            proton_file = h5py.File(self.proton_path, "r")
            electron_file = h5py.File(self.electron_path, "r")
            self._files = (proton_file, electron_file)
            self._sources = (
                proton_file["X"], proton_file["y"],
                electron_file["X"], electron_file["y"],
            )
            self._owner_pid = os.getpid()
        return self._sources

    def close(self):
        """Close the HDF5 handles opened by this process, if any."""
        if self._owner_pid == os.getpid():
            for handle in self._files:
                handle.close()
        self._files = ()
        self._sources = None
        self._owner_pid = None

    def __getstate__(self):
        # h5py objects cannot be pickled; a spawned worker reopens the files lazily.
        state = self.__dict__.copy()
        state.update(_files=(), _sources=None, _owner_pid=None)
        return state

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return ``(image, label)`` at ``index`` of the concatenated dataset.

        Negative indices count from the end of the combined dataset.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        if index < 0:
            index += self.length
        if not 0 <= index < self.length:
            raise IndexError(f"index out of range for CombinedDataset of length {self.length}")
        proton_x, proton_y, electron_x, electron_y = self._get_sources()
        if index < self.proton_len:
            image, label = proton_x[index], proton_y[index]
        else:
            offset = index - self.proton_len
            image, label = electron_x[offset], electron_y[offset]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [16]:
# #%%writefile modular/dataloader.py
# #import CombinedDataset from custom_dataset
import os
#from modules.custom_dataset import CombinedDataset
from torch.utils.data import DataLoader, random_split

# os.cpu_count() may return None, which DataLoader rejects; fall back to 0 (main process).
NUM_WORKERS = os.cpu_count() or 0

def create_dataloaders(
    electron_path,
    proton_path,
    auto_transforms,
    batch_size,
    num_workers=NUM_WORKERS,
    train_fraction=0.8,
    seed=None,
):
    """Build a shuffled train loader and an ordered test loader over both HDF5 files.

    The combined dataset is randomly split into ``train_fraction`` for training
    and the remainder for testing.

    Args:
        electron_path: Path to the electron HDF5 file.
        proton_path: Path to the photon ("proton") HDF5 file.
        auto_transforms: NOTE(review): currently unused — it is not forwarded
            to the dataset, so images are returned without it.
        batch_size: Samples per batch for both loaders.
        num_workers: Worker processes per loader (defaults to NUM_WORKERS).
        train_fraction: Fraction of samples in the train split (default 0.8).
        seed: Optional int; when given, the train/test split is reproducible.

    Returns:
        ``(train_dataloader, test_dataloader)``.

    Raises:
        ValueError: if ``train_fraction`` is not within ``[0, 1]``.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be within [0, 1], got {train_fraction}")

    dataset = CombinedDataset(
        proton_path=proton_path,
        electron_path=electron_path
    )

    train_size = int(train_fraction * len(dataset))
    test_size = len(dataset) - train_size
    # Only pass a generator when seeding, so the unseeded path keeps torch's default.
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_set, test_set = random_split(dataset, [train_size, test_size], **split_kwargs)

    # Pinned host memory only speeds up copies to a GPU; skip it on CPU-only machines.
    pin_memory = torch.cuda.is_available()

    train_dataloader = DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=pin_memory
    )
    test_dataloader = DataLoader(
        test_set,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=pin_memory
    )
    return train_dataloader, test_dataloader
# weights = torchvision.models.ResNet18_Weights.DEFAULT

# Smoke test: pull one shuffled batch and print the image batch shape.
# Paths are passed by keyword so they cannot be swapped: the constructor takes
# (proton_path, electron_path) in that order.
dataset = CombinedDataset(proton_path=proton_path, electron_path=electron_path)
train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
images, labels = next(iter(train_dataloader))
print(f'{images.shape}')

torch.Size([64, 32, 32, 2])


In [17]:
#from modular import dataloader
# Recommended pretrained checkpoint for ResNet-18; the bare name below displays it.
weights = models.ResNet18_Weights.DEFAULT
weights

ResNet18_Weights.IMAGENET1K_V1

In [18]:
# Transfer learning: load the ImageNet-pretrained ResNet-18 and freeze every backbone weight.
# `weights=` replaces the `pretrained=True` flag deprecated in torchvision 0.13;
# ResNet18_Weights.DEFAULT is IMAGENET1K_V1, the checkpoint `pretrained=True` loaded.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in model.parameters():
    param.requires_grad = False

# New 2-class head. Its freshly created parameters default to requires_grad=True,
# so only the head is trained. (`model.fc.requires_grad = True` would only set a
# plain attribute on the module and unfreeze nothing, so it is not used.)
new_layer = nn.Sequential(
    nn.Linear(model.fc.in_features, 256),  # backbone feature width (512 for ResNet-18)
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(256, 2),
    # NOTE(review): BatchNorm1d on the final logits is unusual before a
    # classification loss and raises in train mode for a batch of size 1 — confirm intent.
    nn.BatchNorm1d(2),
)
model.fc = new_layer

# NOTE(review): the dataset batch printed earlier has shape [64, 32, 32, 2]
# (channels-last, 2 channels), while this summary assumes 3-channel 224x224 NCHW
# input — the images must be permuted/adapted before they can be fed to the model.
summary(
    model=model,
    input_size=(32, 3, 224, 224),
    col_names=["input_size", "output_size", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

#loss_fn = nn.CrossEntropyLoss()
#optimizer = optim.Adam(model.parameters(), lr = 0.001)



Layer (type (var_name))                  Input Shape          Output Shape         Trainable
ResNet (ResNet)                          [32, 3, 224, 224]    [32, 2]              Partial
├─Conv2d (conv1)                         [32, 3, 224, 224]    [32, 64, 112, 112]   False
├─BatchNorm2d (bn1)                      [32, 64, 112, 112]   [32, 64, 112, 112]   False
├─ReLU (relu)                            [32, 64, 112, 112]   [32, 64, 112, 112]   --
├─MaxPool2d (maxpool)                    [32, 64, 112, 112]   [32, 64, 56, 56]     --
├─Sequential (layer1)                    [32, 64, 56, 56]     [32, 64, 56, 56]     False
│    └─BasicBlock (0)                    [32, 64, 56, 56]     [32, 64, 56, 56]     False
│    │    └─Conv2d (conv1)               [32, 64, 56, 56]     [32, 64, 56, 56]     False
│    │    └─BatchNorm2d (bn1)            [32, 64, 56, 56]     [32, 64, 56, 56]     False
│    │    └─ReLU (relu)                  [32, 64, 56, 56]     [32, 64, 56, 56]     --
│    │    └─Conv2d (conv

In [None]:
def train_func():
    model.train()
    for 