In [1]:
from torch.utils.data import Dataset, DataLoader
from SSL_for_Diabetic_Retinopathy.data import data_loader
from SSL_for_Diabetic_Retinopathy.models import simCLR_encoders
import torchvision.models as models
import torch
import csv
import os
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from typing import Callable
import numpy as np
import random
import matplotlib.pyplot as plt
import pandas as pd
import torchvision
from torch import nn
from collections import OrderedDict

In [2]:
print(torch.__version__)
print(torchvision.__version__)

1.9.0
0.10.0+cu111


In [3]:
if torch.cuda.is_available():  
    dev = "cuda:0" 
else:  
    dev = "cpu" 
print(dev)

cuda:0


In [4]:
def set_seeds(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [5]:
set_seeds(0)
image_path = '/home/mkelly_mehresearch_org/data/kaggle-eyepacs-data/images/'
csv_path = '/home/mkelly_mehresearch_org/data/kaggle-eyepacs-data/clean_binary.csv'
csv_path2 = '/home/mkelly_mehresearch_org/data/kaggle-eyepacs-data/clean_binary.csv'
test_csv = 'subset.csv'
transform = transforms.Compose([transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
                                transforms.RandomHorizontalFlip(),
                                transforms.RandomVerticalFlip(),
                                transforms.ToTensor()
                               ])
train_data = data_loader.DataSetFromFolder(image_path, csv_path, transform, mode='train', index=False)
val_data = data_loader.DataSetFromFolder(image_path, csv_path2, transform, mode='validation', index=False)
# sampler = ImbalancedDatasetSampler(train_data)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, drop_last=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=True, drop_last=True)

In [6]:
def train(model, train_data, criterion, optimiser, num_epochs,
          device, scheduler=None, start_point=0, save_dir=None, 
          val_data=None):
    # create the directory to store the model and training history
    try:
        os.mkdir(save_dir)
    except:
        print('Directory already exists!!')

    accuracies = []
    losses = []
    val_accuracies = None
    num_batchs = len(train_data)
    for epoch in range(start_point, num_epochs):

        print('epoch {} of {}'.format(epoch, num_epochs))
        running_loss = 0
        running_corrects = 0
        data_size = 0
        for batch_no, (images, labels) in enumerate(train_data):
            batch_size = len(images)
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimiser.step()
            running_loss += loss.item() * images.size(0)
            corrects = (preds.data == labels.data).sum().item()
            running_corrects += corrects
            data_size += batch_size

        # at each epoch if validation data is available
        # calculate validation accuracy
        val_acc = None
        if val_data is not None:
            val_accuracies = []
            # model.eval()
            with torch.no_grad():
                running_val_corrects = 0
                val_size = 0
                for images, labels in val_data:
                    images, labels = images.to(device), labels.to(device)
                    outputs = model(images)
                    _, preds = torch.max(outputs, 1)
                    running_val_corrects += (preds.data == labels.data).sum().item()
                    val_size += len(images)
            val_acc = running_val_corrects / val_size
            val_accuracies.append(val_acc)
            # model.train()

        total_loss = running_loss / data_size
        losses.append(total_loss)
        total_acc = running_corrects / data_size
        accuracies.append(total_acc)

        # save the history if given save_path
        if save_dir is not None:
            with open('{}/history.csv'.format(save_dir), 'a+', newline='') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([total_loss, total_acc, val_acc])

        print('epoch loss: {} epoch accuracy: {} validation accuracy: {}'.format(total_loss, total_acc, val_acc))

        # save the model every 5 epochs if give save_path
        # we are note decaying the learning rate
        # so we don't need to save the optimsier
        
        if save_dir is not None:
            if (epoch % 5 == 0) | (epoch == num_epochs - 1):
                model.state_dict(torch.save(model.state_dict(), save_dir + 'epoch_{}'.format(epoch)))
                
        if scheduler is not None:
            scheduler.step()

    return model, losses, accuracies, val_accuracies

In [7]:
def representation_model(model):
    removed = list(model.children())[:-1]
    model = torch.nn.Sequential(*removed)
    # for param in model.parameters():
        # param.requires_grad = False
    return model

In [8]:
class LinearClassifier(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearClassifier, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        outputs = self.linear(x)
        return outputs

In [9]:
class SimCLREncoder(nn.Module):
    """
    ResNet18 encoder adapted for MNIST
    Args:
        forward:
            x (Tensor): batch of images
    Returns:
        model(x) (Tensor): batch of encoded images
    """

    def __init__(self, encoder_type, out_dim, device, DDP=True, local_rank=None, checkpoint=None):
        super(SimCLREncoder, self).__init__()
        self.device = device
        self.model_dict = {'resnet18': models.resnet18(),
                           'resnet50': models.resnet50(),
                           'resnet101': models.resnet101()}
        self.encoder = self.model_dict[encoder_type]
        self.out_dim = out_dim
        self.checkpoint = checkpoint
        self.in_size = self._get_output_shape()
        self.projection_head = self._get_projection_head(self.in_size)
        self.model = self._get_model().cuda()
        if DDP:
            # convert all batchnorm layers to sync batch norm!
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            self.model = DistributedDataParallel(self.model, device_ids=[local_rank], output_device=local_rank)

    def _get_output_shape(self, image_dim=(1, 3, 224, 224)):
        model = nn.Sequential(*list(self.encoder.children())[:-1],
                      nn.Flatten())
        return model(torch.rand(*image_dim)).data.shape[1]

    def _get_model(self):
        resnet = self.encoder
        model = nn.Sequential(*list(resnet.children())[:-1],
                              nn.Flatten(),
                              self.projection_head)
        if self.checkpoint is not None:
              # uncooment if the model was saved in data parallel!
#             state_dict = torch.load(self.checkpoint)
#             new_state_dict = OrderedDict()
#             for k, v in state_dict.items():
#                 name = k[7:] # remove `module.`
#                 new_state_dict[name] = v
#             model.load_state_dict(new_state_dict)
            model.load_state_dict(torch.load(self.checkpoint))
        return model

    def _get_projection_head(self, in_size):
        projection_head = nn.Sequential(nn.Linear(in_size, in_size),
                                        nn.ReLU(),
                                        nn.Linear(in_size, self.out_dim))
        return projection_head

    def get_model(self):
        return self.model

    def forward(self, x):
        return self.model(x)

In [11]:
class classification_model(torch.nn.Module):
    def __init__(self, rep_model, input_dim, output_dim ):
        super(classification_model, self).__init__()
        self.rep_model = rep_model
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.rep_model(x)
        return self.linear(x)

In [None]:
fine_tune = False
num_epochs = 100
learning_rate = 0.01

In [10]:
set_seeds(0)
encoder = SimCLREncoder('resnet18', 128, dev, DDP=False, checkpoint='aug_1/epoch_99')
representation_model = representation_model(encoder.get_model())

if not fine_tune:
    representation_model.eval()
    for param in representation_model.parameters():
        param.requires_grad = False

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [12]:
set_seeds(0)
model_ft = classification_model(representation_model, 512, 2)
model_ft.load_state_dict(torch.load('trial_2_multi_gpu_classifier/epoch_19'))
model_ft = model_ft.to(dev)

In [14]:
set_seeds(0)
optimiser = optim.SGD(model_ft.parameters(), lr=learning_rate)
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimiser, gamma=0.5)
criterion = torch.nn.CrossEntropyLoss(weight=torch.Tensor([0.19, 0.81])).to(dev)
train(model_ft, train_loader, criterion, optimiser, num_epochs, dev, scheduler=None, start_point=20, val_data=val_loader, save_dir='trial_2_multi_gpu_classifier/')

Directory already exists!!
epoch 20 of 25
epoch loss: 8.554932199437697 epoch accuracy: 0.5391227506426736 validation accuracy: 0.790340909090909
epoch 21 of 25
epoch loss: 47.02837398303505 epoch accuracy: 0.5374558161953727 validation accuracy: 0.7869318181818182
epoch 22 of 25
epoch loss: 45.480270905482435 epoch accuracy: 0.5533218187660668 validation accuracy: 0.7863636363636364
epoch 23 of 25
epoch loss: 82.64133435717585 epoch accuracy: 0.5435812982005142 validation accuracy: 0.6431818181818182
epoch 24 of 25
epoch loss: 107.8288378163904 epoch accuracy: 0.5383394922879178 validation accuracy: 0.5856534090909091


(classification_model(
   (rep_model): Sequential(
     (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
     (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (2): ReLU(inplace=True)
     (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
     (4): Sequential(
       (0): BasicBlock(
         (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
         (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (relu): ReLU(inplace=True)
         (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
         (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       )
       (1): BasicBlock(
         (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
         (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine