### Data source: https://www.bracs.icar.cnr.it/

### Import packages

In [None]:
import numpy as np 
import pandas as pd 
import os
import cv2
import random
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
warnings.filterwarnings("ignore")
from collections import Counter
from sklearn.preprocessing import LabelBinarizer, LabelEncoder

import torch
import torchvision
from PIL import Image
from torch.utils.data import Dataset
import torch.nn as nn
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms
from torch.utils.data.sampler import Sampler
import json

from sklearn.metrics import f1_score, precision_recall_curve
from sklearn.metrics import classification_report

import timm

from omegaconf import OmegaConf
# Load the preprocessing config and narrow it to the section this notebook
# uses; later cells read `img_dir_lvl4` and `weights_dir` from it.
preproc_conf = OmegaConf.load("../conf/preproc.yaml")
preproc_conf = preproc_conf['classic_mil_on_embeddings_bag']['bracs_224_224_patches']

In [None]:
# Index of the CUDA device; interpolated into the `cuda:{CUDANUM}` device
# strings built further down in this notebook.
CUDANUM = 0

### Locate annotations

In [None]:
# Prefix of the .npy files loaded below. It is joined to the file names by
# plain string concatenation, so the configured value has to end with a path
# separator.
parent_folder = preproc_conf.img_dir_lvl4
parent_folder

In [None]:
# Make sure the configured weights directory exists (no error if it already does).
os.makedirs(preproc_conf.weights_dir, exist_ok=True)

### Load data

In [None]:
%%time
# Load the image array and the label array of each split (training,
# validation, test) from `parent_folder`.
x_train = np.load(f'{parent_folder}bracs_level4_regions_224_training_data_macenkonorm_bracs.npy')
y_train = np.load(f'{parent_folder}bracs_level4_regions_224_training_label.npy')

x_val = np.load(f'{parent_folder}bracs_level4_regions_224_validation_data_macenkonorm_bracs.npy')
y_val = np.load(f'{parent_folder}bracs_level4_regions_224_validation_label.npy')

x_test = np.load(f'{parent_folder}bracs_level4_regions_224_test_data_macenkonorm_bracs.npy')
y_test = np.load(f'{parent_folder}bracs_level4_regions_224_test_label.npy')

# Cell output: the shapes of all six arrays.
(x_train.shape, y_train.shape, x_val.shape, y_val.shape, x_test.shape, y_test.shape)

### Preprocess label

In [None]:
# Integer-encode the class labels. The encoder is fitted on the training
# labels only and then reused for every split, so all splits share one
# class-to-index mapping. LabelBinarizer would be the one-hot alternative;
# despite the `_oh` suffix the arrays below come from LabelEncoder.
lb = LabelEncoder().fit(y_train)

y_train_oh, y_val_oh, y_test_oh = (
    lb.transform(split_labels) for split_labels in (y_train, y_val, y_test)
)

# Separate encoded copy of the training labels, used by the inspection cells.
label_oh = lb.transform(y_train)

In [None]:
# Cell output: shape of the encoded training labels and the encoded test labels.
label_oh.shape, y_test_oh

In [None]:
# Cell output: the class names known to the fitted encoder.
lb.classes_

In [None]:
# Cell output: number of training samples per encoded class.
Counter(label_oh)

### Create balanced data loader -> balanced folds with subset

In [None]:
def create_balanced_biopsy_subset(labels, minority_class_ratio=0.2, rnd_seed=38):
    """Split sample indices into a class-balanced subset and the remainder.

    The same number of samples is drawn at random from every class. That
    number is the size of the rarest class multiplied by
    ``minority_class_ratio`` and truncated to an integer.

    Args:
        labels: 1-D array-like with one class label per sample.
        minority_class_ratio: Fraction of the rarest class that is drawn from
            each class.
        rnd_seed: Seed handed to ``np.random.seed``. This reseeds NumPy's
            global generator as a side effect.

    Returns:
        ``(train_idx, test_idx)`` where ``test_idx`` holds the indices of the
        balanced subset (grouped by class, classes in sorted order) and
        ``train_idx`` holds all remaining indices in ascending order.
    """
    labels = np.asarray(labels)
    all_idx = np.arange(labels.shape[0])

    # set random seed as given
    np.random.seed(rnd_seed)

    if labels.shape[0] == 0:
        # Nothing to sample from: both parts are empty.
        return all_idx, all_idx.copy()

    # Per-class sample counts, classes in sorted order.
    classes, class_occurence = np.unique(labels, return_counts=True)

    # Work on the integer counts directly. Rebuilding the minority count from
    # a float32 class weight times N can land just below the true count and
    # then be truncated to one sample too few.
    nr_class_test = int(class_occurence.min() * minority_class_ratio)

    # Draw `nr_class_test` random indices from every class.
    test_local_idx = []
    for s in classes:
        s_idx = np.flatnonzero(labels == s)
        rnd_idx = np.random.permutation(s_idx.shape[0])
        test_local_idx.append(s_idx[rnd_idx[:nr_class_test]])

    # aggregate all the balanced subset's indices
    test_idx = np.concatenate(test_local_idx)

    # every index that was not drawn belongs to the remainder
    train_idx = all_idx[~np.isin(all_idx, test_idx)]

    return train_idx, test_idx

In [None]:
def create_biopsy_subset(labels, minority_class_ratio=0.2, rnd_seed=38):
    """Return the indices of a class-balanced random subset and of the rest.

    Every class contributes the same number of randomly chosen samples:
    ``int(size_of_rarest_class * minority_class_ratio)``.

    Args:
        labels: 1-D array-like with one class label per sample.
        minority_class_ratio: Fraction of the rarest class that is drawn from
            each class.
        rnd_seed: Seed handed to ``np.random.seed``. This reseeds NumPy's
            global generator as a side effect.

    Returns:
        ``(train_idx, test_idx)``: ``test_idx`` is the balanced subset
        (grouped by class, classes in sorted order) and ``train_idx`` contains
        every other index in ascending order.
    """
    labels = np.asarray(labels)
    n_samples = labels.shape[0]
    all_idx = np.arange(n_samples)

    # set random seed as given
    np.random.seed(rnd_seed)

    if n_samples == 0:
        # Nothing to sample from: both parts are empty.
        return all_idx, all_idx.copy()

    # Per-class sample counts, classes in sorted order.
    classes, class_counts = np.unique(labels, return_counts=True)

    # Use the integer minority count itself. Going through a float32 class
    # weight multiplied by N can fall just short of the true count and lose
    # one sample per class to integer truncation.
    nr_class_test = int(class_counts.min() * minority_class_ratio)

    # Draw `nr_class_test` random indices from every class.
    test_local_idx = []
    for s in classes:
        s_idx = np.flatnonzero(labels == s)
        rnd_idx = np.random.permutation(s_idx.shape[0])
        test_local_idx.append(s_idx[rnd_idx[:nr_class_test]])

    # aggregate all the balanced subset's indices
    test_idx = np.concatenate(test_local_idx)

    # every index that was not drawn belongs to the remainder
    train_idx = all_idx[~np.isin(all_idx, test_idx)]

    return train_idx, test_idx

In [None]:
def give_back_balanced_training_fold(X_current, y_current,
                                     minority_class_ratio=1.0, rnd_seed=12):
    """Return the samples and labels of a class-balanced random subset.

    The subset indices come from ``create_biopsy_subset``; its second return
    value (the balanced part) is used to index both ``X_current`` and
    ``y_current``. The labels are returned as given, without further encoding.
    """
    balanced_idx = create_biopsy_subset(y_current, minority_class_ratio, rnd_seed)[1]
    return X_current[balanced_idx], y_current[balanced_idx]

## Dataloader

In [None]:
class CollectionsDataset(Dataset):
    """Map-style dataset pairing each entry of ``data`` with the label at the same index.

    Items are returned as ``{'image': ..., 'label': ...}`` dictionaries; the
    optional ``transform`` is applied to the image only.
    """

    def __init__(self, data, labels, num_classes, transform=None):
        self.data, self.labels = data, labels
        self.transform = transform
        # Stored for reference only; nothing in this class reads it.
        self.num_classes = num_classes

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample, target = self.data[idx], self.labels[idx]

        # A falsy transform (e.g. None) leaves the sample untouched.
        if self.transform:
            sample = self.transform(sample)

        return dict(image=sample, label=target)

In [None]:
#https://github.com/pytorch/pytorch/issues/7359
class BalancedSampler(Sampler):
    """Yield batches of sample indices drawn from a class-balanced subset.

    On every iteration a balanced subset of ``labels`` is selected with
    ``create_biopsy_subset`` (seeded with ``rand_seed``), shuffled with the
    ``random`` module and cut into consecutive slices of ``batch_size``
    indices.

    Note the asymmetry: ``len(sampler)`` is the number of *indices* in the
    balanced subset, whereas iterating yields *batches*. One full batch fewer
    than would fit is yielded (``len // batch_size - 1``), which matches
    ``BatchBalancedSampler.__len__``.

    Args:
        batch_size: Number of indices per yielded batch.
        data: Kept for reference; only passed through to the fold helper.
        labels: 1-D array of class labels used for balancing.
        num_classes: Stored only.
        transform: Stored only.
        rand_seed: Seed used when selecting the balanced subset.
    """

    def __init__(self,
                 batch_size,
                 data,
                 labels,
                 num_classes,
                 transform=None,
                 rand_seed=12):
        self.batch_size = batch_size
        self.data = data
        self.labels = labels
        self.transform = transform
        self.num_classes = num_classes
        self.rand_seed = rand_seed

    def give_back_balanced_training_fold(self, X_current, y_current, minority_class_ratio, rand_seed):
        """Return the indices of the balanced subset of ``y_current`` (``X_current`` is unused)."""
        _, test_idx = create_biopsy_subset(y_current, minority_class_ratio, rand_seed)

        return test_idx

    def _balanced_indices(self):
        """Select the balanced subset for this sampler's labels and seed."""
        return self.give_back_balanced_training_fold(self.data, self.labels,
                                                     minority_class_ratio=1.0,
                                                     rand_seed=self.rand_seed)

    def __len__(self):
        return len(self._balanced_indices())

    def __iter__(self):
        test_idx = self._balanced_indices()

        random.shuffle(test_idx)
        num_batches = len(test_idx) // self.batch_size - 1

        # Slice with the configured batch size. A hard-coded 32 here would
        # disagree with the batch count above for any other batch size.
        for n in range(num_batches):
            yield test_idx[n * self.batch_size:(n + 1) * self.batch_size]

In [None]:
class BatchBalancedSampler(Sampler):
    """Batch-sampler adapter around a sampler that already yields whole batches.

    Iteration simply forwards the wrapped sampler's batches. ``len`` reports
    the number of batches, derived from the wrapped sampler's length (a number
    of indices) as ``len(sampler) // batch_size - 1``.

    Args:
        sampler: Iterable yielding one batch of indices per step; must
            support ``len``.
        batch_size: Batch size used to convert the index count to a batch count.
    """

    def __init__(self, sampler, batch_size):
        self.sampler = sampler
        self.batch_size = batch_size

    def __iter__(self):
        yield from self.sampler

    def __len__(self):
        # Clamp at zero: a negative value would make `len()` raise ValueError
        # when the wrapped sampler holds fewer indices than one batch.
        return max(len(self.sampler) // self.batch_size - 1, 0)

## Model

#### very small LR for backbone

In [None]:
# Build a torchvision ResNet-50 initialised with the `DEFAULT` pretrained weights.
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)

# alternative (NOTE: `models` is not imported in the visible cells)
# model = models.resnet18(pretrained=True)

In [None]:
# Cell output: the model's current `fc` layer, before it is replaced below.
model.fc

In [None]:
# Replace the classification head. Exactly one of the two variants below
# should be active; the other one is disabled.

## HERE THIS IS SIMPLEHEAD !  ## CHOOSE ##
# Single linear layer from the 2048 backbone features to 7 outputs
# (7 matches NUM_CLASSES, which is defined in a later cell).
model.fc = nn.Linear(in_features=2048, out_features=7)


## HERE THIS IS COMPLEXHEAD !  ## CHOOSE ##
# The MLP head is wrapped in a string literal so that it is not executed;
# remove the triple quotes (and disable the simple head) to use it.
"""
model.fc = nn.Sequential(
    nn.Linear(in_features=2048, out_features=128),
    nn.ReLU(),
    nn.BatchNorm1d(128),
    nn.Dropout(p=0.5),
    
    nn.Linear(in_features=128, out_features=128),
    nn.ReLU(),
    nn.BatchNorm1d(128),
    nn.Dropout(p=0.25),

    nn.Linear(in_features=128, out_features=7),

    #nn.Softmax(1)
)
"""

In [None]:
# Define two sets of parameters, one for the backbone and one for the head,
# so that the optimizer can give each its own learning rate.
backbone_params = [param for name, param in model.named_parameters() if 'fc' not in name]
# Materialise the head parameters as a list: `model.fc.parameters()` is a
# one-shot generator, so re-running the optimizer cell would otherwise hand
# the optimizer an already exhausted (empty) iterable for the head group.
head_params = list(model.fc.parameters())

## Training loop

In [None]:
# Run-specific sub-folder of the weights directory; checkpoints and the
# training history are written here.
weights_folder = f"{preproc_conf.weights_dir}weights_train_resnet50_smallLRbackbone_simplehead_level4_macenko_bracs_50epochs/"

# Creating the folder once is enough: the call is idempotent.
os.makedirs(weights_folder, exist_ok=True)

def _evaluate(model, data_loader, criterion, device):
    """Return ``(mean_loss, accuracy_percent)`` of ``model`` over ``data_loader``.

    The model is put into eval mode for the pass, so that layers such as
    BatchNorm and Dropout do not behave as in training and no running
    statistics are updated from validation data. The previous mode is
    restored afterwards. Accuracy is an integer percentage (floor division).
    """
    was_training = model.training
    model.eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    try:
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for batch in data_loader:
                inputs = batch["image"].to(device, dtype=torch.float)
                labels = batch["label"].to(device, dtype=torch.long)

                outputs = model(inputs)
                loss_sum += criterion(outputs, labels).item() * inputs.size(0)

                # The class with the highest score is the prediction.
                n_correct += (outputs.argmax(dim=1) == labels).sum().item()
                n_seen += labels.size(0)
    finally:
        model.train(was_training)

    # Guard against an empty loader instead of dividing by zero.
    n_seen = max(n_seen, 1)
    return loss_sum / n_seen, 100 * n_correct // n_seen


def train_model(model,
                device,
                transform,
                optimizer,
                scheduler,
                num_epochs,
                val_transform=None,
                save_acc_threshold=44):
    """Train ``model`` on class-balanced folds and validate after every epoch.

    Each epoch draws a fresh balanced subset of the training data (the
    sampler seed depends on the epoch), trains on it with cross-entropy loss
    and then evaluates on the full validation set. A checkpoint with the
    model and optimizer state is written to ``weights_folder`` whenever the
    integer validation accuracy exceeds ``save_acc_threshold``.

    The function relies on notebook-level names: ``x_train``, ``y_train_oh``,
    ``x_val``, ``y_val_oh``, ``NUM_CLASSES``, ``BATCH_SIZE``,
    ``weights_folder``, ``writer``, ``CollectionsDataset``,
    ``BalancedSampler`` and ``BatchBalancedSampler``.

    Args:
        model: Network to train; expected to be on ``device`` already.
        device: Device the batches are moved to.
        transform: Transform applied to training images.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Optional LR scheduler, stepped once per epoch; ``None``
            disables scheduling.
        num_epochs: Number of epochs to run.
        val_transform: Transform applied to validation images. Defaults to
            ``transform``, which is what the validation set used before this
            parameter existed.
        save_acc_threshold: A checkpoint is saved when the validation
            accuracy (integer percent) is strictly greater than this value.

    Returns:
        ``(model, history)`` where ``history`` maps ``'train_loss'``,
        ``'train_acc'``, ``'val_loss'`` and ``'val_acc'`` to per-epoch lists.
    """
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    criterion = nn.CrossEntropyLoss()

    if val_transform is None:
        val_transform = transform

    train_dataset = CollectionsDataset(data=x_train,
                                       labels=y_train_oh,
                                       num_classes=NUM_CLASSES,
                                       transform=transform)

    val_dataset = CollectionsDataset(data=x_val,
                                     labels=y_val_oh,
                                     num_classes=NUM_CLASSES,
                                     transform=val_transform)

    # Validation metrics do not depend on the sample order, so no shuffling.
    val_dataset_loader = torch.utils.data.DataLoader(val_dataset,
                                                     batch_size=BATCH_SIZE,
                                                     shuffle=False,
                                                     num_workers=4)

    # training loop with balanced folds
    for epoch in range(num_epochs):

        # New balanced fold every epoch; the seed is a function of the epoch.
        # The sampler and its batch wrapper must agree on the batch size.
        sampler = BalancedSampler(
                     batch_size=BATCH_SIZE,
                     data=x_train,
                     labels=y_train_oh,
                     num_classes=NUM_CLASSES,
                     transform=None,
                     rand_seed=int(epoch*1.5+3*epoch))

        batch_sampler = BatchBalancedSampler(sampler, batch_size=BATCH_SIZE)

        train_dataset_loader = torch.utils.data.DataLoader(train_dataset,
                                                           num_workers=4,
                                                           batch_sampler=batch_sampler)

        print('Epoch {}/{}'.format(epoch, num_epochs - 1))

        model.train()

        running_loss = 0.0
        correct = 0
        total = 0

        for batch in train_dataset_loader:
            inputs = batch["image"].to(device, dtype=torch.float, non_blocking=True)
            labels = batch["label"].to(device, dtype=torch.long, non_blocking=True)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)

            # accuracy bookkeeping
            total += labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        if scheduler is not None:
            scheduler.step()

        # Average over the samples that were actually processed. The sampler
        # length also counts the indices of batches that are never yielded.
        seen = max(total, 1)
        epoch_loss = running_loss / seen
        epoch_acc = 100 * correct // seen

        # Pass the epoch as global step so the points form a curve.
        writer.add_scalar("Train CE Loss", epoch_loss, epoch)
        print('Train CE Loss: {:.4f}'.format(epoch_loss), f'----- Train Accuracy: {epoch_acc} %')

        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)

        # VALIDATION
        epoch_loss_val, epoch_acc_val = _evaluate(model, val_dataset_loader, criterion, device)
        print('Val CE Loss: {:.4f}'.format(epoch_loss_val), f'----- Val Accuracy: {epoch_acc_val} %')
        print('-' * 10)

        history['val_loss'].append(epoch_loss_val)
        history['val_acc'].append(epoch_acc_val)

        if epoch_acc_val > save_acc_threshold:
            # Build the path once so the printed name is the saved name.
            checkpoint_path = weights_folder+f'checkpoint_epoch_{epoch}'+\
                   '_{:.4f}'.format(epoch_loss_val)+f'CE_{epoch_acc_val}_acc.pth'
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            print(checkpoint_path)
            torch.save(checkpoint, checkpoint_path)

    writer.flush()

    return model, history

## Transforms, augmentation

In [None]:
# check augmentations

In [None]:
# One candidate augmentation per pipeline, each after conversion to a tensor,
# so that their visual effect can be checked one at a time.
dummy_transform = transforms.Compose([transforms.ToTensor(), transforms.RandomResizedCrop(size=224, scale=(0.55,1))])

dummy_transform2 = transforms.Compose([transforms.ToTensor(), transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)])

dummy_transform3 = transforms.Compose([transforms.ToTensor(), transforms.RandomAdjustSharpness(sharpness_factor=3, p=0.5)])

dummy_transform4 = transforms.Compose([transforms.ToTensor(), transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 3))])


# Swap in any of the pipelines above to preview it.
dummy_dataset = CollectionsDataset(data=x_train,
                                       labels=y_train_oh,
                                       num_classes=7,
                                       transform=dummy_transform4)

# ToTensor yields channel-first (C, H, W) tensors while imshow expects
# (H, W, C). Swapping axes 0 and 2 would give (W, H, C), i.e. a transposed
# picture, so reorder the axes explicitly.
plt.imshow(dummy_dataset[0]['image'].permute(1, 2, 0).numpy())

In [None]:
# Same first training sample without any transform, passed to imshow as stored.
dummy_dataset = CollectionsDataset(data=x_train,
                                       labels=y_train_oh,
                                       num_classes=7,
                                       transform=None)
plt.imshow(dummy_dataset[0]['image'])

In [None]:
# https://stackoverflow.com/questions/51677788/data-augmentation-in-pytorch

# define some re-usable stuff
IMAGE_SIZE = 224
NUM_CLASSES = 7
BATCH_SIZE = 32
# Fall back to the CPU when CUDA is unavailable, like the device set up in the
# optimizer cell.
device = torch.device(f'cuda:{CUDANUM}' if torch.cuda.is_available() else 'cpu')


# Normalisation comes last: the photometric augmentations operate on float
# images in the [0, 1] range produced by ToTensor (RandomAdjustSharpness
# clamps its output to that range), so applying them to already-normalised
# data would distort the samples they hit.
train_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.RandomResizedCrop(size=IMAGE_SIZE, scale=(0.55,1)),
     #transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
     transforms.RandomAdjustSharpness(sharpness_factor=3, p=0.5),
     transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 3)),
     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), # IMAGENET !
     #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

# push model to device
model = model.to(device)

### Train ResNet-50 with balanced dataset

## Optimizer and scheduler

In [None]:
# TensorBoard writer used by `train_model` (which reads it as a module-level
# name) to log the training loss.
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [None]:
def seed_torch(seed=7):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Seed value applied to every generator.
    """
    import random
    random.seed(seed)
    # Recorded in the environment; it does not change hashing in the
    # interpreter that is already running.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # `torch.device.type` is just 'cuda', never 'cuda:<index>', so comparing
    # it with an indexed string can never match. Check availability instead.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [None]:
# Set a very small learning rate for the backbone and a higher one for the head.
# Each dict is one optimizer parameter group with its own options.
optimizer_ft = optim.Adam([
    {'params': backbone_params, 'lr': 1e-5, 'amsgrad': True},  
    {'params': head_params, 'lr': 1e-4, 'amsgrad': True}      
])

# No LR scheduler is used in this run; the StepLR variant is kept for reference.
lr_sch = None # lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.8) # was 10 and 0.8
device = torch.device(f'cuda:{CUDANUM}' if torch.cuda.is_available() else 'cpu') #device config

seed_torch() # FIX SEED 
# Train for 50 epochs; returns the trained model and the per-epoch history
# that the following cells plot and save.
model_ft, history = train_model(model,
                       device,
                       train_transform,
                       optimizer_ft,
                       lr_sch,
                       num_epochs=50)

In [None]:
# Training and validation cross-entropy loss per epoch.
plt.figure(figsize=(12,8))
plt.plot( np.arange(len(history['train_loss'])), history['train_loss'], label='train' )
plt.plot( np.arange(len(history['val_loss'])), history['val_loss'], label='val' )
plt.legend(fontsize=18)
plt.xlabel('epochs', fontsize=22)
plt.ylabel('Cross entropy loss', fontsize=22)
plt.tick_params(labelsize=18)
#plt.yscale('log')
# The title states the configuration this notebook actually runs: Adam with
# two learning rates, no LR scheduler, minority_class_ratio=1.0.
plt.title('mcr=1.0, bs=32, Adam (amsgrad), lr=1e-5 backbone / 1e-4 head, no LR scheduler')
#plt.savefig('test3_aug3')

### Save plot data as well !!

In [None]:
# Store the per-epoch loss/accuracy lists next to the checkpoints so the plot
# can be reproduced without retraining.
with open(f'{weights_folder}history.json', 'w') as f:
    json.dump(history, f, indent=4)

In [None]:
# Cell output: the folder the checkpoints and history were written to.
weights_folder