In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torchmetrics.classification import Accuracy
import h5py
import numpy as np
from tqdm import tqdm
from torch import Tensor
from typing import Type

# Dataclass

In [2]:
class ParticleDataset(Dataset):
    def __init__(self, electron_data_path, photon_data_path, split='train'):
        with h5py.File(electron_data_path, 'r') as f:
            X_electrons = np.array(f['X'])
            y_electrons = np.array(f['y'])
        with h5py.File(photon_data_path, 'r') as f:
            X_photons = np.array(f['X'])
            y_photons = np.array(f['y'])
            
        X = np.concatenate((X_electrons, X_photons), axis=0)
        y = np.concatenate((y_electrons, y_photons))

        np.random.seed(42)
        indices = np.random.permutation(len(X))
        if split == 'train':
            indices = indices[:int(0.8*len(indices))]
        else:
            indices = indices[int(0.8*len(indices)):]
            
        X = X[indices]
        y = y[indices]
        
        X = np.transpose(X, (0, 3, 1, 2))
        # X = np.expand_dims(X, axis=1)
        
        mean = np.mean([np.mean(x, axis=(1,2)) for x in X], axis=0)
        
        std = np.std([np.std(x, axis=(1,2)) for x in X], axis=0)
        
        for i in range(len(X)):
            for j in range(len(X[i])):
                X[i][j] = (X[i][j] - mean[j]) / std[j]
                
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return torch.tensor(self.X[idx], dtype=torch.float32), torch.tensor(self.y[idx], dtype=torch.float32)

## Creating train and val dataclasses

In [3]:
train_dataset = ParticleDataset('./Data/SingleElectronPt50_IMGCROPS_n249k_RHv1.hdf5',
                          './Data/SinglePhotonPt50_IMGCROPS_n249k_RHv1.hdf5')

val_dataset = ParticleDataset('./Data/SingleElectronPt50_IMGCROPS_n249k_RHv1.hdf5',
                            './Data/SinglePhotonPt50_IMGCROPS_n249k_RHv1.hdf5', split='val')

In [4]:
len(train_dataset), len(val_dataset)

(398400, 99600)

# ResNet Code

In [6]:
class BasicBlock(nn.Module):
    def __init__(
        self, 
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        expansion: int = 1,
        downsample: nn.Module = None
    ) -> None:
        super(BasicBlock, self).__init__()
        self.expansion = expansion
        self.downsample = downsample
        self.conv1 = nn.Conv2d(
            in_channels, 
            out_channels, 
            kernel_size=3, 
            stride=stride, 
            padding=1,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, 
            out_channels*self.expansion, 
            kernel_size=3, 
            padding=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels*self.expansion)
    def forward(self, x: Tensor) -> Tensor:
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return  out

In [8]:
class ResNet(nn.Module):
    def __init__(
        self, 
        block: Type[BasicBlock],
        img_channels: int = 2,
        num_layers: int = 18,
        num_classes: int  = 2,
        dropout_prob: float = 0.4,
        weight_decay: float = 1e-4
    ) -> None:
        super(ResNet, self).__init__()
        if num_layers == 18:
            layers = [2, 2, 2, 2]
            self.expansion = 1
        
        self.in_channels = 64
        self.conv1 = nn.Conv2d(
            in_channels=img_channels,
            out_channels=self.in_channels,
            kernel_size=3, 
            stride=1,
            padding=1,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512*self.expansion, num_classes)
        self.dropout = nn.Dropout(p=dropout_prob)
        self.calculate_accuracy = Accuracy(task='multiclass', num_classes=2, top_k=1)
        
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)
    
    def _make_layer(
        self, 
        block: Type[BasicBlock],
        out_channels: int,
        blocks: int,
        stride: int = 1
    ) -> nn.Sequential:
        downsample = None
        if stride != 1:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels, 
                    out_channels*self.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False 
                ),
                nn.BatchNorm2d(out_channels * self.expansion),
            )
        layers = []
        layers.append(
            block(
                self.in_channels, out_channels, stride, self.expansion, downsample
            )
        )
        self.in_channels = out_channels * self.expansion
        for i in range(1, blocks):
            layers.append(block(
                self.in_channels,
                out_channels,
                expansion=self.expansion
            ))
        return nn.Sequential(*layers)
    def forward(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        
        return x

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet(BasicBlock).to(device)

In [12]:
train_loader = DataLoader(train_dataset, batch_size=2048, shuffle=True, num_workers=9)
val_loader = DataLoader(val_dataset, batch_size=2048, num_workers=9)

In [13]:
inputs, labels = next(iter(train_loader))
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
outputs.shape, labels.shape

(torch.Size([2048, 2]), torch.Size([2048]))

# Training Loop

In [14]:
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

epochs = 100

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    running_accuracy = 0.0
    for i, data in tqdm(enumerate(train_loader, 0), total=len(train_loader)):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        labels = labels.long()
        loss = criterion(outputs, labels)
        accuracy = model.calculate_accuracy(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_accuracy += accuracy
    print(f"Epoch {epoch+1}, Training loss: {running_loss/len(train_loader)}, Accuracy: {running_accuracy/len(train_loader)}")

    if (epoch + 1) % 1 == 0:
        model.eval()
        running_loss = 0.0
        running_accuracy = 0.0
        with torch.no_grad():
            for i, data in tqdm(enumerate(val_loader, 0), total=len(val_loader)):
                inputs, labels = data[0].to(device), data[1].to(device)
                outputs = model(inputs)
                labels = labels.long()
                loss = criterion(outputs, labels)
                accuracy = model.calculate_accuracy(outputs, labels)
                running_loss += loss.item()
                running_accuracy += accuracy
        print(f"Epoch {epoch+1}, Validation loss: {running_loss/len(val_loader)}, Accuracy: {running_accuracy/len(val_loader)}")

print('Finished Training')

  2%|▏         | 3/195 [00:02<02:05,  1.53it/s]

100%|██████████| 195/195 [00:49<00:00,  3.95it/s]

Epoch 1, Training loss: 4.20045590033898, Accuracy: 0.5435935854911804



100%|██████████| 49/49 [00:04<00:00, 11.33it/s]

Epoch 1, Validation loss: 1.075651379264131, Accuracy: 0.60798579454422



100%|██████████| 195/195 [00:48<00:00,  4.03it/s]

Epoch 2, Training loss: 1.392371157805125, Accuracy: 0.5615012049674988



100%|██████████| 49/49 [00:04<00:00, 11.41it/s]

Epoch 2, Validation loss: 0.8318219464652392, Accuracy: 0.6124886274337769



100%|██████████| 195/195 [00:48<00:00,  4.01it/s]

Epoch 3, Training loss: 0.9216139386861752, Accuracy: 0.5796296000480652



100%|██████████| 49/49 [00:04<00:00, 11.34it/s]

Epoch 3, Validation loss: 0.7843041894387226, Accuracy: 0.6189197897911072



100%|██████████| 195/195 [00:48<00:00,  4.00it/s]

Epoch 4, Training loss: 0.816845227816166, Accuracy: 0.5921654105186462



100%|██████████| 49/49 [00:04<00:00, 11.30it/s]

Epoch 4, Validation loss: 0.7148170982088361, Accuracy: 0.6286724805831909



100%|██████████| 195/195 [00:48<00:00,  4.01it/s]

Epoch 5, Training loss: 0.7614744204741258, Accuracy: 0.6038084626197815



100%|██████████| 49/49 [00:04<00:00, 11.29it/s]

Epoch 5, Validation loss: 1.1867082228465957, Accuracy: 0.6351915597915649



100%|██████████| 195/195 [00:48<00:00,  4.00it/s]

Epoch 6, Training loss: 0.709049328779563, Accuracy: 0.6138020753860474



100%|██████████| 49/49 [00:04<00:00, 11.41it/s]

Epoch 6, Validation loss: 0.7853770560147811, Accuracy: 0.632436990737915



100%|██████████| 195/195 [00:48<00:00,  4.00it/s]

Epoch 7, Training loss: 0.7028133866114494, Accuracy: 0.6217465996742249



100%|██████████| 49/49 [00:04<00:00, 11.20it/s]

Epoch 7, Validation loss: 0.6824515924161795, Accuracy: 0.591516375541687



100%|██████████| 195/195 [00:48<00:00,  4.00it/s]

Epoch 8, Training loss: 0.7165614990087655, Accuracy: 0.6221799254417419



100%|██████████| 49/49 [00:04<00:00, 11.32it/s]

Epoch 8, Validation loss: 0.6402513798402281, Accuracy: 0.6425259709358215



100%|██████████| 195/195 [00:48<00:00,  4.00it/s]

Epoch 9, Training loss: 1.0521432191897662, Accuracy: 0.5855473279953003



100%|██████████| 49/49 [00:04<00:00, 11.41it/s]

Epoch 9, Validation loss: 0.7227907351085118, Accuracy: 0.5915562510490417



 73%|███████▎  | 143/195 [00:35<00:12,  4.00it/s]

# Saving model

In [None]:
torch.save(model.state_dict(), 'model_weights.pth')