In [None]:
# Clone the LIVECell repository; its detectron2 config
# (LIVECell/model/anchor_based/livecell_config.yaml) is read when the predictor is built.
# NOTE(review): the LIVECell_dataset_2021 annotations, images and model weights used
# later are not downloaded by this notebook — presumably fetched separately; verify.
!git clone https://github.com/sartorius-research/LIVECell.git

In [None]:
import sys
import torch, detectron2
import numpy as np
import os, json, cv2, random
import tifffile
from os.path import join
import matplotlib.pyplot as plt
from pycocotools.coco import COCO
from tqdm.notebook import tqdm
import torch

# Root of the LIVECell 2021 dataset (annotations, images, pretrained models).
# Override with the LIVECELL_ROOT environment variable; by default it is a path
# relative to the current working directory.
LIVECELL_ROOT = os.environ.get('LIVECELL_ROOT', 'LIVECell_dataset_2021')

# COCO-format ground-truth annotations for the test split.
json_fn = f'{LIVECELL_ROOT}/annotations/LIVECell/livecell_coco_test.json'
# Directory of test images; per-image sub-folders are appended when reading.
image_path = f'{LIVECELL_ROOT}/images/livecell_test_images'
# Anchor-based model checkpoint loaded into the detectron2 predictor.
weight_fn = f'{LIVECELL_ROOT}/models/Anchor_based/ALL/LIVECell_anchor_based_model.pth'

In [None]:
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

# Ground-truth annotations for the test split (COCO API object used by the helpers below).
coco = COCO(annotation_file=json_fn)

# detectron2 config from the cloned repo's anchor-based LIVECell YAML,
# pointed at the pretrained checkpoint.
cfg = get_cfg()
cfg.merge_from_file("LIVECell/model/anchor_based/livecell_config.yaml")
cfg.MODEL.WEIGHTS = weight_fn

predictor = DefaultPredictor(cfg)

In [None]:
import torch.nn.functional as F

def get_annotations(id, device='cuda'):
    """Load ground-truth boxes and instance masks for one COCO image.

    Args:
        id: COCO image id (key into ``coco.imgs``).
        device: torch device the returned tensors are moved to. The default
            'cuda' matches the previous hard-coded ``.cuda()`` call.

    Returns:
        (gt_bboxes, gt_masks):
        gt_bboxes -- (N, 4) tensor of boxes as (y_min, x_min, y_max, x_max),
            converted from COCO's [x, y, width, height].
        gt_masks -- (N, H, W) bool tensor, one binary mask per annotation.
        An image without annotations yields N == 0 instead of raising.
    """
    anns = [coco.anns[ann_id] for ann_id in coco.getAnnIds(imgIds=id)]

    if anns:
        # COCO bbox is [x, y, w, h]; reorder to (y_min, x_min, y_max, x_max).
        x, y, w, h = np.array([ann['bbox'] for ann in anns]).T
        bboxes = np.stack([y, x, y + h, x + w], axis=1)
        masks = np.stack([coco.annToMask(ann) for ann in anns]).astype(bool)
    else:
        # np.stack([]) would raise; build correctly-shaped empty arrays instead.
        img_info = coco.imgs[id]
        bboxes = np.zeros((0, 4), dtype=np.float64)
        masks = np.zeros((0, img_info['height'], img_info['width']), dtype=bool)

    gt_bboxes = torch.from_numpy(bboxes).to(device)
    gt_masks = torch.from_numpy(masks).to(device)
    return gt_bboxes, gt_masks

def get_predictions(id, th=0.5):
    """Run the detectron2 predictor on one test image and keep confident instances.

    Args:
        id: COCO image id. The file is read from
            ``image_path/<file-name prefix before the first '_'>/<file name>``.
        th: minimum score (inclusive) for an instance to be kept.

    Returns:
        (boxes, masks) for the kept instances. Boxes are reordered from
        detectron2's (x0, y0, x1, y1) to (y0, x0, y1, x1), the same layout
        get_annotations produces.
    """
    file_name = coco.imgs[id]['file_name']
    sub_dir = file_name.split('_')[0]
    img = tifffile.imread(join(image_path, sub_dir, file_name))

    # Replicate the single channel into 3 channels for the predictor.
    # Assumes the TIFF is a 2-D grayscale image — TODO confirm.
    three_channel = np.repeat(img[..., None], 3, axis=-1)
    instances = predictor(three_channel)['instances']

    keep = instances.scores >= th
    boxes = instances.pred_boxes.tensor[keep][:, [1, 0, 3, 2]]
    masks = instances.pred_masks[keep]
    return boxes, masks

def has_intersection(gt_boxes, boxes):
    """Pairwise test for positive-area overlap between two sets of boxes.

    Both inputs are 2-D with one box per row laid out as
    (y_min, x_min, y_max, x_max).

    Returns:
        Bool tensor of shape (len(gt_boxes), len(boxes)). Entry [i, j] is True
        when gt box i and box j overlap with strictly positive height AND width.
        Boxes that only touch along an edge do not count.
    """
    lhs = gt_boxes[:, None, :]  # (G, 1, 4)
    rhs = boxes[None, :, :]     # (1, P, 4)

    overlap_h = torch.minimum(lhs[..., 2], rhs[..., 2]) - torch.maximum(lhs[..., 0], rhs[..., 0])
    overlap_w = torch.minimum(lhs[..., 3], rhs[..., 3]) - torch.maximum(lhs[..., 1], rhs[..., 1])

    return (overlap_h > 0) & (overlap_w > 0)

def get_intersect(gt_m, m, gt_ids, pred_ids, chunk_size=256):
    """Count overlapping foreground pixels for selected (GT, prediction) mask pairs.

    Args:
        gt_m: GT masks indexed along dim 0; pixels are counted over dims 1 and 2.
        m: predicted masks, same layout as ``gt_m``.
        gt_ids, pred_ids: equal-length 1-D integer tensors. Pair k is
            (gt_m[gt_ids[k]], m[pred_ids[k]]).
        chunk_size: number of pairs processed per step. It bounds the size of
            the temporary (chunk_size, H, W) tensors.

    Returns:
        1-D int64 tensor of length len(gt_ids) on ``gt_m``'s device. Entry k
        is the number of pixels set in both masks of pair k.
    """
    if gt_ids.numel() == 0:
        # torch.stack/cat of an empty list raises; return an empty result instead.
        return torch.zeros(0, dtype=torch.int64, device=gt_m.device)

    counts = [
        torch.count_nonzero(gt_m[g] & m[p], dim=(1, 2))
        for g, p in zip(torch.split(gt_ids, chunk_size), torch.split(pred_ids, chunk_size))
    ]
    return torch.cat(counts)

class AJI:
    """Running Aggregated Jaccard Index (AJI) accumulator over many images.

    Each ground-truth instance G_i is paired with the predicted instance S_j(i)
    that has the highest IoU with it. The metric is

        AJI = sum_i |G_i ∩ S_j(i)| / (sum_i |G_i ∪ S_j(i)| + sum_{k unused} |S_k|)

    The last sum runs over predictions never selected by any GT instance. A GT
    instance that overlaps no prediction (best IoU == 0) is left unpaired and
    adds only |G_i| to the denominator.

    Call ``update`` once per image, then ``result`` for the accumulated value.
    Attribute naming is kept for compatibility: ``u`` holds the intersection
    numerator and ``c`` holds the union denominator.
    """

    def __init__(self):
        self.c = 0  # running union (denominator)
        self.u = 0  # running intersection (numerator)

    def update(self, gt_m, m, gt_b, b):
        """Add one image's instances to the running totals.

        Args:
            gt_m: GT instance masks indexed along dim 0 (pixels counted over dims 1, 2).
            m: predicted instance masks, same layout as ``gt_m``.
            gt_b: GT boxes, one row per mask in ``gt_m``, as (y_min, x_min, y_max, x_max).
            b: predicted boxes, one row per mask in ``m``, same box layout.
        """
        n_gt, n_pred = gt_b.shape[0], b.shape[0]
        gt_areas = torch.count_nonzero(gt_m, dim=(1, 2)).cpu().numpy()
        areas = torch.count_nonzero(m, dim=(1, 2)).cpu().numpy()

        # Dense |G_i ∩ S_j| matrix. Only pairs whose boxes overlap can be non-zero,
        # so pixel counts are computed just for those candidate pairs.
        intersects = np.zeros((n_gt, n_pred))
        gt_ids, pred_ids = torch.where(has_intersection(gt_b, b))
        if gt_ids.numel() > 0:
            v = get_intersect(gt_m, m, gt_ids, pred_ids)
            intersects[gt_ids.cpu().numpy(), pred_ids.cpu().numpy()] = v.cpu().numpy()

        unions = gt_areas[:, None] + areas[None, :] - intersects
        ious = intersects / (unions + 1e-8)

        # Best-IoU prediction per GT. A row of all-zero IoUs would make argmax
        # pick prediction 0 arbitrarily, so such GT instances are left unpaired.
        if n_pred > 0:
            best = ious.argmax(axis=1)
            matched = ious[np.arange(n_gt), best] > 0
        else:
            best = np.zeros(n_gt, dtype=np.int64)
            matched = np.zeros(n_gt, dtype=bool)
        gt_sel = np.flatnonzero(matched)
        pred_sel = best[matched]
        inter_sum = intersects[gt_sel, pred_sel].sum()

        self.u += inter_sum
        # |G ∪ S| = |G| + |S| - |G ∩ S| for paired instances. Unpaired GTs add only
        # |G|. A prediction selected by several GTs is counted once per selection.
        self.c += gt_areas.sum() + areas[pred_sel].sum() - inter_sum
        # Predictions never selected by any GT are added to the union once each.
        unused = np.ones(n_pred, dtype=bool)
        unused[pred_sel] = False
        self.c += areas[unused].sum()

    def result(self):
        """Return the AJI accumulated so far (intersection sum / union sum)."""
        return self.u / self.c


In [None]:
# Minimum detector score for a predicted instance to be evaluated.
# NOTE(review): 0.21 has no rationale in this notebook — presumably tuned; confirm.
SCORE_THRESHOLD = 0.21

# One AJI accumulator per group, keyed by the file-name prefix before the first
# '_' (the same prefix get_predictions uses as the image sub-folder).
aji = {}
for img_id in tqdm(coco.getImgIds()):
    group = coco.imgs[img_id]['file_name'].split('_')[0]
    gt_b, gt_m = get_annotations(img_id)
    b, m = get_predictions(img_id, SCORE_THRESHOLD)
    aji.setdefault(group, AJI()).update(gt_m, m, gt_b, b)

In [None]:
# Report each group's accumulated AJI, in alphabetical order of group name.
for group, accumulator in sorted(aji.items()):
    print(group)
    print(accumulator.result())