# Preparatory code

## Imports

In [45]:
from d2l import torch as d2l
import numpy as np
from torch import nn
import torchvision
from torchvision.transforms import v2
import torch
import cv2
import os

## Class support functions

In [25]:
def add_to_class(Class):    #@save
    """Register functions as methods in the created class.

    Intended as a decorator: ``@add_to_class(MyClass)`` attaches the decorated
    function to ``MyClass`` under the function's own ``__name__``.

    The decorated function is returned unchanged, so the decorated name in the
    defining scope stays bound to the function instead of being rebound to
    ``None``.
    """
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
        return obj
    return wrapper

# Dataloader

First, a generic data loader.

In [26]:
class DataModule(d2l.HyperParameters):  #@save
    """The base class for data.

    Subclasses implement ``get_dataloader(train)``; ``get_tensorloader`` also
    reads ``self.batch_size``, which subclasses are expected to provide.
    """
    def __init__(self, root='../data', num_workers=4):
        # d2l's save_hyperparameters records the constructor arguments as
        # attributes (subclasses read e.g. self.num_workers).
        self.save_hyperparameters()

    def get_dataloader(self, train):
        raise NotImplementedError

    def train_dataloader(self):
        """Return the training dataloader."""
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        """Return the validation dataloader."""
        return self.get_dataloader(train=False)

    def val_dataloder(self):
        """Backward-compatible alias for the originally misspelled name."""
        return self.val_dataloader()

    def get_tensorloader(self, tensors, train, indices=slice(0, None)):
        """Wrap (a slice of) in-memory tensors in a DataLoader.

        Args:
            tensors: sequence of tensors indexed in lockstep (e.g. features, labels).
            train: shuffle the data when True.
            indices: index/slice applied to every tensor; defaults to all rows.
        """
        tensors = tuple(a[indices] for a in tensors)
        dataset = torch.utils.data.TensorDataset(*tensors)
        return torch.utils.data.DataLoader(dataset, self.batch_size, shuffle=train)

Next, the implementation of our dataloader.
`self.train` (and `self.val`) have the following structure:
* a list of samples, one per input image;
* each sample is a tuple `(image tensor, label)`, where
  * `label` is an int
  * `image tensor` has shape (channels, height, width)


In [27]:
def open_image(path):
    """Open the images at *path*. Not implemented yet (stub).

    Raises:
        NotImplementedError: always.
    """
    # Raise (not return) the exception so callers cannot silently receive
    # the exception class as if it were data.
    raise NotImplementedError

In [54]:
def open_image_placeholder(path, transf=v2.ToTensor()):
    """Load every readable image in directory *path* as a labelled sample.

    Placeholder loader: every image is given the label ``0``.

    Args:
        path: directory containing the image files.
        transf: callable applied to each decoded image (an RGB, HxWxC uint8
            NumPy array); the default converts it to a (channels, height,
            width) float tensor.

    Returns:
        A list of ``(transf(image), 0)`` tuples, in sorted filename order.
    """
    samples = []
    # os.listdir order is arbitrary; sort for a reproducible sample order.
    for fname in sorted(os.listdir(path)):
        arr = cv2.imread(os.path.join(path, fname))
        if arr is None:
            # cv2.imread returns None (it does not raise) for entries it cannot
            # decode, e.g. stray non-image files or subdirectories; skip them.
            continue
        # OpenCV decodes to BGR channel order; convert to RGB.
        arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
        samples.append((transf(arr), 0))

    return samples

open_image_placeholder('../dataset/TestA/Basophil')[0][0].shape

torch.Size([3, 575, 575])

In [57]:
class CroppedCells(DataModule):
    """The dataset of cropped cells.

    For now the training split is the placeholder Basophil image folder, with
    every image labelled 0 by ``open_image_placeholder``. There is no
    validation split yet.
    """
    def __init__(self, transf = None, resize = None, batch_size=64):
        super().__init__()
        self.save_hyperparameters()
        # TODO: transf and resize are recorded but not applied yet.
        self.train = open_image_placeholder('../dataset/TestA/Basophil')
        # No validation data yet. An empty list yields a valid zero-batch
        # loader instead of an AttributeError in get_dataloader(train=False).
        self.val = []

    def get_dataloader(self, train=True):
        """Return a DataLoader over the training (shuffled) or validation split."""
        data = self.train if train else self.val
        return torch.utils.data.DataLoader(data, self.batch_size, shuffle=train, num_workers=self.num_workers)

data = CroppedCells()
X, y = next(iter(data.train_dataloader()))
print(X.shape, X.dtype, y.shape, y.dtype)

torch.Size([64, 3, 575, 575]) torch.float32 torch.Size([64]) torch.int64


# Models

In [None]:
class Module(nn.Module, d2l.HyperParameters):   #@save
    """The base class of models.

    Subclasses set ``self.net`` and implement ``loss`` and
    ``configure_optimizers``. ``plot`` requires a trainer attached as
    ``self.trainer``.
    """
    def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
        super().__init__()
        self.save_hyperparameters()
        # ProgressBoard is not defined or imported in this file; use d2l's.
        self.board = d2l.ProgressBoard()

    def loss(self, y_hat, y):
        raise NotImplementedError

    def forward(self, X):
        assert hasattr(self, "net"), "Neural network is not defined"
        return self.net(X)

    def plot(self, key, value, train):
        """Plot a point in animation.

        Args:
            key: metric name; plotted under ``train_<key>`` or ``val_<key>``.
            value: tensor to plot (e.g. the loss); moved to CPU and detached.
            train: True for a training step, False for a validation step.
        """
        assert hasattr(self, "trainer"), "Trainer is not initiated"
        self.board.xlabel = 'epoch'
        if train:
            # Fractional epoch position of the current training batch.
            x = (self.trainer.train_batch_idx /
                 self.trainer.num_train_batches)
            n = (self.trainer.num_train_batches /
                 self.plot_train_per_epoch)
        else:
            x = self.trainer.epoch + 1
            n = (self.trainer.num_val_batches /
                 self.plot_valid_per_epoch)

        # Clamp to >= 1. With fewer batches than requested plot points,
        # int(n) would be 0, and d2l's ProgressBoard.draw (which draws once
        # every `every_n` calls) would then never draw anything.
        self.board.draw(x, value.to(d2l.cpu()).detach().numpy(),
                        ('train_' if train else 'val_') + key,
                        every_n=max(1, int(n)))

    def training_step(self, batch):
        """Compute and plot the loss on a batch whose last element is the target."""
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=True)
        return l

    def validation_step(self, batch):
        """Compute and plot the validation loss; returns it like training_step."""
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=False)
        return l

    def configure_optimizers(self):
        raise NotImplementedError

## ResNet

# Training

## ResNet v0.1