# Imports:

# In [2]:
import pandas as pd
import numpy as np
import skimage.io
from glob import glob
from tqdm import tqdm
import os
import cv2
import torch
from torchvision.models import resnet50
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import math

# Constants:

# In [3]:
PERCENT_SAMPLING = 0.1
BASELINE = False
MODEL_LOAD_DIR = './models'
BATCH_SIZE = 15
PATCH_SIZE = 128
SCALE_FACTOR = 4
IMAGE_SIZE = 512*SCALE_FACTOR
WARMUP_EPOCHS = 2
# ceil(1 / PERCENT_SAMPLING): 10 when sampling 10% per batch.
PATCH_BATCHES = math.ceil(1/PERCENT_SAMPLING)
INNER_ITERATION = PATCH_BATCHES
# Device for model/tensor placement; falls back to CPU without CUDA.
# (Spelling kept as-is because other cells reference this exact name.)
ACCELARATOR = 'cuda' if torch.cuda.is_available() else 'cpu'
LATENT_DIMENSION = 256
NUM_CLASSES = 6
SEED = 42
STRIDE = PATCH_SIZE
# Patches per side when sliding a PATCH_SIZE window by STRIDE over IMAGE_SIZE.
NUM_PATCHES = ((IMAGE_SIZE-PATCH_SIZE)//STRIDE) + 1
NUM_WORKERS = 4
# NOTE: Windows-style relative paths.
TRAIN_ROOT_DIR = f'..\\data\\pandas_dataset\\training_images_{IMAGE_SIZE}'
VAL_ROOT_DIR = TRAIN_ROOT_DIR
TRAIN_CSV_PATH = '..\\data\\pandas_dataset\\train_kfold.csv'
MEAN = [0.9770, 0.9550, 0.9667]
STD = [0.0783, 0.1387, 0.1006]
MODEL_SAVE_DIR = "./"
DECAY_FACTOR = 1
VALIDATION_EVERY = 1

# Dataset Classes

# In [4]:
def get_best_resolution_index(io_object, dimension):
    """Return the index of the smallest pyramid level that is large enough.

    A level is "large enough" when the smaller of its first two ``shape``
    entries is at least ``dimension``. Levels are assumed to be stored
    largest-first (index 0 = full resolution), as multi-page TIFF pyramids
    are; they are scanned from the smallest level upward, so levels larger
    than needed are never accessed. If no level is large enough, index 0
    (the largest level) is returned.

    Args:
        io_object: Indexable, ``len``-able sequence of image levels, each
            exposing ``.shape``.
        dimension: Minimum required size of a level's shorter side.

    Returns:
        Index into ``io_object``.

    Raises:
        ValueError: If ``io_object`` contains no levels.
    """
    if len(io_object) == 0:
        raise ValueError("io_object has no resolution levels")
    for i in reversed(range(len(io_object))):
        shape = io_object[i].shape
        if min(shape[0], shape[1]) >= dimension:
            return i
    # Nothing reaches ``dimension``: use the largest level available.
    return 0


class PandasDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs.

    Row ``index`` of ``df`` (positional, via ``iloc``) supplies ``image_id``
    and ``isup_grade``; the image is read from ``<root_dir>/<image_id>.png``.

    Args:
        df: Table of samples with ``image_id`` and ``isup_grade`` columns.
        root_dir: Directory containing the ``.png`` images.
        transforms: Optional callable applied to the decoded image.
    """

    def __init__(self, df, root_dir, transforms=None):
        self.df = df
        self.root_dir = root_dir
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, label_tensor)`` for row ``index``.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        row = self.df.iloc[index]
        path = os.path.join(self.root_dir, f"{row.image_id}.png")
        # cv2.imread signals failure by returning None rather than raising;
        # fail here instead of deep inside transforms or batch collation.
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        # NOTE: cv2.imread decodes colour images in BGR channel order.
        if self.transforms is not None:
            image = self.transforms(image)
        return image, torch.tensor(row.isup_grade)

class PatchDataset(Dataset):
    """Index the square crops of a batch of images on a regular grid.

    There are ``num_patches * num_patches`` crops. Item ``choice`` starts at
    offset ``stride * (choice % num_patches)`` along axis 2 and
    ``stride * (choice // num_patches)`` along axis 3 of ``images``, spans
    ``patch_size`` on both axes, and keeps axes 0 and 1 whole. Each item is
    returned as ``(crop, choice)``.
    """

    def __init__(self, images, num_patches, stride, patch_size):
        self.images = images
        self.num_patches = num_patches
        self.stride = stride
        self.patch_size = patch_size

    def __len__(self):
        return self.num_patches * self.num_patches

    def __getitem__(self, choice):
        axis2_slot = choice % self.num_patches
        axis3_slot = choice // self.num_patches
        start2 = self.stride * axis2_slot
        start3 = self.stride * axis3_slot
        size = self.patch_size
        crop = self.images[:, :, start2:start2 + size, start3:start3 + size]
        return crop, choice

# Models:

# In [None]:
class Backbone(nn.Module):
    """torchvision ResNet-50 (``pretrained=True``) with a replaced head.

    The final fully-connected layer is swapped for a fresh ``nn.Linear``
    from the 2048-wide pooled features to ``NUM_CLASSES`` outputs when
    ``baseline`` is truthy, otherwise to ``latent_dim`` outputs.
    """

    def __init__(self, baseline, latent_dim):
        super().__init__()
        self.encoder = resnet50(pretrained=True)
        head_width = NUM_CLASSES if baseline else latent_dim
        self.encoder.fc = nn.Linear(2048, head_width)

    def forward(self, x):
        """Apply the modified ResNet-50 to ``x``."""
        return self.encoder(x)
 
class CNN_Block(nn.Module):
    """Convolutional classifier over a grid of latent feature vectors.

    The flatten size of the final linear layer is derived from an input of
    shape ``(2, latent_dim, num_patches, num_patches)`` (``expected_dim``).
    Four Conv2d(3x3, padding 1) -> ReLU -> BatchNorm2d stages are applied
    (stride 1, then stride 2 three times), with ``Dropout2d(p=0.2)`` between
    stages, followed by flattening and a linear layer to ``num_classes``.

    Args:
        latent_dim: Channels of the input and of every conv stage.
        num_classes: Number of output logits.
        num_patches: Spatial size of both input grid axes.
    """

    def __init__(self, latent_dim, num_classes, num_patches):
        super().__init__()
        self.expected_dim = (2, latent_dim, num_patches, num_patches)
        self.layer1 = self._conv_stage(latent_dim, stride=1)
        self.layer2 = self._conv_stage(latent_dim, stride=2)
        self.layer3 = self._conv_stage(latent_dim, stride=2)
        self.layer4 = self._conv_stage(latent_dim, stride=2)
        self.dropout = nn.Dropout2d(p=0.2)
        flatten_dim = self.get_final_out_dimension(self.expected_dim)
        self.linear = nn.Linear(flatten_dim, num_classes)

    @staticmethod
    def _conv_stage(latent_dim, stride):
        """Build one Conv2d(3x3, padding 1) -> ReLU -> BatchNorm2d stage."""
        return nn.Sequential(
            nn.Conv2d(latent_dim, latent_dim, 3, stride, 1),
            nn.ReLU(),
            nn.BatchNorm2d(latent_dim),
        )

    def get_output_shape(self, model, image_dim):
        """Return the output shape of ``model`` for an input of ``image_dim``.

        The probe runs in eval mode without autograd so it does not update
        BatchNorm running statistics or ``num_batches_tracked``. It uses
        zeros so it does not advance the global RNG. The module's previous
        train/eval mode is restored afterwards.
        """
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                return model(torch.zeros(*image_dim)).shape
        finally:
            model.train(was_training)

    def get_final_out_dimension(self, shape):
        """Return the per-sample flatten size after all four conv stages."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            shape = self.get_output_shape(stage, shape)
        return math.prod(shape[1:])

    def forward(self, x, print_shape=False):
        """Return logits of shape ``(x.shape[0], num_classes)``.

        If ``print_shape`` is true, the size after each stage and after the
        linear layer is printed.
        """
        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        for idx, stage in enumerate(stages):
            x = stage(x)
            if print_shape:
                print(x.size())
            # Dropout goes between stages only, not after the last one.
            if idx < len(stages) - 1:
                x = self.dropout(x)
        x = x.reshape(x.shape[0], -1)
        x = self.linear(x)
        if print_shape:
            print(x.size())
        return x


# Inference functions:

# In [None]:
@torch.no_grad()
def make_baseline_inference(batch_size=None, image_size=None, warmup=10,
                            num_iterations=100, device=None):
    """Measure forward-pass throughput of the baseline ``Backbone``.

    Builds ``Backbone(True, 0)`` in eval mode and feeds it one fixed random
    batch of shape ``(batch_size, 3, image_size, image_size)``. It runs
    ``warmup`` untimed passes, then times ``num_iterations`` passes.

    Args:
        batch_size: Images per pass; defaults to ``BATCH_SIZE``.
        image_size: Input height and width; defaults to ``IMAGE_SIZE``.
        warmup: Number of untimed passes run before measuring.
        num_iterations: Number of timed passes; must be at least 1.
        device: Device to run on; defaults to ``ACCELARATOR``.

    Returns:
        0-d float tensor: ``batch_size`` divided by the mean wall time of a
        timed pass (images per second).

    Raises:
        ValueError: If ``num_iterations`` is less than 1. The mean of an
            empty timing list would be NaN.
    """
    import time  # not imported by the notebook's import cell

    if num_iterations < 1:
        raise ValueError("num_iterations must be at least 1")
    batch_size = BATCH_SIZE if batch_size is None else batch_size
    image_size = IMAGE_SIZE if image_size is None else image_size
    device = torch.device(ACCELARATOR if device is None else device)
    on_cuda = device.type == "cuda"

    def _sync():
        # CUDA launches are asynchronous, so wait for queued work before reading
        # the clock. torch.cuda.synchronize would raise without CUDA, so skip
        # it on other devices.
        if on_cuda:
            torch.cuda.synchronize(device)

    if on_cuda:
        torch.cuda.empty_cache()
    model = Backbone(True, 0)
    model.eval()
    model.to(device)
    inputs = torch.rand(batch_size, 3, image_size, image_size, device=device)

    for _ in range(warmup):
        model(inputs)
    _sync()

    timing = []
    for _ in range(num_iterations):
        start = time.perf_counter()
        model(inputs)
        _sync()
        timing.append(time.perf_counter() - start)

    timing = torch.as_tensor(timing, dtype=torch.float32)
    return batch_size / timing.mean()
