In [None]:
!pip install efficientnet_pytorch

import glob
import os
import random
import shutil
from datetime import datetime
from pathlib import Path
from time import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import seaborn as sn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from efficientnet_pytorch import EfficientNet
from matplotlib import pyplot
from PIL import Image
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, f1_score, auc
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from torch.utils.data import Dataset, random_split, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

%matplotlib inline
%reload_ext tensorboard

def show_batch(data_loader, n=8, nrow=4):
    """Plot the first `n` images of the first batch produced by `data_loader`.

    Args:
        data_loader: iterable yielding `(images, labels)` pairs; only the first
            pair is consumed and its labels are ignored.
        n: maximum number of images taken from the front of the batch.
        nrow: forwarded to `make_grid` as its `nrow` argument.

    Returns:
        None. Nothing is drawn when the loader yields no batches.
    """
    first_batch = next(iter(data_loader), None)
    if first_batch is None:
        return

    images, _ = first_batch
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.set_xticks([])
    ax.set_yticks([])
    grid = make_grid(images[:n], nrow=nrow)
    # Move the leading channel axis last before handing the grid to imshow.
    ax.imshow(grid.permute(1, 2, 0))

def get_num_correct(preds, labels):
    """Count how many rows of `preds` have their arg-max at the index given in `labels`.

    Args:
        preds: tensor whose dimension 1 holds per-class scores.
        labels: tensor of class indices compared element-wise with the arg-max.

    Returns:
        The number of matches as a Python number (`Tensor.item()`).
    """
    predicted_classes = preds.argmax(dim=1)
    matches = predicted_classes == labels
    return matches.sum().item()

In [None]:
# ---- Run configuration -------------------------------------------------------
# Selects the dataset branch in the setup cell (only 'test1' is handled there).
DATASET = 'test1'
# Initial learning rate handed to the Adam optimiser.
LEARNING_RATE = 0.001
# Multiplicative factor (gamma) of the exponential learning-rate schedule.
LR_DECAY = 0.96
LR_UPDATE_INTERVAL_IN_ITERATIONS = None # initialized later to every epoch, if value is None
# Checkpoint frequency; converted to iterations once the train loader exists.
MODEL_SAVE_INTERVAL_IN_EPOCHS = 1

# Worker processes used by every DataLoader.
NUM_WORKERS = 1
LOG_INTERVAL = None # initialized later to every epoch, if value is None
# NOTE(review): not referenced anywhere in the visible notebook code.
IMG_RECONSTRUCTION_INTERVAL = 500
# Seed passed to the torch RNGs in the setup cell.
SEED = 1
# CUDA device index passed to torch.cuda.set_device; None skips that call.
GPU_DEVICE = 0
# When True the model is wrapped in nn.DataParallel.
MULTI_GPU = False  
# Device that batches (and the trainer) are moved to.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# ---- Output locations --------------------------------------------------------
# Local scratch directory (relative to the working directory) for checkpoints and TensorBoard runs.
DATA_PATH = 'logs/'
# Google Drive folder used as the base for exported run artefacts.
G_DRIVE_DIR_BASE = '/content/drive/MyDrive/Colab Notebooks/BEH/'
MODEL_DIR_BASE = DATA_PATH + 'models/'
TB_RUN_DIR_BASE = DATA_PATH + 'runs/'

In [None]:
# Mount Google Drive so the dataset archive under /content/drive can be read by the setup cell.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Seed every random number generator the notebook imports, not only torch, so a
# Restart & Run All reproduces the same run.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Reset the derived intervals so that re-running this cell recomputes them from
# the freshly built train loader (they are filled in further down when None).
LOG_INTERVAL = None # initialized later to every epoch, if value is None
LR_UPDATE_INTERVAL_IN_ITERATIONS = None # initialized later to every epoch, if value is None

if DATASET == 'test1':
    if not os.path.exists('/content/BEH'):
        !unzip -qq '/content/drive/MyDrive/Colab Notebooks/Glaucoma detection/Data/BEH.zip' # link to file on gdrive

    EPOCHS = 10
    BATCH_SIZE = 6 #13 #16
    TRAIN_VAL_RATIO = 0.90
    train_data_path = '/content/BEH/train'
    validation_data_path = '/content/BEH/validation'
    test_data_path = '/content/BEH/test'
    classes = ['glaucoma', 'normal']

    test_set_filenames = [
        sorted([f for f in os.listdir(test_data_path+'/Glaucoma') if os.path.isfile(os.path.join(test_data_path+'/Glaucoma', f))]),
        sorted([f for f in os.listdir(test_data_path+'/Normal') if os.path.isfile(os.path.join(test_data_path+'/Normal', f))])
    ]

    common_transforms = transforms.Compose([
        # transforms.Resize((256, 256)),
        # #transforms.Grayscale(num_output_channels=3),
        # transforms.ToTensor(),
        # transforms.Normalize(mean=[0.25, 0.25, 0.25],std=[0.9, 0.9, 0.9])
        #transforms.ToPILImage(),
        transforms.Resize((300, 300)),
        #transforms.CenterCrop((100, 100)),
        #transforms.RandomCrop((80, 80)),
        # transforms.RandomHorizontalFlip(p=0.5),
        # transforms.RandomRotation(degrees=(-90, 90)),
        # transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # transforms.Normalize(mean=[0.25, 0.25, 0.25],std=[0.9, 0.9, 0.9])
        transforms.Normalize(mean=[0.05, 0.05, 0.05],std=[0.5, 0.5, 0.5])
    ])

    train_transforms = transforms.Compose([
        #transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15, hue=0.05),
        #transforms.RandomAffine(degrees=(-2,2), translate=None, scale=(0.95, 1.1), shear=0, resample=False, fillcolor=0),
        #transforms.CenterCrop((250, 250)),
        #transforms.Resize((256, 256)),
        # transforms.Resize((300, 300)),
        #transforms.CenterCrop((100, 100)),
        #transforms.RandomCrop((80, 80)),
        # transforms.RandomHorizontalFlip(p=0.5),
        # transforms.RandomRotation(degrees=(-90, 90)),
        # transforms.RandomVerticalFlip(p=0.5),
        # transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0, hue=0),
        # transforms.ToTensor(),
        # #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # transforms.Normalize(mean=[0.25, 0.25, 0.25],std=[0.9, 0.9, 0.9]),
        common_transforms
    ])

    test_transforms = transforms.Compose([
        common_transforms
    ])
else:
    raise ValueError('DATASET not specified')

# Human-readable label of this run; it doubles as the leaf directory name for
# the TensorBoard logs, the checkpoints and the Google Drive export.
TB_COMMENT = f'Network {DATASET} batch_size={BATCH_SIZE} lr={LEARNING_RATE} lr_decay={LR_DECAY} E={EPOCHS}'
TB_RUN_DIR, MODEL_DIR, G_DRIVE_DIR = (
    base_dir + TB_COMMENT
    for base_dir in (TB_RUN_DIR_BASE, MODEL_DIR_BASE, G_DRIVE_DIR_BASE)
)
print(TB_COMMENT)


# ---- Device selection --------------------------------------------------------
# Pin the CUDA device only when CUDA is present so the cell also runs on CPU-only runtimes.
if GPU_DEVICE is not None and torch.cuda.is_available():
    torch.cuda.set_device(GPU_DEVICE)

# Scale the batch size BEFORE the loaders are created: the previous code did this
# afterwards (and on an undefined `batch_size` name), so it either crashed or had no effect.
if MULTI_GPU:
    BATCH_SIZE *= max(torch.cuda.device_count(), 1)

# ---- Datasets and loaders ----------------------------------------------------
loaders = {}

# ImageFolder orders samples class by class, so the training loader must shuffle;
# otherwise every batch would contain a single class.
train_set = torchvision.datasets.ImageFolder(root=train_data_path, transform=train_transforms)
loaders['train'] = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# Validation and test use the evaluation transforms and keep a fixed order
# (the test order is relied upon when mapping predictions back to file names).
validation_set = torchvision.datasets.ImageFolder(root=validation_data_path, transform=test_transforms)
loaders['validation'] = DataLoader(validation_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

test_set = torchvision.datasets.ImageFolder(root=test_data_path, transform=test_transforms)
loaders['test'] = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# ---- Derived intervals (default: once per training epoch) --------------------
if LOG_INTERVAL is None:
    LOG_INTERVAL = len(loaders['train'])

print( 8*'#', f'Using {DATASET} dataset', 8*'#')
print("Train loader. \tSize: ", len(train_set), '\tData Shape: ', train_set[0][0].shape, '\tBatch len: ', len(loaders['train']))
print("Val loader. \tSize: ", len(validation_set), '\tData Shape: ', validation_set[0][0].shape, '\tBatch len: ', len(loaders['validation']))
print("Test loader.  \tSize: ", len(test_set), '\tData Shape: ', test_set[0][0].shape, '\tBatch len: ', len(loaders['test']))

if LR_UPDATE_INTERVAL_IN_ITERATIONS is None:
    LR_UPDATE_INTERVAL_IN_ITERATIONS = len(loaders['train'])
MODEL_SAVE_INTERVAL_IN_ITERATIONS = MODEL_SAVE_INTERVAL_IN_EPOCHS * len(loaders['train'])

# ---- Preview one batch per split ---------------------------------------------
# Training samples
show_batch(loaders['train'])

# Validation samples
show_batch(loaders['validation'])

# Testing samples
show_batch(loaders['test'])

# ---- Output directories ------------------------------------------------------
# Start from a clean log directory. Remove exactly the path that is checked
# (the previous shell command deleted a hard-coded absolute path instead).
if os.path.exists(DATA_PATH):
    shutil.rmtree(DATA_PATH)

# makedirs creates DATA_PATH and the models/ and runs/ parents as needed.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(TB_RUN_DIR, exist_ok=True)
  
# train_sets = []
# train_loaders = []
# THRESHOLD = 464

# while True:
#     if len(train_set) < THRESHOLD:
#         train_loaders.append(torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS))
#         break

#     train_set, new_train_set_segment = random_split(train_set, [len(train_set) - THRESHOLD, THRESHOLD])
#     train_loaders.append(torch.utils.data.DataLoader(new_train_set_segment, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS))
# print(len(train_loaders))

In [None]:
class Trainer:
    """
    Wrapper object for handling training and evaluation

    Builds a pretrained EfficientNet-B3 with an Adam optimiser and an exponential
    learning-rate schedule, trains/validates it on the supplied loaders, evaluates
    it on the test loader and offers a few reporting helpers.

    The class reads these notebook-level globals: TB_COMMENT, TB_RUN_DIR, MODEL_DIR,
    LOG_INTERVAL, LR_UPDATE_INTERVAL_IN_ITERATIONS, MODEL_SAVE_INTERVAL_IN_ITERATIONS,
    classes and test_set_filenames.
    """
    def __init__(self, loaders, batch_size, learning_rate, lr_decay, device, multi_gpu):
        """
        Args:
            loaders: dict with 'train', 'validation' and 'test' data loaders.
            batch_size: accepted for interface compatibility; not used by the trainer.
            learning_rate: initial learning rate of the Adam optimiser.
            lr_decay: gamma of the exponential learning-rate schedule.
            device: device the model and every batch are moved to.
            multi_gpu: wrap the model in nn.DataParallel when True.
        """
        self.tb = SummaryWriter(comment=TB_COMMENT, log_dir=TB_RUN_DIR)
        self.device = device
        self.multi_gpu = multi_gpu

        # Results of the most recent test() call ...
        self.all_preds = []
        self.all_labels = []
        self.all_confidences = []
        # ... and the history of all test() calls (one entry appended per call).
        self.all_preds_list = []
        self.all_labels_list = []
        self.all_confidences_list = []
        # Misclassified test samples; initialised here so the reporting helpers
        # do not fail with AttributeError when called before test().
        self.incorrect_samples = []
        self.incorrect_samples_idx = []
        self.incorrect_samples_filename = []
        self.incorrect_samples_targets = []
        self.incorrect_samples_confidences = []

        self.loaders = loaders

        # Other backbones tried here earlier: densenet161, resnet50, googlenet,
        # mobilenet_v2, shufflenet_v2_x1_0 (all from torchvision.models).
        # NOTE(review): the pretrained head is used unchanged; test() slices the first
        # two logits, which suggests the network emits more outputs than there are
        # classes. Consider building it with num_classes=len(classes) - confirm first,
        # since that changes the checkpoint layout.
        self.net = EfficientNet.from_pretrained('efficientnet-b3')
        # Honour the requested device instead of unconditionally calling .cuda().
        self.net = self.net.to(self.device)

        if self.multi_gpu:
            self.net = nn.DataParallel(self.net)

        self.optimizer = optim.Adam(self.net.parameters(), lr=learning_rate)
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=lr_decay)
        print(10*'#', 'PyTorch Model built'.upper(), 10*'#')
        print('No. of params:', sum(p.numel() for p in self.net.parameters()))
        print(TB_COMMENT)

    def __repr__(self):
        return repr(self.net)

    def run(self, epochs, classes):
        """Train for `epochs` epochs, validating after each one.

        Per phase, the mean batch loss and the accuracy are printed and written to
        TensorBoard. A checkpoint is written every MODEL_SAVE_INTERVAL_IN_ITERATIONS
        training iterations and the final weights go to MODEL_DIR/model.pth.tar.

        Args:
            epochs: number of epochs to run.
            classes: accepted for interface compatibility; not used by the loop.
        """
        print(8*'#', 'Run started'.upper(), 8*'#')

        run_start = time()
        for epoch in range(1, epochs+1):
            for phase in ['train', 'validation']:
                is_training = phase == 'train'
                self.net.train(is_training)

                phase_start = time()
                running_loss = 0.0
                correct = 0
                total = 0
                accuracy = 0.0
                batch_len = len(self.loaders[phase])

                for i, (images, labels) in enumerate(self.loaders[phase]):
                    n_iter = ((epoch-1) * batch_len) + i
                    images, labels = images.to(self.device), labels.to(self.device)

                    self.optimizer.zero_grad()

                    # Build the autograd graph only while training; validation needs no gradients.
                    with torch.set_grad_enabled(is_training):
                        preds = self.net(images)
                        loss = F.cross_entropy(preds, labels) # Loss function

                    if is_training:
                        loss.backward()
                        self.optimizer.step()

                    running_loss += loss.item()

                    total += labels.size(0)
                    correct += get_num_correct(preds, labels)
                    accuracy = float(correct) / float(total)

                    if is_training and (n_iter % LOG_INTERVAL) == 0:
                        print('-----------------------\n', 'Epoch: ', epoch)

                    # n_iter == 0 is skipped so the schedule and the checkpoints start
                    # after the first interval has actually elapsed.
                    if is_training and (n_iter % LR_UPDATE_INTERVAL_IN_ITERATIONS) == 0 and (n_iter != 0):
                        self.scheduler.step()

                    if is_training and (n_iter % MODEL_SAVE_INTERVAL_IN_ITERATIONS) == 0 and (n_iter != 0):
                        torch.save(self.net.state_dict(), os.path.join(MODEL_DIR, str(n_iter)+'.pth.tar'))

                mean_loss = running_loss / max(batch_len, 1)
                print('{} \t  Loss: {:.5f}  Accuracy: {:.5f}  Time: {:.3f}s'.format(phase.upper(), mean_loss, accuracy, time()-phase_start))
                n_iter = epoch * batch_len
                self.tb.add_scalar(f'{phase}/loss', mean_loss, n_iter)
                self.tb.add_scalar(f'{phase}/accuracy', accuracy, n_iter)
            print('Epoch Ended at: {}'.format(time()))
            # Elapsed time is measured from the start of the whole run, not of this epoch.
            print('Epoch: {:02d} {:.3f}s elapsed'.format(epoch, time()-run_start))
        self.tb.close()

        torch.save(self.net.state_dict(), os.path.join(MODEL_DIR, 'model.pth.tar'))

    def test(self, show_per_class_accuracy=False, show_incorrect_inferences=False):
        """Evaluate the network on the test loader and store the results.

        The per-sample labels, predictions and confidences of this call replace
        `all_labels`, `all_preds` and `all_confidences`, and are also appended to the
        corresponding `*_list` histories.

        Args:
            show_per_class_accuracy: print the accuracy of every class in `classes`.
            show_incorrect_inferences: collect the misclassified samples (image,
                index within its class folder, file name, target, confidence) for
                show_incorrect_prediction().
        """
        self.net.eval()
        self.all_preds = []
        self.all_labels = []
        self.incorrect_samples = []
        self.incorrect_samples_idx = []
        self.incorrect_samples_filename = []
        self.incorrect_samples_targets = []
        self.incorrect_samples_confidences = []
        self.all_confidences = []

        t0 = time()
        running_loss = 0.0
        correct = 0
        total = 0
        accuracy = 0.0
        num_batches = 0

        with torch.no_grad():
            for images, labels in self.loaders['test']:
                images, labels = images.to(self.device), labels.to(self.device)

                preds = self.net(images)
                predicted = preds.argmax(dim=1)

                # Confidence = probability of class 0 from a softmax over the first two logits.
                self.all_confidences.extend(F.softmax(preds[:, :2], dim=1)[:, 0].tolist())

                loss = F.cross_entropy(preds, labels) # Loss function

                if show_incorrect_inferences:
                    # Probability of the predicted class, from a softmax over all logits.
                    probabilities = F.softmax(preds, dim=1)
                    for idx in torch.nonzero(predicted.ne(labels)).flatten().tolist():
                        target = labels[idx].item()
                        # `total` still counts only the samples of earlier batches, so this is
                        # the position in the (unshuffled) test set. The old code derived it
                        # from a batch index that an inner loop had overwritten.
                        global_idx = total + idx
                        # Class-1 samples follow all class-0 samples; shift to index within the class.
                        relative_idx = global_idx if target == 0 else global_idx - len(test_set_filenames[0])
                        class_name = 'glaucoma' if target == 0 else 'normal'

                        self.incorrect_samples.append(images[idx])
                        self.incorrect_samples_idx.append(relative_idx)
                        self.incorrect_samples_filename.append(class_name + '/' + test_set_filenames[target][relative_idx])
                        self.incorrect_samples_targets.append(target)
                        self.incorrect_samples_confidences.append(probabilities[idx, predicted[idx]].item())

                running_loss += loss.item()
                num_batches += 1
                total += labels.size(0)
                correct += get_num_correct(preds, labels)
                accuracy = float(correct) / float(total)
                self.all_labels = np.append(self.all_labels, labels.cpu().numpy())
                self.all_preds = np.append(self.all_preds, predicted.cpu().numpy())

        self.all_labels_list.append(self.all_labels)
        self.all_preds_list.append(self.all_preds)
        self.all_confidences_list.append(self.all_confidences)

        # Average over the number of batches (previously a shadowed loop index was used).
        print('{} \tLoss: {:.5f}   Accuracy: {:.5f}  Time: {:.3f}s ------ Test Accuracy: {:.5f}'.format(
            'TEST',
            running_loss / max(num_batches, 1),
            accuracy,
            time()-t0,
            accuracy
            ))

        if show_per_class_accuracy:
            # Reuse the predictions gathered above instead of a second pass over the test set.
            labels_arr = np.asarray(self.all_labels)
            preds_arr = np.asarray(self.all_preds)

            print('\nPer class accuracy on TEST set:')
            for class_idx, class_name in enumerate(classes):
                in_class = labels_arr == class_idx
                class_total = int(in_class.sum())
                class_correct = int((preds_arr[in_class] == class_idx).sum())
                # A class without test samples has no defined accuracy.
                percentage = 100 * class_correct / class_total if class_total else float('nan')
                print('Accuracy of {} ({}) : {:.2f}%     ({:5d}/{:5d})'.format(class_name, class_idx, percentage, class_correct, class_total))

    def show_incorrect_prediction(self):
        """Print and plot the misclassified samples collected by test(show_incorrect_inferences=True)."""
        print('\nIncorrect samples\' corrrect labels: ', self.incorrect_samples_targets)
        print('\nIncorrect samples\' confidences: ', self.incorrect_samples_confidences)
        print('\nIncorrect samples\' idxes: ', self.incorrect_samples_idx)
        print('\nIncorrect samples\' filenames (', len(self.incorrect_samples_filename), ') : ', self.incorrect_samples_filename)
        print('Incorrectly predicted samples:')

        # make_grid cannot build a grid from an empty list.
        if not self.incorrect_samples:
            print('No incorrectly predicted samples to display.')
            return

        fig, ax = plt.subplots(figsize=(25, 25))
        ax.set_xticks([]); ax.set_yticks([])
        img_grid = torchvision.utils.make_grid(self.incorrect_samples, nrow=10, normalize=True)
        _ = ax.imshow(img_grid.cpu().detach().permute(1, 2, 0))

    def show_classification_report(self, epoch, target_names=classes):
        """Print sklearn's classification report for the `epoch`-th test() call (0-based)."""
        # Labels and predictions must come from the same test() call.
        all_labels = self.all_labels_list[epoch]
        all_preds = self.all_preds_list[epoch]
        print(classification_report(all_labels, all_preds, target_names=target_names, digits=4))

    def show_confusion_matrix(self, epoch=0, xticklabels=classes, yticklabels=classes, ):
        """Draw the confusion matrix of the `epoch`-th test() call (0-based) as a heat map."""
        all_labels = self.all_labels_list[epoch]
        all_preds = self.all_preds_list[epoch]
        confusion_matrix_test = confusion_matrix(all_labels, all_preds)
        heatmap_test = sn.heatmap(confusion_matrix_test, annot=True)
        _ = heatmap_test.set(xlabel='Predicted label', ylabel='Actual label', xticklabels=xticklabels, yticklabels=yticklabels)

    def saveData(self, G_DRIVE_DIR=''):
        """Move the checkpoint and TensorBoard directories below `G_DRIVE_DIR`.

        Best effort: every step reports its own failure and the remaining steps
        still run. Nothing happens when `G_DRIVE_DIR` is the empty string.
        """
        if G_DRIVE_DIR == '':
            return

        try:
            os.makedirs(os.path.join(G_DRIVE_DIR, 'models'), exist_ok=True)
            os.makedirs(os.path.join(G_DRIVE_DIR, 'runs'), exist_ok=True)
        except OSError as e:
            print('ERROR: G_DRIVE_DIR dir creation error', e)

        try:
            dest = shutil.move( MODEL_DIR, os.path.join(G_DRIVE_DIR, 'models'))
            print("Transfered to: ", dest)
        except OSError as e:
            print('ERROR: G_DRIVE_DIR model transfer error', e)

        try:
            dest = shutil.move( TB_RUN_DIR, os.path.join(G_DRIVE_DIR, 'runs'))
            print("Transfered to: ", dest)
        except OSError as e:
            print('ERROR: G_DRIVE_DIR runs transfer error', e)

    def load(self, load_path):
        """Load a state dict from `load_path` into the network and switch it to eval mode.

        Does nothing when `load_path` is None; load errors are printed, not raised.
        """
        if load_path is not None:
            try:
                # map_location lets a checkpoint saved on GPU be restored on the current device.
                self.net.load_state_dict(torch.load(load_path, map_location=self.device))
                _ = self.net.eval()
                print('Model state loaded')
            except Exception as e:
                print('ERROR: Model state load error: ', e)

In [None]:
# Train Model
EPOCHS = 70
LEARNING_RATE = 0.001

# EPOCHS / LEARNING_RATE are overridden in this cell, so re-derive the run label AND
# the directories built from it. Otherwise logs and checkpoints would be stored under
# the label computed in the setup cell (with the old epoch count).
TB_COMMENT = f'Network {DATASET} batch_size={BATCH_SIZE} lr={LEARNING_RATE} lr_decay={LR_DECAY} E={EPOCHS}'
TB_RUN_DIR = TB_RUN_DIR_BASE + TB_COMMENT
MODEL_DIR = MODEL_DIR_BASE + TB_COMMENT
G_DRIVE_DIR = G_DRIVE_DIR_BASE + TB_COMMENT
os.makedirs(TB_RUN_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
print(TB_COMMENT)


net_trainer = Trainer(loaders, BATCH_SIZE, LEARNING_RATE, LR_DECAY, device=DEVICE, multi_gpu=MULTI_GPU)
# net_trainer.load(load_path = None)
net_trainer.run(EPOCHS, classes=classes)
# net_trainer.saveData('/content/BEH_Net')

# Evaluation is best effort: a failure is reported without discarding the trained model.
try:
    net_trainer.test(show_per_class_accuracy=True, show_incorrect_inferences=False)
except Exception as e:
    print("Exception ", e)

In [None]:
# test() appends one entry to the trainer's *_list attributes per call, so the entries
# are numbered by test() run (1-based here), not by training epoch. Report the most
# recent run instead of a hard-coded index that is out of range after a single call.
EPOCH = len(net_trainer.all_preds_list)
assert EPOCH > 0, 'net_trainer.test() has not produced any results yet'

net_trainer.show_classification_report(epoch=EPOCH-1)

# ROC AUC is a ranking metric and needs continuous scores, not hard class predictions.
# The stored confidences are P(class 0 = glaucoma); roc_auc_score treats the larger
# label (1 = normal) as positive, hence the complement.
glaucoma_probabilities = np.asarray(net_trainer.all_confidences_list[EPOCH-1])
roc_auc = roc_auc_score(net_trainer.all_labels_list[EPOCH-1], 1.0 - glaucoma_probabilities)
print('ROC AUC Score: {:.4f}'.format(roc_auc))

print('\n\nConfusion matrix:')
net_trainer.show_confusion_matrix(epoch = EPOCH-1)