# CNNs from the [following paper](http://cs230.stanford.edu/projects_spring_2021/reports/54.pdf)

In [None]:
import socket
# Print the host name of the machine this notebook kernel is running on.
print(socket.gethostname())

# Import Models

In [None]:
# Import models from multioutput_cnns.py
from multioutput_cnns import MultiOutputCNN_3Layer, MultiOutputCNN_5Layer, MultiOutputCNN_10Layer, MultiOutputCNN_18Layer, MultiOutputCNN_Early, CustomResNet18Model, CustomVgg16Model

# Import packages

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset, Dataset, DataLoader

import numpy as np
import h5py
import matplotlib.pyplot as plt
from PIL import Image
import argparse
from argparse import Namespace

import os
import random
from tqdm import tqdm

import tensorboard
# from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter

from resnet import resnet18, resnet34, resnet50
from datetime import datetime as dt

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
device

# Dataset

In [None]:
# Class index of every particle (PDB ID); the index is the position in this listing.
particle2idx = {
    pdb_id: position
    for position, pdb_id in enumerate([
        '1fpv', '1ss8', '3j03', '1ijg', '3iyf', '6ody',
        '6sp2', '6xs6', '7dwz', '7dx8', '7dx9',
    ])
}

# Class index of every particle-count label, numbered in listing order.
count2idx = {
    label: position
    for position, label in enumerate(['single', 'double', 'triple', 'quadruple'])
}

In [None]:
# Inverse lookup tables: class index -> label.
# Built with enumerate() so indices cannot be mistyped or duplicated
# (the hand-written table used key 2 twice, which dropped '1ss8' and left
# index 1 without an entry).
idx2particle = dict(enumerate([
    '1fpv',
    '1ss8',
    '3j03',
    '1ijg',
    '3iyf',
    '6ody',
    '6sp2',
    '6xs6',
    '7dwz',
    '7dx8',
    '7dx9',
]))

idx2count = dict(enumerate([
    'single',
    'double',
    'triple',
    'quadruple',
]))

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, particles, counts, transform=None, seed=1234):
        """
        CustomDataset is used to contain diffraction images used in model training.

        All images are loaded and transformed eagerly in the constructor, so any
        random transform is sampled exactly once per image; __getitem__ returns
        the stored result unchanged.

        Parameters
        ----------
        root_dir: str
            String representing the directory path of the diffraction image datasets.
        particles: list(str)
            List of strings representing the PDB IDs of the particles being used for model training.
        counts: list(str)
            List of strings representing the particle count of the images being used for model training.
        transform: torchvision.transforms.Compose or None
            A torchvision.transforms.Compose object containing the transforms to apply to the diffraction images.
            If None, the PIL images are stored untransformed.
        seed: int
            An integer used to seed the randomization of the order of the data.
        """
        self.root_dir = root_dir
        self.transform = transform

        self.count_labels = []
        self.particle_labels = []
        self.data = []

        # n is used in generating the file names of files to open; n = 5 will
        # open up the 5k image datasets. LENGTH * n images are read per file.
        n = 5
        images_per_file = LENGTH * n

        for particle in particles:
            for count in counts:
                # Create directory path to dataset
                data_dir = f'{self.root_dir}/{particle}_{str(n)}k_{count}_pps_1e14_thumbnail.h5'

                # Read the first dataset of the h5 file and convert every frame
                # to a PIL image. The context manager closes the file handle
                # once the frames have been read.
                with h5py.File(data_dir, 'r') as f:
                    dset_name = list(f.keys())[0]
                    frames = f[dset_name]
                    images = [Image.fromarray(frames[i]) for i in range(images_per_file)]

                # Apply transforms to images (skipped when no transform was given).
                if self.transform is not None:
                    images = [self.transform(image) for image in images]

                # Store the images together with their labels.
                self.data.extend(images)
                self.count_labels.extend([count2idx[count]] * images_per_file)
                self.particle_labels.extend([particle2idx[particle]] * images_per_file)

        # Shuffle images and both label lists with the same permutation.
        random.seed(seed)
        perm = list(range(len(self.data)))
        random.shuffle(perm)
        self.data = [self.data[i] for i in perm]
        self.count_labels = [self.count_labels[i] for i in perm]
        self.particle_labels = [self.particle_labels[i] for i in perm]

    def __len__(self):
        '''Denotes the total number of samples'''
        return len(self.data)

    def __getitem__(self, index):
        '''Generates one sample of data as (image, count label, particle label)'''
        X = self.data[index]
        count = self.count_labels[index]
        particle = self.particle_labels[index]
        return X, count, particle

# AddNoise Class: A wrapper for the addNoise() function.

In [None]:
class AddNoise(object):
    """
    A wrapper for addNoise(), so that it can be used with other torch transform functions.
    AddNoise applies the following noise to images:
        1. Reduce the fluence by a factor of 100.
        2. Add poisson noise.
        3. Add gaussian noise given a sigma value.
        4. Variance normalization.
    """

    def __init__(self, flux_jitter, gaussian_noise):
        """
        Parameters
        ----------
        flux_jitter: float
            Flux jitter to use when reducing the fluence of the image.

        gaussian_noise: float
            Alias for sigma to be used in gaussian distribution. Sets how much
            gaussian noise to apply to image.
        """
        assert isinstance(flux_jitter, float)
        assert isinstance(gaussian_noise, float)
        self.flux_jitter = flux_jitter
        self.gaussian_noise = gaussian_noise

    def __call__(self, image):
        """ Called by PyTorch when applying the AddNoise transform. """
        return self.addNoise(image)

    @staticmethod
    def _change_intensity(img, flux_jitter):
        """Rescale img so its total intensity is alpha/100 of the original, with alpha ~ N(1, flux_jitter)."""
        factor = 100  # Reduces the fluence by a factor of 100. Ex. 1e14 photons/pulse -> 1e12 photons/pulse.
        mu = 1        # mean jitter
        alpha = np.random.normal(mu, flux_jitter)
        if alpha <= 0:
            alpha = 0.1  # alpha must stay positive
        total = np.sum(img)
        if total == 0:
            # An all-zero image stays all-zero; dividing by the zero sum would
            # give NaN and make the Poisson step raise.
            return np.zeros_like(img, dtype=np.float64)
        n_photons = alpha * total / factor  # number of desired photons per image
        return n_photons * (img / total)

    @staticmethod
    def _var_norm(V):
        """Normalize V to mean 0 and variance 1; infinite pixels are zeroed first."""
        # This shouldn't happen, but zero out infinite pixels. A boolean mask
        # touches only those pixels (indexing with np.argwhere's coordinate
        # pairs would overwrite whole rows instead).
        V = np.where(np.isinf(V), 0, V)
        mean = np.mean(V)
        std = np.std(V)
        if std == 0:
            return np.zeros_like(V)
        return (V - mean) / std

    def addNoise(self, orig_img):
        """
        Applies the following noise to a given image:
            1. Reduce the fluence by a factor of 100.
            2. Add poisson noise.
            3. Add gaussian noise given a sigma value.
            4. Variance normalization.

        Parameters
        ----------
        orig_img: PIL Image
            Image to apply noise to.

        Returns
        -------
        PIL Image
            Image with noise applied to it.
        """
        img = np.asarray(orig_img, dtype=np.float64)
        img = self._change_intensity(img, self.flux_jitter)
        # Apply Poisson statistics.
        img = np.random.poisson(img)
        # Apply Gaussian statistics: zero-mean noise with sigma = gaussian_noise.
        img = img + self.gaussian_noise * np.random.randn(*img.shape)
        img = self._var_norm(img)
        return Image.fromarray(img)
    

# Dataloader

In [None]:
def get_dataloaders(args, train_val_particles, test_particles, test_diff_particle=False):
    """
    Creates torch.utils.data.DataLoader objects for the training, validation, and testing. Part of this includes applying augmentations and noise to the diffraction images.

    Parameters
    ----------
    args
        args.root_dir: str
            String representation of directory containing the data needed for training/testing.
        args.batch_size: int
            Batch size for DataLoaders.
        args.shuffle: bool
            Shuffle the data in DataLoaders.
        args.num_workers: int
            Number of subprocesses to use for data loading.

    train_val_particles: list(str)
        List of str representing the PDB IDs of particles used for training and validation sets.

    test_particles: list(str)
        List of str representing the PDB IDs of particles used for test sets.

    test_diff_particle: bool
        If True, create a test dataloader that uses a different set of particles not specified in train_val_particles or test_particles.
        If False, create a test dataloader that uses the same set of particles specified in train_val_particles/test_particles; train_val_particles and test_particles must be the same!

    Return
    ------
    A training DataLoader, a validation DataLoader, and a test DataLoader.
    """

    # Augmentations and noise applied to images:
    # 1. CenterCrop(128): Crops image at the center with an output size of (128, 128).
    # 2. RandomVerticalFlip(p=0.5): Flips image vertically with 50% probability.
    # 3. RandomHorizontalFlip(p=0.5): Flips image horizontally with 50% probability.
    # 4. RandomAffine(degrees=360, scale(0.9, 1.1)): Random rotation w/ range of (-360, 360) and random zoom w/ range of (0.9x, 1.1x).
    # 5. AddNoise(0.9, 0.15): Applies noise with flux jitter of 0.9 and gaussian noise of 0.15.
    # 6. ToTensor(): Convert image to PyTorch tensor.
    transform = transforms.Compose([transforms.CenterCrop(128),
                                    transforms.RandomVerticalFlip(p=0.5),
                                    transforms.RandomHorizontalFlip(p=0.5),
                                    transforms.RandomAffine(degrees=360, scale=(0.9, 1.1)),
                                    AddNoise(flux_jitter=0.9, gaussian_noise=0.15),
                                    transforms.ToTensor()])

    if not test_diff_particle: # Create train, validation, and test datasets using the same set of particles.
        assert train_val_particles == test_particles
        dataset = CustomDataset(root_dir=args.root_dir,
                                particles=train_val_particles,
                                counts=COUNTS,
                                transform=transform)

        # Size the split from the dataset that was actually loaded rather than
        # from a hard-coded per-particle image count, so the indices always
        # stay in range.
        data_len = len(dataset)
        train_end = int(data_len * 0.7)
        valid_end = int(data_len * 0.8)

        # Split the data into a train, validation, and test set as follows:
        # The first 70% of the dataset is for the training dataset.
        # The next 10% of the dataset is for the validation dataset.
        # The last 20% of the dataset is for the test dataset.
        train_idx = list(range(0, train_end))
        valid_idx = list(range(train_end, valid_end))
        test_idx = list(range(valid_end, data_len))

        # More information on PyTorch Subset: https://pytorch.org/docs/stable/data.html
        # Create the train, validation, and test datasets.
        train_dataset = Subset(dataset, train_idx)
        valid_dataset = Subset(dataset, valid_idx)
        test_dataset = Subset(dataset, test_idx)
    else: # Create train and validation datasets using the same set of particles, and the test dataset with a different set of particles.

        # Create train/valid/test datasets
        train_val_dataset = CustomDataset(root_dir=args.root_dir,
                                          particles=train_val_particles,
                                          counts=COUNTS,
                                          transform=transform)

        # More information on PyTorch Subset: https://pytorch.org/docs/stable/data.html
        # Split the data into a train and validation set as follows:
        # The first 7000 diffraction images is for the training dataset.
        # The next 1000 diffraction images is for the validation dataset.
        # NOTE(review): these sizes are fixed, so the rest of train_val_dataset
        # is unused and it must hold at least 8000 images - confirm intent.
        train_idx = list(range(0, 7000))
        valid_idx = list(range(7000, 8000))

        # Create the train, validation, and test datasets.
        train_dataset = Subset(train_val_dataset, train_idx)
        valid_dataset = Subset(train_val_dataset, valid_idx)

        # The test dataset in this case contains diffraction images of particles not in
        # the training nor validation set.
        test_dataset = CustomDataset(root_dir=args.root_dir,
                                    particles=test_particles,
                                    counts=COUNTS,
                                    transform=transform)

        # Check to see that the images in the train, validation, and test datasets
        # have the same shape of (1, 128, 128).
        assert train_dataset.__getitem__(0)[0].shape == torch.Size([1, 128, 128])
        assert valid_dataset.__getitem__(0)[0].shape == torch.Size([1, 128, 128])
        assert test_dataset.__getitem__(0)[0].shape == torch.Size([1, 128, 128])

    # Create train/valid/test dataloaders
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=args.shuffle,
                                  num_workers=args.num_workers)
    valid_dataloader = DataLoader(dataset=valid_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=args.shuffle,
                                  num_workers=args.num_workers)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=args.shuffle,
                                 num_workers=args.num_workers)
    return train_dataloader, valid_dataloader, test_dataloader

# Evaluate

In [None]:
def evaluate(model, loss_fn, dataloader):
    """Evaluate the model on every batch of `dataloader`.

    Reads the notebook-level globals `args` (for `args.multi_output`) and `device`.

    Args:
        model: (torch.nn.Module) the neural network
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches data

    Return:
        accuracy: float
            Percentage of samples with every predicted label correct. For a single-output
            model this equals count_accuracy.
        count_accuracy: float
            Percentage of samples whose particle-count class was predicted correctly.
        particle_accuracy: float
            Percentage of samples whose particle class was predicted correctly; NaN for a
            single-output model, which makes no particle prediction.
        loss: float
            Mean per-batch loss over the dataloader.

    Raises:
        ValueError: if the dataloader yields no samples.
    """

    # Evaluate the model using PyTorch's eval()
    model.eval()

    total_loss = 0.0
    count_preds = []
    particle_preds = []
    all_count_labels = []
    all_particle_labels = []

    # No gradients are needed for evaluation; this avoids building autograd graphs.
    with torch.no_grad():
        for inputs, count_labels, particle_labels in dataloader:

            # Send images to device.
            inputs = inputs.to(device)

            # reshape(-1) keeps the labels 1-D even for a batch of size one.
            count_targets = count_labels.reshape(-1).to(device)

            if not args.multi_output: # If the model IS NOT multi-output...
                outputs = model(inputs)
                total_loss += loss_fn(outputs, count_targets).item()
                count_preds.extend(torch.argmax(outputs, dim=-1).to('cpu').tolist())
            else: # If the model IS multi-output...

                # y1 contains predictions for particle COUNT classification.
                # y2 contains predictions for particle PDB ID classification.
                y1, y2 = model(inputs)
                particle_targets = particle_labels.reshape(-1).to(device)

                # loss1 is model loss for particle COUNT classification.
                # loss2 is model loss for particle PDB ID classification.
                loss1 = loss_fn(y1, count_targets).item()
                loss2 = loss_fn(y2, particle_targets).item()

                # Total loss of the batch.
                # This is a weighted loss, with the loss from COUNT classification having more weight.
                # Mentioned in Section 4.2 of paper mentioned at the top of the notebook.
                total_loss += 4 * loss1 + loss2

                count_preds.extend(torch.argmax(y1, dim=-1).to('cpu').tolist())
                particle_preds.extend(torch.argmax(y2, dim=-1).to('cpu').tolist())
                all_particle_labels.extend(particle_labels.reshape(-1).tolist())

            all_count_labels.extend(count_labels.reshape(-1).tolist())

    torch.cuda.empty_cache()

    num_samples = len(count_preds)
    if num_samples == 0:
        raise ValueError('evaluate() received a dataloader with no samples')

    # Average the per-batch losses once, after the loop.
    loss = total_loss / len(dataloader)

    # Count accuracy
    count_hits = sum(1 for pred, label in zip(count_preds, all_count_labels) if pred == label)
    count_accuracy = count_hits / num_samples * 100

    if not args.multi_output:
        # A single-output model only predicts the count class.
        return count_accuracy, count_accuracy, float('nan'), loss

    # Particle accuracy
    particle_hits = sum(1 for pred, label in zip(particle_preds, all_particle_labels) if pred == label)
    particle_accuracy = particle_hits / num_samples * 100

    # Total accuracy: both the count and the particle prediction must be correct.
    joint_hits = sum(
        1
        for c_pred, c_label, p_pred, p_label
        in zip(count_preds, all_count_labels, particle_preds, all_particle_labels)
        if c_pred == c_label and p_pred == p_label
    )
    accuracy = joint_hits / num_samples * 100

    # Compute accuracy for each particle type (NaN when a particle has no samples here).
    particle_acc_dict = {}
    for idx, name in idx2particle.items():
        total = sum(1 for label in all_particle_labels if label == idx)
        correct = sum(
            1 for pred, label in zip(particle_preds, all_particle_labels)
            if pred == idx and label == idx
        )
        particle_acc_dict[name] = correct / total if total else float('nan')
    print(particle_acc_dict)

    return accuracy, count_accuracy, particle_accuracy, loss

# Train

In [None]:
def train(args, model, optimizer, loss_fn):
    """
    Train the network on the training data, then save a checkpoint, plot the
    recorded metrics and print the test-set accuracies.

    Parameters
    ----------
    args
        args.epoches: int
            Number of epoches training should last.
        args.multi_output: True
            If True, specifies that the model is multi-output.
            If False, specifies that the model is not multi-output.
            Affects the loss function used in model training.
        args.evaluate_every: int
            Number of epoches that must occur before recording a training/validation accuracy for recordkeeping.
        args.ckpt_path: str
            Location to record model training checkpoints.
        args.logdir: str
            Directory the plots are written to (used by plot_loss / plot_accuracies).
    model: torch.nn.Module
        Model to train.
    optimizer: Optimizers in torch.optim
        Optimizer to use in training.
    loss_fn: PyTorch Loss Function class object
        PyTorch cost function to use in model training and evaluation.
    """

    # Set number of epochs for training.
    num_epochs = args.epoches

    # Generate dataloaders.
    train_dataloader, valid_dataloader, test_dataloader = get_dataloaders(args,
                                                                          PARTICLES,
                                                                          PARTICLES,
                                                                          test_diff_particle=False)

    # Step counter.
    step = 0

    # Initialize lists and dictionaries to use to store loss and accuracy.
    # These are used for plotting the loss and accuracy of the model during the training process.
    train_loss_values = []
    train_accuracies = {
        'Total': [],
        'Count': [],
        'Particle': []
    }
    valid_loss_values = []
    valid_accuracies = {
        'Total': [],
        'Count': [],
        'Particle': []
    }

    # Training loop.
    for epoch in range(num_epochs):
        # evaluate() leaves the model in eval mode, so switch back every epoch.
        model.train()
        with tqdm(total=len(train_dataloader)) as t:
            for inputs, count_labels, particle_labels in train_dataloader:
                step += 1

                # Send inputs to device.
                inputs = inputs.to(device)

                # reshape(-1) keeps the labels 1-D even for a batch of size one.
                count_targets = count_labels.reshape(-1).to(device)

                # Calculate the loss.
                if not args.multi_output:
                    outputs = model(inputs)
                    loss = loss_fn(outputs, count_targets)
                else:
                    y1, y2 = model(inputs)
                    particle_targets = particle_labels.reshape(-1).to(device)
                    loss1 = loss_fn(y1, count_targets)
                    loss2 = loss_fn(y2, particle_targets)
                    # Weighted loss: count classification carries 4x the weight.
                    loss = 4 * loss1 + loss2

                # Backpropagate and update the weights. The graph is rebuilt on
                # every forward pass, so it does not need to be retained.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Show the batch loss on the progress bar.
                cost = loss.item()
                t.set_postfix(train_loss='{:05.3f}'.format(cost))
                t.update()

        # Print out accuracy and loss metrics if condition is met.
        if epoch % args.evaluate_every == 0:
            valid_accuracy, valid_count_accuracy, valid_particle_accuracy, valid_loss = evaluate(model, loss_fn, valid_dataloader)
            valid_accuracies['Total'].append(valid_accuracy)
            valid_accuracies['Count'].append(valid_count_accuracy)
            valid_accuracies['Particle'].append(valid_particle_accuracy)
            valid_loss_values.append(valid_loss)

            train_accuracy, train_count_accuracy, train_particle_accuracy, train_loss = evaluate(model, loss_fn, train_dataloader)
            train_accuracies['Total'].append(train_accuracy)
            train_accuracies['Count'].append(train_count_accuracy)
            train_accuracies['Particle'].append(train_particle_accuracy)
            train_loss_values.append(train_loss)

            print(f'Step {step}: valid loss={valid_loss}, \n valid accuracy={valid_accuracy}, \n valid count accuracy={valid_count_accuracy}, \n valid particle accuracy={valid_particle_accuracy}')
            print(f'Step {step}: train loss={train_loss}, \n train accuracy={train_accuracy}, \n train count accuracy={train_count_accuracy}, \n train particle accuracy={train_particle_accuracy}')

    # Make sure the output directories exist before anything is written to them.
    ckpt_dir = os.path.dirname(args.ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(args.logdir, exist_ok=True)

    # Save the model in designated checkpoint path.
    torch.save(model.state_dict(), args.ckpt_path)

    # Plot the accuracy and loss of train/validation sets.
    plot_loss(args, train_loss_values, 'train')
    plot_accuracies(args, train_accuracies, 'train')
    plot_loss(args, valid_loss_values, 'validation')
    plot_accuracies(args, valid_accuracies, 'validation')

    # Retrieve and display test set accuracy.
    test_accuracy, test_count_accuracy, test_particle_accuracy, _ = evaluate(model, loss_fn, test_dataloader)
    print('Test accuracy: %f' % test_accuracy)
    print('Test count accuracy: %f' % test_count_accuracy)
    print('Test particle accuracy: %f' % test_particle_accuracy)

In [None]:
def plot_accuracies(args, accuracies, split):
    """
    Draw the total, count and particle accuracy curves on one figure, save it as a PNG and show it.

    Parameters
    ----------
    args:
        args.logdir: str
            Directory the PNG is written to.
        args.model: str
            ID of the trained model; used in the file name.
    accuracies: dict(list(float))
        Recorded accuracy values under the keys 'Total', 'Count' and 'Particle'.
    split: str
        Name of the dataset split the values belong to; ex. "train" or "validation"
    """
    # One curve per accuracy type, each with a fixed colour.
    curve_colours = (('Total', 'r'), ('Count', 'g'), ('Particle', 'b'))

    plt.figure(figsize=(6,5))
    plt.title(f"Total, count, and particle accuracies for {split} set")
    for series, colour in curve_colours:
        plt.plot(accuracies[series], label=series, color=colour)
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend()

    # The current timestamp is part of the file name.
    out_path = f'{args.logdir}/{args.model}_{split}_accuracies_{dt.now()}.png'
    plt.savefig(out_path)
    plt.show()

In [None]:
def plot_loss(args, loss_values, split):
    """
    Creates and displays a matplotlib plot of the recorded loss values, and saves it as a PNG.

    Parameters
    ----------
    args:
        args.logdir: str
            Directory location to save an image of the plots.
        args.model: str
            ID representing the model that has been trained.
    loss_values: list(float)
        List of floats representing the loss values recorded during model training.
    split: str
        String representing the type of dataset that the loss values represent; ex. "train" or "validation"
    """
    plt.figure(figsize=(6,5))
    plt.title(f"{split} loss")
    # Label the curve with the split it belongs to, so the validation plot is
    # not captioned "train" in the legend.
    plt.plot(loss_values, label=split, color='b')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(f'{args.logdir}/{args.model}_{split}_loss_{dt.now()}.png')
    plt.show()

# Run the pipeline

In [None]:
def load_model(args):
    """
    Build the network selected by args.model and move it to `device`.

    Parameters
    ----------
    args:
        args.model: str
            The model to load. Must be one of the following:
            {multi_output_cnn_3_layers, multi_output_cnn_5_layers,
             multi_output_cnn_10_layers, multi_output_cnn_18_layers,
             multi_output_cnn_early, multi_output_resnet18, multi_output_vgg16}
        args.num_particles: int
            Number of unique particles to have model train on.
        args.num_counts: int
            Number of unique count types to have model train on.

    Returns
    -------
    The constructed model, already moved to `device`.

    Raises
    ------
    Exception
        If args.model is not one of the supported names.
    """
    def plain_cnn(network_cls):
        # The hand-written CNNs take keyword arguments and a hidden size of 8.
        return network_cls(num_particles=args.num_particles,
                           num_counts=args.num_counts,
                           hidden_dim=8)

    def backbone(network_cls):
        # The ResNet/VGG wrappers take (num_counts, num_particles) positionally.
        return network_cls(args.num_counts, args.num_particles)

    # Model name -> zero-argument factory.
    factories = {
        'multi_output_cnn_3_layers': lambda: plain_cnn(MultiOutputCNN_3Layer),
        'multi_output_cnn_5_layers': lambda: plain_cnn(MultiOutputCNN_5Layer),
        'multi_output_cnn_10_layers': lambda: plain_cnn(MultiOutputCNN_10Layer),
        'multi_output_cnn_18_layers': lambda: plain_cnn(MultiOutputCNN_18Layer),
        'multi_output_cnn_early': lambda: plain_cnn(MultiOutputCNN_Early),
        'multi_output_resnet18': lambda: backbone(CustomResNet18Model),
        'multi_output_vgg16': lambda: backbone(CustomVgg16Model),
    }

    make_model = factories.get(args.model)
    if make_model is None:
        raise Exception('Invalid model type specified. Please selected from following: {multi_output_cnn_3_layers, multi_output_cnn_5_layers, multi_output_cnn_10_layers, multi_output_cnn_18_layers, multi_output_cnn_early, multi_output_resnet18, multi_output_vgg16}')

    return make_model().to(device)

In [None]:
# Particles to train model on.
# PDB IDs of the particles to train the model on; train() passes this list to
# get_dataloaders() for both the train/validation and the test particles.
PARTICLES = ['1fpv', '1ss8', '3j03', '1ijg', '3iyf', '6ody', '6sp2', '6xs6', '7dwz', '7dx8', '7dx9']

# Particle counts to train model on.
COUNTS = ['single', 'double', 'triple', 'quadruple']

# Base image count per file: CustomDataset reads LENGTH * n images (with n = 5)
# from every h5 file it opens.
LENGTH = 1000

In [None]:
# model: Specify the model you want to train (see load_model for the valid names).
# root_dir: Directory containing the training data.
# epoches: Number of epoches to train model on.
# batch_size: Size of image batch for training.
# shuffle: If True, the DataLoaders shuffle the image datasets.
# num_workers: Number of subprocesses to use for data loading.
# num_particles: Number of particles in PARTICLES list (derived, so it cannot drift).
# num_counts: Number of count types in COUNTS list (derived, so it cannot drift).
# length: Same as LENGTH.
# evaluate_every: Number of epoches between every recording of accuracy and loss values.
# logdir: Directory to store log information about model training.
# multi_output: If True, specifies that the model is multi-output.
args = Namespace(
    model='multi_output_cnn_3_layers',
    root_dir='PATH TO DATA',
    epoches=20,
    batch_size=128,
    shuffle=False,
    num_workers=1,
    num_particles=len(PARTICLES),
    num_counts=len(COUNTS),
    length=LENGTH,
    evaluate_every=1,
    logdir='./logs',
    multi_output=True,
)

# The model checkpoint is written next to the logs.
args.ckpt_path = f'{args.logdir}/{args.model}_checkpoint.pth'

In [None]:
"""
Define loss function and optimizer

We will use the cross entropy loss and Adam optimizer
"""

# Create the model selected by args.model (load_model also moves it to `device`).
model = load_model(args)

# Define the cost function: cross entropy, used for both the count and the particle outputs.
criterion = nn.CrossEntropyLoss()

# Define the optimizer: Adam with learning rate 0.001 and weight decay 0.001.
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)

In [None]:
model

In [None]:
train(args, model, optimizer, criterion)

# Draw confusion matrix for particle prediction

In [None]:
# Rebuild the dataloaders and keep only the test one. This constructs a new
# CustomDataset, i.e. the h5 files are read and transformed again.
_, _, test_dataloader = get_dataloaders(args, PARTICLES, PARTICLES, test_diff_particle=False)

In [None]:
import pandas as pd
import seaborn as sns

# One row/column per particle class; rows are true labels, columns are predictions.
# Sized from particle2idx so it always matches the class names used for the axes.
nb_classes = len(particle2idx)
confusion_matrix = np.zeros((nb_classes, nb_classes))

# Inference only: use eval mode and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    for inputs, count_labels, particle_labels in test_dataloader:
        inputs = inputs.to(device)
        # The model's second output holds the particle (PDB ID) predictions.
        _, y2 = model(inputs)
        y2 = torch.argmax(y2, dim=-1).to('cpu').numpy().tolist()
        for t, p in zip(particle_labels.tolist(), y2):
            confusion_matrix[t, p] += 1

plt.figure(figsize=(15,10))

class_names = list(particle2idx.keys())
df_cm = pd.DataFrame(confusion_matrix, index=class_names, columns=class_names).astype(int)
heatmap = sns.heatmap(df_cm, annot=True, fmt="d")

heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right',fontsize=12)
heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right',fontsize=12)
plt.ylabel('True label', fontsize=15)
plt.xlabel('Predicted label', fontsize=15)

# Summary

In [None]:
# NOTE(review): `summary` is not imported anywhere in this notebook - the
# `from torchsummary import summary` line in the import cell is commented out -
# so this cell raises NameError unless that import is restored.
summary(model, input_size=(1, 128, 128))