In [1]:
# Built-in modules
import sys, os, argparse
from collections import OrderedDict, defaultdict
# Public modules
import numpy as np
import matplotlib.pyplot as plt
import torch, torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import progressbar

from torchvision.datasets import ImageFolder
from torchvision.transforms import CenterCrop, ColorJitter, Compose, \
        Normalize, Resize, RandomCrop, RandomHorizontalFlip, \
        RandomRotation, ToTensor
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from PIL import Image
from utils import SimpleAUC

os.environ["CUDA_VISIBLE_DEVICES"]="0"

# Setting parameters

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
init_lr = 1e-4
batch_size = 32
num_epochs = 300
steps_per_epoch = 5
weight_decay = 1e-5

tag = "baseline1"

In [3]:
checkpoint_dir = './train_logs/baseline'

In [4]:
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, '{}.pt'.format(tag))
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    print('Load model trained for {} epochs.'.format(checkpoint['epoch']))
else:
    checkpoint = None

# Build dataset

In [5]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

class RetinaDataset(Dataset):
    """ Custom dataset for Retina image dataset """
    def __init__(self,
                 csv_df,
                 img_dir,
                 transform=None,
                 standardize=False):
        """
        Args:
            csv_path (str): A path to the csv file.
            img_dir (str): A path to the fundus image dir.
            transform (callable, optional): Tranform function
                to be applied fundus images.
            standardize (bool): Whether to apply standardize to numeric data.
        """
        super().__init__()
        self.df = csv_df
        self.img_dir = img_dir
        self.transform = transform
        self.standardize = standardize

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        """ read left and right retina image """
        path = os.path.join(self.img_dir, self.df['image'].values[i]+".jpeg")

        sample = {}
        sample['fundus'] = Image.open(path)
        sample["level"] = 1 if self.df["level"].values[i] > 0 else 0

        if self.transform is not None:
            sample['fundus'] = self.transform(sample['fundus'])
        return sample
    
def build_dataset():
    """ Build dataset """
    # Define transforms for training and evaluation
#     transform_train = Compose([Resize([256, 256]),
#                                RandomCrop([224, 224]),
#                                ColorJitter(brightness=0.2,
#                                            saturation=1),
#                                RandomHorizontalFlip(),
#                                RandomRotation(degrees=30, fill=128),
#                                ToTensor(),
#                                Normalize(IMAGENET_MEAN, IMAGENET_STD)])
#     transform_eval = Compose([Resize([256, 256]),
#                               CenterCrop([224, 224]),
#                               ToTensor(),
#                               Normalize(IMAGENET_MEAN, IMAGENET_STD)])

    transform_train = Compose([Resize([256, 256]),
                               ToTensor(),
                               Normalize(IMAGENET_MEAN, IMAGENET_STD)])
    transform_eval = Compose([Resize([256, 256]),
                              ToTensor(),
                              Normalize(IMAGENET_MEAN, IMAGENET_STD)])
    
    
    df = pd.read_csv("../kaggle_data/trainLabels.csv")
    n_samples = len(df)
    train_df = df[0:int(n_samples*0.8)]
    val_df = df[int(n_samples*0.8):]
    
    IMG_DIR = "../kaggle_data/train_resize"
    
    dataset = {'train': RetinaDataset(csv_df=train_df,
                                   img_dir=IMG_DIR,
                                   transform=transform_train),
               'valid' : RetinaDataset(csv_df=val_df,
                                   img_dir=IMG_DIR,
                                  transform=transform_eval)}

        

    return dataset

In [6]:
dataset = build_dataset()
loader = {}
for key in ['train', 'valid']:
    shuffle = (key != 'valid')
    loader[key] = DataLoader(dataset[key], batch_size=batch_size, shuffle=shuffle,
                             pin_memory=True, num_workers=4)

# Define Model

In [7]:
_model_dict = {
            'resnet18' : torchvision.models.resnet18,
            'resnet34' : torchvision.models.resnet34,
            'resnet50' : torchvision.models.resnet50,
            'resnet101' : torchvision.models.resnet101,
        }

class Classifier(nn.Module):
    """ Pre-trained model to which new layers are attached. """
    def __init__(self,
                 cnn_name,
                 num_classes,
                 pretrained=False):
        """ Initialize module
        Args:
            cnn_name (str): The name of a pretrained CNN model.
            num_classes (int): The number of output classes.
        """
        super().__init__()
        if cnn_name not in _model_dict.keys():
            raise NotImplementedError('{} is not supported.'.format(cnn_name))

        self.num_classes = num_classes
        self.model = _model_dict[cnn_name](pretrained=pretrained)
        self.model.fc = nn.Linear(512, num_classes)

    def forward(self, fundus):
        """ forward pass """
        h = self.model(fundus)
        return h


In [8]:
if checkpoint is None:
    cls = Classifier('resnet18', num_classes=2, pretrained=False)
else:
    cls = Classifier(checkpoint['cnn_name'], num_classes=2,
            pretrained=False)
    cls.load_state_dict(checkpoint['cls_state_dict'])
cls.to(device)

Classifier(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_ru

# Define Loss function and optimizer

In [9]:
# Define loss function
loss_fn = {'ce' : nn.CrossEntropyLoss()}

# Build optimizer

optimizer = optim.SGD(cls.parameters(),
                      lr=init_lr,
                      momentum=0.9,
                      weight_decay=weight_decay,
                      nesterov=True)

if checkpoint is not None:
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

decay_factor = 0.99
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer,lambda epoch: decay_factor ** epoch)
if checkpoint is not None:
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])

# Train

In [10]:
# Define metric objects 
metric_objects = {'train_auroc' : SimpleAUC(),
                  'val_auroc' : SimpleAUC()}
best_val_metric = 0.0 if checkpoint is None else checkpoint['best_val_metric']
i = 0 if checkpoint is None else checkpoint['epoch']
while i < num_epochs:
    # Reset training state variables
    training_loss = defaultdict(lambda: 0.0)
    num_samples = 0
    iterators = {k : iter(v) for k, v in loader.items()}
    for v in metric_objects.values():
        v.reset_state()

    # Training phase 
    cls.train() # Set model to training mode
    with progressbar.ProgressBar(steps_per_epoch) as pbar:
        for j in range(1, steps_per_epoch+1):
            # Initialze loader's iterater
            for k, v in loader.items():
                if j % len(v) == 0:
                    iterators[k] = iter(v)

            # Load a batch of data
            batch = next(iterators['train'])
            fundus = batch['fundus'].to(device)

            y_true = batch["level"].to(device)

            # Forward pass
            y_pred = cls(fundus)
            loss = loss_fn['ce'](y_pred, y_true)

            # Backward pass 
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update training metrics
            training_loss['cls_loss'] += loss.item() * y_true.size(0)
            num_samples += y_true.size(0)
            metric_objects['train_auroc'].update_state(y_pred[:, 1], y_true)

            pbar.update(j)

        for k, v in training_loss.items():
            training_loss[k] = v / float(num_samples)
        lr_scheduler.step()

    
    cls.eval() # Set model to evaluation mode.
    for batch in loader['valid']:
        fundus = batch['fundus'].to(device)
        y_true = batch["level"]

        # Forward pass
        with torch.no_grad():
            y_pred = cls(fundus)
            metric_objects['val_auroc'].update_state(y_pred[:, 1], y_true)

    # Display results after an epoch
    i += 1
    print('Epoch: {:d}/{:d}'.format(i, num_epochs))
    print('training classification loss: {:.4f}'.format(training_loss['cls_loss']))
    for k, v in metric_objects.items():
        print('{}: {:.4f}'.format(k, v.result()))
    
    # Save model when reached the highest validation accuracy 
    curr_val_metric = metric_objects['val_auroc'].result()
    if curr_val_metric > best_val_metric:
        best_val_metric = curr_val_metric
        checkpoint = {'cls_state_dict' : cls.state_dict(),
                      'optimizer_state_dict' : optimizer.state_dict(),
                      'lr_scheduler_state_dict' : lr_scheduler.state_dict(),
                      'best_val_metric' : best_val_metric,
                      'epoch' : i,
                      'cnn_name' : cnn_name}
        torch.save(checkpoint, checkpoint_path)
        print('Model saved.')

| |  #                                              | 500 Elapsed Time: 0:04:43


ValueError: too many values to unpack (expected 2)