In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import cv2
from tqdm.notebook import tqdm
from matplotlib import patches
from skimage.color import label2rgb
from skimage.measure import regionprops

import sys
from os.path import join

import lacss

# The test data (X.npy, y.npy, ...) and the model files (config.json, chkpt)
# are both read from this directory.
data_path, model_path = 'tissuenet_1.0', 'tissuenet_1.0'

In [None]:
def tissue_net_gen_fn(data_path):
    """Yield one TissueNet sample per image as a dict of model-ready arrays.

    Reads ``X.npy``, ``y.npy``, ``platform_list.npy`` and ``tissue_list.npy``
    from ``data_path``. For every image, the label channel with the most
    non-zero pixels is used as the instance label map, and each of its regions
    is converted into a centroid, a bounding box and a list of pixel indices.

    Args:
        data_path: directory containing the four ``.npy`` files.

    Yields:
        dict with keys
          - ``image``: float32 copy of the input image,
          - ``locations``: float32 array (n, 2) of region centroids,
          - ``binary_mask``: float32 foreground mask of the label map,
          - ``bboxes``: float32 array (n, 4) of region bounding boxes,
          - ``mask_indices``: int64 RaggedTensor (n, None, 2) of the image
            (row, col) coordinates of every region's pixels,
          - ``platform`` / ``tissue``: the per-image metadata entries.
        Images without any labelled region yield empty (0, 2) / (0, 4) arrays
        and an empty ragged tensor instead of raising.
    """
    # Read-only memory maps: evaluation never writes back to the dataset
    # files, and 'r+' would fail on read-only storage.
    X = np.load(join(data_path, 'X.npy'), mmap_mode='r')
    Y = np.load(join(data_path, 'y.npy'), mmap_mode='r')
    platforms = np.load(join(data_path, 'platform_list.npy'))
    tissues = np.load(join(data_path, 'tissue_list.npy'))

    for x, y, pf, t in zip(X, Y, platforms, tissues):
        img = x.astype('float32')
        # Use whichever label channel has more labelled pixels.
        label_in_ch0 = np.argmax(np.count_nonzero(y, axis=(0, 1))) == 0
        # Copy out of the read-only memmap so downstream code gets a plain array.
        label = np.array(y[..., 0] if label_in_ch0 else y[..., 1])
        binary_mask = (label > 0).astype('float32')

        props = regionprops(label)  # computed once and reused for every field
        locs = np.array([p['centroid'] for p in props], dtype='float32').reshape(-1, 2)
        bboxes = np.array([p['bbox'] for p in props], dtype='float32').reshape(-1, 4)

        # Pixel coordinates of each region: local (row, col) inside the region's
        # bbox crop, shifted by the bbox origin (min_row, min_col).
        mis = [np.argwhere(p['image']) + p['bbox'][:2] for p in props]
        mi_lengths = np.array([mi.shape[0] for mi in mis], dtype='int64')
        if mis:
            mi_values = np.concatenate(mis).astype('int64')
        else:
            # np.concatenate([]) raises; an empty image has no pixel indices.
            mi_values = np.zeros((0, 2), dtype='int64')
        mask_indices = tf.RaggedTensor.from_row_lengths(mi_values, mi_lengths)

        yield {
            'image': img,
            'locations': locs,
            'binary_mask': binary_mask,
            'bboxes': bboxes,
            'mask_indices': mask_indices,
            'platform': pf,
            'tissue': t,
        }

# Test split as a tf.data pipeline; the generator is re-created (and the
# .npy files re-opened) every time the dataset is iterated.
ds_test = tf.data.Dataset.from_generator(
    generator=lambda: tissue_net_gen_fn(join(data_path, 'test')),
    output_signature=dict(
        image=tf.TensorSpec(shape=[None, None, 2], dtype=tf.float32),
        locations=tf.TensorSpec(shape=[None, 2], dtype=tf.float32),
        binary_mask=tf.TensorSpec(shape=[None, None], dtype=tf.float32),
        bboxes=tf.TensorSpec(shape=[None, 4], dtype=tf.float32),
        # one ragged row of (row, col) pixel indices per instance
        mask_indices=tf.RaggedTensorSpec(shape=[None, None, 2], dtype=tf.int64, ragged_rank=1),
        platform=tf.TensorSpec(shape=[], dtype=tf.string),
        tissue=tf.TensorSpec(shape=[], dtype=tf.string),
    ),
)

In [None]:
import torch
import torch.nn.functional as F

def pred2label(y, th=0.5, shape=(256, 256), min_area=15):
    """Rasterize model instance predictions into a single integer label image.

    Each predicted patch is thresholded at ``th``. Only the first contour that
    ``cv2.findContours`` returns is filled. Patches whose filled area is below
    ``min_area`` pixels are discarded. Surviving patch ``k`` (of ``n`` kept) is
    painted with value ``n - k``; where patches overlap, the larger value
    (i.e. the earlier patch) wins.

    Args:
        y: model output dict. ``y['instance_output'][0]`` holds per-patch
            probabilities and ``y['instance_coords'][0]`` the image (row, col)
            of every patch pixel. Presumably (n, ps, ps, 1) and (n, ps, ps, 2);
            confirm against lacss.
        th: probability threshold for a patch pixel to count as foreground.
        shape: (height, width) of the output label image.
        min_area: minimum filled area, in pixels, for a patch to be kept.

    Returns:
        int32 tf.Tensor of ``shape``; 0 is background.
    """
    coords = y['instance_coords'][0]
    seg = (y['instance_output'][0].numpy() >= th).astype('uint8')

    # NOTE(review): contours[:1] is whichever contour OpenCV lists first, not
    # necessarily the largest; kept as-is to preserve existing results.
    seg_c = np.zeros_like(seg)
    for patch, filled in zip(seg, seg_c):
        contours, _ = cv2.findContours(patch, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        cv2.fillPoly(filled, pts=contours[:1], color=1)  # writes into seg_c in place

    keep = np.count_nonzero(seg_c, axis=(1, 2, 3)) >= min_area
    seg_c = seg_c[keep]
    coords = coords[keep]

    # (patch, row, col) of every foreground pixel -> its image coordinates.
    pps = np.argwhere(seg_c[..., 0] > 0)
    indices = tf.gather_nd(coords, pps)
    values = tf.cast(seg_c.shape[0] - pps[:, 0], tf.int32)

    # Drop pixels whose coordinates fall outside the output image:
    # tensor_scatter_nd_max raises on out-of-range indices on CPU.
    limits = tf.constant(shape, dtype=indices.dtype)
    in_bounds = tf.reduce_all((indices >= 0) & (indices < limits), axis=-1)
    indices = tf.boolean_mask(indices, in_bounds)
    values = tf.boolean_mask(values, in_bounds)

    label = tf.zeros(shape, dtype=tf.int32)
    return tf.tensor_scatter_nd_max(label, indices, values)

def x2label(x, shape=(256, 256)):
    """Build the ground-truth integer label image from a dataset element.

    Instance ``i`` of ``x['mask_indices']`` (a ragged tensor with one row of
    (row, col) pixel indices per instance) is painted with value ``i + 1``;
    0 is background.

    Args:
        x: dataset element containing a ``'mask_indices'`` RaggedTensor.
        shape: (height, width) of the output label image.

    Returns:
        tf.Tensor of ``shape`` with the dtype of ``value_rowids()``.
    """
    indices = x['mask_indices'].values
    ids = x['mask_indices'].value_rowids()
    # scatter_nd sums duplicate indices. Instances built from a single label
    # map (as in tissue_net_gen_fn) do not overlap, so each pixel is written
    # at most once.
    return tf.scatter_nd(indices, ids + 1, shape)

In [None]:

import json
# Build the model from its saved config (with a test-time score override)
# and restore the trained weights.
with open(join(model_path, 'config.json')) as cfg_file:  # closes the handle
    cfg = json.load(cfg_file)
# NOTE(review): presumably the minimum detection score kept at test time;
# confirm against lacss.
cfg['test_min_score'] = .35
model = lacss.models.LacssModel.from_config(cfg)
optimizer = tf.keras.optimizers.Adam()
model.compile(optimizer=optimizer)
model.load_weights(join(model_path, 'chkpt'))

In [None]:
from deepcell_toolbox.metrics import Metrics

# Evaluate on the whole test split: mask mean AP (lacss) plus an object-level
# F1 from deepcell's Metrics on rasterized label images.
gts = []
preds = []

# IoU thresholds 0.50, 0.55, ..., 0.95 for mask mean AP.
iou_thresholds = [.5, .55, .6, .65, .7, .75, .8, .85, .9, .95]
mask_AP = lacss.metrics.MaskMeanAP(iou_thresholds)

for x in tqdm(ds_test):
    gts.append(x2label(x).numpy())
    # Explicit inference mode, as in the AJI evaluation.
    y = model(x, training=False)
    preds.append(pred2label(y).numpy())

    scores = y['pred_location_scores'][0]
    # Named pred_patches (not `patches`) so the matplotlib.patches import
    # is not shadowed.
    pred_patches = y['instance_output'][0]
    coords = y['instance_coords'][0]
    pred_bboxes = lacss.ops.bboxes_of_patches(pred_patches, coords)
    pred = (pred_patches, coords, pred_bboxes)
    gt = (x['mask_indices'], x['bboxes'])
    mask_AP.update_state(gt, pred, scores)

print(mask_AP.result().numpy())

m = Metrics('lacss')
stat = m.calc_object_stats(np.array(gts), np.array(preds))

# F1 = 2 * TP / (n_pred + n_true)
f1 = stat['correct_detections'].sum() * 2 / (stat['n_pred'].sum() + stat['n_true'].sum())
print(f'f1 = {f1}')

In [None]:
platforms = np.load(join(data_path, 'test', 'platform_list.npy'))
# IoU thresholds 0.50, 0.55, ..., 0.95 for mask mean AP.
iou_thresholds = [.5, .55, .6, .65, .7, .75, .8, .85, .9, .95]

# Mask mean AP restricted to the test images of each imaging platform.
for pt in np.unique(platforms):
    ds_t = ds_test.filter(lambda x: x['platform'] == pt)

    mask_AP = lacss.metrics.MaskMeanAP(iou_thresholds)
    for x in tqdm(ds_t):
        # NOTE(review): presumably a key parse_test_data_func expects; confirm.
        x['img_id'] = 0
        xx = lacss.data.parse_test_data_func(x)
        y = model(xx, training=False)

        scores = y['pred_location_scores'][0]
        # Named pred_patches (not `patches`) so the matplotlib.patches
        # import is not shadowed.
        pred_patches = y['instance_output'][0]
        coords = y['instance_coords'][0]
        pred_bboxes = lacss.ops.bboxes_of_patches(pred_patches, coords)
        pred = (pred_patches, coords, pred_bboxes)
        gt = (x['mask_indices'], x['bboxes'])

        mask_AP.update_state(gt, pred, scores)

    print(pt)
    print(mask_AP.result().numpy())

In [None]:
tissues = np.load(join(data_path, 'test', 'tissue_list.npy'))
# IoU thresholds 0.50, 0.55, ..., 0.95 for mask mean AP.
iou_thresholds = [.5, .55, .6, .65, .7, .75, .8, .85, .9, .95]

# Mask mean AP restricted to the test images of each tissue type.
for t in np.unique(tissues):
    ds_t = ds_test.filter(lambda x: x['tissue'] == t)

    mask_AP = lacss.metrics.MaskMeanAP(iou_thresholds)
    for x in tqdm(ds_t):
        # NOTE(review): presumably a key parse_test_data_func expects; confirm.
        x['img_id'] = 0
        xx = lacss.data.parse_test_data_func(x)
        y = model(xx, training=False)

        scores = y['pred_location_scores'][0]
        # Named pred_patches (not `patches`) so the matplotlib.patches
        # import is not shadowed.
        pred_patches = y['instance_output'][0]
        coords = y['instance_coords'][0]
        pred_bboxes = lacss.ops.bboxes_of_patches(pred_patches, coords)
        pred = (pred_patches, coords, pred_bboxes)
        gt = (x['mask_indices'], x['bboxes'])

        mask_AP.update_state(gt, pred, scores)

    print(t)
    print(mask_AP.result().numpy())

In [None]:
from lacss.metrics import AJI

# AJI over the whole test split (presumably Aggregated Jaccard Index).
# NOTE(review): the meaning of the 0.25 argument is defined by
# lacss.metrics.AJI; confirm.
aji = AJI(.25)
for x in tqdm(ds_test):
    y = model(x, training=False)
    # Named pred_patches (not `patches`) so the matplotlib.patches import
    # is not shadowed.
    pred_patches = y['instance_output'][0]
    coords = y['instance_coords'][0]
    aji.update(x['mask_indices'], x['bboxes'], pred_patches, coords)
print(aji.result())

In [None]:
# One random test image per imaging platform. Top row: predicted outlines
# (red) over the input image. Bottom row: predicted outlines (grey) over the
# ground-truth label colouring.
rng = np.random.default_rng(0)  # fixed seed: same examples on every run
choices = [rng.choice(np.flatnonzero(platforms == pt)) for pt in np.unique(platforms)]
n_cols = len(choices)
# One column per group; squeeze=False keeps axs 2-D even for a single column.
fig, axs = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10), squeeze=False)
# Read-only memory maps: plotting never writes to the dataset files.
X = np.load(join(data_path, 'test', 'X.npy'), mmap_mode='r')
Y = np.load(join(data_path, 'test', 'y.npy'), mmap_mode='r')

for k, c in enumerate(choices):
    img = X[c].astype('float32')
    gt = Y[c]
    # Use whichever label channel has more labelled pixels.
    label_in_ch0 = np.argmax(np.count_nonzero(gt, axis=(0, 1))) == 0
    gt = np.array(gt[..., 0] if label_in_ch0 else gt[..., 1])

    y = model({'image': img})

    # Two-channel image -> RGB as (0, ch1, ch0). Clip before the uint8 cast so
    # intensities outside [0, 1] saturate instead of wrapping around.
    img_rgb = (np.clip(np.pad(img, [[0, 0], [0, 0], [0, 1]])[:, :, ::-1], 0, 1) * 255).astype('uint8')
    # label2rgb returns floats in [0, 1]; convert to uint8 so the
    # (128, 128, 128) outline colour is drawn as grey instead of being
    # clipped to white by imshow.
    contour_img = (label2rgb(gt, bg_label=0) * 255).astype('uint8')

    # Scatter every predicted patch onto its own full-size page:
    # page_n prefixes each patch pixel's (row, col) with its patch index.
    coords = y['instance_coords'][0]
    seg = y['instance_output'][0][..., 0]
    n_patches, patch_size, _ = seg.shape
    page_n = tf.tile(tf.range(n_patches)[:, None, None, None], [1, patch_size, patch_size, 1])
    coords_ext = tf.concat([page_n, coords], axis=-1)
    stack_shape = (n_patches,) + img.shape[:2]
    # NOTE(review): assumes all patch coordinates lie inside the image;
    # scatter_nd raises on out-of-range indices on CPU.
    img_stack = tf.scatter_nd(coords_ext, seg, stack_shape)
    img_stack = (img_stack.numpy() >= 0.5).astype('uint8')

    for page in img_stack:
        contours, _ = cv2.findContours(page, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        outline = contours[:1]
        cv2.drawContours(contour_img, outline, -1, (128, 128, 128), 1, cv2.LINE_AA)
        red = np.zeros_like(img_rgb)
        cv2.drawContours(red, outline, -1, (255, 0, 0), 1, cv2.LINE_AA)
        # Saturating add: overlapping outlines stay at 255 instead of the
        # uint8 wrap-around that `+=` produced.
        img_rgb = cv2.add(img_rgb, red)

    ax_img, ax_gt = axs[:, k]
    ax_img.imshow(img_rgb)
    ax_gt.imshow(contour_img)
    ax_img.set_title(str(platforms[c]))
    ax_img.axis('off')
    ax_gt.axis('off')

plt.tight_layout()
plt.show()


In [None]:
# One random test image per tissue type. Top row: predicted outlines (red)
# over the input image. Bottom row: predicted outlines (grey) over the
# ground-truth label colouring.
rng = np.random.default_rng(0)  # fixed seed: same examples on every run
choices = [rng.choice(np.flatnonzero(tissues == t)) for t in np.unique(tissues)]
n_cols = len(choices)
# One column per group; squeeze=False keeps axs 2-D even for a single column.
fig, axs = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10), squeeze=False)
# Read-only memory maps: plotting never writes to the dataset files.
X = np.load(join(data_path, 'test', 'X.npy'), mmap_mode='r')
Y = np.load(join(data_path, 'test', 'y.npy'), mmap_mode='r')

for k, c in enumerate(choices):
    img = X[c].astype('float32')
    gt = Y[c]
    # Use whichever label channel has more labelled pixels.
    label_in_ch0 = np.argmax(np.count_nonzero(gt, axis=(0, 1))) == 0
    gt = np.array(gt[..., 0] if label_in_ch0 else gt[..., 1])

    y = model({'image': img})

    # Two-channel image -> RGB as (0, ch1, ch0). Clip before the uint8 cast so
    # intensities outside [0, 1] saturate instead of wrapping around.
    img_rgb = (np.clip(np.pad(img, [[0, 0], [0, 0], [0, 1]])[:, :, ::-1], 0, 1) * 255).astype('uint8')
    # label2rgb returns floats in [0, 1]; convert to uint8 so the
    # (128, 128, 128) outline colour is drawn as grey instead of being
    # clipped to white by imshow.
    contour_img = (label2rgb(gt, bg_label=0) * 255).astype('uint8')

    # Scatter every predicted patch onto its own full-size page:
    # page_n prefixes each patch pixel's (row, col) with its patch index.
    coords = y['instance_coords'][0]
    seg = y['instance_output'][0][..., 0]
    n_patches, patch_size, _ = seg.shape
    page_n = tf.tile(tf.range(n_patches)[:, None, None, None], [1, patch_size, patch_size, 1])
    coords_ext = tf.concat([page_n, coords], axis=-1)
    stack_shape = (n_patches,) + img.shape[:2]
    # NOTE(review): assumes all patch coordinates lie inside the image;
    # scatter_nd raises on out-of-range indices on CPU.
    img_stack = tf.scatter_nd(coords_ext, seg, stack_shape)
    img_stack = (img_stack.numpy() >= 0.5).astype('uint8')

    for page in img_stack:
        contours, _ = cv2.findContours(page, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        outline = contours[:1]
        cv2.drawContours(contour_img, outline, -1, (128, 128, 128), 1, cv2.LINE_AA)
        red = np.zeros_like(img_rgb)
        cv2.drawContours(red, outline, -1, (255, 0, 0), 1, cv2.LINE_AA)
        # Saturating add: overlapping outlines stay at 255 instead of the
        # uint8 wrap-around that `+=` produced.
        img_rgb = cv2.add(img_rgb, red)

    ax_img, ax_gt = axs[:, k]
    ax_img.imshow(img_rgb)
    ax_gt.imshow(contour_img)
    ax_img.set_title(str(tissues[c]))
    ax_img.axis('off')
    ax_gt.axis('off')

plt.tight_layout()
plt.show()
