## Import libraries and define useful things

In [1]:
import sys

import numpy as np
import pandas as pd
from PIL import Image as PIL_Image

import os, glob
import gc

from collections import OrderedDict

from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, average_precision_score, f1_score
from sklearn.preprocessing import OneHotEncoder

import matplotlib
from matplotlib import pyplot as plt

# from tqdm import tqdm
from tqdm.notebook import tqdm

from scipy.special import softmax
# from scipy.special import expit
from scipy.spatial import distance

import relplot as rp

# sys.path.insert(1, '../RETFound_MAE/')

import torch
import torch.nn as nn
import models_vit
from util.pos_embed import interpolate_pos_embed
# from timm.models.layers import trunc_normal_
from timm.layers import trunc_normal_
import util.lr_decay as lrd

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torchvision import transforms as T
# from torchvision.transforms import v2 as T

import timm
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

print(f'Max float : {sys.float_info.max}')
print(torch.__version__)
print(f'Cuda available : {torch.cuda.is_available()}')
print(f'Number of GPUs : {torch.cuda.device_count()}')
print(f'CUDA Version : {torch.version.cuda}')
print(f'timm Version : {timm.__version__}')

torch.set_default_dtype(torch.float32)

has_gpu = torch.cuda.is_available()
has_mps = torch.backends.mps.is_built() #getattr(torch, 'has_mps', False)
device = 'mps' if torch.backends.mps.is_built() else 'gpu' if torch.cuda.is_available() else 'cpu'

chkpt_dir = './RETFound_mae_natureCFP.pth'
model_name = 'RETFound_mae'
input_size = 224
num_classes=5

def prepare_model(chkpt_dir, arch=model_name):
    # build model
    model = models_vit.__dict__[arch](
        img_size=input_size,
        num_classes=5,
        drop_path_rate=0,
        global_pool=True,
    )
    # load model
    checkpoint = torch.load(chkpt_dir, weights_only=False, map_location=device)
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    return model


model_save_dir = './modelstore/IDRiD_FT_RETFound_multiclass/'
model_descriptor = model_name + 'Nature_CFP_'

Max float : 1.7976931348623157e+308
2.7.0
Cuda available : False
Number of GPUs : 0
CUDA Version : None
timm Version : 1.0.15


In [2]:
# IDRiD 
img_dir_tr = '/Users/msa/Datasets/IDRiD/DiseaseGrading/OriginalImages/TrainingSet/crop_224/'
# full_path_list_tr = sorted(glob.glob(img_dir_tr + '*' + '.jpg', recursive=False))
# print(f'Number of files in {img_dir_tr}\t{len(full_path_list_tr)}', flush=True)

csv_file_tr = '/Users/msa/Datasets/IDRiD/DiseaseGrading/Groundtruths/TrainingLabels.csv'
df_metadata_tr = pd.read_csv(csv_file_tr, low_memory=False)
df_metadata_tr = df_metadata_tr[['Image name', 'Retinopathy grade', 'Risk of macular edema ']]
file_paths = []
split = []
for idx, row in df_metadata_tr.iterrows():
    file_paths.append(img_dir_tr + str(row['Image name']) + '.png') # '.jpg')
    split.append('train')
df_metadata_tr['file_path'] = file_paths
df_metadata_tr['split'] = split
print(f'Metadata shape : {df_metadata_tr.shape}')
print(df_metadata_tr.columns)

img_dir_te = '/Users/msa/Datasets/IDRiD/DiseaseGrading/OriginalImages/TestingSet/crop_224/'
# full_path_list_te = sorted(glob.glob(img_dir_te + '*' + '.jpg', recursive=False))
# print(f'Number of files in {img_dir_te}\t{len(full_path_list_te)}', flush=True)

csv_file_te = '/Users/msa/Datasets/IDRiD/DiseaseGrading/Groundtruths/TestingLabels.csv'
df_metadata_te = pd.read_csv(csv_file_te, low_memory=False)
file_paths = []
split = []
for idx, row in df_metadata_te.iterrows():
    file_paths.append(img_dir_te + str(row['Image name']) + '.png') # '.jpg')
    split.append('test')
df_metadata_te['file_path'] = file_paths
df_metadata_te['split'] = split
print(f'Metadata shape : {df_metadata_te.shape}')
print(df_metadata_te.columns)

df_metadata = pd.concat([df_metadata_tr, df_metadata_te], axis=0)
print(f'Metadata shape : {df_metadata.shape}')
print(df_metadata.columns)

del df_metadata_tr, df_metadata_te, file_paths, split

Metadata shape : (413, 5)
Index(['Image name', 'Retinopathy grade', 'Risk of macular edema ',
       'file_path', 'split'],
      dtype='object')
Metadata shape : (103, 5)
Index(['Image name', 'Retinopathy grade', 'Risk of macular edema ',
       'file_path', 'split'],
      dtype='object')
Metadata shape : (516, 5)
Index(['Image name', 'Retinopathy grade', 'Risk of macular edema ',
       'file_path', 'split'],
      dtype='object')


In [3]:
random_seed = 42
#### df_metadata = df_metadata.sample(frac=1.0, random_state=random_seed) Bonkers!
num_folds = 10
# num_repeats = 5

metadata_splits = []
# Same order as in the .csv files and df_metadata
with open(f'IDRiD_Features_MultiClass.npy', 'rb') as handle:
    X = np.load(handle)
    y = np.load(handle)

print(f'{np.unique(y, return_counts=True)[1]/np.sum(np.unique(y, return_counts=True)[1])}')

kFoldCV = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=random_seed)
for i, (trainval_index, test_index) in enumerate(kFoldCV.split(X, y)):

    X_trainval = X[trainval_index,:]
    y_trainval = y[trainval_index]
    df_metadata_trainval = df_metadata.iloc[trainval_index]

    kFoldCV_inner = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=random_seed)
    for ii, (train_index, val_index) in enumerate(kFoldCV_inner.split(X_trainval, y_trainval)):
        break

    metadata_dict = {"train": df_metadata_trainval.iloc[train_index], 
                     "val": df_metadata_trainval.iloc[val_index], 
                     "test": df_metadata.iloc[test_index]
                    }
    
    print(metadata_dict['train'].shape)
    print(metadata_dict['val'].shape)
    print(metadata_dict['test'].shape)
    
    metadata_splits.append(metadata_dict)

del metadata_dict, X_trainval, y_trainval, trainval_index

[0.3255814  0.04844961 0.3255814  0.18023256 0.12015504]
(417, 5)
(47, 5)
(52, 5)
(417, 5)
(47, 5)
(52, 5)
(417, 5)
(47, 5)
(52, 5)
(417, 5)
(47, 5)
(52, 5)
(417, 5)
(47, 5)
(52, 5)
(417, 5)
(47, 5)
(52, 5)
(418, 5)
(47, 5)
(51, 5)
(418, 5)
(47, 5)
(51, 5)
(418, 5)
(47, 5)
(51, 5)
(418, 5)
(47, 5)
(51, 5)


## Preprare dataset, dataloaders, transformations and perform cross-validation with RETFound

In [4]:
from bazinga import IDRiD_ImageDataset

transforms_train = T.Compose([
    T.ToTensor(), #v1
    # T.ToImage(),  # v2
    # T.ToDtype(torch.uint8, scale=True), # v2
    
    # T.Resize(size=(input_size,input_size), interpolation=T.InterpolationMode.BILINEAR),
    T.RandomResizedCrop(size=(input_size, input_size), scale=(0.9, 1.0), ratio=(0.9, 1.1), interpolation=T.InterpolationMode.BILINEAR), 
    
    T.RandomApply([T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], p=0.9),
    T.RandomApply([T.GaussianBlur((23,23), sigma=(0.1, 2.0))], p=0.1),

    # T.RandomGrayscale(p=0.25), 
    
#     # Following the color transformations, spatial/geometric transformations are due.
#     # T.RandomHorizontalFlip(p=0.25),
#     # T.RandomVerticalFlip(p=0.25),    
#     # use NEAREST. Others cause values outside [0,255]
    T.RandomApply([T.RandomRotation(10, interpolation=T.InterpolationMode.BILINEAR)], p=0.9),
    
    # T.ToDtype(torch.float32, scale=True), # v2
    T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), # [-1, 1]
])

transforms_inf = T.Compose([
    T.ToTensor(), #v1
    # T.ToImage(),  # v2
    # T.ToDtype(torch.uint8, scale=True), # v2
    
    # T.Resize(size=(input_size,input_size), interpolation=T.InterpolationMode.BILINEAR),

    # T.ToDtype(torch.float32, scale=True), # v2
    T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), # [-1, 1]
])

transforms = {'train': transforms_train, 'val': transforms_inf, 'test': transforms_inf}

In [5]:
# num_workers = 8
# batch_size = 1 # 32 // n_views #52 # 128 # 96 # 128 # 208 # 164 # 112 # 75 # 22*4

# # Note that shuffle is mutually exclusive with Sampler
# # shuffle_dict = {'train': False, 'test': False} #, 'test': False}

# split = 'train'

# idrid_dataset = IDRiD_ImageDataset(metadata_splits[0][split], transforms=transforms[split], target_transforms=None)

# dataloader = DataLoader(idrid_dataset, batch_size=batch_size,
#                         shuffle=False, sampler=None, # samplers[split], 
#                         num_workers=num_workers, pin_memory=True)

# width = 5
# height = 5
# n_rows = 5 # len(data_loaders)
# n_cols = 5

# f = plt.figure(figsize=(n_cols*width, n_rows*height))

# loader = iter(dataloader)
# print(f'Size : {len(idrid_dataset)}')

# for i in range(n_rows): #, (split, loader) in enumerate(data_loaders.items()):
    
#     for j in range(n_cols):

#         img, label = next(loader) # cfp and oct views packed together
        
#         idx = (i*n_cols)+j
        
#         ax = f.add_subplot(n_rows, n_cols, idx+1)
#         img = torch.squeeze(img)
#         temp_img = torch.squeeze(img.permute(1,2,0))
#         print(f'Min : {torch.amin(temp_img)}\tMean : {temp_img.mean((0,1))}\tStd : {temp_img.std((0,1))}\tMax : {torch.amax(temp_img)}')

#         ax.imshow(temp_img)
        
#         ax.set_title(f'{label.item()}')
#         ax.set_xlabel('')
#         ax.set_ylabel('')
#         ax.set_xticks([])
#         ax.set_xticklabels([])
#         ax.set_yticks([])
#         ax.set_yticklabels([])

# # plt.savefig('../../retfoundm_images_multiview.png')
# plt.show()

In [6]:
#################################################################################
### My generic model architecture for adapting RETFound to downstream tasks ###
#################################################################################
class DownstreamTask_Network(nn.Module):
    def __init__(self, backbone, num_outputs, frozen_backbone=True, hidden_dims=[1024, 2048, 512]):
        super(DownstreamTask_Network, self).__init__()
        
        self.backbone = backbone #.fullstack
        if frozen_backbone:
            print(f'Freezing the backbone parameters')
            for param in self.backbone.parameters():
                param.requires_grad = False
        self.num_outputs = num_outputs
        self.hidden_dims = hidden_dims # dim[0] corresponds to the input layer!
        
        # reconstruct the avg. pooling layer before head
        # self.backbone.append(nn.Sequential(nn.AdaptiveAvgPool2d(1), 
        #                                    nn.LayerNorm([self.hidden_dims[0],1,1]), 
        #                                    nn.Flatten(start_dim=1, end_dim=-1)
        #                                   )
        #                     )        
        
        # self.head = nn.Sequential()
        modules = []
        if len(self.hidden_dims) > 1: # hidden layers exist
            for idx in range(len(self.hidden_dims)-1):
                modules.append(nn.Linear(self.hidden_dims[idx], self.hidden_dims[idx+1]))
                modules.append(nn.LayerNorm(self.hidden_dims[idx+1]))
                modules.append(nn.GELU(approximate='tanh'))

        modules.append(nn.Linear(self.hidden_dims[-1], self.num_outputs))
        
        self.head = nn.Sequential(*modules)
    
    def forward_features(self, x):
        return self.backbone.forward_features(x)
    
    def forward(self, x):
        return self.head(self.forward_features(x))

In [7]:
class PrintCallback:
    def on_epoch_end(self, epoch, logs):
        print(f"Epoch {epoch}: loss[train] = {logs['loss']['train']:.5f}\tloss[val] = {logs['loss']['val']:.5f}") #, accuracy = {logs['accuracy']:.4f}")

In [None]:
def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs, batch_accumulation_steps, device, callbacks=None, 
                model_save_dir='../../', model_descriptor='IDRiD_RETFound_'):
    
    if model_save_dir is not None and not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir, exist_ok=True)
    
    total_loss_history = {'train' : [], 'val' : []}
    best_loss = {'train' : sys.float_info.max, 'val' : sys.float_info.max}
        
    # Create once at the beginning of training
    # scaler = torch.amp.GradScaler(device=device)
    # log_softmax = nn.LogSoftmax(dim=1)

    for epoch in range(num_epochs):

        print(f'Epoch {epoch+1}/{num_epochs}', flush=True)
        print('-' * 10, flush=True)

        epoch_loss = {'train' : 0.0, 'val' : 0.0}
        logs = {'loss': epoch_loss}
    
        for phase in ['train','val']:

            print(f'# of examples for "{phase}" : {len(dataloaders[phase].dataset)}', flush=True)

            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = [] 
            # zero the parameter gradients
            optimizer.zero_grad(set_to_none=True)
            batch_idx = 0 
            
            for datum in tqdm(iter(dataloaders[phase])): # Each is of shape B x K x 3 x H x W
                # Casts operations to mixed precision
                # forward
                # with torch.autocast(device_type=device, dtype=torch.float16): #torch.cuda.amp.autocast():
                with torch.set_grad_enabled(phase == 'train'): 

                    datum[0] = datum[0].to(device) # img
                    # print(f'{datum[0].type()}')
                    datum[1] = datum[1].to(device) # label

                    output = model(datum[0]) #(torch.unsqueeze(datum[0], dim=0))

                    # features = model.forward_features(datum[0])
                    # print(torch.squeeze(features.cpu().detach()).numpy())
                    # print(torch.squeeze(output.cpu().detach()).numpy())

                    # loss = criterion(nn.functional.softmax(output, dim=-1), nn.functional.one_hot(datum[1].long(), num_classes=num_classes)) #torch.unsqueeze(datum[1].float(), dim=1))
                    loss = criterion(output, datum[1].long()) #torch.unsqueeze(datum[1].float(), dim=1))
                    # loss = criterion(log_softmax(output), nn.functional.one_hot(datum[1].long(), num_classes=num_classes))
                    running_loss.append(loss.item())
                    print(f'Batch loss : {loss.item()}')
                    
                    # # .backward() here accumulates gradients
                    if phase == 'train':
                        loss.backward() # do not call .backward within autocast. it is not recommended. 
                        # scaler.scale(loss).backward()
                
                if phase == 'train' and ((batch_idx+1)%batch_accumulation_steps==0):
                    optimizer.step()
                    # scaler.step(optimizer) # a step in the direction of gradients accumulated so far
                    # # Updates the scale for next iteration
                    # scaler.update()
                    
                    # scheduler.step()
                    
                    optimizer.zero_grad(set_to_none=True)
                                
                batch_idx += 1
                
            epoch_loss[phase] = np.sum(np.asarray(running_loss, dtype=np.float32)) / (len(dataloaders[phase].dataset))
            total_loss_history[phase].append(epoch_loss[phase])

            print(f'Epoch {epoch+1}, {phase} Loss : {epoch_loss[phase]:.4f}', flush=True)
                            
            # deep copy the model
            if phase == 'val' and epoch_loss[phase] < best_loss[phase]: 
                
                print(f'Better validation loss found : {epoch_loss["val"]}\t previous : {best_loss["val"]}', flush=True)
                
                best_loss['train'] = epoch_loss['train']
                best_loss['val'] = epoch_loss['val']
            
                checkpoint_path = f'{model_save_dir}{model_descriptor}.tar'
                torch.save({'epoch': epoch,
                            'model_state_dict': model.state_dict(),
                            # 'optimizer_state_dict': optimizer.state_dict(),
                            'loss': epoch_loss[phase],
                           }, checkpoint_path)
        
        if callbacks is not None:
            for callback in callbacks:
                callback.on_epoch_end(epoch+1, logs)
    
    print('Best validation loss : {0:.4f}\t found at epoch : {1:d}'.format(np.min(total_loss_history['val']), np.argmin(total_loss_history['val'])), flush=True)
    
    return model, total_loss_history

In [9]:
num_epochs = 20
# verbose = False

num_workers = 8
batch_size = 8
batch_accumulation_steps = 1
print(f'Batch size : {batch_size}\t batch accumulation steps : {batch_accumulation_steps}')
# criterion = nn.BCEWithLogitsLoss(reduction='sum')
criterion = nn.CrossEntropyLoss(reduction='sum')
# criterion = nn.NLLLoss(reduction='sum')

#SCHEDULER FOR COSINE DECAY
T_0 = 100 # 100 # 10
T_mult = 1 # 1 # 2
eta_min = 1e-5

base_lr = 1e-3 #5e-4 # 0.001
lr = base_lr * (float(batch_size*batch_accumulation_steps)/256) # base learning rate w.r.t. effective batch size
weight_decay = 1e-2
layer_decay = 0.75

# Note that shuffle is mutually exclusive with Sampler
shuffle_dict = {'train': True, 'val': False, 'test': False}

frozen_backbone = False

loss_histories = []

for cv_idx, metadata in enumerate(metadata_splits):
    
    print(f'Cross validation idx : {cv_idx}\tTraining split size : {metadata["train"].shape[0]}\tValidation split size : {metadata["val"].shape[0]}\tTest split size : {metadata["test"].shape[0]}')
    
    dim_mlp = 1024
    
    backbone = prepare_model(chkpt_dir, model_name)
    # no_weight_decay_list = backbone.no_weight_decay()
    model = DownstreamTask_Network(backbone, num_outputs=num_classes, frozen_backbone=frozen_backbone, hidden_dims=[dim_mlp])
        
    param_count = sum([p.numel() for p in model.parameters()])
    print(f'# of parameters in model : {param_count}')
    param_count = sum([p.numel() for p in model.parameters() if p.requires_grad])
    print(f'# of trainable parameters in model : {param_count}')
    model = model.float()
    model = model.to(device)
    # print(model)
    
    # set up the optimization objects
    # param_groups = lrd.param_groups_lrd(model, weight_decay, no_weight_decay_list=no_weight_decay_list, layer_decay=layer_decay)
    optimizer = torch.optim.AdamW([params for params in model.parameters() if params.requires_grad],
                                  lr=lr, 
                                  betas=(0.9, 0.999), eps=1e-08, 
                                  weight_decay=weight_decay, amsgrad=False 
                                 )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0, T_mult, eta_min=eta_min, last_epoch=-1)
    
    # set up the datasets and loaders for training and validation splits
    datasets = OrderedDict()
    dataloaders = OrderedDict()
    for phase in ['train','val']:        
        print(f'Metadata split shape : {metadata[phase].shape}', flush=True)
        
        datasets[phase] = IDRiD_ImageDataset(metadata[phase], target_column='Retinopathy grade', 
                                             transforms=transforms[phase], target_transforms=None)
        dataloaders[phase] = DataLoader(datasets[phase], batch_size=batch_size,
                                        shuffle=shuffle_dict[phase], sampler=None, 
                                        num_workers=num_workers, pin_memory=True, drop_last=False)
    
    # train the model along with model selection based on validation resutls, e.g., validation loss
    model, total_loss_history = train_model(model, dataloaders, criterion, optimizer, scheduler, 
                                            num_epochs, batch_accumulation_steps, device, [PrintCallback()],
                                            model_save_dir, model_descriptor+str(cv_idx))

    loss_histories.append(total_loss_history)

    del backbone, model
    gc.collect()
    torch.cuda.empty_cache()
    torch.clear_autocast_cache()
    
    # break # cross-val.

# with open(f'{model_save_dir}{model_descriptor}LossHist.pkl', 'wb') as handle:
#     pickle.dump(loss_histories, handle, protocol=4)

del loss_histories

print('Done!')

Batch size : 8	 batch accumulation steps : 1
Cross validation idx : 0	Training split size : 417	Validation split size : 47	Test split size : 52
# of parameters in model : 303311882
# of trainable parameters in model : 303311882
Metadata split shape : (417, 5)
Metadata split shape : (47, 5)
Epoch 1/20
----------
# of examples for "train" : 417




  0%|          | 0/53 [00:00<?, ?it/s]

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

## Evaluate the models on the same splits as above

In [None]:
# run the models for inference on the same partitions as above.
# for both training and validation splits, calculate loss as follows:
#    for each prediction head, sum the loss over minibatches and divide by the total sample size

# criterion = nn.BCEWithLogitsLoss(reduction='sum')
criterion = nn.CrossEntropyLoss(reduction='sum')

# with open(f'{model_save_dir}{model_descriptor}LossHist.pkl', 'rb') as handle:
#     loss_histories = pickle.load(handle)

# print(f'Reading loss histories from {len(loss_histories)}-fold cross validation')

verbose = False

num_workers = 8
batch_size = 8 #128

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Device : {device}')

avg_losses = []
performance_metrics = []

df_results = pd.DataFrame()
acc_col = []
roc_auc_col = []
avg_prec_col = []
f1_col = []
calib_error_col = []
cv_col = []
split_col = []
sampling_col = []

all_predictions = []
all_labels = []


for cv_idx, metadata in enumerate(metadata_splits):
    
    print(f'Cross validation idx : {cv_idx}\tTraining split size : {metadata["train"].shape[0]}\tValidation split size : {metadata["val"].shape[0]}\tTest split size : {metadata["test"].shape[0]}')
    
    dim_mlp = 1024
    
    backbone = prepare_model(chkpt_dir, 'vit_large_patch16')    
    model = DownstreamTask_Network(backbone, num_outputs=1, frozen_backbone=False, hidden_dims=[dim_mlp])
    
    checkpoint_path = f'{model_save_dir}{model_descriptor+str(cv_idx)}.tar'
    print(f'Loading : {checkpoint_path}')
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    del checkpoint

    model = model.to(device)
    # print(model)
        
    # set up the datasets and loaders for training and validation splits
    # datasets = OrderedDict()
    # dataloaders = OrderedDict()
    avg_losses.append(OrderedDict())
    performance_metrics.append(OrderedDict())

    for phase in ['test']: #['train','val','test']:        
        print(f'Metadata split shape : {metadata[phase].shape}', flush=True)

        dataset = IDRiD_ImageDataset(metadata[phase], target_column='Retinopathy grade', 
                                     transforms=transforms[phase], target_transforms=None)
        dataloader = DataLoader(dataset, batch_size=batch_size,
                                shuffle=False, sampler=None, 
                                num_workers=num_workers, pin_memory=True, drop_last=False)

        print(f'# of examples for "{phase}" : {len(dataloader.dataset)}', flush=True)
        
        model.eval()   # Set model to evaluate mode

        running_loss = [] 
        predictive_prob = [] 
        labels = [] 
        performance_metrics[-1][phase] = OrderedDict()
        batch_idx = 0 

        for datum in tqdm(iter(dataloader)): # Each is of shape B x K x 3 x H x W

            # Casts operations to mixed precision
            # forward
            # with torch.autocast(device_type='cuda', dtype=torch.float16): #torch.cuda.amp.autocast():

            with torch.set_grad_enabled(False): 
                
                datum[0] = datum[0].to(device) # img
                datum[1] = datum[1].to(device) # label
                
                output = model(datum[0]) #(torch.unsqueeze(datum[0], dim=0))
                
                loss = criterion(output, torch.unsqueeze(datum[1], dim=1).float())
                running_loss.append(loss.item())
                # running_loss.append(np.sum(np.abs(np.subtract(output.cpu().detach().numpy(), datum[1].cpu().detach().numpy()))))
                
                predictive_prob.append(nn.functional.sigmoid(output).cpu().detach().numpy())
                labels.append(datum[1].cpu().detach().numpy())
                    

            # # if ((batch_idx+1)%(2*batch_accumulation_steps)==0):
            # if batch_idx % 10 == 0:
            #     print(f'Batch idx : {batch_idx}\tLoss : {running_loss[-1]}', flush=True)
            
            batch_idx += 1
            
        print(f'{np.sum(np.asarray(running_loss, dtype=np.float32))} / {len(dataloader.dataset)}')
        avg_losses[-1][phase] = np.sum(np.asarray(running_loss, dtype=np.float32)) / len(dataloader.dataset)
        
        predictive_prob = np.concatenate(predictive_prob, axis=0)
        labels = np.concatenate(labels, axis=0)
        predictions = np.asarray((predictive_prob >= 0.5), dtype=np.int64)

        all_predictions.append(predictions)
        all_labels.append(labels)
        
        acc = accuracy_score(labels, predictions)
        performance_metrics[-1][phase]['accuracy'] = acc
        print(f'Accuracy : {acc}')
        
        roc_auc = roc_auc_score(labels, predictive_prob)
        performance_metrics[-1][phase]['roc-auc'] = roc_auc
        print(f'ROC-AUC : {roc_auc}')

        avg_prec = average_precision_score(labels, predictive_prob)
        performance_metrics[-1][phase]['avg-prec'] = avg_prec
        print(f'Avg. Prec. : {avg_prec}')

        f1 = f1_score(labels, predictions)
        performance_metrics[-1][phase]['f1'] = f1
        print(f'F1 : {f1}')

        calib_error = rp.smECE(np.squeeze(predictive_prob), labels)
        performance_metrics[-1][phase]['ECE'] = calib_error
        print(f'ECE : {calib_error}')


        acc_col.append(acc)
        roc_auc_col.append(roc_auc)
        avg_prec_col.append(avg_prec)
        f1_col.append(f1)
        calib_error_col.append(calib_error)
        cv_col.append(cv_idx)
        split_col.append(phase)
        sampling_col.append("RETFound")
        
        
    # for phase in ['train','val']:
    #     print(f'-- Task losses {phase} : {avg_losses[-1][phase]} --')
    # print(f'Loss [train] : {avg_losses[-1]["train"]}\t[val] : {avg_losses[-1]["val"]}\nAccuracy [train] : {performance_metrics[-1]["train"]["accuracy"]}\t[val] : {performance_metrics[-1]["val"]["accuracy"]}\nROC-AUC [train] : {performance_metrics[-1]["train"]["roc-auc"]}\t[val] : {performance_metrics[-1]["val"]["roc-auc"]}\nAvg. Prec. [train] : {performance_metrics[-1]["train"]["avg-prec"]}\t[val] : {performance_metrics[-1]["val"]["avg-prec"]}\nF1 [train] : {performance_metrics[-1]["train"]["f1"]}\t[val] : {performance_metrics[-1]["val"]["f1"]}')
    # print(f'Loss [train] : {avg_losses[-1]["train"]}\t[val] : {avg_losses[-1]["val"]}\t[test] : {avg_losses[-1]["test"]}\nAccuracy [train] : {performance_metrics[-1]["train"]["accuracy"]}\t[val] : {performance_metrics[-1]["val"]["accuracy"]}\t[test] : {performance_metrics[-1]["test"]["accuracy"]}\nROC-AUC [train] : {performance_metrics[-1]["train"]["roc-auc"]}\t[val] : {performance_metrics[-1]["val"]["roc-auc"]}\t[test] : {performance_metrics[-1]["test"]["roc-auc"]}\nAvg. Prec. [train] : {performance_metrics[-1]["train"]["avg-prec"]}\t[val] : {performance_metrics[-1]["val"]["avg-prec"]}\t[test] : {performance_metrics[-1]["test"]["avg-prec"]}\nF1 [train] : {performance_metrics[-1]["train"]["f1"]}\t[val] : {performance_metrics[-1]["val"]["f1"]}\t[test] : {performance_metrics[-1]["test"]["f1"]}')

    del backbone, model
    gc.collect()
    torch.cuda.empty_cache()
    
    # break # cross-val.

# with open(f'{model_save_dir}{model_descriptor}CrossValPerformance.pkl', 'wb') as handle:
#     pickle.dump([avg_losses, 
#                  performance_metrics], handle, protocol=4)

# del avg_losses, performance_metrics


df_results['Fold'] = cv_col
df_results['Split'] = split_col
df_results['Sampling'] = sampling_col
df_results['Accuracy'] = acc_col
df_results['ROC-AUC'] = roc_auc_col
df_results['Avg. Precision'] = avg_prec_col
df_results['F1 Score'] = f1_col
df_results['ECE'] = calib_error_col


# df_results.to_csv('ICL4Ophthalmology_IDRiD_Binary_RETFound_3splits_FIXED.csv', index=False)


print('Done!')

all_predictions = np.concatenate(all_predictions, axis=0)
print(f'All predictions shape : {all_predictions.shape}')
print(all_predictions)
all_labels = np.concatenate(all_labels, axis=0)
print(f'All labels shape : {all_labels.shape}')
print(all_labels)


# with open(f'IDRiD_Binary_PredictionsRETFound_FIXED.npy', 'wb') as handle:
#     # pickle.dump(out_data, handle, protocol=4)
#     np.save(handle, all_predictions)
#     np.save(handle, all_labels)