In [1]:
%config Completer.use_jedi = False # use autocompletion

import torch
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import detectron2
from pathlib import Path
import random, cv2, os
import matplotlib.pyplot as plt
import numpy as np
import pycocotools.mask as mask_util
# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.utils.logger import setup_logger
from detectron2.evaluation.evaluator import DatasetEvaluator
from detectron2.engine import BestCheckpointer
from detectron2.checkpoint import DetectionCheckpointer
# import PyCOCO tools
from pycocotools.coco import COCO
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image

# Configure detectron2's logger for this session; the call returns the
# `detectron2` logger object, which is what the cell displays as its output.
setup_logger()

<Logger detectron2 (DEBUG)>

# Evaluation

In [None]:
# use this to unregister datasets
# DatasetCatalog.clear()

In [2]:
# Inference configuration: start from the model-zoo Cascade Mask R-CNN (X-152)
# config, then override the project-specific settings below.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml"))
# cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"))
cfg.INPUT.MASK_FORMAT='bitmask'
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3 
cfg.MODEL.WEIGHTS = 'output/model_best_fold_3.pth'  
cfg.TEST.DETECTIONS_PER_IMAGE = 1000
# Detections scoring below this never reach the scoring code further down, so
# it is the effective lower bound for any per-class threshold searched later.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
predictor = DefaultPredictor(cfg)

dataDir=Path('sartorius-cell-instance-segmentation')

# Registering a name twice trips an assertion inside detectron2, which made
# this cell fail on every re-run unless the catalog was cleared by hand.
# Only register when the name is not in the catalog yet.
if 'sartorius_val' not in DatasetCatalog.list():
    register_coco_instances('sartorius_val',{},'sartorius-coco-dataset-5-fold-split/annotations_val_fold_3.json', dataDir)
val_ds = DatasetCatalog.get('sartorius_val')

[32m[12/10 10:40:07 d2.data.datasets.coco]: [0mLoaded 121 images in COCO format from sartorius-coco-dataset-5-fold-split/annotations_val_fold_3.json


# Hyperparameter grid search

In [7]:
def precision_at(threshold, iou):
    """Count matches in an IoU matrix at a single IoU threshold.

    A cell counts as a match when its IoU is strictly greater than
    ``threshold``.

    Args:
        threshold: IoU cut-off.
        iou: 2-D array of pairwise IoU values.

    Returns:
        A 3-tuple ``(tp, fp, fn)`` where
        ``tp`` is the number of rows with exactly one match,
        ``fp`` is the number of columns with no match, and
        ``fn`` is the number of rows with no match.
        Whether rows are predictions or ground-truth objects depends on how
        the caller built ``iou``; the sum ``tp + fp + fn`` does not depend on
        which of the two unmatched counts is called fp or fn.
    """
    hits = iou > threshold
    hits_per_row = np.sum(hits, axis=1)
    hits_per_col = np.sum(hits, axis=0)

    n_true_pos = np.sum(hits_per_row == 1)   # rows matched exactly once
    n_false_pos = np.sum(hits_per_col == 0)  # columns nobody matched
    n_false_neg = np.sum(hits_per_row == 0)  # rows that matched nothing
    return n_true_pos, n_false_pos, n_false_neg

def score(pred, targ, thresholds, pixels):
    """Score one image's predicted instances against its ground truth.

    The image is assigned the most frequent predicted class; that class picks
    the confidence threshold and the minimum mask area used to filter the
    predictions. The surviving masks are compared with the ground-truth masks
    and ``tp / (tp + fp + fn)`` is averaged over IoU thresholds
    0.50, 0.55, ..., 0.95.

    Args:
        pred: predictor output; ``pred['instances']`` must provide
            ``pred_classes``, ``scores`` and ``pred_masks``.
        targ: dataset record whose ``'annotations'`` entries carry an encoded
            mask under ``'segmentation'``.
        thresholds: per-class minimum confidence, indexed by class id.
        pixels: per-class minimum mask area in pixels, indexed by class id.

    Returns:
        The mean precision as a float. ``0.0`` when there is nothing to match,
        i.e. no detections, no mask left after filtering, or no annotations.
    """
    instances = pred['instances']
    enc_targs = [ann['segmentation'] for ann in targ['annotations']]

    # torch.mode cannot reduce an empty tensor, and with no ground truth no
    # match is possible either way.
    if len(instances.pred_classes) == 0 or len(enc_targs) == 0:
        return 0.0

    pred_class = int(torch.mode(instances.pred_classes)[0])

    # keep only masks whose confidence reaches the class threshold
    take = instances.scores >= thresholds[pred_class]
    pred_masks = instances.pred_masks[take].cpu().numpy()

    # drop masks smaller than the class's minimum area
    min_area = pixels[pred_class]
    res_masks = [mask for mask in pred_masks if mask.sum() >= min_area]

    # mask_util.iou returns a plain empty list when either side is empty,
    # which cannot be compared against a threshold; every ground-truth object
    # is unmatched in that case, so the score is 0.
    if not res_masks:
        return 0.0

    enc_preds = [mask_util.encode(np.asarray(p, order='F')) for p in res_masks]
    # rows of `ious` are predictions, columns are ground-truth objects
    ious = mask_util.iou(enc_preds, enc_targs, [0]*len(enc_targs))

    prec = []
    for t in np.arange(0.5, 1.0, 0.05):
        tp, fp, fn = precision_at(t, ious)
        denom = tp + fp + fn
        # guard the degenerate all-zero case instead of producing NaN
        prec.append(tp / denom if denom else 0.0)
    return np.mean(prec)

def score_all(thresholds, pixels):
    """Average the per-image score over the whole validation set.

    Every record in ``val_ds`` is read from disk, run through ``predictor``
    and scored with ``score`` using the given per-class settings.

    Args:
        thresholds: per-class confidence thresholds, passed on to ``score``.
        pixels: per-class minimum mask areas, passed on to ``score``.

    Returns:
        Mean of the per-image scores.
    """
    per_image = [
        score(predictor(cv2.imread(record['file_name'])), record, thresholds, pixels)
        for record in val_ds
    ]
    return np.mean(per_image)

# Coordinate-wise grid search: tune one class's parameter at a time while the
# others stay fixed, first the confidence thresholds, then the minimum areas.

# Starting values, indexed by class id.
thresholds = [0.5, 0.5, 0.5]
pixels = [60, 60, 60]

# best_threshs[c] is (best threshold, its validation score) for class c;
# (0, 0) means the class was not searched or nothing scored above 0.
best_threshs = [(0, 0), (0, 0), (0, 0)]

# Only class index 2 is searched here; widen the range to tune the others.
for i in tqdm(range(2, 3)):
    start_value = thresholds[i]
    best_score = 0
    for j in tqdm(np.arange(0.4, 0.75, 0.025)):
        thresholds[i] = j
        curr_score = score_all(thresholds, pixels)
        if curr_score > best_score:
            best_score = curr_score
            best_threshs[i] = (j, best_score)
        # three decimals: the grid step is 0.025, two would misreport it
        print(f'Class number is {i+1}, threshold is {round(j, 3)}, score is {curr_score}')
    # fall back to the starting value if no candidate scored above 0
    thresholds[i] = best_threshs[i][0] if best_score > 0 else start_value

best_pixels = [(0, 0), (0, 0), (0, 0)]
# `thresholds` now holds the tuned value for each searched class and the
# starting value for the rest. It is deliberately not rebuilt from
# best_threshs: the (0, 0) placeholders would set unsearched classes to 0.

for i in tqdm(range(3)):
    start_value = pixels[i]
    best_score = 0
    for j in tqdm(np.arange(0, 180, 10)):
        pixels[i] = j
        curr_score = score_all(thresholds, pixels)
        if curr_score > best_score:
            best_score = curr_score
            best_pixels[i] = (j, best_score)
        print(f'Class number is {i+1}, pixel thresh is {round(j, 2)}, score is {curr_score}')
    # fall back to the starting value if no candidate scored above 0
    pixels[i] = best_pixels[i][0] if best_score > 0 else start_value

best_threshs, best_pixels

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

KeyboardInterrupt: 

Results of grid search

In [None]:
# Best settings found for fold 1.
# NOTE(review): the config cell in this notebook loads the fold-3 checkpoint
# and the fold-3 validation split; point both at fold 1 before re-running,
# otherwise this call scores fold-3 data. The figures in the trailing comment
# presumably read as 0.2896 / 0.3070 — confirm.
score_all([0.5, 0.65, 0.5], [70, 75, 70])  # val acc 2896, LB 3070, fold 1

In [None]:
# Best settings found for fold 2.
# NOTE(review): the config cell in this notebook loads the fold-3 checkpoint
# and the fold-3 validation split; point both at fold 2 before re-running,
# otherwise this call scores fold-3 data. The figures in the trailing comment
# presumably read as 0.2870 / 0.3150 — confirm.
score_all([0.5, 0.575, 0.5], [60, 60, 130])  # val acc 2870, LB 3150 fold 2

In [None]:
# Best settings found for fold 3, the fold whose checkpoint and validation
# split the config cell in this notebook loads.
# NOTE(review): the figures in the trailing comment presumably read as
# 0.3002 / 0.3090 — confirm.
score_all([0.5, 0.6, 0.5], [60, 60, 130])  # val acc 3002, LB 3090 fold 3

## Let's look at some of the validation files to check if things look reasonable

We show predictions on the left and ground truth on the right

In [9]:
# Draw predictions (left column) next to ground truth (right column) for a
# few randomly chosen validation images.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5   # set a custom testing threshold
predictor = DefaultPredictor(cfg)
dataset_dicts = DatasetCatalog.get('sartorius_val')
# Take the metadata from the dataset this notebook actually registers.
# 'sartorius_train' is never registered here, so its metadata would be empty
# and the drawings would carry no class names.
val_metadata = MetadataCatalog.get('sartorius_val')
outs = []
for d in random.sample(dataset_dicts, 3):    
    im = cv2.imread(d["file_name"])
    outputs = predictor(im)  # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
    rgb = im[:, :, ::-1]  # cv2 loads BGR, Visualizer expects RGB
    v = Visualizer(rgb,
                   metadata=val_metadata,
                   instance_mode=ColorMode.IMAGE_BW   # remove the colors of unsegmented pixels. This option is only available for segmentation models
    )
    out_pred = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    visualizer = Visualizer(rgb, metadata=val_metadata)
    out_target = visualizer.draw_dataset_dict(d)
    outs.append(out_pred)
    outs.append(out_target)
_,axs = plt.subplots(len(outs)//2,2,figsize=(40,45))
for idx, (ax, out) in enumerate(zip(axs.reshape(-1), outs)):
    # get_image() is already RGB, which is what imshow expects; flipping the
    # channels again would swap red and blue in the overlays.
    ax.imshow(out.get_image())
    # outs alternates prediction, ground truth
    ax.set_title('Prediction' if idx % 2 == 0 else 'Ground truth')