# Import packages

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset, Dataset, DataLoader

import numpy as np
import h5py
import matplotlib.pyplot as plt
from PIL import Image
import argparse
from argparse import Namespace

import os
import random
from tqdm import tqdm

import tensorboard
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter

from resnet import resnet18, resnet34, resnet50
from datetime import datetime as dt

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
# Show which device was selected in the cell above.
device

device(type='cuda')

# Dataset

In [15]:
# Class-label lookups. The position of each name in its list is its integer label.
particle2idx = {pdb_id: label for label, pdb_id in enumerate(
    ['1fpv', '1ss8', '3j03', '1ijg', '3iyf', '6ody',
     '6sp2', '6xs6', '7dwz', '7dx8', '7dx9'])}

count2idx = {name: label for label, name in enumerate(
    ['single', 'double', 'triple', 'quadruple'])}

In [16]:
# Inverse lookups (label index -> name), used when reporting results.
# Every index 0..10 must appear exactly once. A repeated key silently overwrites
# the earlier entry and drops that particle from evaluate()'s per-particle report.
idx2particle = {
    0: '1fpv',
    1: '1ss8',
    2: '3j03',
    3: '1ijg',
    4: '3iyf',
    5: '6ody',
    6: '6sp2',
    7: '6xs6',
    8: '7dwz',
    9: '7dx8',
    10: '7dx9'
}

idx2count = {
    0: 'single',
    1: 'double',
    2: 'triple',
    3: 'quadruple'
}

In [17]:
class CustomDataset(torch.utils.data.Dataset):
    """In-memory dataset of diffraction thumbnails with count and particle labels.

    For every (particle, count) pair this reads the HDF5 file
    ``{root_dir}/{particle}_1k_{count}_pps_1e14_thumbnail.h5`` and takes the
    first ``length`` images of its first dataset (default: the module-level
    ``LENGTH``). Each image is converted to a PIL image and, if ``transform`` is
    given, transformed once at construction time. Images are labelled through
    the module-level ``count2idx`` / ``particle2idx`` maps. All samples are then
    shuffled with ``seed``, so index-based splits are reproducible.

    NOTE(review): the transform runs here, not in ``__getitem__``. Random
    augmentations (noise, flips) are therefore fixed per sample for the life of
    the dataset rather than redrawn each epoch.
    """

    def __init__(self, root_dir, particles, counts, transform=None, seed=1234, length=None):
        self.root_dir = root_dir
        self.transform = transform

        self.count_labels = []
        self.particle_labels = []
        self.data = []

        n = 1  # multiplier that appears as '{n}k' in the file name
        n_images = (LENGTH if length is None else length) * n

        for particle in particles:
            for count in counts:
                data_dir = f'{self.root_dir}/{particle}_{n}k_{count}_pps_1e14_thumbnail.h5'

                # Read the needed images in one go and close the file right away.
                with h5py.File(data_dir, 'r') as f:
                    dset = f[list(f.keys())[0]]
                    if len(dset) < n_images:
                        raise ValueError(
                            f'{data_dir} holds {len(dset)} images, expected at least {n_images}')
                    raw = dset[:n_images]

                images = [Image.fromarray(img) for img in raw]
                if self.transform is not None:
                    images = [self.transform(img) for img in images]

                self.data.extend(images)
                self.count_labels.extend([count2idx[count]] * n_images)
                self.particle_labels.extend([particle2idx[particle]] * n_images)

        # Shuffle the data and both label lists with one permutation. A private RNG
        # gives the same permutation as seeding the global `random` module, without
        # changing global state.
        rng = random.Random(seed)
        perm = list(range(len(self.data)))
        rng.shuffle(perm)
        self.data = [self.data[i] for i in perm]
        self.count_labels = [self.count_labels[i] for i in perm]
        self.particle_labels = [self.particle_labels[i] for i in perm]

    def __len__(self):
        '''Denotes the total number of samples'''
        return len(self.data)

    def __getitem__(self, index):
        '''Generates one sample of data: (image, count label, particle label)'''
        X = self.data[index]
        count = self.count_labels[index]
        particle = self.particle_labels[index]
        return X, count, particle

# Dataloader

In [19]:
class AddNoise(object):
    """
    A torchvision.transforms wrapper for addNoise().

    Each call draws a random flux scale, applies Poisson shot noise and then
    additive Gaussian noise, and finally normalises the image to zero mean and
    unit variance. The input may be a PIL image or a 2-D array. The output is a
    float32 NumPy array of the same shape.
    """

    def __init__(self, flux_jitter, gaussian_noise):
        # flux_jitter: std of the multiplicative flux factor (drawn around mean 1)
        # gaussian_noise: std (sigma) of the additive Gaussian noise
        assert isinstance(flux_jitter, float)
        assert isinstance(gaussian_noise, float)
        self.flux_jitter = flux_jitter
        self.gaussian_noise = gaussian_noise

    def __call__(self, image):
        return self.addNoise(image)

    def addNoise(self, orig_img):
        """Apply the noise model to ``orig_img`` and return a float32 array."""

        def changeIntensity(img, flux_jitter):
            factor = 100  # FIXME: correct for data which has 100 more flux
            mu = 1  # mean jitter
            alpha = np.random.normal(mu, flux_jitter)
            if alpha <= 0:
                alpha = 0.1  # alpha can't be zero
            total = np.sum(img)
            if total == 0:
                # No photons to redistribute. Returning zeros avoids 0/0 = NaN,
                # which np.random.poisson would reject.
                return np.zeros_like(img)
            n_photons = alpha * total / factor  # number of desired photons per image
            return n_photons * (img / total)    # noise-free measurement at that flux

        def poisson(img):
            # apply Poisson (shot-noise) statistics
            return np.random.poisson(img)

        def gaussian(img, sigma):
            # add N(0, sigma^2) noise per pixel
            return img + sigma * np.random.randn(*img.shape)

        def varNorm(V):
            # Per-image variance normalisation: mean 0, variance 1.
            # A boolean mask zeroes individual non-finite pixels. np.argwhere would
            # instead yield (row, col) pairs that index whole rows of a 2-D image.
            V[~np.isfinite(V)] = 0
            mean = np.mean(V)
            std = np.std(V)
            if std == 0:
                return np.zeros_like(V)
            return (V - mean) / std

        img = np.asarray(orig_img, dtype=np.float64)
        img = changeIntensity(img, self.flux_jitter)
        img = poisson(img)
        img = gaussian(img, self.gaussian_noise)
        img = varNorm(img)
        # float32 so ToTensor produces a tensor matching float32 model weights
        # (ToTensor keeps float64 arrays as double tensors).
        return img.astype(np.float32)
    

In [20]:
def get_dataloaders(args, train_val_particles, test_particles, test_diff_particle=False):
    """Build the train / validation / test DataLoaders.

    Args:
        args: namespace providing ``root_dir``, ``batch_size``, ``shuffle`` and
            ``num_workers``.
        train_val_particles: particle ids used for training and validation.
        test_particles: particle ids used for testing. Must equal
            ``train_val_particles`` unless ``test_diff_particle`` is True.
        test_diff_particle: if False, one dataset is split 70% / 10% / 20% into
            train / valid / test. If True, train / valid are the first 7000 /
            next 1000 samples of the ``train_val_particles`` dataset, and test is a
            separate dataset built from ``test_particles``.

    Returns:
        (train_dataloader, valid_dataloader, test_dataloader)
    """
    # Original from Shawn
    #transform = transforms.Compose([transforms.CenterCrop(128),
    #                                transforms.RandomVerticalFlip(p=0.5),
    #                                transforms.RandomHorizontalFlip(p=0.5),
    #                                transforms.ToTensor()])

    # Modified by EricFlorin.
    # AddNoise returns a NumPy array, so convert it to a tensor *before* the random
    # flips (torchvision's flips accept PIL images or tensors, not ndarrays). Then
    # cast to float32 to match the model weights.
    transform = transforms.Compose([AddNoise(0.9, 1.0),
                                    transforms.ToTensor(),
                                    transforms.ConvertImageDtype(torch.float32),
                                    transforms.RandomVerticalFlip(p=0.5),
                                    transforms.RandomHorizontalFlip(p=0.5)])

    if not test_diff_particle:
        assert train_val_particles == test_particles
        dataset = CustomDataset(root_dir=args.root_dir,
                                particles=train_val_particles,
                                counts=COUNTS,
                                transform=transform)
        print(len(dataset))
        # Split on the real dataset size (particles x counts x images per file).
        # A fixed per-particle constant can overshoot it and yield indices past
        # the end of the dataset.
        data_len = len(dataset)
        train_end, valid_end = int(data_len * 0.7), int(data_len * 0.8)
        train_dataset = Subset(dataset, list(range(0, train_end)))
        valid_dataset = Subset(dataset, list(range(train_end, valid_end)))
        test_dataset = Subset(dataset, list(range(valid_end, data_len)))
    else:
        # Create train/valid/test datasets
        train_val_dataset = CustomDataset(root_dir=args.root_dir,
                                          particles=train_val_particles,
                                          counts=COUNTS,
                                          transform=transform)
        # NOTE(review): fixed split sizes; assumes at least 8000 train/val samples.
        train_idx = list(range(0, 7000))
        valid_idx = list(range(7000, 8000))
        train_dataset = Subset(train_val_dataset, train_idx)
        valid_dataset = Subset(train_val_dataset, valid_idx)
        test_dataset = CustomDataset(root_dir=args.root_dir,
                                     particles=test_particles,
                                     counts=COUNTS,
                                     transform=transform)

        assert train_dataset[0][0].shape == torch.Size([1, 128, 128])
        assert valid_dataset[0][0].shape == torch.Size([1, 128, 128])
        assert test_dataset[0][0].shape == torch.Size([1, 128, 128])

    # Create train/valid/test dataloaders
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=args.shuffle,
                                  num_workers=args.num_workers)
    valid_dataloader = DataLoader(dataset=valid_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=args.shuffle,
                                  num_workers=args.num_workers)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=args.shuffle,
                                 num_workers=args.num_workers)
    return train_dataloader, valid_dataloader, test_dataloader

# Utils

In [None]:
#def sp_noise(image,prob):
#    '''
#    Add salt and pepper noise to image
#    prob: Probability of the noise
#    '''
#    output = np.zeros(image.shape,np.uint8)
#    thres = 1 - prob 
#    for i in range(image.shape[0]):
#        for j in range(image.shape[1]):
#            rdn = random.random()
#            if rdn < prob:
#                output[i][j] = 0
#            elif rdn > thres:
#                output[i][j] = 255
#            else:
#                output[i][j] = image[i][j]
#    return output

In [None]:
def read_thumbnails(fname):
    """Print basic info about an HDF5 thumbnail file and plot a 5 x 4 grid of its images.

    Uses the file's first dataset. It prints the file's keys, the dataset's shape
    and its length, then shows images 5, 10, ..., 100 with the colour scale fixed
    to [0, 10]. The dataset therefore needs more than 100 images.
    """
    with h5py.File(fname, 'r') as f:
        keys = list(f.keys())
        print("Check the datasets keys: " + str(keys))

        dset = f[keys[0]]
        print("Check the shape of the dataset: " + str(dset.shape))
        print(len(dset))

        columns, rows = 4, 5
        fig = plt.figure(figsize=(15, 15))
        for i in range(1, columns * rows + 1):
            fig.add_subplot(rows, columns, i)
            # dset[...] reads into memory, so the plot stays valid after the file closes.
            plt.imshow(dset[5 * i], vmin=0, vmax=10)
    plt.show()

### Combine two h5 files and save as new file

In [None]:
# Directory holding the SPI thumbnail HDF5 files read and written by the merge cells below.
root_dir = './data/thumbnail'

In [None]:
# Load the 1k and 3k single-particle thumbnail files for one particle into a single list.
particle = '1fpv'
filenames = [f'{root_dir}/SPI_{particle}_1k_single_thumbnail.h5',
             f'{root_dir}/SPI_{particle}_3k_single_thumbnail.h5']

dataset = []

for fname in filenames:
    # Read-only access is enough because the source files are never modified.
    # The context manager closes each file once its images are copied out.
    with h5py.File(fname, 'r') as f:
        dset_name = list(f.keys())[0]
        dset = f[dset_name]
        dataset.extend(dset[:])  # one bulk read; extends with one array per image

In [None]:
# Sanity check: total number of images read from the two source files.
len(dataset)

In [None]:
# Save the first 4000 merged images as a new '4k' file. The output dataset is
# sized (4000, 128, 130), so each source frame is assumed to be 128 x 130
# (TODO confirm against the inputs). The `with` block closes the file even if
# a copy fails part-way.
with h5py.File(f'{root_dir}/SPI_{particle}_4k_single_thumbnail.h5', 'w') as fout:
    ds = fout.create_dataset('photons', (4000, 128, 130), dtype='float32')
    for i in range(4000):
        ds[i, :, :] = dataset[i]

In [None]:
# Preview one merged thumbnail file: print its keys/shape and plot a grid of images.
# Note: this rebinds the global `particle` and `root_dir` used by earlier cells.
particle = '6xs6'
root_dir = './data/thumbnail'
image_dir = f'{root_dir}/SPI_{particle}_4k_single_thumbnail.h5'

read_thumbnails(image_dir)

# Models

## 3-layer Multi-output CNN (Late)

In [None]:
class MultiOutputCNN(nn.Module):
    """Late-branching multi-output CNN with three strided conv layers.

    A shared conv trunk feeds two linear heads that return log-probabilities:
    the first for the particle count, the second for the particle type. The
    shape comments and the 2048-wide heads assume a 128 x 128 single-channel
    input with hidden_dim=8.

    NOTE: later cells in this notebook redefine ``MultiOutputCNN``. Whichever
    definition ran last is the one ``load_model`` instantiates.
    """

    def __init__(self, num_particles=11, num_counts=4, hidden_dim=8):
        super(MultiOutputCNN, self).__init__()
        h = hidden_dim
        self.conv1 = nn.Conv2d(1, h, kernel_size=2, stride=2)            # -> (8, 64, 64)
        self.conv2 = nn.Conv2d(h, 4 * h, kernel_size=4, stride=4)        # -> (32, 16, 16)
        self.conv3 = nn.Conv2d(4 * h, 16 * h, kernel_size=4, stride=4)   # -> (128, 4, 4)
        self.dropout1 = nn.Dropout(0.25)
        # 2048 = 128 channels * 4 * 4 spatial positions
        self.fc1 = nn.Linear(2048, num_counts)
        self.fc2 = nn.Linear(2048, num_particles)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        features = torch.flatten(self.dropout1(x), 1)

        count_log_probs = F.log_softmax(self.fc1(features), dim=1)
        particle_log_probs = F.log_softmax(self.fc2(features), dim=1)
        return count_log_probs, particle_log_probs

## 5-layer Multi-output CNN (Late)

In [None]:
class MultiOutputCNN(nn.Module):
    """Late-branching multi-output CNN with five stride-2 conv layers.

    Each conv halves the spatial size and doubles the channel count. For a
    128 x 128 input and hidden_dim=8 the trunk output is (128, 4, 4), which
    gives the 2048 features fed to the count head (fc1) and the particle head
    (fc2). Both heads return log-probabilities.

    NOTE: other cells in this notebook also define ``MultiOutputCNN``. Whichever
    definition ran last is the one ``load_model`` instantiates.
    """

    def __init__(self, num_particles=11, num_counts=4, hidden_dim=8):
        super(MultiOutputCNN, self).__init__()
        # Channel widths 1 -> 8 -> 16 -> 32 -> 64 -> 128 (for hidden_dim=8);
        # spatial size 128 -> 64 -> 32 -> 16 -> 8 -> 4.
        widths = [1, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8, hidden_dim * 16]
        for layer_no, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f'conv{layer_no}', nn.Conv2d(c_in, c_out, 2, 2))
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)  # defined but not used in forward()
        self.fc1 = nn.Linear(2048, num_counts)
        self.fc2 = nn.Linear(2048, num_particles)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = F.relu(conv(x))
        features = torch.flatten(self.dropout1(x), 1)
        return (F.log_softmax(self.fc1(features), dim=1),
                F.log_softmax(self.fc2(features), dim=1))

## 10-layer Multi-output CNN (Late)

In [None]:
class MultiOutputCNN(nn.Module):
    """Late-branching multi-output CNN with nine conv layers ("10-layer" variant).

    Four stride-2 convs downsample 128 -> 8. Stride-1 2x2 and 1x1 convs then
    shrink the map to (128, 4, 4), i.e. 2048 features, which feed the count
    head (fc1) and particle head (fc2). Both heads return log-probabilities.
    Shapes assume a 128 x 128 input with hidden_dim=8.

    NOTE: other cells in this notebook also define ``MultiOutputCNN``. Whichever
    definition ran last is the one ``load_model`` instantiates.
    """

    def __init__(self, num_particles=11, num_counts=4, hidden_dim=8):
        super(MultiOutputCNN, self).__init__()
        h = hidden_dim
        self.conv1 = nn.Conv2d(1, h, 2, 2)              # (8, 64, 64)
        self.conv2 = nn.Conv2d(h, 2 * h, 2, 2)          # (16, 32, 32)
        self.conv3 = nn.Conv2d(2 * h, 4 * h, 2, 2)      # (32, 16, 16)
        self.conv4 = nn.Conv2d(4 * h, 8 * h, 2, 2)      # (64, 8, 8)
        self.conv5 = nn.Conv2d(8 * h, 8 * h, 2, 1)      # (64, 7, 7)
        self.conv6 = nn.Conv2d(8 * h, 8 * h, 2, 1)      # (64, 6, 6)
        self.conv7 = nn.Conv2d(8 * h, 8 * h, 2, 1)      # (64, 5, 5)
        self.conv8 = nn.Conv2d(8 * h, 16 * h, 1, 1)     # (128, 5, 5)
        self.conv9 = nn.Conv2d(16 * h, 16 * h, 2, 1)    # (128, 4, 4)
        self.dropout1 = nn.Dropout(0.25)
        self.fc1 = nn.Linear(2048, num_counts)
        self.fc2 = nn.Linear(2048, num_particles)

    def forward(self, x):
        trunk = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5,
                 self.conv6, self.conv7, self.conv8, self.conv9)
        for conv in trunk:
            x = F.relu(conv(x))
        features = torch.flatten(self.dropout1(x), 1)
        count_log_probs = F.log_softmax(self.fc1(features), dim=1)
        particle_log_probs = F.log_softmax(self.fc2(features), dim=1)
        return count_log_probs, particle_log_probs

## 18-layer Multi-output CNN (Late)

In [None]:
class MultiOutputCNN(nn.Module):
    """Late-branching multi-output CNN with 17 conv layers ("18-layer" variant).

    Three stride-2 convs downsample 128 -> 16. A stack of stride-1 2x2 convs
    (with 1x1 convs to widen channels) then reduces the map to (128, 4, 4),
    i.e. 2048 features, which feed the count head (fc1) and particle head
    (fc2). Both heads return log-probabilities. Shapes assume a 128 x 128 input
    with hidden_dim=8.

    NOTE: other cells in this notebook also define ``MultiOutputCNN``. Whichever
    definition ran last is the one ``load_model`` instantiates.
    """

    def __init__(self, num_particles=11, num_counts=4, hidden_dim=8):
        super(MultiOutputCNN, self).__init__()
        h = hidden_dim
        # (in_channels, out_channels, kernel, stride) for conv1 ... conv17
        layer_specs = [
            (1, h, 2, 2),            # conv1:  (8, 64, 64)
            (h, 2 * h, 2, 2),        # conv2:  (16, 32, 32)
            (2 * h, 4 * h, 2, 2),    # conv3:  (32, 16, 16)
            (4 * h, 4 * h, 2, 1),    # conv4:  (32, 15, 15)
            (4 * h, 4 * h, 2, 1),    # conv5:  (32, 14, 14)
            (4 * h, 4 * h, 2, 1),    # conv6:  (32, 13, 13)
            (4 * h, 4 * h, 2, 1),    # conv7:  (32, 12, 12)
            (4 * h, 4 * h, 2, 1),    # conv8:  (32, 11, 11)
            (4 * h, 4 * h, 2, 1),    # conv9:  (32, 10, 10)
            (4 * h, 4 * h, 2, 1),    # conv10: (32, 9, 9)
            (4 * h, 8 * h, 1, 1),    # conv11: (64, 9, 9)
            (8 * h, 8 * h, 2, 1),    # conv12: (64, 8, 8)
            (8 * h, 8 * h, 2, 1),    # conv13: (64, 7, 7)
            (8 * h, 8 * h, 2, 1),    # conv14: (64, 6, 6)
            (8 * h, 8 * h, 2, 1),    # conv15: (64, 5, 5)
            (8 * h, 16 * h, 1, 1),   # conv16: (128, 5, 5)
            (16 * h, 16 * h, 2, 1),  # conv17: (128, 4, 4)
        ]
        for layer_no, (c_in, c_out, k, s) in enumerate(layer_specs, start=1):
            setattr(self, f'conv{layer_no}', nn.Conv2d(c_in, c_out, k, s))
        self.dropout1 = nn.Dropout(0.25)
        self.fc1 = nn.Linear(2048, num_counts)
        self.fc2 = nn.Linear(2048, num_particles)

    def forward(self, x):
        for layer_no in range(1, 18):
            x = F.relu(getattr(self, f'conv{layer_no}')(x))
        features = torch.flatten(self.dropout1(x), 1)
        return (F.log_softmax(self.fc1(features), dim=1),
                F.log_softmax(self.fc2(features), dim=1))

## Multi-output CNN (Early)

In [None]:
class MultiOutputCNN(nn.Module):
    """Early-branching multi-output CNN.

    Two shared stride-2 convs (128 -> 32 spatial) feed two independent branches
    with identical architecture. Branch ``b1`` predicts the particle count and
    branch ``b2`` the particle type. Each branch ends with dropout and a linear
    head over 2048 = 128 x 4 x 4 features and returns log-probabilities. Shapes
    assume a 128 x 128 input with hidden_dim=8.

    NOTE: this is the last ``MultiOutputCNN`` definition in the notebook, so on
    Restart & Run All it is the one ``load_model`` instantiates.
    """

    def __init__(self, num_particles=11, num_counts=4, hidden_dim=8):
        super(MultiOutputCNN, self).__init__()
        h = hidden_dim

        # Shared trunk
        self.conv1 = nn.Conv2d(1, h, 2, 2)        # (8, 64, 64)
        self.conv2 = nn.Conv2d(h, 2 * h, 2, 2)    # (16, 32, 32)

        # Per-branch layers conv3 ... conv9: (in_channels, out_channels, kernel, stride)
        branch_specs = [
            (2 * h, 4 * h, 2, 2),    # conv3: (32, 16, 16)
            (4 * h, 8 * h, 2, 2),    # conv4: (64, 8, 8)
            (8 * h, 8 * h, 2, 1),    # conv5: (64, 7, 7)
            (8 * h, 8 * h, 2, 1),    # conv6: (64, 6, 6)
            (8 * h, 8 * h, 2, 1),    # conv7: (64, 5, 5)
            (8 * h, 16 * h, 1, 1),   # conv8: (128, 5, 5)
            (16 * h, 16 * h, 2, 1),  # conv9: (128, 4, 4)
        ]
        # Branched: b1 -> count head, b2 -> particle head (registered in that order)
        for branch, n_out in (('b1', num_counts), ('b2', num_particles)):
            for layer_no, (c_in, c_out, k, s) in enumerate(branch_specs, start=3):
                setattr(self, f'conv{layer_no}_{branch}', nn.Conv2d(c_in, c_out, k, s))
            setattr(self, f'dropout1_{branch}', nn.Dropout(0.25))
            setattr(self, f'fc1_{branch}', nn.Linear(2048, n_out))

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x))

        def run_branch(features, branch):
            # Apply conv3..conv9, dropout and the linear head of one branch.
            for layer_no in range(3, 10):
                features = F.relu(getattr(self, f'conv{layer_no}_{branch}')(features))
            features = getattr(self, f'dropout1_{branch}')(features)
            features = torch.flatten(features, 1)
            return F.log_softmax(getattr(self, f'fc1_{branch}')(features), dim=1)

        y1 = run_branch(x, 'b1')
        y2 = run_branch(x, 'b2')
        return y1, y2

## ResNet18

In [None]:
class CustomResNet18Model(nn.Module):
    """Randomly initialised ResNet-18 trunk with two log-softmax heads.

    The stem conv is replaced to accept single-channel images. The ImageNet
    classifier is replaced by dropout (p=0.5), so the pooled ResNet features go
    through dropout and then feed ``fc1`` (count head) and ``fc2`` (particle head).
    """

    def __init__(self, num_counts, num_particles):
        super(CustomResNet18Model, self).__init__()
        self.model_resnet = models.resnet18(pretrained=False)
        # Single-channel 2-D stem. A Conv1d here cannot process (N, 1, H, W) images.
        self.model_resnet.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=True)

        num_ftrs = self.model_resnet.fc.in_features
        # Put the dropout in place of `fc`. A forward hook on the original `fc`
        # would never fire, because that module is discarded when `fc` is replaced.
        self.model_resnet.fc = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(num_ftrs, num_counts)
        self.fc2 = nn.Linear(num_ftrs, num_particles)

    def forward(self, x):
        x = self.model_resnet(x)
        out1 = self.fc1(x)
        y1 = F.log_softmax(out1, dim=1)
        out2 = self.fc2(x)
        y2 = F.log_softmax(out2, dim=1)
        return y1, y2

## VGG16

In [None]:
class CustomVgg16Model(nn.Module):
    """Randomly initialised VGG-16 feature extractor with two log-softmax heads.

    The first conv is swapped for a single-channel one, and the VGG classifier
    is replaced by an identity. The flattened pooled features therefore feed
    ``fc1`` (count head) and ``fc2`` (particle head) directly.
    """

    def __init__(self, num_counts, num_particles):
        super(CustomVgg16Model, self).__init__()
        self.model_vgg16 = models.vgg16(pretrained=False, progress=True)
        backbone = self.model_vgg16
        backbone.features[0] = torch.nn.Conv2d(1, 64, (3, 3), (1, 1), (1, 1))
        # Input width of the original first classifier layer = flattened feature size.
        in_features = backbone.classifier[0].in_features
        backbone.classifier = nn.Identity()
        self.fc1 = nn.Linear(in_features, num_counts)
        self.fc2 = nn.Linear(in_features, num_particles)

    def forward(self, x):
        features = self.model_vgg16(x)
        return (F.log_softmax(self.fc1(features), dim=1),
                F.log_softmax(self.fc2(features), dim=1))

# Evaluate

In [None]:
def evaluate(model, loss_fn, dataloader):
    """Evaluate the model on every batch of `dataloader`.

    Args:
        model: (torch.nn.Module) the neural network. When `args.multi_output` is
            True it must return a `(count_output, particle_output)` pair.
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        dataloader: (DataLoader) a torch.utils.data.DataLoader object yielding
            `(inputs, count_labels, particle_labels)` batches

    Returns:
        (accuracy, count_accuracy, particle_accuracy, loss). Accuracies are in
        percent; `loss` is the mean per-batch loss (`4 * count_loss +
        particle_loss` in multi-output mode). In single-output mode only the
        count is predicted, so `accuracy == count_accuracy` and
        `particle_accuracy` is NaN. Values are NaN for an empty dataloader.

    In multi-output mode this also prints the per-particle accuracy (a fraction,
    NaN for particles that do not occur in `dataloader`). Reads the globals
    `args`, `device` and `idx2particle`.
    """
    model.eval()

    loss = 0.0
    preds1, preds2 = [], []
    all_count_labels, all_particle_labels = [], []

    # Inference only: build no autograd graph and keep no input tensors.
    with torch.no_grad():
        for inputs, count_labels, particle_labels in dataloader:
            inputs = inputs.to(device)
            # Labels are (batch,) tensors. squeeze(0) would turn a final batch
            # of size 1 into a 0-d target.
            count_targets = count_labels.to(device)
            if not args.multi_output:
                outputs = model(inputs)
                loss += loss_fn(outputs, count_targets).item()
                preds1.extend(torch.argmax(outputs, dim=-1).cpu().tolist())
            else:
                y1, y2 = model(inputs)
                loss1 = loss_fn(y1, count_targets).item()
                loss2 = loss_fn(y2, particle_labels.to(device)).item()
                # Same weighting as the training loss in train().
                loss += 4 * loss1 + loss2
                preds1.extend(torch.argmax(y1, dim=-1).cpu().tolist())
                preds2.extend(torch.argmax(y2, dim=-1).cpu().tolist())
                all_particle_labels.extend(particle_labels.cpu().tolist())
            all_count_labels.extend(count_labels.cpu().tolist())

    torch.cuda.empty_cache()

    n_samples = len(preds1)
    n_batches = len(dataloader)

    def percent(n_correct):
        return n_correct / n_samples * 100 if n_samples else float('nan')

    loss = loss / n_batches if n_batches else float('nan')

    # Count accuracy
    count_hits = [p == t for p, t in zip(preds1, all_count_labels)]
    count_accuracy = percent(sum(count_hits))

    if not args.multi_output:
        return count_accuracy, count_accuracy, float('nan'), loss

    # Particle accuracy, and total accuracy (count AND particle both correct)
    particle_hits = [p == t for p, t in zip(preds2, all_particle_labels)]
    particle_accuracy = percent(sum(particle_hits))
    accuracy = percent(sum(c and p for c, p in zip(count_hits, particle_hits)))

    # Accuracy for each particle type (NaN when the particle is absent from this split)
    particle_acc_dict = {}
    for idx, name in idx2particle.items():
        total = sum(1 for t in all_particle_labels if t == idx)
        correct = sum(1 for p, t in zip(preds2, all_particle_labels) if t == idx and p == t)
        particle_acc_dict[name] = correct / total if total else float('nan')
    print(particle_acc_dict)

    return accuracy, count_accuracy, particle_accuracy, loss

### Draw confusion matrix for particle prediction

In [None]:
# Rebuild the test loader for the confusion-matrix cell below.
# NOTE(review): this uses `args` and `PARTICLES`, which are only defined in the
# "Run the pipeline" cells further down, so it fails on a fresh kernel unless
# those cells run first.
_, _, test_dataloader = get_dataloaders(args, PARTICLES, PARTICLES, test_diff_particle=False)

In [None]:
import pandas as pd
import seaborn as sns

# Confusion matrix of true vs. predicted particle labels on the test set.
nb_classes = len(particle2idx)
confusion_matrix = np.zeros((nb_classes, nb_classes))

# Inference only: disable dropout and autograd bookkeeping.
model.eval()
with torch.no_grad():
    for inputs, count_labels, particle_labels in test_dataloader:
        inputs = inputs.to(device)
        _, y2 = model(inputs)
        y2 = torch.argmax(y2, dim=-1).to('cpu').numpy().tolist()
        for t, p in zip(particle_labels.tolist(), y2):
            confusion_matrix[t, p] += 1

plt.figure(figsize=(15, 10))

# particle2idx preserves label order, so its keys line up with the matrix rows/columns.
class_names = list(particle2idx.keys())
df_cm = pd.DataFrame(confusion_matrix, index=class_names, columns=class_names).astype(int)
heatmap = sns.heatmap(df_cm, annot=True, fmt="d")

heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=12)
heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=12)
plt.ylabel('True label', fontsize=15)
plt.xlabel('Predicted label', fontsize=15)
plt.show()

# Train

In [None]:
def train(args, model, optimizer, loss_fn):
    """Train the network on the training data.

    Builds the dataloaders with ``get_dataloaders`` (a train/valid/test split
    of ``PARTICLES``) and trains for ``args.epoches`` epochs. Every
    ``args.evaluate_every`` epochs it evaluates on the validation and training
    sets. At the end it saves the final ``state_dict`` to ``args.ckpt_path``,
    plots the loss/accuracy curves and prints the test-set accuracies.

    In multi-output mode the loss is ``4 * count_loss + particle_loss``.
    Uses the module-level ``device``, ``PARTICLES``, ``get_dataloaders``,
    ``evaluate``, ``plot_loss`` and ``plot_accuracies``.
    """
    num_epochs = args.epoches

    train_dataloader, valid_dataloader, test_dataloader = get_dataloaders(args,
                                                                          PARTICLES,
                                                                          PARTICLES,
                                                                          test_diff_particle=False)
    step = 0

    train_loss_values = []
    train_accuracies = {'Total': [], 'Count': [], 'Particle': []}
    valid_loss_values = []
    valid_accuracies = {'Total': [], 'Count': [], 'Particle': []}

    for epoch in range(num_epochs):
        # evaluate() switches the model to eval mode, so re-enable training mode each epoch.
        model.train()
        with tqdm(total=len(train_dataloader)) as t:
            for inputs, count_labels, particle_labels in train_dataloader:
                step += 1
                inputs = inputs.to(device)
                # Labels are (batch,) tensors. squeeze(0) would turn a final batch
                # of size 1 into a 0-d target.
                count_targets = count_labels.to(device)

                if not args.multi_output:
                    loss = loss_fn(model(inputs), count_targets)
                else:
                    y1, y2 = model(inputs)
                    loss1 = loss_fn(y1, count_targets)
                    loss2 = loss_fn(y2, particle_labels.to(device))
                    # Count loss is weighted 4x relative to the particle loss.
                    loss = 4 * loss1 + loss2

                optimizer.zero_grad()
                loss.backward()  # each graph is used once; no need to retain it
                optimizer.step()

                t.set_postfix(train_loss='{:05.3f}'.format(loss.item()))
                t.update()

        if epoch % args.evaluate_every == 0:
            valid_accuracy, valid_count_accuracy, valid_particle_accuracy, valid_loss = evaluate(model, loss_fn, valid_dataloader)
            valid_accuracies['Total'].append(valid_accuracy)
            valid_accuracies['Count'].append(valid_count_accuracy)
            valid_accuracies['Particle'].append(valid_particle_accuracy)
            valid_loss_values.append(valid_loss)

            train_accuracy, train_count_accuracy, train_particle_accuracy, train_loss = evaluate(model, loss_fn, train_dataloader)
            train_accuracies['Total'].append(train_accuracy)
            train_accuracies['Count'].append(train_count_accuracy)
            train_accuracies['Particle'].append(train_particle_accuracy)
            train_loss_values.append(train_loss)

            print(f'Step {step}: valid loss={valid_loss}, \n valid accuracy={valid_accuracy}, \n valid count accuracy={valid_count_accuracy}, \n valid particle accuracy={valid_particle_accuracy}')
            print(f'Step {step}: train loss={train_loss}, \n train accuracy={train_accuracy}, \n train count accuracy={train_count_accuracy}, \n train particle accuracy={train_particle_accuracy}')

    # Make sure the checkpoint directory exists before saving.
    os.makedirs(os.path.dirname(args.ckpt_path) or '.', exist_ok=True)
    torch.save(model.state_dict(), args.ckpt_path)

    plot_loss(args, train_loss_values, 'train')
    plot_accuracies(args, train_accuracies, 'train')
    plot_loss(args, valid_loss_values, 'validation')
    plot_accuracies(args, valid_accuracies, 'validation')

    test_accuracy, test_count_accuracy, test_particle_accuracy, _ = evaluate(model, loss_fn, test_dataloader)
    print('Test accuracy: %f' % test_accuracy)
    print('Test count accuracy: %f' % test_count_accuracy)
    print('Test particle accuracy: %f' % test_particle_accuracy)

In [None]:
def plot_accuracies(args, accuracies, split):
    """Plot total/count/particle accuracy curves for one split, save a PNG and show it.

    Args:
        args: namespace providing ``logdir`` (output directory, created if
            missing) and ``model`` (used in the file name).
        accuracies: dict with 'Total', 'Count' and 'Particle' lists, one value
            per evaluation.
        split: split name used in the title and the file name.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    ax.set_title(f"Total, count, and particle accuracies for {split} set")
    for key, color in (('Total', 'r'), ('Count', 'g'), ('Particle', 'b')):
        ax.plot(accuracies[key], label=key, color=color)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.legend()

    os.makedirs(args.logdir, exist_ok=True)
    # Filesystem-safe timestamp: str(datetime) contains spaces and ':', which Windows rejects.
    stamp = dt.now().strftime('%Y%m%d-%H%M%S')
    fig.savefig(f'{args.logdir}/{args.model}_{split}_accuracies_{stamp}.png')
    plt.show()

In [None]:
def plot_loss(args, loss_values, split):
    """Plot the loss curve for one split, save a PNG and show it.

    Args:
        args: namespace providing ``logdir`` (output directory, created if
            missing) and ``model`` (used in the file name).
        loss_values: sequence of loss values, one per evaluation.
        split: split name ('train', 'validation', ...) used in the title,
            legend and file name.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    ax.set_title(f"{split} loss")
    # The legend uses the split name so validation curves are not labelled "train".
    ax.plot(loss_values, label=split, color='b')
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()

    os.makedirs(args.logdir, exist_ok=True)
    # Filesystem-safe timestamp: str(datetime) contains spaces and ':', which Windows rejects.
    stamp = dt.now().strftime('%Y%m%d-%H%M%S')
    fig.savefig(f'{args.logdir}/{args.model}_{split}_loss_{stamp}.png')
    plt.show()

# Run the pipeline

In [None]:
def load_model(args):
    """Instantiate the architecture named by ``args.model`` and move it to ``device``.

    Supported names: 'multi_output_cnn', 'multi_output_resnet18' and
    'multi_output_vgg16'. The head sizes come from ``args.num_particles`` and
    ``args.num_counts``.

    Raises:
        ValueError: if ``args.model`` is not a supported name.
    """
    num_particles = args.num_particles
    num_counts = args.num_counts

    if args.model == 'multi_output_cnn':
        model = MultiOutputCNN(num_particles=num_particles, num_counts=num_counts, hidden_dim=8)
    elif args.model == 'multi_output_resnet18':
        model = CustomResNet18Model(num_counts, num_particles)
    elif args.model == 'multi_output_vgg16':
        model = CustomVgg16Model(num_counts, num_particles)
    else:
        raise ValueError(
            f"Unknown model {args.model!r}; expected 'multi_output_cnn', "
            f"'multi_output_resnet18' or 'multi_output_vgg16'")
    return model.to(device)

In [None]:
# Particle ids and multiplicities to load; order matches particle2idx / count2idx.
PARTICLES = [
    '1fpv', '1ss8', '3j03', '1ijg', '3iyf', '6ody',
    '6sp2', '6xs6', '7dwz', '7dx8', '7dx9',
]
COUNTS = ['single', 'double', 'triple', 'quadruple']
LENGTH = 1_000  # images read per (particle, count) file by CustomDataset

In [None]:
# Run configuration, read by load_model, get_dataloaders, train and the plot helpers.
args = Namespace(
    model='multi_output_cnn',  # multi_output_cnn || multi_output_resnet18 || multi_output_vgg16
    root_dir='./data/thumbnail',
    epoches=20,           # (sic) read as args.epoches by train()
    batch_size=128,
    shuffle=False,        # CustomDataset already shuffles once with a fixed seed
    num_workers=1,
    num_particles=11,
    num_counts=4,
    length=1000,          # NOTE(review): not read by any visible code; CustomDataset uses LENGTH
    evaluate_every=1,
    logdir='./logs',
    multi_output=True,
)
args.ckpt_path = f'{args.logdir}/{args.model}_checkpoint.pth'

In [None]:
"""
Define loss function and optimizer

We will use the cross entropy loss and Adam optimizer
"""

# Build the architecture selected by args.model (already moved to `device`).
model = load_model(args)

# Cross-entropy on the model outputs. The models return log-softmax values;
# log_softmax is idempotent, so this equals NLL loss on those outputs.
criterion = nn.CrossEntropyLoss()

# Adam with learning rate 1e-3 and L2 weight decay 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-3)

In [None]:
# Display the model's module structure.
model

In [None]:
# Run training, periodic evaluation, checkpointing, plotting and the final test evaluation.
train(args, model, optimizer, criterion)

In [None]:
# Layer-by-layer summary for a 1 x 128 x 128 input. torchsummary defaults to
# device="cuda", so pass the notebook's device to make this work on CPU too.
summary(model, input_size=(1, 128, 128), device=device.type)