In [1]:
import argparse
import os
import sys
from pathlib import Path

import pytorch_lightning as pl
import torch
import yaml
from faim_ipa.utils import get_git_root

sys.path.append(str(get_git_root()))

from source.rdcnet.model import RDCNet2d
from source.data.MoNuSeg import MoNuSeg

  from torch.distributed.optim import ZeroRedundancyOptimizer


In [2]:
# Test data comes from the pre-packed MoNuSeg dataset file under <git root>/raw_data.
monuseg_file = get_git_root() / "raw_data" / "MoNuSeg_dataset.pth"
dm = MoNuSeg(monuseg_file)

In [3]:
# Trained baseline checkpoint, resolved relative to the repository root (as the
# dataset path above is) instead of a user-specific absolute path.
# NOTE(review): assumes the checkpoint lives under <git root>/processed_data, as the
# former absolute path /home/tibuch/pytorch-rdc-net/processed_data/... suggests — confirm.
CHECKPOINT_PATH = (
    get_git_root()
    / "processed_data"
    / "0a550d5_MoNuSeg-baseline"
    / "lightning_logs"
    / "version_0"
    / "checkpoints"
    / "rdcnet-epoch=80-f1=0.78.ckpt"
)
model = RDCNet2d.load_from_checkpoint(CHECKPOINT_PATH)

In [4]:
from napari import Viewer
import numpy as np

In [5]:
# napari viewer that the cells below populate with the input image, the ground
# truth and the predicted labellings for visual comparison.
viewer = Viewer()



In [6]:
# Take only the first test batch. next(iter(...)) raises StopIteration on an empty
# loader instead of silently leaving x / y undefined (or stale from a previous run).
x, y = next(iter(dm.test_dataloader()))

In [7]:
# First image of the batch (x is presumably (batch, C, H, W) — confirm). rgb=True needs
# the colour channels on the last axis, so move axis 0 to the end. Move the tensor to
# CPU explicitly, as the GT cell does, so napari never receives a device tensor.
viewer.add_image(x[0].detach().cpu().moveaxis(0, -1), rgb=True)

<Image layer 'Image' at 0x7aad770ef890>

In [8]:
# Ground-truth instance labels of the first sample (y[0, 0]; presumably batch 0,
# channel 0 — confirm against the data module).
viewer.add_labels(
    y[0, 0].detach().cpu(),
    name='GT',
)

<Labels layer 'GT' at 0x7aaedcee8da0>

In [9]:
# Reference segmentation from the model's built-in post-processing.
# The checkpoint is loaded without an explicit eval() and the manual pipeline below
# runs in eval mode under no_grad, so switch modes here first; both paths then see
# the same network behaviour and no autograd graph is built.
model.eval()
with torch.no_grad():
    instances = model.predict_instances(x.to(model.device))

In [10]:
# The model's own instance prediction for the same (first) sample.
viewer.add_labels(
    instances[0, 0],
    name='instances',
)

<Labels layer 'instances' at 0x7aad8b455430>

In [11]:
# Raw forward pass (eval mode, no autograd) so the individual network heads can be
# post-processed by hand below. Module.eval() returns the module itself.
with torch.no_grad():
    pred = model.eval()(x.to(model.device))

In [12]:
# Sanity check: the forward pass yields three outputs, unpacked in the next cell.
n_heads = len(pred)
n_heads

3

In [13]:
import torch.nn.functional as F
import torch.nn as nn
import torchist

In [14]:
# Unpack the three network outputs and drop their autograd history.
embeddings, weights, semantic_classes = pred
embeddings, weights = embeddings.detach(), weights.detach()
semantic = semantic_classes.detach()

# Zero-pad every side by two margins; the vote histogram below spans the padded
# extent, so embeddings pointing slightly outside the image still land in a bin.
padding = 2 * model.hparams.margin
pad_spec = (padding,) * 4
embeddings = F.pad(embeddings, pad_spec)
semantic = F.pad(semantic, pad_spec)

shape = embeddings.shape[-2:]

# Any non-zero argmax class of the first sample counts as foreground.
fg_mask = semantic[0].argmax(dim=0).to(torch.bool)

# Turn offsets into absolute (padded) positions, keep the first sample only, and
# round the foreground positions to whole pixels.
grid = model._get_coordinate_grid(embeddings)
embeddings = (embeddings + grid)[0]
fg_embeddings = embeddings[:, fg_mask].round()

In [15]:
# Every foreground pixel casts one vote at its rounded embedded position. The bins are
# unit-width over [0, shape) of the padded image, so votes[i, j] counts the votes whose
# rounded position is (i, j).
# NOTE: thresholding experiments (zeroing bins below `margin`, outside `fg_mask`, or
# below the vote standard deviation) were tried here and are currently disabled.
vote_positions = fg_embeddings.moveaxis(0, -1)
votes = torchist.histogramdd(
    vote_positions,
    bins=tuple(shape),
    low=(0, 0),
    upp=shape,
)

In [16]:
#viewer.add_image(votes[padding:-padding, padding:-padding].detach().cpu(), name='votes')

In [17]:
# Candidate centres: pixels whose vote count equals the maximum of their
# (2*margin+1) x (2*margin+1) neighbourhood and that received at least `margin` votes
# (assuming margin >= 1). Vote counts are integers, so votes - (local_max - 1) > 0
# exactly when votes == local_max. Equal maxima on a plateau all become candidates.
local_max = F.max_pool2d(
    votes.unsqueeze(0).type(torch.float32),
    kernel_size=2 * model.hparams.margin + 1,
    stride=1,
    padding=model.hparams.margin,
)[0]
max_filtered = torch.clip(local_max - 1, 0, votes.max())
is_local_max = torch.clip(votes - max_filtered, 0, 1).type(torch.bool)
centers = (is_local_max * votes) >= model.hparams.margin

# (2, K) padded coordinates of the K candidate centres.
center_coords = grid[0, :, centers]

In [18]:
# (2, number of candidate centres)
center_coords.size()

torch.Size([2, 674])

In [19]:
# Euclidean distance from every padded pixel to every candidate centre:
# result has shape (K, H, W) with K = center_coords.shape[1].
pixel_coords = grid[0][:, None]                 # (2, 1, H, W)
center_points = center_coords[:, :, None, None]  # (2, K, 1, 1)
dists = torch.norm(pixel_coords - center_points, dim=0)

In [20]:
# Prepend a "background" channel (index 0) filled with K + 1, where K is the number of
# candidate centres, then zero every channel where no vote landed. After the argmin in
# the next cell, pixels without votes (all channels 0 -> first index) and pixels whose
# nearest centre is farther than K + 1 fall to background.
# NOTE(review): K + 1 mixes a centre count with a pixel distance; if a plain "no limit"
# sentinel was intended, torch.inf (or the image diagonal) would be clearer — confirm.
n_candidates = dists.shape[0]
background = torch.full_like(dists[:1], n_candidates + 1)
has_votes = votes.unsqueeze(0) > 0
vote_dists = torch.cat([background, dists], dim=0) * has_votes

In [21]:
# Per-pixel assignment: 0 = background, k + 1 = candidate centre k.
closest_center = vote_dists.argmin(dim=0)

In [22]:
from ignite.utils import to_onehot

In [23]:
# One-hot encode the assignment and drop the background plane, leaving one (H, W)
# plane per candidate centre: plane k corresponds to center_coords[:, k].
# num_classes is a plain int derived from the number of candidate centres, rather than
# the 0-d tensor closest_center.max() + 1. The plane count therefore always equals
# center_coords.shape[1], which the sampling loop relies on when it indexes covs[i]
# and estimated_centers[i], even if some centre is assigned no pixel.
num_centers = int(center_coords.shape[1])
cc_onehot = to_onehot(closest_center.unsqueeze(0).detach(), num_classes=num_centers + 1)[0, 1:]

In [24]:
# Split the vote histogram per centre: plane k keeps the vote counts only at the
# pixels assigned to candidate centre k (zero elsewhere).
scattered_votes = cc_onehot * votes.to(torch.int32)[None]

In [25]:
# Summarise each candidate centre's assigned votes as a 2-D Gaussian: the vote-weighted
# mean position and the vote-weighted coordinate covariance.
estimated_centers = []
covs = []
# Coordinates of every padded pixel as a (2, H*W) matrix, one column per pixel.
m = torch.stack([g.ravel() for g in grid[0]])
for sv in scattered_votes:
    estimated_centers.append(
        torch.sum(grid[0] * sv.unsqueeze(0), dim=(1, 2)) / torch.sum(sv)
    )
    # fweights: each pixel's coordinate is counted once per vote it received.
    cov = torch.cov(m, fweights=sv.ravel())
    # Fall back to the identity when the covariance is undefined (NaN, e.g. only a
    # single vote) or not positive definite (det <= 0, e.g. all votes on one line),
    # because MultivariateNormal in the sampling cell needs a valid covariance.
    if not torch.any(torch.isnan(cov)) and torch.det(cov) > 0:
        covs.append(cov)
    else:
        covs.append(torch.eye(2, dtype=cov.dtype, device=cov.device))

estimated_centers = torch.stack(estimated_centers)
covs = torch.stack(covs)

In [26]:
#viewer.add_points(estimated_centers.detach().cpu().numpy()-padding, face_color="#00000000")

In [27]:
from tqdm.notebook import tqdm

In [28]:
#viewer.add_image(scattered_votes[i].detach().cpu().numpy())

In [29]:
# Monte-Carlo instance masks: for each candidate centre draw n_samples points from its
# vote Gaussian and count, per pixel, how many draws lie within `margin` of that
# pixel's embedded position.
n_samples = 200
SAMPLING_SEED = 0
torch.manual_seed(SAMPLING_SEED)  # sampling is stochastic; seed for reproducible masks

# exp(-0.5 * (d / sigma)^2) >= 0.5  <=>  d <= sigma * sqrt(-2 ln 0.5) = margin, so this
# sigma makes the Gaussian test below equivalent to "distance <= margin".
# It is loop-invariant, so compute it once.
sigma = model.hparams.margin * (-2 * np.log(0.5)) ** -0.5

# Plane 0 stays all-zero (background); plane i + 1 holds the counts for centre i.
sampled_label_img = np.zeros((center_coords.shape[1] + 1, *embeddings.shape[1:]), dtype=np.int32)

for i in tqdm(range(center_coords.shape[1])):
    if covs[i][0, 0] != 0 and covs[i][1, 1] != 0:
        mvg = torch.distributions.MultivariateNormal(estimated_centers[i], covs[i])
        samples = mvg.sample((n_samples,)).T
    else:
        # Degenerate covariance: every "sample" is the estimated centre itself.
        samples = torch.repeat_interleave(estimated_centers[i].unsqueeze(0), n_samples, dim=0).T
    # (n_samples, H, W) distances between each draw and every pixel's embedding. Uses a
    # separate name so the centre-distance tensor `dists` is not overwritten.
    sample_dists = torch.norm(embeddings.unsqueeze(1) - samples.unsqueeze(-1).unsqueeze(-1), dim=0)
    rois = torch.sum(torch.exp(-0.5 * (sample_dists / sigma) ** 2) >= 0.5, dim=0).type(torch.int32)

    sampled_label_img[i + 1] = rois.detach().cpu().numpy()

# Despite its name, `uncertainty` is the fraction of draws covering each pixel, i.e. a
# per-instance certainty in [0, 1].
uncertainty = sampled_label_img.astype(np.float32) / float(n_samples)
uncertainty_counts = np.sum(uncertainty > 0, axis=0)
# Mean coverage over the instances touching each pixel. Pixels no instance touches are
# set to 0 explicitly: with `where=` but no `out=`, np.true_divide leaves those entries
# uninitialised.
aggregated_uncertainty = np.true_divide(
    np.sum(uncertainty, axis=0),
    uncertainty_counts,
    out=np.zeros(uncertainty_counts.shape, dtype=np.float64),
    where=uncertainty_counts > 0,
)

  0%|          | 0/674 [00:00<?, ?it/s]

In [30]:
# Minimum fraction of the sampled draws that must cover a pixel (strictly greater
# than this value) for the pixel to be kept for an instance.
certainty_th: float = 0.5

In [31]:
# Discard low-confidence coverage, then give every pixel the label with the highest
# remaining sample count (0 = background when nothing survives).
# NOTE: additionally masking the result with fg_mask was tried and is disabled.
sampled_labeling = np.where(uncertainty > certainty_th, sampled_label_img, 0).argmax(axis=0)

In [32]:
# Merge near-duplicate instances (e.g. several candidate centres detected on the same
# vote plateau). Labels whose sampled masks overlap with IoU above merge_iou_th are
# folded into the lowest label id of their group.
cleaned_sampled_label_img = sampled_label_img.copy()
cleaned_uncertainty = uncertainty.copy()

# A single label's count never exceeds n_samples, so a summed count above n_samples
# marks pixels covered by more than one label.
overlap_mask = np.sum(cleaned_sampled_label_img * (cleaned_uncertainty > 0), axis=0) > n_samples
overlapping_labels = cleaned_sampled_label_img * overlap_mask
overlapping_label_ids = np.flatnonzero((overlapping_labels > 0).any(axis=(1, 2)))

ths = 0  # a label covers a pixel if more than `ths` samples landed there
merge_iou_th = 0.95
merge_candidates = {}
claimed = set()  # label ids already scheduled to be merged into another label
for i, current in enumerate(overlapping_label_ids):
    if current in claimed:
        continue
    current_mask = cleaned_sampled_label_img[current] > ths
    for target in overlapping_label_ids[i + 1:]:
        if target in claimed:
            continue
        target_mask = cleaned_sampled_label_img[target] > ths
        intersection = np.logical_and(current_mask, target_mask).sum()
        union = np.logical_or(current_mask, target_mask).sum()
        iou = intersection / union
        if iou > merge_iou_th:
            merge_candidates.setdefault(current, []).append(target)
            claimed.add(target)

print(merge_candidates)

for keep_id, absorbed_ids in merge_candidates.items():
    for label_id in absorbed_ids:
        cleaned_sampled_label_img[keep_id] = np.clip(
            cleaned_sampled_label_img[keep_id] + cleaned_sampled_label_img[label_id], 0, n_samples
        )
        cleaned_uncertainty[keep_id] = np.clip(
            cleaned_uncertainty[keep_id] + cleaned_uncertainty[label_id], 0, 1
        )
        cleaned_sampled_label_img[label_id] = 0
        cleaned_uncertainty[label_id] = 0

{}


In [33]:
# Same confidence-filtered argmax as for the raw sampled labelling, applied to the
# merged stack (0 = background).
# NOTE: additionally masking the result with fg_mask was tried and is disabled.
cleaned_labeling = np.where(cleaned_uncertainty > certainty_th, cleaned_sampled_label_img, 0).argmax(axis=0)

In [34]:
from source.matching import matching

In [35]:
# Score the raw sampled labelling (padding cropped back off) against the ground truth.
matching(
    y[0, 0].detach().cpu().numpy(),
    sampled_labeling[padding:-padding, padding:-padding],
)

Matching(criterion='iou', thresh=0.5, fp=164, tp=509, fn=67, precision=0.7563150074294205, recall=0.8836805555555556, accuracy=0.6878378378378378, f1=0.8150520416333067, n_true=576, n_pred=673, mean_true_score=0.6667058732774522, mean_matched_score=0.7544647996224214, panoptic_quality=0.6149280752727182)

In [36]:
# Score the merged labelling (padding cropped back off) against the ground truth.
matching(
    y[0, 0].detach().cpu().numpy(),
    cleaned_labeling[padding:-padding, padding:-padding],
)

Matching(criterion='iou', thresh=0.5, fp=164, tp=509, fn=67, precision=0.7563150074294205, recall=0.8836805555555556, accuracy=0.6878378378378378, f1=0.8150520416333067, n_true=576, n_pred=673, mean_true_score=0.6667058732774522, mean_matched_score=0.7544647996224214, panoptic_quality=0.6149280752727182)

In [37]:
# Baseline score of the model's built-in predict_instances output. Compared without
# cropping, since it was computed on the unpadded input. Index the first sample
# (instances[0, 0]), as the viewer cell does, so both arguments are 2-D label images.
matching(y[0, 0].detach().cpu().numpy(), instances[0, 0])

Matching(criterion='iou', thresh=0.5, fp=151, tp=505, fn=71, precision=0.7698170731707317, recall=0.8767361111111112, accuracy=0.6946354883081155, f1=0.8198051948051948, n_true=576, n_pred=656, mean_true_score=0.6616769896613227, mean_matched_score=0.7547048436533107, panoptic_quality=0.6187109513716265)

In [38]:

# Merged labelling with the padding cropped off, for visual comparison.
# Layer name typo fixed ('cleand' -> 'cleaned').
viewer.add_labels(cleaned_labeling[padding:-padding, padding:-padding], name='cleaned labeling')

<Labels layer 'cleand labeling' at 0x7aad8a7a6660>

In [39]:
# Raw sampled labelling (before duplicate merging), padding cropped off.
viewer.add_labels(
    sampled_labeling[padding:-padding, padding:-padding],
    name='sampled',
)

<Labels layer 'sampled' at 0x7aad8a7feb40>