# OCT Retina Diagnose

## Project: Build CNN Model to diagnose OCT Retina images

---

### Why We're Here 

Retinal optical coherence tomography (OCT) is an imaging technique used to capture high-resolution cross sections of the retinas of living patients. Approximately 30 million OCT scans are performed each year, and the analysis and interpretation of these images takes up a significant amount of time (Swanson and Fujimoto, 2017).

![](https://i.imgur.com/fSTeZMd.png)
Figure 1: Representative Optical Coherence Tomography Images and the Workflow Diagram [Kermany et al. 2018]

(A) (Far left) choroidal neovascularization (CNV) with neovascular membrane (white arrowheads) and associated subretinal fluid (arrows). (Middle left) Diabetic macular edema (DME) with retinal-thickening-associated intraretinal fluid (arrows). (Middle right) Multiple drusen (arrowheads) present in early AMD. (Far right) Normal retina with preserved foveal contour and absence of any retinal fluid/edema.

#### Content
The dataset is organized into 3 folders (train, test, val), each containing one subfolder per image category (NORMAL, CNV, DME, DRUSEN). There are 84,495 OCT images (JPEG) across the 4 categories.

Images are labeled as (disease)-(randomized patient ID)-(image number by this patient) and split into 4 directories: CNV, DME, DRUSEN, and NORMAL.
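
The naming scheme above can be parsed with a few lines of Python. This is a minimal sketch: the helper `parse_oct_filename` and the sample filename are hypothetical, and it assumes the three fields are separated by hyphens and that the file has a single extension.

```python
from pathlib import Path
from typing import NamedTuple


class OctFileInfo(NamedTuple):
    disease: str
    patient_id: str
    image_number: int


def parse_oct_filename(filename: str) -> OctFileInfo:
    """Split '(disease)-(patient ID)-(image number).ext' into its parts."""
    stem = Path(filename).stem  # drop any directory and the extension
    disease, patient_id, image_number = stem.split("-")
    return OctFileInfo(disease, patient_id, int(image_number))


# Hypothetical filename that follows the documented pattern
info = parse_oct_filename("CNV-1234567-12.jpeg")
print(info)
```

Having the patient ID as a separate field makes it possible to group images by patient, for example to check that one patient's scans do not end up in both the training and the test split.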

Optical coherence tomography (OCT) images (Spectralis OCT, Heidelberg Engineering, Germany) were selected from retrospective cohorts of adult patients from the Shiley Eye Institute of the University of California San Diego, the California Retinal Research Foundation, Medical Center Ophthalmology Associates, the Shanghai First People’s Hospital, and Beijing Tongren Eye Center between July 1, 2013 and March 1, 2017.

Before training, each image went through a tiered grading system consisting of multiple layers of trained graders of increasing expertise for verification and correction of image labels. Each image imported into the database started with a label matching the most recent diagnosis of the patient. The first tier of graders consisted of undergraduate and medical students who had taken and passed an OCT interpretation course review. This first tier of graders conducted initial quality control and excluded OCT images containing severe artifacts or significant image resolution reductions. The second tier of graders consisted of four ophthalmologists who independently graded each image that had passed the first tier. The presence or absence of choroidal neovascularization (active or in the form of subretinal fibrosis), macular edema, drusen, and other pathologies visible on the OCT scan were recorded. Finally, a third tier of two senior independent retinal specialists, each with over 20 years of clinical retina experience, verified the true labels for each image. The dataset selection and stratification process is displayed in a CONSORT-style diagram in Figure 2B of Kermany et al. (2018). To account for human error in grading, a validation subset of 993 scans was graded separately by two ophthalmologist graders, with disagreement in clinical labels arbitrated by a senior retinal specialist.


#### In this notebook we design CNN models and benchmark their accuracy at classifying retinal OCT images as CNV, DME, DRUSEN, or NORMAL

### The Road Ahead

We break the notebook into separate steps.

* [Step 0](#step0): Import Datasets
* [Step 1](#step1): Define Train and Test function
* [Step 2](#step2): Train and Test original ResNet18
* [Step 3](#step3): Train and Test modified ResNet18 (trained from scratch)
* [Step 4](#step4): Train and Test pretrained ResNet50 (transfer learning)
* [Step 5](#step5): Conclusion

---
<a id='step0'></a>
## Step 0: Import Datasets

Download Dataset OCT2017.zip


In [10]:
import copy
import os
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.models as models
import torchvision.transforms as transforms
from torchvision import datasets

from PIL import Image, ImageFile
# Let PIL load images whose files are truncated instead of raising an error
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
BATCH_SIZE = 32

In [3]:
data_dir = './OCT_data'
train_dir = data_dir + '/train224'
val_dir = data_dir + '/val224'
test_dir = data_dir + '/test224'

In [4]:
# Define transforms for the data.
# The images are already 224x224, so no resize or crop is needed.

# Data augmentation:
# a random horizontal flip is the only augmentation applied
# (training set only).

train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406],
                                                         [0.229, 0.224, 0.225])])

val_transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize([0.485, 0.456, 0.406],
                                                         [0.229, 0.224, 0.225])])

test_transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize([0.485, 0.456, 0.406],
                                                         [0.229, 0.224, 0.225])])

data_transforms = [train_transform, val_transform, test_transform]

# Load datasets
train_data = datasets.ImageFolder(train_dir, transform=train_transform)
val_data = datasets.ImageFolder(val_dir, transform=val_transform)
test_data = datasets.ImageFolder(test_dir, transform=test_transform)

image_datasets = [train_data, val_data, test_data]

# Create Loader
trainloader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valloader = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)
testloader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

loaders = {"train":trainloader, "test":testloader, "val":valloader}

### Class names mapped to indices by PyTorch ImageFolder


In [5]:
cat_to_name = train_data.class_to_idx
cat_to_name

{'CNV': 0, 'DME': 1, 'DRUSEN': 2, 'NORMAL': 3}

### Check data shape

In [6]:
len(train_data.classes)

4

In [7]:
len(trainloader)

1668
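
The loader length is the number of batches, not the number of images. As a minimal sketch (the helper names `num_batches` and `sample_range` are hypothetical, and `drop_last=False`, the `DataLoader` default, is assumed), the relation between the two can be worked out as follows:

```python
import math

BATCH_SIZE = 32


def num_batches(num_samples: int, batch_size: int) -> int:
    """Length of a DataLoader when the last partial batch is kept."""
    return math.ceil(num_samples / batch_size)


def sample_range(n_batches: int, batch_size: int) -> tuple:
    """Smallest and largest dataset sizes that yield exactly `n_batches` batches."""
    return (n_batches - 1) * batch_size + 1, n_batches * batch_size


low, high = sample_range(1668, BATCH_SIZE)
print(low, high)  # 53345 53376
```

So 1668 batches of 32 correspond to a training folder of between 53,345 and 53,376 images.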

### Check balance of data

In [13]:
# Training
train_info_df = pd.read_csv('OCT_data/train.csv')
train_info_df['label'].value_counts()

3    1000
1    1000
2    1000
0    1000
Name: label, dtype: int64

In [14]:
# Validation
val_info_df = pd.read_csv('OCT_data/val.csv')
val_info_df['label'].value_counts()

3    100
2    100
1    100
0    100
Name: label, dtype: int64

In [15]:
# Testing
test_info_df = pd.read_csv('OCT_data/test.csv')
test_info_df['label'].value_counts()

3    242
2    242
1    242
0    242
Name: label, dtype: int64

### Check device

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

<a id='step1'></a>
## Step 1: Define Train and Test Functions

In [59]:
def train(n_epochs, loaders, model, optimizer, criterion, save_path=None, valid_loss_min=np.inf):
    """Train `model`; return it together with the lowest validation loss seen.

    Validation runs every 400 training batches. If `save_path` is given,
    the weights are saved each time the validation loss improves.
    """

    for epoch in range(1, n_epochs+1):
        # running average of the training loss over this epoch
        train_loss = 0.0

        ###################
        # train the model #
        ###################
        model.train()
        for batch_idx, (data, target) in enumerate(loaders['train']):
            data, target = data.to(device), target.to(device)

            # forward pass, backward pass, parameter update
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # incremental mean: new_mean = old_mean + (x - old_mean) / count
            train_loss += (loss.item() - train_loss) / (batch_idx + 1)

            # compute the validation loss every 400 batches
            if batch_idx % 400 == 0:

                ######################
                # validate the model #
                ######################
                valid_loss = 0.0
                model.eval()
                with torch.no_grad():  # no gradients needed for validation
                    for val_idx, (val_data, val_target) in enumerate(loaders['val']):
                        val_data, val_target = val_data.to(device), val_target.to(device)
                        val_loss = criterion(model(val_data), val_target)
                        valid_loss += (val_loss.item() - valid_loss) / (val_idx + 1)

                # print training/validation statistics
                print('Epoch: {} \tbatch {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
                    epoch,
                    batch_idx,
                    train_loss,
                    valid_loss
                    ))

                # track the best validation loss; save the model if a path is given
                if valid_loss <= valid_loss_min:
                    if save_path is not None:
                        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
                        valid_loss_min,
                        valid_loss))
                        torch.save(model.state_dict(), save_path)
                    valid_loss_min = valid_loss

                # switch back to training mode
                model.train()

    # return trained model and the lowest validation loss
    return model, valid_loss_min
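
The loss update inside `train` is an incremental mean: each new batch loss moves the running value by `(loss - mean) / count`. A minimal sketch with made-up loss values (the helper `running_mean` is hypothetical) shows that this gives the same result as averaging all batch losses at the end:

```python
def running_mean(values):
    """Incremental mean, the same update rule used for the losses in `train`."""
    mean = 0.0
    for k, x in enumerate(values):
        mean = mean + (x - mean) / (k + 1)
    return mean


losses = [1.2, 0.8, 0.4, 0.6]  # made-up batch losses
print(running_mean(losses), sum(losses) / len(losses))
```

Because the first update uses a count of 1, the running value is overwritten by the first loss, so the starting value of 0.0 does not bias the result.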

In [60]:
def test(loaders, model):
    """Return the accuracy of `model` on the test set."""
    n_correct = 0
    model.eval()
    with torch.no_grad():
        for data, target in loaders['test']:
            data, target = data.to(device), target.to(device)

            logits = model(data)
            target_pred = logits.argmax(dim=1)  # predicted class index per image
            n_correct += torch.sum(target_pred == target).item()

    return n_correct / len(loaders['test'].dataset)
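
`test` returns a single overall accuracy. When a per-class view is wanted, a confusion matrix can be built from the same predictions. The following is a minimal sketch with made-up labels; the helper `confusion_matrix` is hypothetical and not part of the notebook's pipeline.

```python
import torch


def confusion_matrix(targets: torch.Tensor, preds: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Rows are true classes, columns are predicted classes."""
    flat = targets * n_classes + preds  # one bin per (true, predicted) pair
    counts = torch.bincount(flat, minlength=n_classes * n_classes)
    return counts.reshape(n_classes, n_classes)


# Made-up labels for the 4 classes (0=CNV, 1=DME, 2=DRUSEN, 3=NORMAL)
targets = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
preds = torch.tensor([0, 1, 1, 1, 2, 0, 3, 3])

cm = confusion_matrix(targets, preds, n_classes=4)
per_class_acc = cm.diag().float() / cm.sum(dim=1).float()
print(cm)
print(per_class_acc)
```

In a real run, `targets` and `preds` would be collected batch by batch inside the evaluation loop and concatenated before calling the helper.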

---
<a id='step2'></a>
## Step 2: Train and Test Original ResNet-18

<img src='https://www.researchgate.net/profile/Paolo_Napoletano/publication/322476121/figure/tbl1/AS:668726449946625@1536448218498/ResNet-18-Architecture.png' width=500px>

Here we build a general ResNet model.
To get a specific ResNet, we only need to set the number of basic blocks in each of the four stages.

For example:
- resnet18 = ResNet(2,2,2,2,N_output)
- resnet34 = ResNet(3,4,6,3,N_output)

Instead of importing the prebuilt ResNet-18 model from PyTorch, we implement it by hand to warm up our PyTorch skills.
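
The numbers in the model names follow from these block counts. As a minimal sketch (the helper `resnet_depth` is hypothetical, and it uses the usual convention of counting the stem convolution, the two 3x3 convolutions in every basic block, and the final fully connected layer, while ignoring the 1x1 downsample convolutions):

```python
def resnet_depth(blocks_per_stage, convs_per_block=2):
    """Weighted-layer count: stem conv + convs in all basic blocks + final fc."""
    return 1 + convs_per_block * sum(blocks_per_stage) + 1


print(resnet_depth((2, 2, 2, 2)))  # 18
print(resnet_depth((3, 4, 6, 3)))  # 34
```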

In [61]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, conv1_stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=conv1_stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        
        self.bn = nn.BatchNorm2d(out_channels)
        
    def forward(self, x):
        y = self.bn(self.conv1(x))
        y = F.relu(y)
        y = self.bn(self.conv2(y))
        
        return y
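
Note that the block above applies the same `self.bn` module after both convolutions, so the two positions share one set of affine parameters and one set of running statistics. The standard ResNet basic block uses a separate normalization layer for each convolution. A minimal sketch of that variant follows; the class name `BasicBlockSeparateNorm` is hypothetical, and this is not the code that produced the results reported in this notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlockSeparateNorm(nn.Module):
    def __init__(self, in_channels, out_channels, conv1_stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=conv1_stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)  # statistics for conv1 only
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)  # statistics for conv2 only

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return y  # the skip connection is still added by the caller


block = BasicBlockSeparateNorm(64, 128, conv1_stride=2)
out = block(torch.randn(2, 64, 56, 56))
print(out.shape)  # torch.Size([2, 128, 28, 28])
```

Whether this change alone would close the accuracy gap reported later is untested here; it is one difference worth ruling out before attributing the gap to the choice of normalization.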
        

In [62]:
def get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

In [63]:
class Layer(nn.Module):
    def __init__(self, in_channels, out_channels, N):
        super().__init__()
        self.N = N
        self.basic_block1 = BasicBlock(in_channels, out_channels, conv1_stride=2)
        self.basic_blocks = get_clones(BasicBlock(out_channels, out_channels, conv1_stride=1), self.N-1)
        self.downsample = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False),
                                        nn.BatchNorm2d(out_channels))
        
        
    def forward(self, x):
        # Add the residual (skip) connection after each block
        y = self.basic_block1(x)
        y = y + self.downsample(x)

        for i in range(self.N-1):
            y = y + self.basic_blocks[i](y)
            
        return y
        

In [64]:
class ResNet(nn.Module):
    def __init__(self, N1, N2, N3, N4, N_output):
        super().__init__()
        self.N1 = N1
        self.N2 = N2
        self.N3 = N3
        self.N4 = N4
        self.N_output = N_output
        
        # Layer 0
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1)
        
        # Layer 1
        self.basic_blocks_1 = get_clones(BasicBlock(64, 64, conv1_stride=1), self.N1)
        
        
        # Other layers
        self.layer2 = Layer(64, 128, self.N2)
        self.layer3 = Layer(128, 256, self.N3)
        self.layer4 = Layer(256, 512, self.N4)
        
        
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.fc = nn.Linear(512, self.N_output)
        
        
    
    def forward(self, x):
        # Layer 0
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.pool1(y)
        
        # Layer 1
        for i in range(self.N1):
            y = y + self.basic_blocks_1[i](y)
        
        y = self.layer2(y)
        y = self.layer3(y)
        y = self.layer4(y)
        
        y = self.avgpool(y)
        y = y.view(x.size(0), -1) # Flatten
        
        y = self.fc(y)
        
        return y
        

### Initialize the ResNet-18 model, loss function, and optimizer

In [65]:
# Initialize the model
resnet18_model = ResNet(2,2,2,2,4)

# Loss function
res18_criterion = nn.CrossEntropyLoss()

# Optimizer settings taken from https://www.kaggle.com/carloalbertobarbano/vgg16-transfer-learning-pytorch
res18_optimizer = optim.SGD(resnet18_model.parameters(), lr=0.001, momentum=0.9)

resnet18_model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (basic_blocks_1): ModuleList(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Layer(
    (basic_block1): BasicBlock(
      (conv1): Conv2d(64, 128, k

### Train ResNet-18

In [66]:
val_loss_min_res18 = np.inf

In [67]:
resnet18_model, val_loss_min_res18 = train(3, loaders, resnet18_model, res18_optimizer, res18_criterion,
                                          save_path=None,
                                          valid_loss_min=val_loss_min_res18)

Epoch: 1 	batch 0 	Training Loss: 1.386990 	Validation Loss: 1.388235
Epoch: 1 	batch 400 	Training Loss: 0.599900 	Validation Loss: 1.376849
Epoch: 1 	batch 800 	Training Loss: 0.479702 	Validation Loss: 1.243111
Epoch: 1 	batch 1200 	Training Loss: 0.420820 	Validation Loss: 1.136152
Epoch: 1 	batch 1600 	Training Loss: 0.382876 	Validation Loss: 1.203773
Epoch: 2 	batch 0 	Training Loss: 0.388344 	Validation Loss: 1.171317
Epoch: 2 	batch 400 	Training Loss: 0.245765 	Validation Loss: 1.261480
Epoch: 2 	batch 800 	Training Loss: 0.237673 	Validation Loss: 1.097593
Epoch: 2 	batch 1200 	Training Loss: 0.229866 	Validation Loss: 1.124406
Epoch: 2 	batch 1600 	Training Loss: 0.223209 	Validation Loss: 1.053322
Epoch: 3 	batch 0 	Training Loss: 0.116971 	Validation Loss: 1.301796
Epoch: 3 	batch 400 	Training Loss: 0.189956 	Validation Loss: 1.121692
Epoch: 3 	batch 800 	Training Loss: 0.182260 	Validation Loss: 1.160074
Epoch: 3 	batch 1200 	Training Loss: 0.178916 	Validation Loss: 1.

### Test ResNet-18

In [68]:
test(loaders, resnet18_model)

0.29545454545454547

---
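<a id='step3'></a>
## Step 3: Train and Test modified ResNet18
Instead of designing a new CNN from scratch, I modify the ResNet-18 above to achieve higher accuracy.

The modification: every BatchNorm layer is replaced with GroupNorm with num_groups=1.
With a single group, each sample is normalized over all of its channels and spatial positions, so the layer behaves like LayerNorm.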
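
The link between GroupNorm with one group and LayerNorm can be checked numerically. This is a minimal sketch on a random tensor; it compares a freshly initialized `nn.GroupNorm(1, C)` (weight 1, bias 0) with `F.layer_norm` applied over the channel and spatial dimensions, both using the default `eps`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 64, 8, 8)  # (batch, channels, height, width)

gn = nn.GroupNorm(1, 64)  # one group: normalize over all of (C, H, W) per sample
ln_out = F.layer_norm(x, x.shape[1:])  # layer norm over (C, H, W), no affine

same = torch.allclose(gn(x), ln_out, atol=1e-5)
print(same)
```

The two differ only in their learnable affine parameters: GroupNorm learns one scale and shift per channel, whereas `nn.LayerNorm` over `(C, H, W)` would learn one per element.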

In [79]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, conv1_stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=conv1_stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        
        self.gn = nn.GroupNorm(1, out_channels)
        
    def forward(self, x):
        y = self.gn(self.conv1(x))
        y = F.relu(y)
        y = self.gn(self.conv2(y))
        
        return y
        

In [80]:
class Layer(nn.Module):
    def __init__(self, in_channels, out_channels, N):
        super().__init__()
        self.N = N
        self.basic_block1 = BasicBlock(in_channels, out_channels, conv1_stride=2)
        self.basic_blocks = get_clones(BasicBlock(out_channels, out_channels, conv1_stride=1), self.N-1)
        self.downsample = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False),
                                        nn.GroupNorm(1, out_channels))
        
        
    def forward(self, x):
        # Add the residual (skip) connection after each block
        y = self.basic_block1(x)
        y = y + self.downsample(x)

        for i in range(self.N-1):
            y = y + self.basic_blocks[i](y)
            
        return y
        

In [81]:
class ResNetGN(nn.Module):
    def __init__(self, N1, N2, N3, N4, N_output):
        super().__init__()
        self.N1 = N1
        self.N2 = N2
        self.N3 = N3
        self.N4 = N4
        self.N_output = N_output
        
        # Layer 0
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.gn1 = nn.GroupNorm(1, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1)
        
        # Layer 1
        self.basic_blocks_1 = get_clones(BasicBlock(64, 64, conv1_stride=1), self.N1)
        
        
        # Other layers
        self.layer2 = Layer(64, 128, self.N2)
        self.layer3 = Layer(128, 256, self.N3)
        self.layer4 = Layer(256, 512, self.N4)
        
        
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.fc = nn.Linear(512, self.N_output)
        
        
    
    def forward(self, x):
        # Layer 0
        y = F.relu(self.gn1(self.conv1(x)))
        y = self.pool1(y)
        
        # Layer 1
        for i in range(self.N1):
            y = y + self.basic_blocks_1[i](y)
        
        y = self.layer2(y)
        y = self.layer3(y)
        y = self.layer4(y)
        
        y = self.avgpool(y)
        y = y.view(x.size(0), -1) # Flatten
        
        y = self.fc(y)
        
        return y
        

In [82]:
# Initialize the model
resnet18_GN_model = ResNetGN(2,2,2,2,4)

# Loss function
res18_GN_criterion = nn.CrossEntropyLoss()

# Optimizer settings taken from https://www.kaggle.com/carloalbertobarbano/vgg16-transfer-learning-pytorch
res18_GN_optimizer = optim.SGD(resnet18_GN_model.parameters(), lr=0.001, momentum=0.9)

resnet18_GN_model.to(device)

ResNetGN(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (gn1): GroupNorm(1, 64, eps=1e-05, affine=True)
  (pool1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (basic_blocks_1): ModuleList(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (gn): GroupNorm(1, 64, eps=1e-05, affine=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (gn): GroupNorm(1, 64, eps=1e-05, affine=True)
    )
  )
  (layer2): Layer(
    (basic_block1): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), s

In [83]:
val_loss_min_res18_gn = np.inf

In [84]:
resnet18_GN_model, val_loss_min_res18_gn = train(3, loaders, resnet18_GN_model, res18_GN_optimizer, res18_GN_criterion,
                                          save_path=None,
                                          valid_loss_min=val_loss_min_res18_gn)

Epoch: 1 	batch 0 	Training Loss: 1.633554 	Validation Loss: 3.744724
Epoch: 1 	batch 400 	Training Loss: 1.365723 	Validation Loss: 1.557902
Epoch: 1 	batch 800 	Training Loss: 1.275613 	Validation Loss: 1.046402
Epoch: 1 	batch 1200 	Training Loss: 1.116997 	Validation Loss: 1.022070
Epoch: 1 	batch 1600 	Training Loss: 0.998848 	Validation Loss: 0.649461
Epoch: 2 	batch 0 	Training Loss: 0.445842 	Validation Loss: 0.730278
Epoch: 2 	batch 400 	Training Loss: 0.494867 	Validation Loss: 0.425684
Epoch: 2 	batch 800 	Training Loss: 0.452669 	Validation Loss: 0.329882
Epoch: 2 	batch 1200 	Training Loss: 0.416345 	Validation Loss: 0.290522
Epoch: 2 	batch 1600 	Training Loss: 0.390313 	Validation Loss: 0.201673
Epoch: 3 	batch 0 	Training Loss: 0.128026 	Validation Loss: 0.144296
Epoch: 3 	batch 400 	Training Loss: 0.280475 	Validation Loss: 0.126784
Epoch: 3 	batch 800 	Training Loss: 0.272431 	Validation Loss: 0.229600
Epoch: 3 	batch 1200 	Training Loss: 0.262124 	Validation Loss: 0.

In [86]:
test(loaders, resnet18_GN_model)

0.981404958677686

__Note: about 98.1% test accuracy, remarkably high for a model trained from scratch for only 3 epochs__

---
<a id='step4'></a>
## Step 4: Train and Test pretrained ResNet50 (transfer learning)

In [301]:
# Note: torchvision >= 0.13 deprecates `pretrained=True` in favor of the `weights` argument
model_pretrained = models.resnet50(pretrained=True)
# model_pretrained

In [302]:
# Freeze all the pretrained weights
for params in model_pretrained.parameters():
    params.requires_grad = False

# Define a new classifier (its parameters are trainable by default)
fc_input_features = model_pretrained.fc.in_features
model_pretrained.fc = nn.Linear(fc_input_features, 4)

# Loss function
model_pretrained_criterion = nn.CrossEntropyLoss()

# Optimizer settings taken from https://www.kaggle.com/carloalbertobarbano/vgg16-transfer-learning-pytorch
# Only the new classifier is trainable, so only its parameters are passed to the optimizer
model_pretrained_optimizer = optim.SGD(model_pretrained.fc.parameters(), lr=0.001, momentum=0.9)

model_pretrained.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 
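
The effect of freezing the backbone and replacing the classifier can be checked by counting trainable parameters. The sketch below uses a small stand-in network instead of ResNet-50 so that it runs without downloading weights; the stand-in architecture is purely hypothetical.

```python
import torch.nn as nn

# Small stand-in for a pretrained backbone plus its original 10-class head
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

# Freeze everything, then swap in a fresh 4-class head
for p in model.parameters():
    p.requires_grad = False
model[4] = nn.Linear(model[4].in_features, 4)  # new layers are trainable by default

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(trainable, total)  # 36 260
```

Only the 36 parameters of the new head receive gradients; the same pattern applied to `model_pretrained` leaves just `fc` trainable.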

In [303]:
val_loss_min_model_pretrained = np.inf

In [304]:
model_pretrained, val_loss_min_model_pretrained = train(3, loaders, model_pretrained, model_pretrained_optimizer, 
                                                        model_pretrained_criterion,
                                                        valid_loss_min=val_loss_min_model_pretrained)

Epoch 1 	batch 0 	loss: 1.4611423015594482
Epoch 1 	batch 100 	loss: 1.0470964908599854
Epoch 1 	batch 200 	loss: 0.9074360728263855
Epoch 1 	batch 300 	loss: 0.8397747874259949
Epoch 1 	batch 400 	loss: 0.7934359908103943
Epoch 1 	batch 500 	loss: 0.7630904316902161
Epoch 1 	batch 600 	loss: 0.7393508553504944
Epoch 1 	batch 700 	loss: 0.7205396890640259
Epoch 1 	batch 800 	loss: 0.702239990234375
Epoch 1 	batch 900 	loss: 0.6898284554481506
Epoch 1 	batch 1000 	loss: 0.6785690784454346
Epoch 1 	batch 1100 	loss: 0.669047474861145
Epoch 1 	batch 1200 	loss: 0.6647937297821045
Epoch 1 	batch 1300 	loss: 0.6556822061538696
Epoch 1 	batch 1400 	loss: 0.6490129828453064
Epoch 1 	batch 1500 	loss: 0.642238438129425
Epoch 1 	batch 1600 	loss: 0.6345715522766113
Epoch: 1 	Training Loss: 0.631112 	Validation Loss: 0.438402
Epoch 2 	batch 0 	loss: 0.3330639898777008
Epoch 2 	batch 100 	loss: 0.5132921934127808
Epoch 2 	batch 200 	loss: 0.5142815113067627
Epoch 2 	batch 300 	loss: 0.51718276739

In [305]:
test(loaders, model_pretrained)

0.8646694214876033

<a id='step5'></a>
## Step 5: Conclusion

1. With 3 epochs each, we trained 3 models:
    - Original ResNet-18: test accuracy = 0.295
    - Modified ResNet-18 (GroupNorm with num_groups=1, i.e. LayerNorm-style normalization): test accuracy = 0.981
    - Pretrained ResNet-50 (frozen backbone, new classifier): test accuracy = 0.865
2. SGD performed better than the Adam optimizer (the Adam runs are not shown in this notebook)
3. Interestingly, replacing BatchNorm with GroupNorm in the ResNet-18 trained from scratch raised test accuracy from 0.295 to 0.981
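
To put these accuracies in context, they can be converted back into counts of correctly classified test images and compared with chance level. This is a minimal sketch that assumes the balanced test set of 4 x 242 images shown in the balance check; the accuracy values are the ones returned by `test` above.

```python
N_TEST = 4 * 242  # balanced test set: 242 images per class

accuracies = {
    "Original ResNet-18": 0.29545454545454547,
    "Modified ResNet-18 (GroupNorm)": 0.981404958677686,
    "Pretrained ResNet-50": 0.8646694214876033,
}

chance = 1 / 4  # expected accuracy of random guessing on 4 balanced classes

correct_counts = {name: round(acc * N_TEST) for name, acc in accuracies.items()}
for name, acc in accuracies.items():
    print(f"{name}: {correct_counts[name]}/{N_TEST} correct, {acc - chance:+.3f} vs. chance")
```

The original ResNet-18 classifies 286 of 968 images correctly, which is only about 4.5 percentage points above chance, while the GroupNorm variant gets 950 of 968 right.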