In [7]:
import torch
import torchvision
import pandas as pd
import glob
import os
from PIL import Image
import cv2
import numpy as np
from PIL import Image
import random
from pathlib import Path
import matplotlib.pyplot as plt

In [8]:
def set_seed(s):
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(s)
    random.seed(s)
    os.environ['PYTHONHASHSEED'] = str(s)
set_seed(0)

In [9]:
from skimage.exposure import equalize_adapthist
from skimage.transform import warp_polar

class CLAHE(torch.nn.Module):
    def forward(self, img):
        image = np.array(img, dtype=np.float64) / 255.0
        image = equalize_adapthist(image)
        image = (image*255).astype('uint8')

        return image

class POLAR(torch.nn.Module):
    def polar(self,image):
        return warp_polar(image, radius=(max(image.shape) // 2), multichannel=True)
    
    def forward(self, image):
        image = np.array(image, dtype=np.float64)
        image = self.polar(image)
        return image

In [10]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

split = "test"
batch_size = 32
num_workers = 0
train_path = f"/home/wangqy/gardnet/RIM-ONE_DL_images/partitioned_by_hospital/training_set" # path to dataset training set
path = f"/home/wangqy/gardnet/RIM-ONE_DL_images/partitioned_by_hospital/{split}_set"        # path to dataset folder
output_dir = "/home/wangqy/gardnet/RIM-ONE_DL_images/OUTPUTS"                               # path to save checkpoints

train_transform = torchvision.transforms.Compose([
            CLAHE(),
            transforms.ToTensor(),
            transforms.Resize(256),
            transforms.RandomVerticalFlip(),
            transforms.RandomHorizontalFlip(),
            transforms.RandomAffine(0,scale=(1.0,1.3))
        ])
transform = torchvision.transforms.Compose([
            CLAHE(),
            transforms.ToTensor(),
            transforms.Resize(256)
        ])
train_dataset = ImageFolder(train_path, transform=train_transform)
num = int(np.floor(len(train_dataset) * 1))
indices = np.random.choice(len(train_dataset), num, replace=False)
train_dataset = torch.utils.data.Subset(train_dataset, indices)
train_loader = DataLoader(train_dataset, 
                  batch_size=batch_size, 
                  shuffle=True,
                  num_workers=num_workers,
              )
test_dataset = ImageFolder(path, transform=transform)
test_loader = DataLoader(test_dataset, 
                  batch_size=batch_size, 
                  shuffle=True,
                  num_workers=num_workers,
              )

print(len(train_dataset))
print(len(test_dataset))

311
174




In [11]:
_labels = []
for j in range(len(train_dataset)):
    _labels.append(train_dataset[j][1])
_labels = np.asarray(_labels)
np.unique(_labels)

array([0, 1])

In [12]:
import timm

model_name = "efficientnet_b0"
pretrained = True
dropout = 0.2
lr = 0.0005
#momentum = 0.1
epochs = 20

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model = timm.create_model(model_name, pretrained=pretrained, num_classes=2, drop_rate=dropout)
model = model.to(device)

In [13]:
path = "/home/wangqy/gardnet/Checkpoints/rimonedl_1.pt"
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['state_dict'])
print("Best F1 {} from epoch {}\n".format(checkpoint["best_f1"], checkpoint["epoch"]))

Best F1 0.9490370014311152 from epoch 13



In [14]:
import os
from sklearn.utils import class_weight
from torch.nn import CrossEntropyLoss
import torch.optim as optim
from tqdm import tqdm
import sklearn
from sklearn import metrics
from sklearn.metrics import f1_score


if not os.path.exists(output_dir):
    os.makedirs(output_dir)

weight_referable = class_weight.compute_class_weight(class_weight='balanced', classes = np.unique(_labels), y=_labels).astype('float32')    
weight_referable = np.array([weight_referable[0], weight_referable[1]])
criterion = CrossEntropyLoss(weight=torch.from_numpy(weight_referable).to(device))
print(weight_referable)

optimizer = optim.Adam(model.parameters(),lr=lr)

epoch_resume = 0
best_f1 = 0.0


# Train
if epoch_resume < epochs:
    print("Resuming training\n")
    for epoch in range(epoch_resume, epochs):
        for split in ['Train']:
            if split == "Train":
                model.train()
            else:
                model.eval()

            epoch_total_loss = 0
            labels = []
            predictions = []
            loader = train_loader if split == "Train" else val_loader
            for batch_num, (inp, target) in enumerate(tqdm(loader)):
                labels+=(target)
                optimizer.zero_grad()
                output = model(inp.to(device))
                _, batch_prediction = torch.max(output, dim=1)
                predictions += batch_prediction.detach().tolist()
                batch_loss = criterion(output, (target).to(device))
                epoch_total_loss += batch_loss.item()

                if split == "Train":
                    batch_loss.backward()
                    optimizer.step()

            avrg_loss = epoch_total_loss / loader.dataset.__len__()
            accuracy = metrics.accuracy_score(labels, predictions)
            confusion = metrics.confusion_matrix(labels, predictions)
            _f1_score = f1_score(labels, predictions, average="macro")
            auc = sklearn.metrics.roc_auc_score(labels, predictions)
            print("%s Epoch %d - loss=%0.4f AUC=%0.4f F1=%0.4f  Accuracy=%0.4f" % (split, epoch, avrg_loss, auc, _f1_score, accuracy))


        # save model
        checkpoint = {
            'epoch': epoch,
            'best_f1': best_f1,
            'f1': _f1_score,
            'auc': auc,
            'loss': avrg_loss,
            'state_dict': model.state_dict(),
            'opt_dict': optimizer.state_dict(),
            #'scheduler_dict': scheduler.state_dict()
        }

        torch.save(checkpoint, os.path.join(output_dir, f"checkpoint_{epoch}.pt"))
        if _f1_score > best_f1:
            best_f1 = _f1_score
            checkpoint["best_f1"] = best_f1
            torch.save(checkpoint, os.path.join(output_dir, "best.pt"))
else:
    print("Skipping training\n")

[1.3405173 0.7974359]
Resuming training



100%|██████████| 10/10 [00:12<00:00,  1.21s/it]


Train Epoch 0 - loss=0.1157 AUC=0.0386 F1=0.0413  Accuracy=0.0418


100%|██████████| 10/10 [00:11<00:00,  1.20s/it]


Train Epoch 1 - loss=0.0524 AUC=0.1756 F1=0.1833  Accuracy=0.1961


100%|██████████| 10/10 [00:11<00:00,  1.13s/it]


Train Epoch 2 - loss=0.0258 AUC=0.4359 F1=0.4252  Accuracy=0.4984


100%|██████████| 10/10 [00:11<00:00,  1.13s/it]


Train Epoch 3 - loss=0.0177 AUC=0.6633 F1=0.6608  Accuracy=0.6785


100%|██████████| 10/10 [00:11<00:00,  1.15s/it]


Train Epoch 4 - loss=0.0133 AUC=0.8347 F1=0.8077  Accuracy=0.8103


100%|██████████| 10/10 [00:11<00:00,  1.13s/it]


Train Epoch 5 - loss=0.0114 AUC=0.8690 F1=0.8458  Accuracy=0.8489


100%|██████████| 10/10 [00:11<00:00,  1.14s/it]


Train Epoch 6 - loss=0.0096 AUC=0.8758 F1=0.8578  Accuracy=0.8617


100%|██████████| 10/10 [00:11<00:00,  1.14s/it]


Train Epoch 7 - loss=0.0080 AUC=0.9145 F1=0.8943  Accuracy=0.8971


100%|██████████| 10/10 [00:11<00:00,  1.15s/it]


Train Epoch 8 - loss=0.0066 AUC=0.9272 F1=0.9161  Accuracy=0.9196


100%|██████████| 10/10 [00:11<00:00,  1.20s/it]


Train Epoch 9 - loss=0.0057 AUC=0.9288 F1=0.9221  Accuracy=0.9260


100%|██████████| 10/10 [00:11<00:00,  1.17s/it]


Train Epoch 10 - loss=0.0052 AUC=0.9495 F1=0.9364  Accuracy=0.9389


100%|██████████| 10/10 [00:11<00:00,  1.18s/it]


Train Epoch 11 - loss=0.0042 AUC=0.9580 F1=0.9493  Accuracy=0.9518


100%|██████████| 10/10 [00:11<00:00,  1.20s/it]


Train Epoch 12 - loss=0.0052 AUC=0.9494 F1=0.9424  Accuracy=0.9453


100%|██████████| 10/10 [00:11<00:00,  1.19s/it]


Train Epoch 13 - loss=0.0044 AUC=0.9554 F1=0.9522  Accuracy=0.9550


100%|██████████| 10/10 [00:11<00:00,  1.15s/it]


Train Epoch 14 - loss=0.0040 AUC=0.9640 F1=0.9592  Accuracy=0.9614


100%|██████████| 10/10 [00:11<00:00,  1.15s/it]


Train Epoch 15 - loss=0.0040 AUC=0.9580 F1=0.9493  Accuracy=0.9518


100%|██████████| 10/10 [00:11<00:00,  1.14s/it]


Train Epoch 16 - loss=0.0035 AUC=0.9683 F1=0.9626  Accuracy=0.9646


100%|██████████| 10/10 [00:11<00:00,  1.15s/it]


Train Epoch 17 - loss=0.0036 AUC=0.9699 F1=0.9691  Accuracy=0.9711


100%|██████████| 10/10 [00:11<00:00,  1.15s/it]


Train Epoch 18 - loss=0.0033 AUC=0.9726 F1=0.9661  Accuracy=0.9678


100%|██████████| 10/10 [00:11<00:00,  1.12s/it]


Train Epoch 19 - loss=0.0033 AUC=0.9667 F1=0.9562  Accuracy=0.9582


In [15]:
path = f"{output_dir}/best.pt"
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['state_dict'])
print("Best F1 {} from epoch {}\n".format(checkpoint["best_f1"], checkpoint["epoch"]))

Best F1 0.9691185718856538 from epoch 17



In [16]:
import torch
from tqdm import tqdm
import sklearn
from sklearn import metrics
from sklearn.metrics import f1_score

model.eval()
labels = []
predictions = []
with torch.no_grad():
    for (inp, target) in tqdm(test_loader):
        labels+=(target)
        batch_prediction = model(inp.to(device))
        _, batch_prediction = torch.max(batch_prediction, dim=1)
        predictions += batch_prediction.detach().tolist()
accuracy = metrics.accuracy_score(labels, predictions)
print("Test Accuracy = %0.5f" % (accuracy))

confusion = metrics.confusion_matrix(labels, predictions)
print(confusion)

_f1_score = f1_score(labels, predictions, average="macro")
print("Test F1 = %0.5f" % (_f1_score))

auc = sklearn.metrics.roc_auc_score(labels, predictions)
print("Test AUC = %0.5f" % (auc))

100%|██████████| 6/6 [00:04<00:00,  1.45it/s]

Test Accuracy = 0.83908
[[47  9]
 [19 99]]
Test F1 = 0.82330
Test AUC = 0.83913



