In [1]:
import numpy as np
import torch
import torch.nn as nn
import json

from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import TensorDataset, Dataset, random_split

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data_path = "Datasets/CMS_task_i/"

In [3]:
import h5py
filename = data_path + "photons.hdf5"

photons = h5py.File(filename, "r")

X_photon = torch.Tensor(photons[list(photons.keys())[0]][:])
y_photon = torch.Tensor(photons[list(photons.keys())[1]][:])

In [4]:
filename = data_path + "electrons.hdf5"

electron = h5py.File(filename, "r")

X_electron = torch.Tensor(electron[list(electron.keys())[0]][:])
y_electron = torch.Tensor(electron[list(electron.keys())[1]][:])

### let's check X and y dimensions

In [5]:
print("X (photon) has shape: {}\ny (photon) has shape: {}".format(X_photon.shape, y_photon.shape))

X (photon) has shape: torch.Size([249000, 32, 32, 2])
y (photon) has shape: torch.Size([249000])


In [6]:
print("X (electron) has shape: {}\ny (electron) has shape: {}".format(X_electron.shape, y_electron.shape))

X (electron) has shape: torch.Size([249000, 32, 32, 2])
y (electron) has shape: torch.Size([249000])


### And now concat everything along the first dimension

In [7]:
X = torch.cat((X_electron, X_photon), axis=0)
y = torch.cat((y_electron, y_photon), axis=0)

In [8]:
y = y.type(torch.LongTensor)

In [9]:
print("X has shape: {}\ny has shape: {}".format(X.shape, y.shape))

X has shape: torch.Size([498000, 32, 32, 2])
y has shape: torch.Size([498000])


In [10]:
X = torch.permute(X, (0, 3, 1, 2))

In [11]:
# N C H W
# channels first
print("X has shape: {}\ny has shape: {}".format(X.shape, y.shape))

X has shape: torch.Size([498000, 2, 32, 32])
y has shape: torch.Size([498000])


### Some parameters

In [12]:
batch_size = 128

### Transforms

In [13]:
from torchvision.transforms.functional import rotate
from torch.utils.data import Dataset

class PhotonElectronDataset(Dataset):
    """
        This class will return the raw-rotationed digit pairs in the VAE encoder latent space.
    """

    def __init__(self, dataset, ds_transforms=None):
        """
        Arguments:
            dataset: should be a TensorDataset from PyTorch.
        """
        self.dataset = dataset
        self.ds_transforms = ds_transforms

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """
            each sample returned will consist of the original image
            and the rotated image latent representations.
        """
        
        if self.ds_transforms:
            return self.ds_transforms(self.dataset[idx][0]), self.dataset[idx][1]

        return self.dataset[idx][0], self.dataset[idx][1]

In [14]:
# define transforms
transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.Pad((40, 40, 40, 40))
])

In [15]:
tensor_ds = TensorDataset(X, y)

In [16]:
dataset = PhotonElectronDataset(tensor_ds, transform)

In [17]:
train_size = int(0.8 * len(y))
test_size = len(y) - train_size

In [18]:
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

In [19]:
train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )

In [20]:
test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=True
    )

### ResNet from scratch

In [21]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channel, out_channels, kernel_sizes, stride=2, paddings = [0,0,0,0],downsample=False, verbose=False):
        super(ResidualBlock, self).__init__()

        self.verbose = verbose
        
        self.conv1 = nn.Sequential(
                        nn.Conv2d(in_channel, out_channels[0], kernel_size = kernel_sizes[0], stride = 2, padding=paddings[0]),
                        nn.BatchNorm2d(out_channels[0]),
                        nn.ReLU())
        
        self.conv2 = nn.Sequential(
                        nn.Conv2d(out_channels[0], out_channels[1], kernel_size = kernel_sizes[1], padding=paddings[1]),
                        nn.BatchNorm2d(out_channels[1]),
                        nn.ReLU())
        
        self.conv3 = nn.Sequential(
                        nn.Conv2d(out_channels[1], out_channels[2], kernel_size = kernel_sizes[2], padding=paddings[2]),
                        nn.BatchNorm2d(out_channels[2]))

        self.conv4 = nn.Sequential(
                        nn.Conv2d(in_channel, out_channels[3], kernel_size = 1, stride=stride, padding=paddings[3]),
                        nn.BatchNorm2d(out_channels[3]))
        
        self.downsample = downsample
        self.relu = nn.ReLU()
        self.out_channels = out_channels
        
    def forward(self, x):
        residual = self.conv4(x)
        
        out = self.conv1(x)

        if self.verbose:
            print("input shape: ", x.shape)

        if self.verbose:
            print("Conv1 output shape: ", out.shape)
        out = self.conv2(out)
        if self.verbose:
            print("Conv2 output shape: ", out.shape)
        out = self.conv3(out)
    
        if self.verbose:
            print("Conv3 output shape: ", out.shape)
        
        # if self.downsample:
        #     residual = self.downsample(x)
        
        if self.verbose:
            print("output conv-1-2-3 shape: ",out.shape)
        if self.verbose:
            print("skip connection shape: ",residual.shape)

        # if self.downsample:
        #     residual = F.adaptive_avg_pool2d(residual, (out.size(2), out.size(3)))
        
        out += residual
        out = self.relu(out)
        return out

### My idea was to add one trainable convolutional layer before the whole architecture to have a third channel.

In [22]:
class ResNet(nn.Module):
    def __init__(self, block, num_classes = 2): # Only 2 classes, not 10 like in the original.
        super(ResNet, self).__init__()

        self.conv0 = nn.Conv2d(2, 3, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)
        
        self.conv1 = nn.Sequential(
                        nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3), # original ResNet uses 3 channels for Imagenet. We have 2 for HEP.
                        nn.BatchNorm2d(64),
                        nn.ReLU())
        
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding=1)
        
        self.layer0 = ResidualBlock(64, [32,32,64,64], [1,3,1,1], paddings=[0,1,0,0])
        self.layer1 = ResidualBlock(64, [64,64,128,128], [1,3,1,1], paddings=[0,1,0,0]) 
        self.layer2 = ResidualBlock(128, [128,128,256,256], [1,3,1,1], paddings=[0,1,0,0]) 
        self.layer3 = ResidualBlock(256, [256,256,512,512], [1,3,1,1], paddings=[0,1,0,0]) 

        self.fc = nn.Linear(12800, num_classes)
        
    def forward(self, x):

        x = self.conv0(x)
        # print("SHAPE antes: ", x.shape)
        x = self.conv1(x)
        # print("SHAPE : ", x.shape)
        x = self.maxpool(x)
        # print("SHAPE maxpool: ", x.shape)
        # print("\nlayer 0 : \n")
        x = self.layer0(x)
        # print("\nlayer 1 : \n")
        x = self.layer1(x)
        # print("\nlayer 2 : \n")
        x = self.layer2(x)
        # print("\nlayer 3 : \n")
        x = self.layer3(x)

        # x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        # print("\nx : \n", x.shape, x)
        # raise

        return x

In [116]:
del model

In [23]:
num_classes = 2
num_epochs = 50
learning_rate = 1e-3

model = ResNet(ResidualBlock).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay = 1e-2)  

# Train the model
total_step = len(train_loader)

In [24]:
print(model)

ResNet(
  (conv0): Conv2d(2, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer0): ResidualBlock(
    (conv1): Sequential(
      (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(2, 2))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (conv2): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (conv3): Sequential(
      (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 

## Train

In [25]:
import gc
from tqdm import tqdm

total_step = len(train_loader)

### initialize logs
res = {'epochs': [], 'train_loss': [], 'val_loss': [],\
       'train_acc': [], 'val_acc': [], 'best_val': 0, 'best_epoch': 0}

for epoch in range(num_epochs):
    correct = 0
    total = 0
    total_loss = 0
    for i, (images, labels) in tqdm(enumerate(train_loader), total=len(train_loader)):
        
        optimizer.zero_grad()
        
        # Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = model(images)
        
        preds = torch.max(outputs, 1).indices
        loss = criterion(outputs, labels)
        
        correct += (preds == labels).sum().item()
        total_loss += loss.item()
        total += labels.size(0)
        
        # Backward and optimize
        loss.backward()
        optimizer.step()
        del images, labels, outputs
        torch.cuda.empty_cache()
        gc.collect()
        
    acc = 100 * correct / total
    
    res['train_acc'].append(acc)
    res['train_loss'].append(total_loss)
    
    print ('Epoch [{}/{}], Train accuracy: {:.4f} / Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, acc, loss.item()))
            
    # Validation
    with torch.no_grad():
        correct = 0
        total = 0
        total_loss = 0
        for images, labels in tqdm(test_loader, total = len(test_loader)):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            preds = torch.max(outputs, 1).indices
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            total += labels.size(0)
            correct += (preds == labels).sum().item()
            del images, labels, outputs
            
        acc = 100 * correct / total
        
        if acc > res['best_val']:
            res['best_val'] = acc
            res['best_epoch'] = epoch
            # torch.save(model.state_dict(), "models/CMS_task_i/best-model.pt")
        
        res['val_acc'].append(acc)
        res['val_loss'].append(total_loss)
        
        json_object = json.dumps(res, indent=4)
        with open("logs/CMS_task_i/train-result-epoch{}.json".format(epoch), "w") as outfile:
            outfile.write(json_object)
            
        print('Validation accuracy: {} / validation loss: {}'.format(acc, total_loss)) 

100%|██████████| 3113/3113 [11:39<00:00,  4.45it/s]


Epoch [1/50], Train accuracy: 66.3479 / Loss: 0.6032


100%|██████████| 779/779 [00:56<00:00, 13.71it/s]


Validation accuracy: 69.83835341365462 / validation loss: 455.13471484184265


100%|██████████| 3113/3113 [11:39<00:00,  4.45it/s]


Epoch [2/50], Train accuracy: 71.3022 / Loss: 0.5575


100%|██████████| 779/779 [00:58<00:00, 13.32it/s]


Validation accuracy: 71.41967871485944 / validation loss: 442.12287613749504


100%|██████████| 3113/3113 [11:39<00:00,  4.45it/s]


Epoch [3/50], Train accuracy: 72.0843 / Loss: 0.5533


100%|██████████| 779/779 [00:56<00:00, 13.74it/s]


Validation accuracy: 71.90261044176707 / validation loss: 437.2502802014351


100%|██████████| 3113/3113 [11:39<00:00,  4.45it/s]


Epoch [4/50], Train accuracy: 72.5126 / Loss: 0.5391


100%|██████████| 779/779 [00:56<00:00, 13.71it/s]


Validation accuracy: 72.15160642570281 / validation loss: 432.6467012465


100%|██████████| 3113/3113 [11:42<00:00,  4.43it/s]


Epoch [5/50], Train accuracy: 72.7455 / Loss: 0.6251


100%|██████████| 779/779 [00:56<00:00, 13.79it/s]


Validation accuracy: 72.95682730923694 / validation loss: 425.0837941169739


100%|██████████| 3113/3113 [11:40<00:00,  4.44it/s]


Epoch [6/50], Train accuracy: 72.9636 / Loss: 0.5867


100%|██████████| 779/779 [00:58<00:00, 13.38it/s]


Validation accuracy: 72.58132530120481 / validation loss: 428.717197984457


100%|██████████| 3113/3113 [11:42<00:00,  4.43it/s]


Epoch [7/50], Train accuracy: 73.1606 / Loss: 0.4824


100%|██████████| 779/779 [00:57<00:00, 13.45it/s]


Validation accuracy: 73.12048192771084 / validation loss: 423.877406924963


100%|██████████| 3113/3113 [11:43<00:00,  4.42it/s]


Epoch [8/50], Train accuracy: 73.2809 / Loss: 0.4710


100%|██████████| 779/779 [00:57<00:00, 13.54it/s]


Validation accuracy: 72.64056224899599 / validation loss: 426.29316836595535


100%|██████████| 3113/3113 [11:41<00:00,  4.44it/s]


Epoch [9/50], Train accuracy: 73.4064 / Loss: 0.5217


100%|██████████| 779/779 [00:57<00:00, 13.51it/s]


Validation accuracy: 73.1987951807229 / validation loss: 422.795225083828


100%|██████████| 3113/3113 [11:39<00:00,  4.45it/s]


Epoch [10/50], Train accuracy: 73.5018 / Loss: 0.6221


100%|██████████| 779/779 [00:58<00:00, 13.37it/s]


Validation accuracy: 72.95080321285141 / validation loss: 423.42400577664375


100%|██████████| 3113/3113 [11:40<00:00,  4.45it/s]


Epoch [11/50], Train accuracy: 73.5633 / Loss: 0.5433


100%|██████████| 779/779 [00:57<00:00, 13.54it/s]


Validation accuracy: 73.18373493975903 / validation loss: 421.7777629196644


100%|██████████| 3113/3113 [11:40<00:00,  4.45it/s]


Epoch [12/50], Train accuracy: 73.6288 / Loss: 0.4792


100%|██████████| 779/779 [00:57<00:00, 13.49it/s]


Validation accuracy: 73.21285140562249 / validation loss: 421.5066642463207


100%|██████████| 3113/3113 [11:41<00:00,  4.44it/s]


Epoch [13/50], Train accuracy: 73.7751 / Loss: 0.6262


100%|██████████| 779/779 [00:57<00:00, 13.48it/s]


Validation accuracy: 73.5933734939759 / validation loss: 418.74328660964966


100%|██████████| 3113/3113 [11:41<00:00,  4.44it/s]


Epoch [14/50], Train accuracy: 73.7967 / Loss: 0.5151


100%|██████████| 779/779 [00:57<00:00, 13.47it/s]


Validation accuracy: 73.53614457831326 / validation loss: 419.5775503218174


100%|██████████| 3113/3113 [11:42<00:00,  4.43it/s]


Epoch [15/50], Train accuracy: 73.9051 / Loss: 0.4879


100%|██████████| 779/779 [00:57<00:00, 13.54it/s]


Validation accuracy: 73.58333333333333 / validation loss: 417.9462950527668


100%|██████████| 3113/3113 [11:41<00:00,  4.44it/s]


Epoch [16/50], Train accuracy: 73.9661 / Loss: 0.4577


100%|██████████| 779/779 [00:58<00:00, 13.22it/s]


Validation accuracy: 73.65863453815261 / validation loss: 417.0901833176613


100%|██████████| 3113/3113 [11:41<00:00,  4.44it/s]


Epoch [17/50], Train accuracy: 74.0562 / Loss: 0.5307


100%|██████████| 779/779 [00:57<00:00, 13.48it/s]


Validation accuracy: 73.59136546184739 / validation loss: 418.13355028629303


100%|██████████| 3113/3113 [11:41<00:00,  4.44it/s]


Epoch [18/50], Train accuracy: 74.0580 / Loss: 0.5523


100%|██████████| 779/779 [00:57<00:00, 13.55it/s]


Validation accuracy: 73.59236947791165 / validation loss: 417.9561694562435


100%|██████████| 3113/3113 [11:42<00:00,  4.43it/s]


Epoch [19/50], Train accuracy: 74.2038 / Loss: 0.5346


100%|██████████| 779/779 [00:57<00:00, 13.59it/s]


Validation accuracy: 73.55923694779116 / validation loss: 419.28973269462585


100%|██████████| 3113/3113 [11:41<00:00,  4.44it/s]


Epoch [20/50], Train accuracy: 74.1943 / Loss: 0.5663


100%|██████████| 779/779 [00:57<00:00, 13.52it/s]


Validation accuracy: 73.44076305220884 / validation loss: 419.54751190543175


100%|██████████| 3113/3113 [11:45<00:00,  4.41it/s]


Epoch [21/50], Train accuracy: 74.3037 / Loss: 0.5179


100%|██████████| 779/779 [00:57<00:00, 13.45it/s]


Validation accuracy: 73.83232931726907 / validation loss: 416.3046959936619


100%|██████████| 3113/3113 [11:42<00:00,  4.43it/s]


Epoch [22/50], Train accuracy: 74.3773 / Loss: 0.5121


100%|██████████| 779/779 [00:57<00:00, 13.55it/s]


Validation accuracy: 73.71586345381526 / validation loss: 417.3388133943081


100%|██████████| 3113/3113 [11:42<00:00,  4.43it/s]


Epoch [23/50], Train accuracy: 74.4669 / Loss: 0.5662


100%|██████████| 779/779 [00:58<00:00, 13.41it/s]


Validation accuracy: 72.89859437751004 / validation loss: 424.6370293200016


100%|██████████| 3113/3113 [11:43<00:00,  4.43it/s]


Epoch [24/50], Train accuracy: 74.5233 / Loss: 0.4358


100%|██████████| 779/779 [00:57<00:00, 13.54it/s]


Validation accuracy: 73.51506024096386 / validation loss: 419.8921719789505


100%|██████████| 3113/3113 [11:41<00:00,  4.43it/s]


Epoch [25/50], Train accuracy: 74.5989 / Loss: 0.4267


100%|██████████| 779/779 [00:57<00:00, 13.45it/s]


Validation accuracy: 73.56124497991968 / validation loss: 420.2965717613697


100%|██████████| 3113/3113 [11:43<00:00,  4.43it/s]


Epoch [26/50], Train accuracy: 74.6283 / Loss: 0.5392


100%|██████████| 779/779 [00:57<00:00, 13.54it/s]


Validation accuracy: 73.67369477911646 / validation loss: 418.3039543926716


100%|██████████| 3113/3113 [11:42<00:00,  4.43it/s]


Epoch [27/50], Train accuracy: 74.7558 / Loss: 0.6173


100%|██████████| 3113/3113 [11:40<00:00,  4.44it/s]


Epoch [36/50], Train accuracy: 75.7631 / Loss: 0.4138


100%|██████████| 779/779 [01:00<00:00, 12.88it/s]


Validation accuracy: 73.28714859437751 / validation loss: 427.96770933270454


100%|██████████| 3113/3113 [11:52<00:00,  4.37it/s]


Epoch [37/50], Train accuracy: 75.8918 / Loss: 0.5133


100%|██████████| 779/779 [01:00<00:00, 12.96it/s]


Validation accuracy: 73.11546184738955 / validation loss: 426.32376578450203


100%|██████████| 3113/3113 [11:51<00:00,  4.37it/s]


Epoch [38/50], Train accuracy: 76.0133 / Loss: 0.5871


100%|██████████| 779/779 [01:00<00:00, 12.92it/s]


Validation accuracy: 73.29317269076306 / validation loss: 425.2430007457733


100%|██████████| 3113/3113 [11:50<00:00,  4.38it/s]


Epoch [39/50], Train accuracy: 76.1433 / Loss: 0.4416


100%|██████████| 779/779 [00:59<00:00, 12.99it/s]


Validation accuracy: 73.16666666666667 / validation loss: 428.31611317396164


100%|██████████| 3113/3113 [11:50<00:00,  4.38it/s]


Epoch [40/50], Train accuracy: 76.3517 / Loss: 0.4569


100%|██████████| 779/779 [00:59<00:00, 13.04it/s]


Validation accuracy: 72.94377510040161 / validation loss: 436.14124631881714


100%|██████████| 3113/3113 [11:50<00:00,  4.38it/s]


Epoch [41/50], Train accuracy: 76.4661 / Loss: 0.5352


100%|██████████| 779/779 [01:00<00:00, 12.97it/s]


Validation accuracy: 73.00602409638554 / validation loss: 432.88165950775146


100%|██████████| 3113/3113 [11:50<00:00,  4.38it/s]


Epoch [42/50], Train accuracy: 76.6719 / Loss: 0.4453


100%|██████████| 779/779 [00:59<00:00, 13.16it/s]


Validation accuracy: 72.84236947791165 / validation loss: 436.3695269227028


100%|██████████| 3113/3113 [11:49<00:00,  4.39it/s]


Epoch [43/50], Train accuracy: 76.8727 / Loss: 0.4239


100%|██████████| 779/779 [00:59<00:00, 13.18it/s]


Validation accuracy: 72.72690763052209 / validation loss: 443.80795815587044


100%|██████████| 3113/3113 [11:48<00:00,  4.39it/s]


Epoch [44/50], Train accuracy: 76.9955 / Loss: 0.4780


100%|██████████| 779/779 [00:58<00:00, 13.23it/s]


Validation accuracy: 72.73895582329317 / validation loss: 441.62601512670517


100%|██████████| 3113/3113 [11:48<00:00,  4.39it/s]


Epoch [45/50], Train accuracy: 77.2096 / Loss: 0.4546


100%|██████████| 779/779 [00:58<00:00, 13.42it/s]


Validation accuracy: 72.57831325301204 / validation loss: 444.1111565232277


100%|██████████| 3113/3113 [11:44<00:00,  4.42it/s]


Epoch [46/50], Train accuracy: 77.3838 / Loss: 0.3915


100%|██████████| 779/779 [00:58<00:00, 13.31it/s]


Validation accuracy: 72.46184738955823 / validation loss: 452.8960107266903


100%|██████████| 3113/3113 [11:49<00:00,  4.39it/s]


Epoch [47/50], Train accuracy: 77.6062 / Loss: 0.4541


100%|██████████| 779/779 [01:00<00:00, 12.85it/s]


Validation accuracy: 72.394578313253 / validation loss: 452.5816104412079


100%|██████████| 3113/3113 [11:52<00:00,  4.37it/s]


Epoch [48/50], Train accuracy: 77.7799 / Loss: 0.5113


100%|██████████| 779/779 [01:00<00:00, 12.85it/s]


Validation accuracy: 72.33433734939759 / validation loss: 453.6984914839268


100%|██████████| 3113/3113 [11:52<00:00,  4.37it/s]


Epoch [49/50], Train accuracy: 77.9024 / Loss: 0.5447


100%|██████████| 779/779 [01:01<00:00, 12.75it/s]


Validation accuracy: 72.3012048192771 / validation loss: 457.51171255111694


100%|██████████| 3113/3113 [11:52<00:00,  4.37it/s]


Epoch [50/50], Train accuracy: 78.1785 / Loss: 0.4832


100%|██████████| 779/779 [01:00<00:00, 12.88it/s]

Validation accuracy: 72.23795180722891 / validation loss: 455.81895673274994





The validation loss is some orders of magnitude higher than training because it's not divided per sample.