In [None]:
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
# import sys
from os.path import join

# Directory holding the evaluated split: X.npy (input images) and y.npy
# (label images) are loaded from here below. Relative to the notebook's
# working directory.
data_path = 'tissuenet_1.0/test'

In [None]:
import torch
import torch.nn.functional as F

def has_intersection(gt_boxes, boxes):
    """Pairwise overlap test between two sets of axis-aligned boxes.

    Each row of ``gt_boxes`` (N rows) and ``boxes`` (M rows) is
    ``(y_min, x_min, y_max, x_max)``. Returns an (N, M) bool tensor whose
    ``[i, j]`` entry is True when box i and box j share a region of strictly
    positive height AND width. Boxes that merely touch along an edge do not count.
    """
    lhs = gt_boxes[:, None, :]  # (N, 1, 4): broadcast along the second set
    rhs = boxes[None, :, :]     # (1, M, 4): broadcast along the first set

    overlap_h = torch.minimum(lhs[..., 2], rhs[..., 2]) - torch.maximum(lhs[..., 0], rhs[..., 0])
    overlap_w = torch.minimum(lhs[..., 3], rhs[..., 3]) - torch.maximum(lhs[..., 1], rhs[..., 1])
    return torch.logical_and(overlap_h > 0, overlap_w > 0)

def get_intersect(gt_m, m, gt_ids, pred_ids, chunk_size=256):
    """Count shared pixels for selected (ground-truth, prediction) mask pairs.

    Args:
        gt_m: stack of masks indexed along dim 0 (expected boolean), one per
            ground-truth object; dims 1 and 2 are reduced.
        m: stack of masks indexed along dim 0 (expected boolean), one per
            predicted object.
        gt_ids, pred_ids: 1-D index tensors of equal length; pair k is
            ``(gt_ids[k], pred_ids[k])``.
        chunk_size: number of pairs materialised at once. This bounds the
            size of the temporary ``(chunk, H, W)`` tensor.

    Returns:
        1-D int64 tensor whose entry k is the number of pixels set in both
        ``gt_m[gt_ids[k]]`` and ``m[pred_ids[k]]``. Empty if there are no pairs.
    """
    n_pairs = gt_ids.shape[0]
    if n_pairs == 0:
        # torch.cat/stack reject an empty list, so handle "no candidates" explicitly.
        return torch.zeros([0], dtype=torch.int64, device=gt_m.device)

    counts = []
    for start in range(0, n_pairs, chunk_size):
        stop = start + chunk_size
        overlap = gt_m[gt_ids[start:stop]] & m[pred_ids[start:stop]]
        counts.append(torch.count_nonzero(overlap, dim=(1, 2)))
    return torch.cat(counts)

class AJI:
    """Accumulates the Aggregated Jaccard Index (AJI) over a sequence of images.

    Call ``update`` once per image, then ``result`` for the dataset-level score

        sum_i |G_i ∩ S*_i|  /  ( sum_i |G_i ∪ S*_i| + sum_{unused S} |S| )

    Here S*_i is the prediction with the highest IoU against ground-truth
    object G_i. A ground-truth object that overlaps no prediction contributes
    only its own area to the union. A prediction may be the best match of
    several ground-truth objects and is then counted in each of their unions.

    Attribute names are kept for backward compatibility, but note their meaning:
    ``u`` holds the intersection total (numerator), ``c`` the union total (denominator).
    """

    def __init__(self):
        self.c = 0  # running union pixel count (denominator)
        self.u = 0  # running intersection pixel count (numerator)

    def update(self, gt_m, m, gt_b, b):
        """Add one image's contribution to the running totals.

        Args:
            gt_m: mask stack, one mask per ground-truth object along dim 0.
            m: mask stack, one mask per predicted object along dim 0.
            gt_b: (n_gt, 4) boxes ``(y_min, x_min, y_max, x_max)`` for ``gt_m``.
            b: (n_pred, 4) boxes for ``m``.
        """
        gt_areas = torch.count_nonzero(gt_m, dim=(1, 2)).cpu().numpy()
        areas = torch.count_nonzero(m, dim=(1, 2)).cpu().numpy()
        n_gt, n_pred = gt_b.shape[0], b.shape[0]

        if n_gt == 0 or n_pred == 0:
            # Nothing can be matched: every object only enlarges the union.
            self.c += gt_areas.sum() + areas.sum()
            return

        intersects = np.zeros((n_gt, n_pred))
        ious = np.zeros((n_gt, n_pred))
        # Only pairs with overlapping bounding boxes can share pixels.
        gt_ids, pred_ids = torch.where(has_intersection(gt_b, b))
        if gt_ids.numel() > 0:
            v = get_intersect(gt_m, m, gt_ids, pred_ids).cpu().numpy()
            gi, pj = gt_ids.cpu().numpy(), pred_ids.cpu().numpy()
            intersects[gi, pj] = v
            ious[gi, pj] = v / (areas[pj] + gt_areas[gi] - v + 1e-8)

        rows = np.arange(n_gt)
        best_matches = ious.argmax(axis=1)
        best_intersects = intersects[rows, best_matches]
        # argmax of an all-zero row is 0; such GT objects have no real match and
        # must not charge an arbitrary prediction's area to the union.
        matched = ious[rows, best_matches] > 0
        matched_preds = best_matches[matched]

        self.u += best_intersects.sum()
        # |G ∪ S| = |G| + |S| - |G ∩ S| for matched pairs; unmatched G adds just |G|.
        self.c += gt_areas.sum() + areas[matched_preds].sum() - best_intersects.sum()

        # Predictions never selected as a best match are pure false positives.
        unused = np.ones(n_pred, dtype=bool)
        unused[matched_preds] = False
        self.c += areas[unused].sum()

    def result(self):
        """Return the accumulated AJI, or NaN when no pixels have been accumulated."""
        if self.c == 0:
            return float('nan')
        return self.u / self.c


In [None]:
from skimage.measure import regionprops
from deepcell.applications import Mesmer

app = Mesmer()

# Read-only memory maps: the evaluation must never write back to the dataset files.
gt_x = np.load(join(data_path, 'X.npy'), mmap_mode='r')
gt_y = np.load(join(data_path, 'y.npy'), mmap_mode='r')


def label_boxes(label_img):
    """Return an (n_labels, 4) int array of regionprops bboxes; row k is label k+1.

    Label ids absent from the image keep an all-zero row (an empty box).
    """
    boxes = np.zeros([label_img.max(), 4], int)
    for region in regionprops(label_img):
        boxes[region['label'] - 1] = region['bbox']
    return boxes


def label_masks(label_img):
    """Return an (n_labels, H, W) bool array; slice k is the mask of label k+1."""
    return label_img == np.arange(1, label_img.max() + 1)[:, None, None]


aji = AJI()

for x, y in tqdm(zip(gt_x, gt_y), total=len(gt_x)):
    # Pass an in-memory copy so nothing downstream can touch the memory-mapped file.
    pr = app.predict(np.array(x)[None, ...])[0, ..., 0]
    # NOTE(review): uses whichever label channel has more labelled pixels as ground
    # truth (presumably the whole-cell channel) to compare against prediction
    # channel 0 — confirm against the dataset's channel order.
    label_in_ch0 = np.argmax(np.count_nonzero(y, axis=(0, 1))) == 0
    y = y[..., 0] if label_in_ch0 else y[..., 1]
    aji.update(
        torch.from_numpy(label_masks(y)),
        torch.from_numpy(label_masks(pr)),
        torch.from_numpy(label_boxes(y)),
        torch.from_numpy(label_boxes(pr)),
    )

print(aji.result())