### Notebook 2: Training the Cascade R-CNN model

In [1]:
import detectron2
import torch
from pathlib import Path
import random, cv2, os
import matplotlib.pyplot as plt
import numpy as np
import pycocotools.mask as mask_util
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.utils.logger import setup_logger
from detectron2.evaluation.evaluator import DatasetEvaluator
from detectron2.structures import polygons_to_bitmask
# Set up detectron2's logger; the d2.* progress lines in the cell outputs
# below (dataset loading, training metrics) are emitted through it.
setup_logger()

<Logger detectron2 (DEBUG)>

#### Pretraining on the LIVECell data

In [2]:
# Register the LIVECell COCO annotation files with detectron2 under the
# dataset names referenced by the training configuration. All three splits
# share the same image root directory.
dataDir = Path('../LIVECell_dataset_2021/images/livecell_train_val_images')
cfg = get_cfg()
for dataset_name, annotation_file in (
    ('sartorius_train', '../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_train.json'),
    ('sartorius_val', '../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_val.json'),
    ('sartorius_test', '../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_test.json'),
):
    register_coco_instances(dataset_name, {}, annotation_file, dataDir)

[32m[01/04 07:58:58 d2.data.datasets.coco]: [0mLoading ../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_train.json takes 9.06 seconds.
[32m[01/04 07:58:59 d2.data.datasets.coco]: [0mLoaded 3253 images in COCO format from ../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_train.json


In [2]:
def polygon_to_rle(polygon, shape=(520, 704)):
    """Rasterise a single polygon and return it as a pycocotools RLE.

    Args:
        polygon: coordinate sequence of one polygon (presumably the flat
            COCO ``[x0, y0, x1, y1, ...]`` layout -- TODO confirm).
        shape: ``(height, width)`` of the mask to rasterise into.

    Returns:
        The run-length encoding produced by ``mask_util.encode``.
    """
    height, width = shape[0], shape[1]
    # Every coordinate is shifted by a quarter pixel before rasterisation.
    shifted_polygon = np.asarray(polygon) + 0.25
    bitmask = polygons_to_bitmask([shifted_polygon], height, width)
    # pycocotools requires a Fortran-ordered array.
    return mask_util.encode(np.asfortranarray(bitmask))

def precision_at(threshold, iou):
    """Count matches in a pairwise IoU matrix at a single threshold.

    Args:
        threshold: IoU value a pair must strictly exceed to be a match.
        iou: 2-D array of pairwise IoU values.

    Returns:
        Tuple ``(true_positives, false_positives, false_negatives)``:
        the number of rows with exactly one match, the number of columns
        with no match, and the number of rows with no match.
    """
    hits = iou > threshold
    hits_per_row = np.sum(hits, axis=1)
    hits_per_col = np.sum(hits, axis=0)
    tp_count = np.sum(hits_per_row == 1)
    fp_count = np.sum(hits_per_col == 0)
    fn_count = np.sum(hits_per_row == 0)
    return tp_count, fp_count, fn_count

def score(pred, targ):
    """Average the per-threshold precision of one image's predictions.

    Args:
        pred: model output dict; ``pred['instances'].pred_masks`` holds the
            predicted masks (moved to CPU and converted to NumPy here).
        targ: list of annotation dicts. Each ``'segmentation'`` entry is
            indexed with ``[0]`` and that first element is rasterised as a
            polygon; any further elements are ignored.

    Returns:
        Mean of ``tp / (tp + fp + fn)`` over IoU thresholds 0.5 to 0.95
        in steps of 0.05.
    """
    masks = pred['instances'].pred_masks.cpu().numpy()
    enc_preds = []
    for mask in masks:
        # pycocotools requires Fortran-ordered input.
        enc_preds.append(mask_util.encode(np.asarray(mask, order='F')))

    segmentations = [annotation['segmentation'] for annotation in targ]
    enc_targs = [polygon_to_rle(segmentation[0]) for segmentation in segmentations]

    # No target is flagged as "crowd".
    crowd_flags = [0] * len(enc_targs)
    ious = mask_util.iou(enc_preds, enc_targs, crowd_flags)

    per_threshold = []
    for threshold in np.arange(0.5, 1.0, 0.05):
        tp, fp, fn = precision_at(threshold, ious)
        per_threshold.append(tp / (tp + fp + fn))
    return np.mean(per_threshold)

class MAPIOUEvaluator(DatasetEvaluator):
    """Evaluator reporting the mean per-image score computed by ``score``.

    Ground-truth annotations of the dataset are cached by ``image_id`` at
    construction time; ``process`` appends one score per image and
    ``evaluate`` averages them.
    """

    def __init__(self, dataset_name):
        """Cache the annotations of ``dataset_name`` keyed by image id."""
        dataset_dicts = DatasetCatalog.get(dataset_name)
        self.annotations_cache = {item['image_id']:item['annotations'] for item in dataset_dicts}
        # Initialise here as well so process() works even if reset() has not
        # been called yet.
        self.scores = []

    def reset(self):
        """Discard scores accumulated so far, starting a fresh evaluation run.

        Without this override ``self.scores`` was never created and
        ``process`` failed with ``AttributeError``.
        """
        self.scores = []

    def process(self, inputs, outputs):
        """Score each (input, output) pair and record the result.

        Images for which the model produced no instances score 0 without
        looking up their annotations.
        """
        for inp, out in zip(inputs, outputs):
            if len(out['instances']) == 0:
                self.scores.append(0)
            else:
                targ = self.annotations_cache[inp['image_id']]
                self.scores.append(score(out, targ))

    def evaluate(self):
        """Return the mean of all recorded scores under the key "MaP IoU"."""
        return {"MaP IoU": np.mean(self.scores)}

class Trainer(DefaultTrainer):
    """DefaultTrainer variant that evaluates with ``MAPIOUEvaluator``."""

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Build the evaluator for ``dataset_name``.

        ``cfg`` and ``output_folder`` are accepted for interface
        compatibility but are not used.
        """
        evaluator = MAPIOUEvaluator(dataset_name)
        return evaluator

In [None]:
# Configure Cascade Mask R-CNN (R50-FPN, 3x) and pretrain it on LIVECell,
# starting from the model-zoo checkpoint.
cfg.merge_from_file(model_zoo.get_config_file("Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("sartorius_train", "sartorius_test")
cfg.DATASETS.TEST = ("sartorius_val",)
cfg.DATALOADER.NUM_WORKERS = 4
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml")
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.0005
cfg.SOLVER.MAX_ITER = 100000
cfg.SOLVER.STEPS = []
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 8
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = .5

# Every DatasetCatalog.get() call re-parses the COCO json (several seconds
# per file), so count the training images once and reuse the result for both
# the checkpoint period and the evaluation period.
num_train_images = sum(len(DatasetCatalog.get(name)) for name in cfg.DATASETS.TRAIN)
iters_per_epoch = num_train_images // cfg.SOLVER.IMS_PER_BATCH
cfg.SOLVER.CHECKPOINT_PERIOD = iters_per_epoch  # Once per epoch
cfg.TEST.EVAL_PERIOD = iters_per_epoch  # Once per epoch

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
print(cfg.OUTPUT_DIR)
trainer = Trainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()

[32m[01/03 22:40:20 d2.data.datasets.coco]: [0mLoading ../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_train.json takes 8.94 seconds.
[32m[01/03 22:40:21 d2.data.datasets.coco]: [0mLoaded 3253 images in COCO format from ../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_train.json
[32m[01/03 22:40:28 d2.data.datasets.coco]: [0mLoading ../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_test.json takes 3.81 seconds.
[32m[01/03 22:40:28 d2.data.datasets.coco]: [0mLoaded 1564 images in COCO format from ../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_test.json
[32m[01/03 22:40:39 d2.data.datasets.coco]: [0mLoading ../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_train.json takes 9.34 seconds.
[32m[01/03 22:40:39 d2.data.datasets.coco]: [0mLoaded 3253 images in COCO format from ../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_train.json
[32m[01/03 22:40:47 d2.data.datasets.coco]: [0mLoading ../LIVECell_dataset_2021

Skip loading parameter 'roi_heads.box_predictor.0.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (9, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.0.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (9,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.1.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (9, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.1.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (9,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.2.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (9, 1024) 

[32m[01/03 22:41:11 d2.engine.train_loop]: [0mStarting training from iteration 0


  max_size = (max_size + (stride - 1)) // stride * stride
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[32m[01/03 22:41:21 d2.utils.events]: [0m eta: 11:33:37  iter: 19  total_loss: 13.71  loss_cls_stage0: 2.169  loss_box_reg_stage0: 0.5272  loss_cls_stage1: 2.13  loss_box_reg_stage1: 0.7249  loss_cls_stage2: 2.335  loss_box_reg_stage2: 0.6518  loss_mask: 0.6915  loss_rpn_cls: 4.05  loss_rpn_loc: 0.3246  time: 0.4247  data_time: 0.0214  lr: 9.9905e-06  max_mem: 7824M
[32m[01/03 22:41:26 d2.utils.memory]: [0mAttempting to copy inputs of <function pairwise_iou at 0x7f88c8065d30> to CPU due to CUDA OOM
[32m[01/03 22:41:31 d2.utils.events]: [0m eta: 11:10:49  iter: 39  total_loss: 10.29  loss_cls_stage0: 2.071  loss_box_reg_stage0: 0.5426  loss_cls_stage1: 2.046  loss_box_reg_stage1: 0.7413  loss_cls_stage2: 2.136  loss_box_reg_stage2: 0.6098  loss_mask: 0.6874  loss_rpn_cls: 0.9625  loss_rpn_loc: 0.3089  time: 0.4653  data_time: 0.0064  lr: 1.998e-05  max_mem: 7824M
[32m[01/03 22:41:38 d2.utils.events]: [0m eta: 10:48:15  iter: 59  total_loss: 8.879  loss_cls_stage0: 1.892  loss_bo

In [4]:
# The previous cell crashed (out of memory); rebuild the same configuration
# and resume training from the last checkpoint.
cfg.merge_from_file(model_zoo.get_config_file("Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("sartorius_train", "sartorius_test")
cfg.DATASETS.TEST = ("sartorius_val",)
cfg.DATALOADER.NUM_WORKERS = 4
# With resume=True below, these weights are only a fallback for the case
# where no checkpoint is found in cfg.OUTPUT_DIR.
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml")
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.0005
cfg.SOLVER.MAX_ITER = 100000
cfg.SOLVER.STEPS = []
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 8
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = .5

# Every DatasetCatalog.get() call re-parses the COCO json (several seconds
# per file), so count the training images once and reuse the result for both
# the checkpoint period and the evaluation period.
num_train_images = sum(len(DatasetCatalog.get(name)) for name in cfg.DATASETS.TRAIN)
iters_per_epoch = num_train_images // cfg.SOLVER.IMS_PER_BATCH
cfg.SOLVER.CHECKPOINT_PERIOD = iters_per_epoch  # Once per epoch
cfg.TEST.EVAL_PERIOD = iters_per_epoch  # Once per epoch

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
print(cfg.OUTPUT_DIR)
trainer = Trainer(cfg)
# resume=True: continue from the last checkpoint saved in cfg.OUTPUT_DIR.
trainer.resume_or_load(resume=True)
trainer.train()

[32m[01/04 07:59:44 d2.data.datasets.coco]: [0mLoading ../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_train.json takes 9.06 seconds.
[32m[01/04 07:59:44 d2.data.datasets.coco]: [0mLoaded 3253 images in COCO format from ../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_train.json
[32m[01/04 07:59:51 d2.data.datasets.coco]: [0mLoading ../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_test.json takes 3.84 seconds.
[32m[01/04 07:59:51 d2.data.datasets.coco]: [0mLoaded 1564 images in COCO format from ../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_test.json
[32m[01/04 08:00:02 d2.data.datasets.coco]: [0mLoading ../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_train.json takes 9.24 seconds.
[32m[01/04 08:00:02 d2.data.datasets.coco]: [0mLoaded 3253 images in COCO format from ../LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_train.json
[32m[01/04 08:00:10 d2.data.datasets.coco]: [0mLoading ../LIVECell_dataset_2021

  max_size = (max_size + (stride - 1)) // stride * stride
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[32m[01/04 08:00:45 d2.utils.events]: [0m eta: 5:32:04  iter: 48179  total_loss: 3.287  loss_cls_stage0: 0.2116  loss_box_reg_stage0: 0.3803  loss_cls_stage1: 0.2406  loss_box_reg_stage1: 0.8234  loss_cls_stage2: 0.2541  loss_box_reg_stage2: 0.943  loss_mask: 0.2475  loss_rpn_cls: 0.09018  loss_rpn_loc: 0.1758  time: 0.4005  data_time: 0.0190  lr: 0.0005  max_mem: 6374M
[32m[01/04 08:00:48 d2.utils.memory]: [0mAttempting to copy inputs of <function pairwise_iou at 0x7f406df78d30> to CPU due to CUDA OOM
[32m[01/04 08:00:53 d2.utils.memory]: [0mAttempting to copy inputs of <function pairwise_iou at 0x7f406df78d30> to CPU due to CUDA OOM
[32m[01/04 08:00:59 d2.utils.events]: [0m eta: 5:26:56  iter: 48199  total_loss: 3.264  loss_cls_stage0: 0.1818  loss_box_reg_stage0: 0.3661  loss_cls_stage1: 0.1947  loss_box_reg_stage1: 0.7988  loss_cls_stage2: 0.2304  loss_box_reg_stage2: 0.9902  loss_mask: 0.2489  loss_rpn_cls: 0.0888  loss_rpn_loc: 0.1782  time: 0.5472  data_time: 0.0066  lr:

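We keep the final checkpoint of this pretraining run (`model_final.pth`) and use it to initialise the training on the Sartorius data below.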

#### Training on the Sartorius data:

In [2]:
# Start from a fresh config that reads masks in 'bitmask' format and register
# the Sartorius train/val COCO annotation files, with '../' as image root.
dataDir = Path('../')
cfg = get_cfg()
cfg.INPUT.MASK_FORMAT = 'bitmask'
for dataset_name, annotation_file in (
    ('sartorius_train', '../sartorius-annotations-coco-format/annotations_train.json'),
    ('sartorius_val', '../sartorius-annotations-coco-format/annotations_val.json'),
):
    register_coco_instances(dataset_name, {}, annotation_file, dataDir)

[32m[01/04 18:30:41 d2.data.datasets.coco]: [0mLoaded 485 images in COCO format from ../sartorius-annotations-coco-format/annotations_train.json


In [3]:
def precision_at(threshold, iou):
    """Tally matches of a pairwise IoU matrix at one threshold.

    Args:
        threshold: a pair is a match when its IoU is strictly greater.
        iou: 2-D array of pairwise IoU values.

    Returns:
        ``(true_positives, false_positives, false_negatives)``, i.e. the
        number of rows with exactly one match, of columns without any
        match, and of rows without any match.
    """
    is_match = iou > threshold
    row_totals = np.sum(is_match, axis=1)
    column_totals = np.sum(is_match, axis=0)
    return (
        np.sum(row_totals == 1),
        np.sum(column_totals == 0),
        np.sum(row_totals == 0),
    )

def score(pred, targ):
    """Average the per-threshold precision of one image's predictions.

    Args:
        pred: model output dict; ``pred['instances'].pred_masks`` holds the
            predicted masks (moved to CPU and converted to NumPy here).
        targ: list of annotation dicts whose ``'segmentation'`` entries are
            passed to ``mask_util.iou`` unchanged (presumably already RLE,
            given the 'bitmask' mask format -- TODO confirm).

    Returns:
        Mean of ``tp / (tp + fp + fn)`` over IoU thresholds 0.5 to 0.95
        in steps of 0.05.
    """
    masks = pred['instances'].pred_masks.cpu().numpy()
    enc_preds = []
    for mask in masks:
        # pycocotools requires Fortran-ordered input.
        enc_preds.append(mask_util.encode(np.asarray(mask, order='F')))

    enc_targs = [annotation['segmentation'] for annotation in targ]

    # No target is flagged as "crowd".
    crowd_flags = [0] * len(enc_targs)
    ious = mask_util.iou(enc_preds, enc_targs, crowd_flags)

    per_threshold = []
    for threshold in np.arange(0.5, 1.0, 0.05):
        tp, fp, fn = precision_at(threshold, ious)
        per_threshold.append(tp / (tp + fp + fn))
    return np.mean(per_threshold)

class MAPIOUEvaluator(DatasetEvaluator):
    """Evaluator reporting the mean per-image score computed by ``score``."""

    def __init__(self, dataset_name):
        """Cache the annotations of ``dataset_name`` keyed by image id."""
        cache = {}
        for record in DatasetCatalog.get(dataset_name):
            cache[record['image_id']] = record['annotations']
        self.annotations_cache = cache

    def reset(self):
        """Start a fresh evaluation run with no recorded scores."""
        self.scores = []

    def process(self, inputs, outputs):
        """Record one score per (input, output) pair.

        An image with no predicted instances scores 0 and its annotations
        are not looked up.
        """
        for sample, prediction in zip(inputs, outputs):
            if len(prediction['instances']) > 0:
                image_score = score(prediction, self.annotations_cache[sample['image_id']])
            else:
                image_score = 0
            self.scores.append(image_score)

    def evaluate(self):
        """Return the mean of all recorded scores under the key "MaP IoU"."""
        mean_score = np.mean(self.scores)
        return {"MaP IoU": mean_score}

class Trainer(DefaultTrainer):
    """DefaultTrainer variant that evaluates with ``MAPIOUEvaluator``."""

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Build the evaluator for ``dataset_name``.

        ``cfg`` and ``output_folder`` are accepted for interface
        compatibility but are not used.
        """
        evaluator = MAPIOUEvaluator(dataset_name)
        return evaluator

In [4]:
# Fine-tune on the Sartorius data, initialising from the checkpoint stored at
# 'output_2/model_final.pth'.
cfg.merge_from_file(model_zoo.get_config_file("Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("sartorius_train",)
cfg.DATASETS.TEST = ("sartorius_val",)
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = 'output_2/model_final.pth'
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.0005
cfg.SOLVER.MAX_ITER = 10000
cfg.SOLVER.STEPS = []
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = .5

# Every DatasetCatalog.get() call re-loads the COCO json, so count the
# training images once and reuse the result for both periods below
# (one checkpoint and one evaluation per pass over the training set).
num_train_images = sum(len(DatasetCatalog.get(name)) for name in cfg.DATASETS.TRAIN)
iters_per_epoch = num_train_images // cfg.SOLVER.IMS_PER_BATCH
cfg.SOLVER.CHECKPOINT_PERIOD = iters_per_epoch
cfg.TEST.EVAL_PERIOD = iters_per_epoch

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = Trainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()

[32m[01/04 18:30:43 d2.data.datasets.coco]: [0mLoaded 485 images in COCO format from ../sartorius-annotations-coco-format/annotations_train.json
[32m[01/04 18:30:44 d2.data.datasets.coco]: [0mLoaded 485 images in COCO format from ../sartorius-annotations-coco-format/annotations_train.json
[32m[01/04 18:30:47 d2.engine.defaults]: [0mModel:
GeneralizedRCNN(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1,

Skip loading parameter 'roi_heads.box_predictor.0.cls_score.weight' to the model due to incompatible shapes: (9, 1024) in the checkpoint but (4, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.0.cls_score.bias' to the model due to incompatible shapes: (9,) in the checkpoint but (4,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.1.cls_score.weight' to the model due to incompatible shapes: (9, 1024) in the checkpoint but (4, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.1.cls_score.bias' to the model due to incompatible shapes: (9,) in the checkpoint but (4,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.2.cls_score.weight' to the model due to incompatible shapes: (9, 1024) in the checkpoint but (4, 1024) in th

[32m[01/04 18:30:48 d2.engine.train_loop]: [0mStarting training from iteration 0


  max_size = (max_size + (stride - 1)) // stride * stride
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[32m[01/04 18:31:01 d2.utils.events]: [0m eta: 1:01:23  iter: 19  total_loss: 7.083  loss_cls_stage0: 1.465  loss_box_reg_stage0: 0.4064  loss_cls_stage1: 1.401  loss_box_reg_stage1: 0.5394  loss_cls_stage2: 1.369  loss_box_reg_stage2: 0.5443  loss_mask: 0.6926  loss_rpn_cls: 0.311  loss_rpn_loc: 0.248  time: 0.6473  data_time: 0.2787  lr: 9.9905e-06  max_mem: 5046M
[32m[01/04 18:31:15 d2.utils.events]: [0m eta: 1:01:16  iter: 39  total_loss: 6.707  loss_cls_stage0: 1.376  loss_box_reg_stage0: 0.3767  loss_cls_stage1: 1.333  loss_box_reg_stage1: 0.5161  loss_cls_stage2: 1.297  loss_box_reg_stage2: 0.4185  loss_mask: 0.6902  loss_rpn_cls: 0.3332  loss_rpn_loc: 0.2939  time: 0.6726  data_time: 0.3332  lr: 1.998e-05  max_mem: 5217M
[32m[01/04 18:31:23 d2.utils.events]: [0m eta: 0:58:48  iter: 59  total_loss: 6.34  loss_cls_stage0: 1.218  loss_box_reg_stage0: 0.4114  loss_cls_stage1: 1.189  loss_box_reg_stage1: 0.5748  loss_cls_stage2: 1.163  loss_box_reg_stage2: 0.5217  loss_mask: 0