In [1]:
from torch.utils.data import Dataset, DataLoader
from SSL_for_Diabetic_Retinopathy.data import data_loader
import torchvision.models as models
import torch
import csv
import os
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from typing import Callable
import numpy as np
import random
import matplotlib.pyplot as plt
import pandas as pd
import torchvision

In [2]:
import torchvision
print(torch.__version__)
print(torchvision.__version__)

1.9.0
0.10.0+cu111


In [3]:
if torch.cuda.is_available():  
    dev = "cuda:0" 
else:  
    dev = "cpu" 
print(dev)

cuda:0


In [5]:
def set_seeds(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [7]:
set_seeds(0)
image_path = '/home/mkelly_mehresearch_org/data/kaggle-eyepacs-data/images/'
csv_path = '/home/mkelly_mehresearch_org/data/kaggle-eyepacs-data/clean_binary.csv'
csv_path2 = '/home/mkelly_mehresearch_org/data/kaggle-eyepacs-data/clean_binary.csv'
test_csv = 'subset.csv'
transform = transforms.Compose([transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
                                transforms.RandomHorizontalFlip(),
                                transforms.RandomVerticalFlip(),
                                transforms.ToTensor()
                               ])
train_data = data_loader.DataSetFromFolder(image_path, csv_path, transform, mode='train', index=False)
val_data = data_loader.DataSetFromFolder(image_path, csv_path2, transform, mode='validation', index=False)
sampler = ImbalancedDatasetSampler(train_data)
train_loader = DataLoader(train_data, batch_size=64, sampler=sampler, drop_last=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=True, drop_last=True)

In [8]:
def train(model, train_data, criterion, optimiser, num_epochs,
          device, scheduler=None, start_point=0, save_dir=None, 
          val_data=None):
    # create the directory to store the model and training history
    try:
        os.mkdir(save_dir)
    except:
        print('Directory already exists!!')

    accuracies = []
    losses = []
    val_accuracies = None
    num_batchs = len(train_data)
    for epoch in range(start_point, num_epochs):

        print('epoch {} of {}'.format(epoch, num_epochs))
        running_loss = 0
        running_corrects = 0
        data_size = 0
        for batch_no, (images, labels) in enumerate(train_data):
            batch_size = len(images)
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimiser.step()
            running_loss += loss.item() * images.size(0)
            corrects = (preds.data == labels.data).sum().item()
            running_corrects += corrects
            data_size += batch_size

        # at each epoch if validation data is available
        # calculate validation accuracy
        val_acc = None
        if val_data is not None:
            val_accuracies = []
            # model.eval()
            with torch.no_grad():
                running_val_corrects = 0
                val_size = 0
                for images, labels in val_data:
                    images, labels = images.to(device), labels.to(device)
                    outputs = model(images)
                    _, preds = torch.max(outputs, 1)
                    running_val_corrects += (preds.data == labels.data).sum().item()
                    val_size += len(images)
            val_acc = running_val_corrects / val_size
            val_accuracies.append(val_acc)
            # model.train()

        total_loss = running_loss / data_size
        losses.append(total_loss)
        total_acc = running_corrects / data_size
        accuracies.append(total_acc)

        # save the history if given save_path
        if save_dir is not None:
            with open('{}/history.csv'.format(save_dir), 'a+', newline='') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([total_loss, total_acc, val_acc])

        print('epoch loss: {} epoch accuracy: {} validation accuracy: {}'.format(total_loss, total_acc, val_acc))

        # save the model every 5 epochs if give save_path
        # we are note decaying the learning rate
        # so we don't need to save the optimsier
        
        if save_dir is not None:
            if (epoch % 5 == 0) | (epoch == num_epochs - 1):
                model.state_dict(torch.save(model.state_dict(), save_dir + 'epoch_{}'.format(epoch)))
                
        if scheduler is not None:
            scheduler.step()

    return model, losses, accuracies, val_accuracies

In [9]:
set_seeds(0)
fine_tune = False
num_epochs = 2
learning_rate = 0.01
r18 = models.resnet18(pretrained=False) 

if not fine_tune:
    for param in r18.parameters():
        param.requires_grad = False
    r18.eval()
    
features = r18.fc.in_features
r18.fc = torch.nn.Linear(features, 2)

r18 = r18.to(dev)

trainable_params = r18.parameters()
if not fine_tune:
    trainable_params = []
    for name, param in r18.named_parameters():
        if param.requires_grad:
            trainable_params.append(param)

In [10]:
set_seeds(0)
optimiser = optim.SGD(trainable_params, lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss().to(dev)
train(r18, train_loader, criterion, optimiser, num_epochs dev, scheduler=None, start_point=0, val_data=val_loader, save_dir='random_baseline_classifier/')

epoch 0 of 25


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


epoch loss: 1.1733875692197167 epoch accuracy: 0.5120501285347043 validation accuracy: 0.7889204545454546
epoch 1 of 25
epoch loss: 1.1175226720868776 epoch accuracy: 0.5345035347043702 validation accuracy: 0.7417613636363637
epoch 2 of 25
epoch loss: 3.567212635110154 epoch accuracy: 0.5096200192802056 validation accuracy: 0.23650568181818182
epoch 3 of 25
epoch loss: 10.31633847691406 epoch accuracy: 0.5091380141388174 validation accuracy: 0.4971590909090909
epoch 4 of 25
epoch loss: 15.492928085106504 epoch accuracy: 0.5148417416452442 validation accuracy: 0.6362215909090909
epoch 5 of 25
epoch loss: 21.394456627129895 epoch accuracy: 0.514480237789203 validation accuracy: 0.8079545454545455
epoch 6 of 25
epoch loss: 21.56015670728561 epoch accuracy: 0.5152032455012854 validation accuracy: 0.8164772727272728
epoch 7 of 25
epoch loss: 31.170816846862245 epoch accuracy: 0.5124718830334191 validation accuracy: 0.5066761363636364
epoch 8 of 25
epoch loss: 49.274399409257356 epoch accura

(ResNet(
   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
   (layer1): Sequential(
     (0): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU(inplace=True)
       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     )
     (1): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU