# Install all necessary modules here




In [None]:
from tqdm.notebook import tqdm
from torchmetrics.detection.mean_ap import MeanAveragePrecision
import torch
import os
import numpy as np
import platform
import pickle
from PIL import Image
from torchvision import transforms
from numpy import random
from torch.utils.data import DataLoader
import albumentations as A


# Preparation of milestone three

Today we will start preparing the third milestone. The third milestone is to train an object detector to recognize cells. To successfully complete the milestone, you will have to complete the following sub-tasks:
- Initialize a PyTorch object detector. I suggest choosing a RetinaNet or FCOS detection model. More information can be found [here](https://pytorch.org/vision/stable/models.html#object-detection-instance-segmentation-and-person-keypoint-detection). Since we do not have endless compute power available, we will use a frozen ResNet18 pre-trained on ImageNet as backbone and only train the detection and classification heads of our object detector. So you will have to find a way to **freeze** the backbone of your detector.
- You will have to write a [training and validation/test](https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html) loop to train your detector. Make sure you measure the convergence of the training by monitoring a detection metric like the [mAP](https://torchmetrics.readthedocs.io/en/stable/detection/mean_average_precision.html). Also, you will have to find a way to select the best model during training based on some metric.
- You will have to train your model until convergence using the dataset class you created for the last milestone. You will also have to pass your dataset to a DataLoader to benefit from parallel data loading (multiple worker processes) as well as automatic batching.
- At the end, you will have to save the **state_dict** of your trained object detector, to be able to reuse it later.

Please use a Jupyter notebook for coding your training/testing pipeline. In the end, you will have to submit that Jupyter notebook on Moodle.

# If you run the notebook in Colab, you have to mount the Google Drive with the images. Proceed as follows:

- **First**: Open the following **[link](https://drive.google.com/drive/folders/18P74V8kli6qDZtGBLN-tPrJFu3O2NPEK?usp=sharing)** in a new tab.
- **Second**: Add a shortcut to this folder to your Google Drive.
Example: [Link](https://drive.google.com/file/d/1IcFGGIoktPkDj9-4j5IQ3evInn0c2aq-/view?usp=sharing)
- **Third**: Run the line of code below
- **Fourth**: Grant Google access to your Drive

In [None]:
from google.colab import drive

# Make Google Drive available under /content/gdrive (Colab asks for authorization).
drive.mount('/content/gdrive')

# Folder reached through the Drive shortcut created in the steps above.
path_to_slides = '/content/gdrive/MyDrive/AgNORs/'

# 1. Initializing the model

In this project, we will use a RetinaNet object detector whose backbone is a network pre-trained on ImageNet (the code below uses MobileNetV2; the milestone description suggests ResNet18). The model and the pre-trained weights can easily be loaded from the torchvision library. Note that the anchor boxes used by the model may need to be adjusted to the object sizes of this specific task.

The behavior of the RetinaNet model depends on whether it is in training or evaluation mode. In training mode, the model expects both the images and a list of target dictionaries (with `boxes` and `labels`) as input and returns a dictionary of losses. In evaluation mode, the model only expects the images as input and returns a list of predictions (one dictionary with `boxes`, `scores` and `labels` per image) without computing any losses.

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Random-crop object-detection dataset over a set of slide images.

    Each item is a ``crop_size`` crop taken at a randomly sampled position of a
    randomly sampled slide, together with every annotated box that lies
    strictly inside that crop (shifted into crop coordinates). The crop
    positions are fixed for one pseudo epoch of ``pseudo_epoch_length`` items
    and can be redrawn with :meth:`trigger_sampling`.

    Boxes that are only partially inside a crop are not returned, although the
    visible part of the object remains in the image.

    Args:
        annotations_frame: DataFrame with the columns ``filename``, ``label``,
            ``min_x``, ``min_y``, ``max_x`` and ``max_y``.
        path_to_slides: Directory containing the files listed in
            ``annotations_frame.filename``.
        crop_size: ``(width, height)`` of the crops in pixels.
        pseudo_epoch_length: Number of crops per pseudo epoch (``len(self)``).
        transformations: Optional callable invoked as
            ``transformations(image=np_array, bboxes=boxes)`` that returns a
            dict with ``'image'`` and ``'bboxes'`` (e.g. an albumentations
            ``Compose`` with pascal_voc boxes). Each box is passed as
            ``[x_min, y_min, x_max, y_max, label]``.
    """

    def __init__(self, annotations_frame,
                 path_to_slides,
                 crop_size = (128,128),
                 pseudo_epoch_length:int = 1000,
                 transformations = None):

        super().__init__()

        # Kept for backward compatibility only; paths are built with
        # os.path.join, which picks the right separator on every OS.
        self.separator = os.sep

        self.anno_frame = annotations_frame
        self.path_to_slides = path_to_slides
        self.crop_size = crop_size
        self.pseudo_epoch_length = pseudo_epoch_length

        # slide_dict: slide name -> RGB PIL image.
        # annotations_list: one entry per box, in the format
        #   [slide_name, label, min_x, min_y, max_x, max_y]
        self.slide_dict, self.annotations_list = self._initialize()
        # The same entries grouped per slide, so a crop only scans its own slide.
        self.annotations_by_slide = {}
        for entry in self.annotations_list:
            self.annotations_by_slide.setdefault(entry[0], []).append(entry)

        self.sample_cord_list = self._sample_cord_list()

        # set up transformations
        self.transformations = transformations
        self.transform_to_tensor = transforms.Compose([transforms.ToTensor()])


    def _initialize(self):
        """Load every slide once and collect its bounding boxes.

        Returns:
            ``(slide_dict, annotations_list)``: slide name -> RGB image, and a
            list of ``[slide_name, label, min_x, min_y, max_x, max_y]`` entries.
        """
        slide_dict = {}
        annotations_list = []
        columns = ['max_x', 'max_y', 'min_x', 'min_y', 'label']
        for slide in self.anno_frame.filename.unique():
            # Open each file once; convert() copies the pixels, so the file
            # handle can be closed right afterwards.
            with Image.open(os.path.join(self.path_to_slides, slide)) as im_obj:
                slide_dict[slide] = im_obj.convert('RGB')
            slide_rows = self.anno_frame[self.anno_frame.filename == slide][columns]
            for _, row in slide_rows.iterrows():
                annotations_list.append([slide, row['label'], row['min_x'], row['min_y'], row['max_x'], row['max_y']])

        return slide_dict, annotations_list


    def __getitem__(self, index):
        """Return ``(image, target)`` for the ``index``-th sampled crop.

        ``image`` is produced by ``self.transform_to_tensor``. ``target`` holds
        ``boxes`` (float32, shape ``(N, 4)``, crop-relative
        ``[x_min, y_min, x_max, y_max]``) and ``labels`` (int64). Every
        annotated box gets label 1.
        """
        slide, x_cord, y_cord = self.sample_cord_list[index]
        x_cord = np.int64(x_cord)
        y_cord = np.int64(y_cord)
        img = self.slide_dict[slide].crop((x_cord, y_cord, x_cord + self.crop_size[0], y_cord + self.crop_size[1]))

        # Boxes inside the crop, shifted into crop coordinates, in pascal_voc
        # order with the class label appended: [x_min, y_min, x_max, y_max, label]
        labels_boxes = [[b[1] - x_cord, b[2] - y_cord, b[3] - x_cord, b[4] - y_cord, b[0]]
                        for b in self._get_boxes_and_label(slide, x_cord, y_cord)]

        if self.transformations is not None:
            # Augment every crop, including crops that contain no boxes.
            transformed = self.transformations(image=np.array(img), bboxes=labels_boxes)
            img = transformed['image']
            labels_boxes = transformed['bboxes']
        img = self.transform_to_tensor(img)

        # b[:4] takes the coordinates whether or not the label is appended;
        # reshape keeps the (0, 4) shape when there are (or remain) no boxes.
        boxes = torch.tensor([list(b[:4]) for b in labels_boxes], dtype=torch.float32).reshape(-1, 4)
        if boxes.shape[0] == 0:
            # NOTE(review): one label for zero boxes is kept for compatibility
            # with the validation code, which special-cases empty targets; an
            # empty int64 tensor would be the consistent choice.
            labels = torch.tensor([0], dtype=torch.int64)
        else:
            labels = torch.ones(boxes.shape[0], dtype=torch.int64)

        target = {
            'boxes': boxes,
            'labels': labels
        }

        return img, target


    def _sample_cord_list(self):
        """Draw ``pseudo_epoch_length`` random crop origins.

        Slides are drawn uniformly with replacement. For each draw the origin
        is uniform over all positions where the crop fits into that slide, so
        slides of different sizes are supported.

        Returns:
            Object array of shape ``(pseudo_epoch_length, 3)`` with rows
            ``[slide_name, x, y]``.
        """
        slide_names = np.array(list(self.slide_dict.keys()))
        slide_indice = random.choice(np.arange(len(slide_names)), size=self.pseudo_epoch_length, replace=True)
        slides = slide_names[slide_indice]
        # (width, height) of every slide, looked up once per slide
        slide_sizes = np.array([self.slide_dict[name].size for name in slide_names]).reshape(-1, 2)
        # randint's upper bound is exclusive: +1 so a crop may touch the right/bottom border
        high = slide_sizes[slide_indice] - np.asarray(self.crop_size) + 1
        cordinates = random.randint(low=0, high=high, size=(self.pseudo_epoch_length, 2))
        # object dtype keeps names as str and coordinates as ints instead of
        # casting everything to one string array
        return np.concatenate((slides.reshape(-1, 1).astype(object), cordinates.astype(object)), axis=-1)

    def __len__(self):
        return self.pseudo_epoch_length

    def _get_boxes_and_label(self, slide, x_cord, y_cord):
        """Return ``[label, min_x, min_y, max_x, max_y]`` for every box of
        ``slide`` lying strictly inside the crop starting at ``(x_cord, y_cord)``."""
        x_end = x_cord + self.crop_size[0]
        y_end = y_cord + self.crop_size[1]
        return [line[1:] for line in self.annotations_by_slide.get(slide, [])
                if line[2] > x_cord and line[3] > y_cord and line[4] < x_end and line[5] < y_end]

    def collate_fn(self, batch):
        """Combine ``(image, target)`` pairs into one batch for the DataLoader.

        Images all have the crop size and are stacked into one tensor; targets
        contain a varying number of boxes and are therefore returned as a list.

        :param batch: an iterable of ``(image, target)`` pairs from ``__getitem__``
        :return: ``(images, targets)`` with ``images`` of shape ``(B, C, H, W)``
            and ``targets`` a list of ``B`` target dicts
        """
        images = torch.stack([sample[0] for sample in batch], dim=0)
        targets = [sample[1] for sample in batch]
        return images, targets

    def trigger_sampling(self):
        """Redraw the crop positions for the next pseudo epoch."""
        self.sample_cord_list = self._sample_cord_list()

In [None]:
import torch
import torchvision
from torchvision.models.detection import RetinaNet
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models import MobileNet_V2_Weights

from torchvision.ops.misc import FrozenBatchNorm2d

# Load an ImageNet-pre-trained MobileNetV2 and keep only its convolutional
# feature extractor. FrozenBatchNorm2d replaces the BatchNorm layers so the
# pre-trained normalization statistics stay fixed while model.train() is
# active; requires_grad=False alone would not stop BatchNorm from updating
# its running mean/variance.
backbone = torchvision.models.mobilenet_v2(
    weights=MobileNet_V2_Weights.DEFAULT,
    norm_layer=FrozenBatchNorm2d,
).features

# RetinaNet needs to know the number of output channels of the backbone.
# For the mobilenet_v2 feature extractor this is 1280.
backbone.out_channels = 1280

# The backbone returns a single feature map, so one tuple of sizes and one
# tuple of aspect ratios is given: 5 sizes x 3 ratios = 15 anchors per location.
anchor_generator = AnchorGenerator(
    sizes=((32, 64, 128, 256, 512),),
    aspect_ratios=((0.5, 1.0, 2.0),),
)

# put the pieces together inside a RetinaNet model
model = RetinaNet(backbone,
                  num_classes=2,
                  anchor_generator=anchor_generator)

# Freeze the backbone so that only the classification and box-regression
# heads are trained.
for p in model.backbone.parameters():
    p.requires_grad = False



# 2. Setting up an optimizer, a detection metric and the train and validation dataloaders

To train the object detector, it is necessary to select an appropriate optimizer. Additionally, the torchmetrics class needs to be instantiated before it can be used for evaluation or tracking metrics during training.
Additionally, initialize a training and a validation dataloader for your dataset. For more information on how to set up your dataloaders, have a look [here](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html).


In [None]:
# add your code

# initialize dataset
path_to_images = '/content/gdrive/MyDrive/AgNORs/'
# NOTE: unpickling can execute arbitrary code; only load annotation files you trust.
# The with-block closes the file handle after loading.
with open(os.path.join(path_to_images, 'annotation_frame.p'), 'rb') as annotation_file:
    annotation_frame = pickle.load(annotation_file)

In [None]:
batch_size = 2
num_workers = 2
num_samples = 250

# Train/val split on image level (75 % / 25 %), so no image appears in both sets.
imgs = annotation_frame.filename.unique()
train_imgs = imgs[np.random.choice(np.arange(len(imgs)), size = int(0.75 * len(imgs)), replace = False)]
train_img_set = set(train_imgs)
val_imgs = [i for i in imgs if i not in train_img_set]

# Augmentations, applied to the training crops only.
transform = A.Compose([
    A.HorizontalFlip(p=0.3),
    A.RandomBrightnessContrast(p=0.3),
    A.Blur(p=0.3),
    A.ColorJitter(p=0.3),
    A.GaussNoise(p=0.1)
], bbox_params=A.BboxParams(format='pascal_voc'))

train_df = annotation_frame[annotation_frame.filename.isin(train_imgs)]
val_df = annotation_frame[annotation_frame.filename.isin(val_imgs)]

train_ds = Dataset(annotations_frame = train_df,
                   path_to_slides = path_to_images,
                   crop_size = (256,256),
                   pseudo_epoch_length = num_samples,
                   transformations = transform)
# No augmentation for validation, so the metric is measured on unmodified crops.
val_ds = Dataset(annotations_frame = val_df,
                 path_to_slides = path_to_images,
                 crop_size = (256,256),
                 pseudo_epoch_length = num_samples,
                 transformations = None)

train_dl = DataLoader(train_ds,
                      batch_size = batch_size,
                      num_workers = num_workers,
                      collate_fn = train_ds.collate_fn)
val_dl = DataLoader(val_ds,
                    batch_size = batch_size,
                    num_workers = num_workers,
                    collate_fn = val_ds.collate_fn)
# The datasets stay referenced by the DataLoaders (train_dl.dataset, val_dl.dataset).

In [None]:
from torch import optim
learning_rate = 1e-4

# SGD with momentum over the parameters that still receive gradients;
# parameters frozen with requires_grad=False are left out.
trainable_params = list(filter(lambda param: param.requires_grad, model.parameters()))
params = trainable_params
optimizer = optim.SGD(params=params, lr=learning_rate, momentum=0.9)

# 3. Training and validation loop

Please write two functions, one for training and one for evaluating your object detector. Use these functions to train the detector for a few epochs. During training, track both the training losses and the validation metrics to monitor the model's performance. Save the best detector as observed by the validation metric.

In [None]:
def train_one_epoch(train_loader, model, optimizer, device:str = 'cpu', epoch:int = 0):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        train_loader: DataLoader yielding ``(images, targets)``, where
            ``images`` supports ``.to(device)`` and ``targets`` is a list of
            dicts of tensors (the format produced by ``Dataset.collate_fn``).
        model: Detection model that, in train mode, is called as
            ``model(images, targets)`` and returns a dict of losses containing
            ``'classification'`` and ``'bbox_regression'``.
        optimizer: Optimizer over the trainable parameters of ``model``.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index, only used in the progress message.

    Returns:
        Mean over all batches of the summed loss, as a Python float
        (0.0 if the loader yields no batches).
    """
    model.train()

    running_loss = 0.0
    loss_classifier = 0.0
    loss_box_reg = 0.0
    n_batches = 0

    for i, (images, targets) in tqdm(enumerate(train_loader), total=len(train_loader)):
        images = images.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # .item() yields plain floats; accumulating the loss tensors themselves
        # would keep every iteration's autograd graph alive.
        running_loss += losses.item()
        loss_classifier += loss_dict['classification'].item()
        loss_box_reg += loss_dict['bbox_regression'].item()
        n_batches += 1

        # print every 5 mini-batches
        if i % 5 == 0:
            print(f"Epoch: {epoch + 1}, overall loss train: {running_loss / n_batches:.4f} "
                  f"(classification: {loss_classifier / n_batches:.4f}, "
                  f"bbox regression: {loss_box_reg / n_batches:.4f})", end='\r')

    return running_loss / n_batches if n_batches else 0.0

def validation_one_epoch(val_loader, model, device:str = 'cpu', epoch:int = 0):
    """Evaluate ``model`` on ``val_loader`` with torchmetrics' MeanAveragePrecision.

    Args:
        val_loader: DataLoader yielding ``(images, targets)`` as in training.
        model: Detection model that, in eval mode, returns one prediction dict
            per image in the format accepted by ``MeanAveragePrecision.update``.
        device: Device the images are moved to (``'cpu'`` or e.g. ``'cuda'``).
        epoch: Unused; kept so the signature mirrors ``train_one_epoch``.

    Returns:
        The dict returned by ``MeanAveragePrecision.compute()``.
    """
    metric = MeanAveragePrecision()

    # switch to validation mode
    model.eval()

    with torch.no_grad():
        for images, targets in tqdm(val_loader, total=len(val_loader)):
            predictions = model(images.to(device))

            # Update the metric with CPU tensors so the accumulated state does
            # not occupy GPU memory.
            predictions = [{k: v.cpu() for k, v in p.items()} for p in predictions]
            targets = [{k: v.cpu() for k, v in t.items()} for t in targets]

            # Crops without objects: give them an empty label tensor matching
            # the empty boxes, instead of a dummy box that the metric would
            # count as a ground-truth object.
            for t in targets:
                if t['boxes'].shape[0] == 0:
                    t['labels'] = t['labels'][:0]

            metric.update(predictions, targets)

    metrics_values = metric.compute()

    print('\n')
    print(f"mAP 50: {metrics_values['map_50']:.3f}\n")
    return metrics_values

In [None]:
max_epochs = 15
losses_train = []
mAP_val = []

# fall back to the CPU when no GPU is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

best_map = 0.0
best_model = None
best_model_path = 'best_model_state_dict.pt'

for e in range(max_epochs):
    print(f"Epoch {e+1}\n-------------------------------")

    # draw new crop positions for this epoch
    train_dl.dataset.trigger_sampling()

    # training loop
    train_loss = train_one_epoch(
        train_loader = train_dl,
        model = model,
        optimizer = optimizer,
        epoch = e,
        device = device)
    losses_train.append(train_loss)

    # validation loop
    metrics = validation_one_epoch(
        model = model,
        val_loader = val_dl,
        device = device)
    mAP_val.append(metrics['map_50'].numpy())

    if metrics['map_50'] > best_map:
        print(f"Saving best model at epoch {e+1} with mAP50 of {metrics['map_50'].numpy():.3f}")
        # state_dict() returns references to the live parameters, which later
        # optimizer steps keep changing; store an independent CPU copy.
        best_model = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        best_map = metrics['map_50']

# Persist the best weights; reload them with
# model.load_state_dict(torch.load(best_model_path)).
if best_model is not None:
    torch.save(best_model, best_model_path)