In [1]:
from pathlib import Path
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from skimage.measure import label, regionprops
import pickle
from torchvision import transforms
import cv2
import numpy as np
from torchvision.ops import nms
from torchmetrics.detection import MeanAveragePrecision
from pathlib import Path
from PIL import Image
from torchvision.utils import draw_bounding_boxes
from tqdm import tqdm
import numpy as np
import platform
import matplotlib.pyplot as plt
from numpy import random
from os import listdir
from os.path import isfile, join
from skimage.filters import gabor, gabor_kernel
from skimage.morphology import dilation, erosion
from scipy.signal import convolve2d
import imutils
from imutils import contours

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
class Dataset(torch.utils.data.Dataset):

    def __init__(self, annotations_dict, slide_names, path_to_slides, crop_size = (128,128), pseudo_epoch_length:int = 1000, transformations = None):
        super().__init__()
        
        self.anno_dict = annotations_dict
        self.slide_names = slide_names
        self.path_to_slides = path_to_slides
        self.crop_size = crop_size
        self.pseudo_epoch_length = pseudo_epoch_length
        
        # list which holds annotations of all slides in slide_names in the format
        # slide_name, annotation, label, min_x, max_x, min_y, max_y
        
        self.slide_dict, self.annotations_list, self.slide_dict_segmented = self._initialize()
        self.sample_cord_list = self._sample_cord_list()

        # set up transformations
        if transformations is not None:
            self.transformations = transformations
        else:
            self.transformations = transforms.Compose([transforms.ToTensor()])


    def _initialize(self):
        # open all images and store them in self.slide_dict with their name as key value
        slide_dict = {}
        slide_dict_segmented = {}
        annotations_list = []
        for slide in os.listdir(self.path_to_slides):
            if slide in self.slide_names:
                im_obj = Image.open(os.path.join(self.path_to_slides, slide)).convert('RGB')
                slide_dict[slide] = im_obj
                # setting up a list with all bounding boxes


                # Building Segmented Image
                segmented = np.zeros((im_obj.size[1], im_obj.size[0]), dtype=np.uint8)
                for annotation in self.anno_dict[slide]:
                    x_y_list = [(self.anno_dict[slide][annotation]['x'][i], self.anno_dict[slide][annotation]['y'][i]) for i in range(len(self.anno_dict[slide][annotation]['y']))]
                    cv2.fillPoly(segmented, pts=np.array([x_y_list]), color=(self.anno_dict[slide][annotation]['class']))

                slide_dict_segmented[slide] = segmented

                for annotation in self.anno_dict[slide]:
                    max_x, min_x = max(self.anno_dict[slide][annotation]['x']), min(self.anno_dict[slide][annotation]['x'])
                    max_y, min_y = max(self.anno_dict[slide][annotation]['y']), min(self.anno_dict[slide][annotation]['y'])
                    # since 0 is always the background class
                    label = self.anno_dict[slide][annotation]['class']

                    annotations_list.append([slide, annotation, label, min_x, min_y, max_x, max_y])

        return slide_dict, annotations_list, slide_dict_segmented


    def __getitem__(self,index):
        slide, x_cord, y_cord = self.sample_cord_list[index]
        x_cord = np.int64(x_cord)
        y_cord = np.int64(y_cord)
        # load image
        img = self.slide_dict[slide].crop((x_cord,y_cord,x_cord + self.crop_size[0],y_cord + self.crop_size[1]))

        segmented = (self.slide_dict_segmented[slide])[y_cord:y_cord+self.crop_size[1], x_cord:x_cord+self.crop_size[0]]
        # transform image
        img = self.transformations(img)
        
        # load boxes for the image
        labels_boxes = self._get_boxes_and_label(slide,x_cord,y_cord)
        # check if there is no labeld instance on the image
        if len(labels_boxes) == 0:
            labels = torch.tensor([0], dtype = torch.int64)
            boxes = torch.zeros((0,4),dtype = torch.float32)
        else:
            labels = torch.tensor([line[0]-1 for line in labels_boxes], dtype=torch.int64).to(device)
            # now, you need to change the originale box cordinates to the cordinates of the image
            boxes = torch.tensor([[line[1] - x_cord, line[2] - y_cord, line[3] - x_cord, line[4] - y_cord] for line in labels_boxes],dtype=torch.float32).to(device)
        

        target = {
            "boxes" :boxes,
            "labels": labels,
            "segmentation": torch.from_numpy(segmented).type(torch.int64).to(device)
        }

        return img, target
        

    def _sample_cord_list(self):
        # select slides from which to sample an image
        slides = random.choice(self.slide_names, size = self.pseudo_epoch_length, replace = True)
        # select coordinates from which to load images
        # only works if all images have the same size
        width,height = self.slide_dict[slides[0]].size
        cordinates = random.randint(low = (0,0), high=(width - self.crop_size[0], height - self.crop_size[1]), size = (self.pseudo_epoch_length,2))
        return np.concatenate((slides.reshape(-1,1),cordinates), axis = -1)

    def __len__(self):
        return self.pseudo_epoch_length

    def _get_boxes_and_label(self,slide,x_cord,y_cord):
        return [line[2::] for line in self.annotations_list if line[0] == slide and line[3] > x_cord and line [4] > y_cord and line[5] < x_cord + self.crop_size[0] and line[6] < y_cord + self.crop_size[1]]

    def collate_fn(self, batch):
        """
        Since each image may have a different number of objects, we need a collate function (to be passed to the DataLoader).
        This describes how to combine these tensors of different sizes. We use lists.
        Note: this need not be defined in this Class, can be standalone.
        :param batch: an iterable of N sets from __iter__()
        :return: a tensor of images, lists of varying-size tensors of bounding boxes, labels, and difficulties
        """

        images = list()
        segmentations = list()
        targets = list()

        for b in batch:
            images.append(b[0])
            segmentations.append(b[1]["segmentation"])
            targets.append(b[1])
            
        images = torch.stack(images, dim=0).to(device)
        segmentations = torch.stack(segmentations, dim=0).to(device)
        
        return images, segmentations, targets 

In [4]:
path_to_slides = 'AgNOR_ROI/'
annotations = pickle.load(open(path_to_slides+"annotations_dict_train.p","rb"))

In [5]:
# slides are the filenames of the train images
slides = list(annotations.keys())

In [6]:
batch_size=4

# setting up datasets
training_dataset = Dataset(annotations,slide_names=[slides[0], slides[2]],path_to_slides = path_to_slides ,crop_size=(256,256), pseudo_epoch_length=1000)
validation_dataset = Dataset(annotations,slide_names=[slides[1]],path_to_slides = path_to_slides ,crop_size=(256,256), pseudo_epoch_length=1000)

# setting up dataloaders
train_loader = DataLoader(training_dataset,batch_size=batch_size,collate_fn=training_dataset.collate_fn)
val_loader = DataLoader(validation_dataset,batch_size=batch_size,collate_fn=validation_dataset.collate_fn)

In [7]:
class UNetEncoderBlock(nn.Module):
    def __init__(self, c_in, c_out, k=3, pad=1, stride=2, first_layer=False, use_incr=False):
        super().__init__()
        incr = 0 if (first_layer != True or use_incr == False) else 4
        self.encoder = nn.Sequential(
            nn.Identity() if first_layer else nn.ReLU(), 
            nn.Conv2d(c_in, c_out, (k + 1 + incr), padding=(pad + incr // 2), stride=stride), 
            nn.Identity() if first_layer else nn.InstanceNorm2d(c_out), 
            nn.ReLU(), 
            nn.Conv2d(c_out, c_out, k, padding=pad), 
            nn.InstanceNorm2d(c_out)
        )

    def forward(self, x):
        x = self.encoder(x)
        return x
    
class UNetDecoderBlock(nn.Module):
    def __init__(self, c_in, c_out, k=3, pad=1, stride=2, last_layer=False, use_incr=False):
        super().__init__()
        incr = 0 if (last_layer != True or use_incr == False) else 4
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(c_in, c_in, (k + 1), padding=pad, stride=stride), 
            nn.InstanceNorm2d(c_in), 
            nn.ReLU(), 
            nn.Conv2d(c_in, c_out, (k + incr), padding=(pad + incr // 2)), 
            nn.Identity() if last_layer else nn.InstanceNorm2d(c_out), 
            nn.Identity() if last_layer else nn.ReLU()
        )

    def forward(self, x):
        x = self.decoder(x)
        return x
    
class UNet(nn.Module):
    def __init__(self, hidden_size=1024, c_in=3, c_out=7):
        super().__init__()
        self.encoders = nn.ModuleList([
            UNetEncoderBlock(c_in, (hidden_size//8), first_layer=True),
            UNetEncoderBlock(hidden_size//8, hidden_size//4),
            UNetEncoderBlock(hidden_size//4, hidden_size//2),
            UNetEncoderBlock(hidden_size//2, hidden_size),
        ])
        self.decoders = nn.ModuleList([
            UNetDecoderBlock(hidden_size, hidden_size//2),
            UNetDecoderBlock(2*hidden_size//2, hidden_size//4),
            UNetDecoderBlock(2*hidden_size//4, hidden_size//8),
            UNetDecoderBlock((2*hidden_size//8), c_out, last_layer=True),
        ])


    def forward(self, x):
        encodings = []
        for i, encoder in enumerate(self.encoders):
            x = encoder(x)
            encodings.insert(0, x)

        for i, decoder in enumerate(self.decoders):
            if i > 0:
                x = torch.cat((x, encodings[i]), axis=1)
            x = decoder(x)
        
        return x.softmax(dim=1)

In [8]:
def remove_border_bboxes(bbox_list):
    bbox_new = []
    for bbox in bbox_list:
        tmp = [int(x) for x in bbox]
        # bbox = [x_min, y_min, x_max, y_max]
        if (0 in tmp or  256 in tmp):
            continue
        else:
            bbox_new.append(bbox)
    return bbox_new

In [9]:
def cv_contours(img):
    cnts = cv2.findContours(img.type(torch.uint8).numpy(), cv2.RETR_TREE,cv2.CHAIN_APPROX_NONE)
    cnts = imutils.grab_contours(cnts)
    if len(cnts) == 0:
        return []
    cnts = contours.sort_contours(cnts)[0]
    # # loop over the contours
    bboxes = []
    for (i, c) in enumerate(cnts):
        (x, y, w, h) = cv2.boundingRect(c)
        if w * h >= 500:
            bboxes.append([x, y, x+w, y+h])
    bboxes = remove_border_bboxes(bboxes) 
    bboxes = torch.Tensor(bboxes)
    # z = draw_bounding_boxes((images[0]*255).type(torch.uint8), bboxes)
    return bboxes

In [10]:
def eval_forward(model, images):
    seg_pred = model(images)
    seg_pred_to_channels = F.one_hot(seg_pred.argmax(dim=1)).permute(0, 3, 1, 2).detach().cpu()[:, 1:, :, :]  # bs x 6 x 256 x 256 
    predictions = []
    for pred in seg_pred_to_channels:  # pred: 6 x 256 x 256 
        prediction = {}
        for i, pred_class in enumerate(pred):  # pred_class: 256 x 256
            bboxes = cv_contours(pred_class)  # bboxes: [[1, 1, 1, 1], [2, 3, 3 ,2]]
            if len(bboxes) > 0:
                prediction["boxes"] = bboxes if "boxes" not in prediction.keys() else torch.concat([prediction["boxes"], bboxes])
                prediction["scores"] = torch.ones((bboxes.shape[0])) if "scores" not in prediction.keys() else torch.concat([prediction["scores"], torch.ones((bboxes.shape[0]))])
                prediction["labels"] = torch.ones((bboxes.shape[0])) * i if "labels" not in prediction.keys() else torch.concat([prediction["labels"], torch.ones((bboxes.shape[0])) * i])
        predictions.append(prediction)
    return predictions, seg_pred

In [11]:
def validation_one_epoch(val_loader, model, device:str = 'cpu', epoch:int = 0):
    metric = MeanAveragePrecision()
    with torch.no_grad():
        for i, (images, _, targets) in enumerate(val_loader):
            images = images.to(device)

            predictions, _ = eval_forward(model, images)
            
            for idx,t in enumerate(targets):
                if len(t["boxes"]) == 0:
                    targets[idx]['boxes'] = torch.tensor([[0,0,0,0]], dtype = torch.float32).to(device)
            targets = [{'boxes': t["boxes"].cpu(), 'labels': t["labels"].cpu()} for t in targets]
            metric.update(predictions,targets)
            
    metrics_values = metric.compute()
    print(f"mAP 50: {metrics_values['map_50']:.3f}")

In [12]:
model = UNet().to(device)
epochs = 40
optim = torch.optim.Adam(model.parameters(), lr=3e-3)
# lr_schedule = WarmupLinearLRSchedule(optim, init_lr=3e-5, peak_lr=3e-4, end_lr=5e-5, warmup_epochs=int(epochs*0.1), epochs=epochs)
# pos_weight = torch.Tensor([1, 5, 5, 5, 5, 5]).to(device)
pos_weight = torch.Tensor([1, 1, 1, 1, 1, 1]).to(device)
from kornia.losses import FocalLoss
loss_fn = FocalLoss(alpha=0.5, gamma=5., reduction="mean")
# loss_fn = nn.CrossEntropyLoss(weight=pos_weight)

In [None]:
losses = []
for e in range(epochs):
    moving_loss = 0.
    pbar = tqdm(train_loader)
    for i, (images, segmentations, _) in enumerate(pbar):
    # images, segmentations = d  # overfitting one example to see if model can learn
        seg_pred = model(images)
        loss = loss_fn(seg_pred, segmentations)
        moving_loss += loss.item() / (i+1)
        losses.append(moving_loss)
#         if i % 10 == 0:
#             plt.imshow(torch.cat([
#                 torch.cat([s for s in seg_pred.argmax(dim=1).detach().cpu()], dim=1),
#                 torch.cat([s for s in segmentations.cpu()], dim=1)
#             ], dim=0))
#             plt.show()
        optim.zero_grad()
        loss.backward()
        optim.step()
        pbar.set_postfix(loss=moving_loss)
#     print("Predicting on Val Set:")
#     lr_schedule.step()
    
    print(f"======================================================================================")
    print(f"========================================Epoch: {e}======================================")
    
    torch.save(model.state_dict(), "ckpt.pt")
    model.eval()
    validation_one_epoch(val_loader, model, "cuda")
    images, _, targets = next(iter(val_loader))
    idx = 3
    preds, seg_pred = eval_forward(model, images)
    preds = preds[idx]
    bboxes = targets[idx]["boxes"]
    print(f'Target-Labels: {targets[idx]["labels"]}')
    print(f'Predicted Labels: {preds["labels"]}')
    print("Num boxes real", len(targets[idx]["labels"]))
    print("Num boxes pred", len(preds["labels"]))
    plt.imshow(draw_bounding_boxes((images[idx]*255).type(torch.uint8), bboxes).permute(1, 2, 0))
    plt.show()
    plt.imshow(draw_bounding_boxes((images[idx]*255).type(torch.uint8), preds["boxes"]).permute(1, 2, 0))
    plt.show()
    plt.imshow(seg_pred.argmax(dim=1)[idx].detach().cpu())
    plt.show()
#     pbar = tqdm(val_loader)
#     for i, (images, segmentations, _) in enumerate(pbar):
#         seg_pred = model(images)
#         loss = loss_fn(seg_pred, segmentations)
#         if i % 50 == 0:
#             plt.imshow(torch.cat([
#                 torch.cat([s for s in seg_pred.argmax(dim=1).detach().cpu()], dim=1),
#                 torch.cat([s for s in segmentations.cpu()], dim=1)
#             ], dim=0))
#             plt.show()
#     print("Finished Prediction on Val Set")
    model.train()
    print("======================================================================================")

In [None]:
## Save model
torch.save(model.state_dict(), "model_3.pt")

### Checkpoint

In [12]:
## load model
model = UNet().to(device)
model.load_state_dict(torch.load("ckpt.pt", map_location=device))

<All keys matched successfully>

## Eval

In [13]:
def evaluate_on_whole_images(model,test_slides:list, image_folder_path:Path, device:str = 'cuda', overlap:int=50, crop_size:tuple=(256,256), compute_map:bool=True, annotations:dict = None):
    """Evaluates a model on the whole image. All parts of the image are processed seperately, with an overlap of defined size. Detections are filtered by non maximal surpression.
     If compute_map is False, only predictions are made, using the model passed.

    Args:
        model (_type_): Pytorch Model
        test_slides (list): Name of the slides to perform evaluation on
        image_folder_path (Path): path to the folder where the images are stored
        device (str, optional): Which device to use for making predictions. Defaults to 'cuda'.
        overlap (int, optional): Overlap of subsequent crops in Pixels. Defaults to 50.
        crop_size (tuple, optional): Size of the crops which are passed to the model. Defaults to (256,256).
        compute_map (bool, optional): Wether to compute maP. Defaults to True.
        annotations (dict, optional): Pass the annotations dict if you want to calculate the map. Defaults to None.

    Returns:
        If compute_map is False, only a dict holding the predictions for each image is returned. Otherwise, also a dict with metrics is returned
    """
    transform = transforms.Compose([transforms.ToTensor()])
    model.to(device)

    if compute_map:
        metric = MeanAveragePrecision()

    results_boxes = {}
    with torch.no_grad():
        for slide_name in test_slides:
            totalpredictions = {'boxes': [], 'labels': [], 'scores': []}
            results_boxes[slide_name] = []
            # process one image
            img = Image.open(image_folder_path / Path(slide_name))
            width, height = img.size
            # sample crops with overlap
            for xmin in tqdm(np.arange(0,width + 1, (crop_size[0] - overlap)), f'evaluating on image {slide_name}'):
                for ymin in np.arange(0,height + 1, crop_size[1] - overlap):
                    crop = img.crop((xmin, ymin, xmin+crop_size[0], ymin+crop_size[1])).convert('RGB')
                    crop = transform(crop)

                    
                    # detect figures on crop
                    predictions, _ = eval_forward(model, crop.unsqueeze(dim = 0).to(device))
                    
                    
            
                    # correct offset, so bring the coordinates back to the coordinate system of the whole image
                    predictions[0]['boxes'] = torch.Tensor([[x1+xmin,y1+ymin,x2+xmin,y2+ymin] for x1,y1,x2,y2 in predictions[0]['boxes']])


                    for det in predictions[0]['boxes']:
                        totalpredictions['boxes'].append(det)
                    for det in predictions[0]['scores']:
                        totalpredictions['scores'].append(det)
                    for det in predictions[0]['labels']:
                        totalpredictions['labels'].append(det+1)

            if (len(totalpredictions['boxes'])>0):
                totalpredictions['boxes'] = torch.stack(totalpredictions['boxes']).to('cpu')
                totalpredictions['labels'] = torch.stack(totalpredictions['labels']).to('cpu')
                totalpredictions['scores'] = torch.stack(totalpredictions['scores']).to('cpu')
            else: # stack does not work for empty arrays
                totalpredictions['boxes'] = predictions[0]['boxes'].to('cpu')
                totalpredictions['labels'] = predictions[0]['labels'].to('cpu')
                totalpredictions['scores'] = predictions[0]['scores'].to('cpu') # empty anyways
                
           

            for b,l,sc in zip(totalpredictions['boxes'], totalpredictions['labels'], totalpredictions['scores']):
                results_boxes[slide_name].append([*b, l, sc])

            # get the targets from the annotation data
            if compute_map:
                if annotations == None:
                    print(f"annotations dict required to compute the map!")
                    return None
                else:
                    # get targets from the annotations dict
                    targets = get_targets(annotations,slide_name)
                    # update matric with detections, made on the current image
                    metric.update([totalpredictions],[targets])

        # finally compute the Ap over all test images
        if compute_map:
            metric_values = metric.compute()
            return metric_values, results_boxes
        else:
            return results_boxes

In [14]:
def get_targets(annotations:dict,slide_name:str):
    """Returns a dict with boxes and labels in pytorch format.

    Args:
        annotations (dict): annotations dict
        slide_name (str): name of the slide for which to return the boxes and labels

    Returns:
        (dict): Dict holding boxes and labels in pytorch format
    """
    boxes = []
    labels = []
    for annotation in annotations[slide_name].values():
        maxx,minx = max(annotation['x']), min(annotation['x'])
        maxy,miny = max(annotation['y']), min(annotation['y'])
        boxes.append([minx,miny,maxx,maxy])
        labels.append(annotation['class'])

    targets = {
        'boxes': torch.tensor(boxes, dtype= torch.float32),
        'labels': torch.tensor(labels, dtype = torch.int64)
    }
    return targets

In [15]:
## Training Image: AgNOR_0484.tiff, AgNOR_0517.tiff
## Validation Image: AgNOR_0622.tiff
path_to_slides = Path('AgNOR_ROI/')
annotations = pickle.load(open(path_to_slides / Path("annotations_dict_train.p"),"rb"))
test_slides = list(annotations.keys())[2]

In [16]:
metric_values, predictions = evaluate_on_whole_images(
    model=model,
    test_slides=[test_slides],
    image_folder_path=path_to_slides,
    annotations = annotations)

evaluating on image AgNOR_0622.tiff: 100%|███████████████████████████████████████████████████████████████████| 8/8 [00:02<00:00,  3.05it/s]


In [17]:
print(f"map_50 on Validation Image: {metric_values['map_50']:.5f}")

map_50 on Validation Image: 0.51250


## Final Evaluations

In [18]:
path_to_slides = Path('AgNOR_ROI/test_images/')
test_slides = os.listdir(path_to_slides)[:-1]
print(test_slides)

['AgNOR_0495.tiff', 'AgNOR_2876.tiff', 'AgNOR_2906.tiff', 'AgNOR_8581.tiff', 'AgNOR_9845.tiff']


In [137]:
predictions = evaluate_on_whole_images(
    model=model,
    test_slides=test_slides,
    image_folder_path=path_to_slides,
    annotations = annotations,
    compute_map = False)

evaluating on image AgNOR_0495.tiff: 100%|███████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  4.17it/s]
evaluating on image AgNOR_2876.tiff: 100%|███████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  4.36it/s]
evaluating on image AgNOR_2906.tiff: 100%|███████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  4.37it/s]
evaluating on image AgNOR_8581.tiff: 100%|███████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  4.34it/s]
evaluating on image AgNOR_9845.tiff: 100%|███████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  4.37it/s]


In [138]:
pickle.dump(predictions, open('test_predictions_index_1', 'wb'))