# Load Data & Train Model
This notebook loads the entire dataset, splits it into a proper train/test split, and trains the model using K-Fold cross-validation.

## Imports

In [None]:
# Essentials
import math
import random
import os
import copy
import glob
import shutil
import numpy as np
import json
import importlib
import time as time

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential as Seq, Linear as Lin, Conv2d

import torchvision.models as models
from torchvision import datasets, transforms

from torch.utils.data import Dataset, DataLoader

import torch.optim as optim
import timm
from timm.models import create_model
from timm.data import create_transform
from sklearn.metrics import accuracy_score

# Images
import albumentations
import albumentations.pytorch

import cv2

from PIL import Image

# Machine Learning
from sklearn.model_selection import KFold

from barbar import Bar

import utils
importlib.reload(utils)

from utils import get_data, select_gpu, get_model, get_class_weigths, get_files
from utils import My_data, CustomTransforms, FocalLoss

## Set Device

In [None]:
# Pick the compute device: a CUDA device whose index is chosen by
# `select_gpu()` when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{select_gpu()}")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Show how many CUDA devices PyTorch can see (displayed as the cell's output).
torch.cuda.device_count()

In [None]:
# Seed every random number generator used in this notebook, regardless of
# whether a GPU is present. `torch.manual_seed` seeds the CPU generator (used
# for weight initialisation and DataLoader shuffling) as well as the CUDA
# generators; `random` and NumPy are seeded separately.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

if torch.cuda.is_available():
    # Release cached allocator blocks and report the current memory state.
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
    torch.cuda.manual_seed_all(SEED)
    # Let cuDNN auto-tune kernels for speed. With benchmark mode on and
    # deterministic mode off, GPU runs are not guaranteed to be bit-identical
    # even with fixed seeds.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

## Load Dataset

This downloads the dataset to the server and unzips it so we can use it. Check the size of the folder before running the rest of the code; it should be around 4 GB.

In [None]:
# The eight class codes: the first four are the benign ones, the last four
# the malignant ones.
benign_classes = ["A", "F", "TA", "PT"]
malignant_classes = ["DC", "LC", "MC", "PC"]
classes = benign_classes + malignant_classes

# The four zoom levels, kept as strings because they are compared against
# fields parsed out of the file names.
zooms = ["40", "100", "200", "400"]

In [None]:
# Architecture name; also selects which adversarial (PGD) image folder is used.
model_name = "swin"

# Glob patterns for the original images and the PGD-perturbed images.
org_dataset = './dataset/train/original/**/**/*.png'
pert_dataset = f'./dataset/train/pgd_attack/{model_name}/**/**/*.png'

# `get_files` is a project helper; its results are concatenated, so it is
# presumably returning a list of path strings matching the pattern — verify
# against utils.
dataset = get_files(org_dataset) + get_files(pert_dataset)
print(len(dataset))

# Create the train_dict: train_dict[class][zoom] -> list of image paths.
# The text after the last "_" in a path is split on "-"; field 0 is the class
# code and field 3 the zoom level. Each path is parsed once (instead of once
# per class/zoom pair), and paths that do not fit the pattern are skipped
# rather than raising IndexError.
train_dict = {c: {z: [] for z in zooms} for c in classes}
unmatched = 0
for path in dataset:
    fields = path.split("_")[-1].split("-")
    per_zoom = train_dict.get(fields[0])
    if per_zoom is not None and len(fields) > 3 and fields[3] in per_zoom:
        per_zoom[fields[3]].append(path)
    else:
        unmatched += 1

if unmatched:
    print(f"Warning: {unmatched} file(s) did not match any class/zoom and were left out of train_dict")

## Prepare Data
In this section we do the following:
1. Store all image paths in a dictionary that is filtered on class and zoom level
2. Create a stratified train/test split (based on class and zoom level) and store the image paths of both sets again in a filtered dictionary
3. Copy all images from the raw data source to the structured data folder "dataset/"
4. Create a 5-fold cross validation split for the train dataset
5. Prepare all splits for forward pass of the model

### Prepare Data for Forward Pass

In [None]:
# Provider of the image transform pipelines (project helper from utils).
transform = CustomTransforms()

# Wrap all collected image paths in the project dataset, using the "train"
# transform pipeline.
train_dataset = My_data(dataset, transforms=transform.get_transform("train"))

# A single shuffled training loader over the whole dataset.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    prefetch_factor=2,
)

## Create Model

In [None]:
# Build the network through the project helper `get_model`, passing the
# selected device and the architecture name chosen above.
model = get_model(device, model=model_name)

In [None]:
# Report how many named parameters the model has.
print(len(list(model.named_parameters())))

# Print the current trainable/frozen status of every parameter, then unfreeze
# it so that the whole network is trained.
for name, param in model.named_parameters():
    if not param.requires_grad:
        print(f"Parameter '{name}' does not require grad.")
    else:
        print(f"Parameter '{name}' requires grad.")

    param.requires_grad = True

In [None]:
class FocalLoss(nn.Module):
    """Sigmoid focal loss for multi-label classification.

    For every (sample, class) entry the binary cross-entropy is scaled by
    ``(1 - p_t) ** gamma``, where ``p_t`` is the predicted probability of the
    true label, so that well-classified entries contribute less. The result
    is optionally scaled per class and averaged into a scalar.

    Args:
        gamma: Focusing exponent; ``0`` reduces the loss to plain (weighted)
            binary cross-entropy.
        class_weights: Optional tensor broadcastable against the logits
            (typically one weight per class) that scales each class's loss.

    Backward compatibility: if a non-scalar tensor is passed as ``gamma`` and
    ``class_weights`` is omitted (i.e. the weights were passed as the first
    positional argument), the tensor is used as ``class_weights`` and
    ``gamma`` falls back to its default of 2.0.
    """

    def __init__(self, gamma=2.0, class_weights=None):
        super().__init__()
        if class_weights is None and isinstance(gamma, torch.Tensor) and gamma.dim() > 0:
            class_weights, gamma = gamma, 2.0
        self.gamma = gamma
        self.class_weights = class_weights

    def forward(self, logits, labels):
        """Return the mean focal loss as a scalar tensor.

        Args:
            logits: Raw (pre-sigmoid) model outputs.
            labels: Binary targets with the same shape as ``logits``.
        """
        labels = labels.to(logits.dtype)
        # Element-wise BCE computed from logits: numerically stable, unlike
        # sigmoid followed by BCELoss.
        ce_loss = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
        probs = torch.sigmoid(logits)
        # Probability the model assigns to the true label of each entry.
        p_t = probs * labels + (1 - probs) * (1 - labels)
        loss = (1 - p_t).pow(self.gamma) * ce_loss
        if self.class_weights is not None:
            loss = loss * self.class_weights.to(device=loss.device, dtype=loss.dtype)
        # Reduce to a scalar so callers can use loss.item() and loss.backward().
        return loss.mean()

## Train Model

This section performs the actual training of the model. We first define the methods `fit` and `validate` that are called during training. Afterwards, we define the loop that optimizes the model.

In [None]:
def fit(model, dataloader, optimizer, scheduler, criterion, accum_iter=4, threshold=0.5):
    """Train `model` for one epoch with gradient accumulation.

    Uses the module-level `device` for tensor placement.

    Args:
        model: Network to train; called as ``model(inputs)`` and expected to
            return one logit per class for each sample.
        dataloader: Yields ``(inputs, labels)`` batches, with labels shaped
            ``(batch, num_classes)``; ``len(dataloader.dataset)`` is used to
            normalise the returned metrics.
        optimizer: Optimizer updating the model's parameters.
        scheduler: Learning-rate scheduler, stepped once at the end of the epoch.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss.
        accum_iter: Number of batches whose gradients are accumulated before
            each optimizer step.
        threshold: Sigmoid probability at or above which a class is predicted.

    Returns:
        ``(train_loss, train_accuracy)`` as Python floats: the sample-weighted
        mean loss and the percentage of samples for which every class was
        predicted correctly.

    Raises:
        ValueError: If ``accum_iter`` is smaller than 1.
    """
    if accum_iter < 1:
        raise ValueError("accum_iter must be >= 1")

    model.train()
    running_loss = 0.0
    running_correct = 0
    pending_update = False

    # Start from clean gradients; inside the loop they are only cleared right
    # after an optimizer step so that they can accumulate across batches.
    optimizer.zero_grad()

    for step, (inputs, labels) in enumerate(Bar(dataloader)):
        inputs = inputs.to(device)
        labels = labels.float().to(device)

        # Forward pass and loss.
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        running_loss += loss.item() * inputs.size(0)

        # Exact-match accuracy: a sample counts as correct only if all of its
        # thresholded class predictions equal the labels.
        with torch.no_grad():
            preds = (torch.sigmoid(outputs) >= threshold).float()
            running_correct += (preds == labels).all(dim=1).sum().item()

        # Scale so the accumulated gradient is the average over the group.
        (loss / accum_iter).backward()
        pending_update = True

        if (step + 1) % accum_iter == 0:
            optimizer.step()
            optimizer.zero_grad()
            pending_update = False

    # Apply the gradients of a trailing, incomplete accumulation group.
    if pending_update:
        optimizer.step()
        optimizer.zero_grad()

    scheduler.step()

    num_samples = len(dataloader.dataset)
    train_loss = running_loss / num_samples
    train_accuracy = 100. * running_correct / num_samples
    return train_loss, train_accuracy

In [None]:
# Per-epoch records: [epoch, train loss, train accuracy, duration in seconds].
history = []

if torch.cuda.is_available():
    torch.cuda.empty_cache()

model = model.to(device)

# Class weights come from the project helper, fed with the class/zoom file
# dictionary, and are moved to the training device.
class_weights = get_class_weigths(train_dict).to(device)

# NOTE(review): the weights are passed positionally, i.e. into FocalLoss's
# first parameter (`gamma`) rather than by `class_weights=`. Confirm which
# binding is intended, and that `fit` still receives a scalar loss, before
# changing this call.
criterion = FocalLoss(class_weights)

# Only parameters that require gradients are handed to the optimizer.
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

# Halve the learning rate every 15 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)

best_acc = 0.0
epochs = 50

# Make sure the checkpoint directory exists before the first torch.save.
os.makedirs('models', exist_ok=True)

for epoch in range(epochs):
    epoch_start = time.time()
    print('Epoch-{0}/{1} lr: {2}'.format(epoch+1, epochs, optimizer.param_groups[0]['lr']))

    train_epoch_loss, train_epoch_accuracy = fit(model, train_dataloader, optimizer, scheduler, criterion)
    # Store plain Python floats so the saved history holds no device tensors.
    train_epoch_loss = float(train_epoch_loss)
    train_epoch_accuracy = float(train_epoch_accuracy)

    epoch_end = time.time()
    history.append([epoch+1, train_epoch_loss, train_epoch_accuracy, (epoch_end-epoch_start)])
    print(f"Train Loss: {train_epoch_loss:.4f}, Train Acc: {train_epoch_accuracy:.2f},time : {epoch_end-epoch_start:.2f}")

    # Persist the running history after every epoch.
    torch.save({'history': history}, f'models/Master_retrained_{model_name}_his.pth')

    # Checkpoint whenever the training accuracy improves.
    if train_epoch_accuracy > best_acc:
        best_acc = train_epoch_accuracy
        best_model_wts = copy.deepcopy(model.state_dict())

        best_epoch = epoch
        torch.save({
            'epoch': epoch+1,
            'model_state_dict': best_model_wts,
            # Loss value of this epoch (not the criterion object, which would
            # tie the checkpoint to a class defined in this notebook).
            'loss': train_epoch_loss,
            'history': history,
            'best_epoch': best_epoch+1,
            }, f'models/Master_retrained_{model_name}.pth')