# Preparation for Milestone Four

Today, we will begin preparing for the final milestone. Here, we will assemble all the pieces of the pipeline you've created. You will need to write a function **compute_AgNOR_score**. This function first utilizes the detection model to locate cells within a given image and then feeds those cells into a classification model to classify them into one of the AgNOR classes. Finally, you will aggregate all predictions into a final AgNOR score for the entire image.

In [24]:
!pip install torch torchvision torchmetrics albumentations opencv-python



In [26]:
from google.colab import drive
import os
import pandas as pd
import numpy as np
import torch
import random
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision
from torchvision.models.detection import RetinaNet
from torchvision.models.detection.retinanet import RetinaNetClassificationHead, AnchorGenerator
from torchvision.models import MobileNet_V2_Weights
from tqdm import tqdm
from torchmetrics.detection.mean_ap import MeanAveragePrecision

In [27]:
# path to the link you created
annotations_path = '/content/gdrive/MyDrive/AgNORs/annotation_frame.p'
path_to_slides = '/content/gdrive/MyDrive/AgNORs/'

# mount the data
drive.mount('/content/gdrive')
annotations = pd.read_pickle(annotations_path)
print(annotations.head())
print(annotations.columns)
df = pd.read_csv(path_to_slides +"annotation_frame.csv")
unique_filenames = df['filename'].unique()

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
          filename  max_x  max_y  min_x  min_y  label
0  AgNOR_0495.tiff     26     41      4     15      1
1  AgNOR_0495.tiff     71     23     42      0      2
2  AgNOR_0495.tiff    133     61    104     37      1
3  AgNOR_0495.tiff    143    117    121     88      2
4  AgNOR_0495.tiff    224     37    199     12      1
Index(['filename', 'max_x', 'max_y', 'min_x', 'min_y', 'label'], dtype='object')


In [28]:
total_images = len(unique_filenames)
train_size = int(total_images * 0.8)
test_size = int(total_images * 0.1)
validation_size = total_images - train_size - test_size

train_images = unique_filenames[:train_size]
test_images = unique_filenames[train_size:train_size + test_size]
validation_images = unique_filenames[train_size + test_size:]

train_df = df[df['filename'].isin(train_images)]
val_df = df[df['filename'].isin(validation_images)]

class CustomDataset(Dataset):
    def __init__(self, dataframe, image_dir, transform=None):
        self.dataframe = dataframe
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.dataframe.iloc[idx, 0])
        image = Image.open(img_name).convert("RGB")
        boxes = self.dataframe.iloc[idx, 1:5].values
        boxes = boxes.astype(np.float32).reshape(-1, 4)
        target = {'boxes': torch.tensor(boxes), 'labels': torch.tensor(self.dataframe.iloc[idx, 5], dtype=torch.int64).unsqueeze(0)}
        if self.transform:
            image = self.transform(image)
        return image, target

train_transform = A.Compose([
    A.Resize(256, 256),
    A.RandomCrop(224, 224),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Resize(256, 256),
    A.CenterCrop(224, 224),
    ToTensorV2()
])

train_dataset = CustomDataset(train_df, path_to_slides, transform=train_transform)
val_dataset = CustomDataset(val_df, path_to_slides, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))

# 1. Write a function "process_image" which receives an image and runs the detection model on it.

The function should have the following parameters:

1. image: The image on which you want to run inference.
2. crop_size: The size of the crops you want to load from the image.
3. overlap: Percentage or number of pixels the crops should overlap.
4. model: The object detection model. This function should generally be able to run with any detection model.
5. detection_threshold: A threshold to apply to the detections to reject false positives.

The function will have to tile the image into **overlapping crops** and then feed each crop to the model. After that, all detections have to be transformed to the global coordinate system of the image since the detections are within the coordinate system of the image crop. Subsequently, [non-maximal suppression](https://pytorch.org/vision/stable/generated/torchvision.ops.nms.html) needs to be applied to the detections in order to reject overlapping detections. In the end, the function will return the coordinates and scores of the detected cells that exceed the given threshold. Use **torch_no_grad** to save computation time and also ensure your **model is in evaluation mode** before feeding the cells to it.

In [30]:
def process_image(image, crop_size, overlap, model, detection_threshold):
    if isinstance(image, Image.Image):
        image = np.array(image)
    img_h, img_w = image.shape[:2]
    if overlap < 1:
        step_size = int(crop_size * (1 - overlap))
    else:
        step_size = crop_size - overlap
    model.eval()
    detections = []
    with torch.no_grad():
        for y in range(0, img_h, step_size):
            for x in range(0, img_w, step_size):
                crop = image[y:y+crop_size, x:x+crop_size]
                if crop.shape[0] < crop_size or crop.shape[1] < crop_size:
                    pad_h = crop_size - crop.shape[0]
                    pad_w = crop_size - crop.shape[1]
                    crop = np.pad(crop, ((0, pad_h), (0, pad_w), (0, 0)), mode='constant', constant_values=0)
                crop_tensor = torch.tensor(crop).permute(2, 0, 1).unsqueeze(0).float()
                output = model(crop_tensor)[0]
                for box, score in zip(output['boxes'], output['scores']):
                    if score >= detection_threshold:
                        global_box = box + torch.tensor([x, y, x, y])
                        detections.append((global_box, score))
    if len(detections) == 0:
        return [], []
    boxes, scores = zip(*detections)
    boxes = torch.stack(boxes)
    scores = torch.tensor(scores)
    keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)
    final_boxes = boxes[keep].tolist()
    final_scores = scores[keep].tolist()
    return final_boxes, final_scores

# 2. Write a function "process_cells" which classifies the cells from the coordinates that were given to the model.

The function should have the following parameters:

1. image: The image from which to load the cells.
2. coords: Coordinates of the cells which you found with the detection algorithm.
3. model: The trained classification model.
4. crop_size: A size to resize the crops to. It should be equal to the size with which you trained the classification network.

The function should load each cell from the respective image and feed them to the classification model. Save the prediction and, in the end, aggregate the classifications of all cells into a final AgNOR score. The function should return the labels of the respective cells as well as the final AgNOR score.

In [31]:
def process_cells(image, coords, model, crop_size):
    transform = transforms.Compose([
        transforms.Resize((crop_size, crop_size)),
        transforms.ToTensor(),
    ])
    model.eval()
    labels = []
    with torch.no_grad():
        for box in coords:
            cropped_image = image.crop(box)
            input_tensor = transform(cropped_image).unsqueeze(0)
            output = model(input_tensor)
            label = output.argmax(1).item()
            labels.append(label)
    AgNOR_score = sum(labels) / len(labels) if labels else 0
    return labels, AgNOR_score

# 3. Combine both functions into the function **compute_AgNOR_score**.

This function should receive the image as a parameter and also require all parameters to execute the subfunctions. In the end, this function should return the overall AgNOR score of the image.

In [32]:
def compute_AgNOR_score(image, detection_model, classification_model, crop_size, overlap, detection_threshold):
    boxes, scores = process_image(image, crop_size, overlap, detection_model, detection_threshold)
    labels, AgNOR_score = process_cells(image, boxes, classification_model, crop_size)
    return AgNOR_score

# 4. Test your pipeline.

Take several images (approximately 5) and run them through your pipeline. Then, calculate the error between the predicted AgNOR score and the AgNOR score defined by the labels of the cells in the annotation file. To obtain this label, simply calculate the mean of the labels of the respective image.

In [34]:
sample_images = unique_filenames[:5]
ground_truth_scores = []
for image_name in sample_images:
    image_annotations = df[df['filename'] == image_name]
    ground_truth_score = image_annotations['label'].mean()
    ground_truth_scores.append(ground_truth_score)

detection_model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
classification_model = torchvision.models.resnet50(pretrained=True)
classification_model.fc = torch.nn.Linear(classification_model.fc.in_features, 2)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
classification_model.to(device)
classification_model.eval()

crop_size = 224
overlap = 0.2
detection_threshold = 0.5

predicted_scores = []
for image_name in sample_images:
    image_path = os.path.join(path_to_slides, image_name)
    image = Image.open(image_path).convert("RGB")
    predicted_score = compute_AgNOR_score(image, detection_model, classification_model, crop_size, overlap, detection_threshold)
    predicted_scores.append(predicted_score)

errors = [abs(pred - gt) for pred, gt in zip(predicted_scores, ground_truth_scores)]
mean_error = sum(errors) / len(errors)
print(f'Mean Absolute Error: {mean_error}')

Mean Absolute Error: 1.9812982357580686
