In [None]:
from cellpose import (
    core,
    models,
    denoise
)
from torch.utils.data import (
    ConcatDataset,
    DataLoader
)

import torch
import numpy as np
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt

# Fail fast: everything below assumes the model runs on a GPU.
gpu_available = core.use_gpu()
assert gpu_available == 1, "No GPU detected"


In [None]:
# Cellpose 'cyto3' model combined with the "denoise_cyto3" restoration step.
model = denoise.CellposeDenoiseModel(
    model_type='cyto3',
    restore_type="denoise_cyto3",
    gpu=1,
)

# Dataset

In [None]:
import numpy as np

class BBox:
	"""Axis-aligned bounding box with one coordinate per axis.

	`min` is the lower corner and `max` the upper corner; both arrays have the
	same shape. A box whose `max` is below its `min` along any axis is empty
	and has area 0.
	"""

	min: np.ndarray
	max: np.ndarray

	def __init__(
		self,
		_min: np.ndarray,
		_max: np.ndarray
	):
		"""Build a box from its lower corner `_min` and upper corner `_max`.

		Raises:
			AssertionError: if the two corners do not have the same shape.
		"""
		assert _min.shape == _max.shape, "Both must have the same dimension"
		self.min = _min
		self.max = _max

	def __repr__(self) -> str:
		return f"BBox({self.min}; {self.max})"

	def intersect(self, bb):
		"""Return the overlap of this box and `bb`.

		When the boxes do not overlap the result is inverted (`max < min` on
		some axis); `area()` reports 0 for such a box.
		"""
		assert self.min.shape == bb.min.shape, "Both must have the same dimension"

		return BBox(
			_min = np.maximum(self.min, bb.min),
			_max = np.minimum(self.max, bb.max)
		)

	def union(self, bb):
		"""Return the smallest box enclosing both this box and `bb`."""
		assert self.min.shape == bb.min.shape, "Both must have the same dimension"

		return BBox(
			_min = np.minimum(self.min, bb.min),
			_max = np.maximum(self.max, bb.max)
		)

	def IoU(self, bb) -> float:
		"""Intersection-over-union of this box and `bb`, in [0, 1].

		Returns 0.0 when the union area is 0 (both boxes empty) instead of
		dividing by zero.
		"""
		a_its = self.intersect(bb).area()
		# Area(A U B) = Area(A) + Area(B) - Area(A inter B)
		a_union = self.area() + bb.area() - a_its
		if a_union <= 0:
			return 0.0
		return float(a_its / a_union)

	def area(self) -> float:
		"""Product of the box extents; 0 for an empty (inverted) box."""
		# Clamp negative extents: a disjoint intersection has max < min, and
		# with an even number of inverted axes the raw product would be positive.
		return float(np.clip(self.max - self.min, 0, None).prod())


In [None]:
from dataclasses import dataclass

@dataclass
class MendeleyCategory:
    """An annotation category, as built from the "categories" list of the annotation JSON."""

    # Value of the category's "supercategory" field.
    supercategory: str
    # Value of the category's "name" field.
    name: str


@dataclass
class MendeleyCategoryEntryCell:
    """One annotated cell: its category and its bounding box."""

    # Category looked up from the annotation's "category_id".
    category: MendeleyCategory
    # Box built from the annotation's [x, y, width, height] "bbox" field as
    # (top-left corner, top-left corner + size).
    bbox: BBox

In [None]:
class MendeleyDatasetEntry:
    """Metadata of one image plus the cells annotated on it.

    An entry starts with no cells; `add_entry` appends them one at a time.
    """

    width: int
    height: int
    filename: str

    cells: list[MendeleyCategoryEntryCell]

    def __init__(self, width: int, height: int, filename: str):
        self.width, self.height, self.filename = width, height, filename
        self.cells = list()

    def add_entry(self, entry: MendeleyCategoryEntryCell):
        """Record one more annotated cell for this image."""
        self.cells += [entry]


In [None]:
import os
import json
from typing import Tuple

import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class MendeleyDataset(Dataset):
    """Map-style dataset over a COCO-style annotation file.

    Each item is ``(image, mask, cells)``:
      * ``image``: the image file as a NumPy array divided by 255
        (assumes 8-bit pixel values — TODO confirm for other dtypes);
      * ``mask``: boolean array with the image's shape, True inside every
        annotated bounding box;
      * ``cells``: the ``MendeleyCategoryEntryCell`` list for that image.
    """

    images_path: str
    entries: list[MendeleyDatasetEntry]
    categories: dict[int, MendeleyCategory]

    def __init__(self, annotation_path: str, images_path: str):
        """Parse the annotation JSON.

        Args:
            annotation_path: JSON file with "images", "categories" and
                "annotations" lists.
            images_path: directory containing each image's "file_name".

        Raises:
            KeyError: if an annotation references an unknown image or category id.
        """
        self.images_path = images_path
        with open(annotation_path) as af:
            infos = json.load(af)

        entries = {
            img["id"]: MendeleyDatasetEntry(
                width=img["width"],
                height=img["height"],
                filename=img["file_name"],
            )
            for img in infos["images"]
        }
        self.categories = {
            cat["id"]: MendeleyCategory(
                name=cat["name"],
                supercategory=cat["supercategory"],
            )
            for cat in infos["categories"]
        }
        for anno in infos["annotations"]:
            # bbox layout: [x, y, width, height], (x, y) being the top-left corner.
            top_left = np.array(anno["bbox"][0:2])
            size = np.array(anno["bbox"][2:4])
            entries[anno["image_id"]].add_entry(
                MendeleyCategoryEntryCell(
                    category=self.categories[anno["category_id"]],
                    bbox=BBox(top_left, top_left + size),
                )
            )
        # Dicts keep insertion order, so entries follow the JSON "images" order.
        self.entries = list(entries.values())

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray, list[MendeleyCategoryEntryCell]]:
        """Load image `idx` and build its bounding-box mask."""
        entry = self.entries[idx]
        # Close the file handle once the pixels are copied; this method is
        # called from many worker threads during the benchmark.
        with Image.open(os.path.join(self.images_path, entry.filename)) as pil_img:
            img = np.array(pil_img)

        real_mask = np.zeros(img.shape, dtype=bool)
        for info in entry.cells:
            # Rows are indexed by y (coordinate 1), columns by x (coordinate 0).
            # Floor starts / ceil stops so fractional boxes are fully covered,
            # and clamp starts at 0 so a negative coordinate cannot wrap around.
            x_start = max(0, int(np.floor(info.bbox.min[0])))
            y_start = max(0, int(np.floor(info.bbox.min[1])))
            x_stop = int(np.ceil(info.bbox.max[0]))
            y_stop = int(np.ceil(info.bbox.max[1]))
            real_mask[y_start:y_stop, x_start:x_stop] = True
        return img / 255., real_mask, entry.cells

    def __len__(self) -> int:
        """Number of images listed in the annotation file."""
        return len(self.entries)


In [None]:
# LIVECell test split: annotation JSON plus the directory of its images.
dt = MendeleyDataset(
    annotation_path="images/livecell_coco_test.json",
    images_path="images/livecell_test_images",
)

In [None]:
img, real_mask, infos = dt[1]

# Left panel: the raw image.
ax = plt.subplot(1, 2, 1)
ax.imshow(img)
ax.axis('off')

# Right panel: same image with each annotated box drawn edge by edge
# (top, bottom, left, right) in red.
ax = plt.subplot(1, 2, 2)
ax.imshow(img)
for info in infos:
    x_lo, y_lo = info.bbox.min[0], info.bbox.min[1]
    x_hi, y_hi = info.bbox.max[0], info.bbox.max[1]
    for xs, ys in (
        ([x_lo, x_hi], [y_lo, y_lo]),
        ([x_lo, x_hi], [y_hi, y_hi]),
        ([x_lo, x_lo], [y_lo, y_hi]),
        ([x_hi, x_hi], [y_lo, y_hi]),
    ):
        ax.plot(xs, ys, color='red')
ax.axis('off')

# Test on a single image

In [None]:
# Segment the sample image. Note: the benchmark below passes channels=[2, 1] instead.
masks, flows, styles, imgs_dn = model.eval(
    img,
    channels=[0, 0],
    diameter=None,
    flow_threshold=None,
)

In [None]:
# Pixels carrying any non-zero predicted label.
plt.imshow(masks != 0)

In [None]:
# Input image before Cellpose processing.
plt.imshow(X=img)

In [None]:
# Image returned by the denoising model.
plt.imshow(X=imgs_dn)

In [None]:
from cellpose import utils

plt.imshow(X=imgs_dn)
# Red: outlines of the predicted cells.
outline = utils.outlines_list(masks)
for contour in outline:
    plt.plot(contour[:, 0], contour[:, 1], color='red')
# Dark green: annotated boxes, drawn as top, bottom, left and right edges.
for info in infos:
    x_lo, y_lo = info.bbox.min[0], info.bbox.min[1]
    x_hi, y_hi = info.bbox.max[0], info.bbox.max[1]
    for xs, ys in (
        ([x_lo, x_hi], [y_lo, y_lo]),
        ([x_lo, x_hi], [y_hi, y_hi]),
        ([x_lo, x_lo], [y_lo, y_hi]),
        ([x_hi, x_hi], [y_lo, y_hi]),
    ):
        plt.plot(xs, ys, color='darkgreen')

In [None]:
# Reference mask: True inside every annotated box.
plt.imshow(X=real_mask)

In [None]:
# Left: predicted as cell but outside every annotated box.
ax = plt.subplot(1, 2, 1)
ax.imshow(~real_mask & (masks != 0))
# Right: inside an annotated box but predicted as background.
ax = plt.subplot(1, 2, 2)
ax.imshow(real_mask & (masks == 0))

# Benchmark: accuracy and precision vs. noise level

In [None]:
def compute_acc_prec(masks: np.ndarray, real_mask: np.ndarray) -> tuple[float, float]:
	"""Pixel-wise accuracy and foreground precision of a predicted segmentation.

	Args:
		masks: predicted label image; 0 is background, any other value is
			foreground.
		real_mask: reference mask of the same shape; non-zero / True pixels
			are foreground.

	Returns:
		(accuracy, precision): accuracy is the fraction of pixels where
		prediction and reference agree; precision is TP / (TP + FP) for the
		foreground class, or 0.0 when nothing is predicted as foreground.
	"""
	# Foreground is the positive class (`np.isin(x, [0])` would select background).
	predicted = np.asarray(masks) != 0
	expected = np.asarray(real_mask) != 0

	mismatches = np.count_nonzero(predicted ^ expected)
	tp = np.count_nonzero(predicted & expected)
	fp = np.count_nonzero(predicted & ~expected)

	accuracy = 1 - mismatches / predicted.size
	predicted_positives = tp + fp
	precision = tp / predicted_positives if predicted_positives else 0.0

	return accuracy, precision

In [None]:
from multiprocessing.pool import ThreadPool

def compute_accuracy_and_precision(filter_ = None, channels = None, n_threads: int = 16) -> tuple[float, float]:
	"""Mean pixel accuracy and precision of `model` over every image in `dt`.

	Each image is optionally transformed by `filter_` (e.g. to add noise),
	segmented with `model.eval`, and scored by `compute_acc_prec` against
	its bounding-box mask.

	Args:
		filter_: optional callable applied to each image before `model.eval`;
			None leaves images untouched.
		channels: `channels` argument forwarded to `model.eval`; None means
			[2, 1]. Note the single-image test cell uses [0, 0].
		n_threads: number of worker threads evaluating images concurrently.

	Returns:
		(mean accuracy, mean precision) across the dataset.

	Raises:
		ValueError: if the dataset is empty.
	"""
	if channels is None:
		channels = [2, 1]
	n_images = len(dt)
	if n_images == 0:
		raise ValueError("dataset is empty")

	def running_ex(i):
		print(f"{i} / {n_images}")
		img, real_mask, _ = dt[i]
		if filter_ is not None:
			img = filter_(img)
		# NOTE(review): assumes model.eval is safe to call from several threads at once — confirm.
		masks, _flows, _styles, _imgs_dn = model.eval(img, diameter=None, flow_threshold=None, channels=channels)
		return compute_acc_prec(masks, real_mask)

	# The context manager shuts the worker threads down once map() returns,
	# so repeated calls (one per noise level) do not leak pools.
	with ThreadPool(n_threads) as pool:
		acc_prec = pool.map(running_ex, range(n_images))

	mean_accuracy, mean_precision = np.mean(acc_prec, axis=0)
	return float(mean_accuracy), float(mean_precision)


In [None]:
SEED = 0
output = []
# Seeded so the run can be regenerated. NOTE: the worker threads in
# compute_accuracy_and_precision draw from this shared generator in
# scheduling order, so per-image noise is still not fully reproducible.
rng = np.random.default_rng(SEED)
# 0.00 to 0.30 in steps of 0.03; the +0.01 makes arange include 0.30.
noise_level = np.arange(0., 0.3+0.01, 0.03)

for noise in noise_level:
	print(f"Testing dataset with noise of {noise}")
	# `noise` is bound as a default so the filter keeps this iteration's value.
	def img_filter(img: np.ndarray, noise: float = noise) -> np.ndarray:
		"""Add zero-mean Gaussian noise with std `noise`, then clip to [0, 1]."""
		out = img + noise * rng.normal(size=img.shape)
		return np.clip(out, 0., 1.)

	output.append(compute_accuracy_and_precision(img_filter))
output = np.array(output)

In [None]:
# Accuracy (top) and precision (bottom) against the std of the added Gaussian noise.
fig, (ax_acc, ax_prec) = plt.subplots(2, 1, sharex=True)
ax_acc.plot(noise_level, output[:, 0], label="Accuracy with loc=0")
ax_acc.set(title="Segmentation under additive Gaussian noise", ylabel="Pixel accuracy")
ax_acc.legend()
ax_prec.plot(noise_level, output[:, 1], label="Precision with loc=0")
ax_prec.set(xlabel="Noise level (std of added Gaussian noise)", ylabel="Pixel precision")
ax_prec.legend()
plt.show()

In [None]:
output_poisson = []
lam = 4.

for noise in noise_level:
	print(f"Testing dataset with noise of {noise}")
	# `noise` is bound as a default so the filter keeps this iteration's value.
	def img_filter(img: np.ndarray, noise: float = noise) -> np.ndarray:
		"""Add scaled, mean-centred Poisson noise, then clip to [0, 1].

		Poisson(lam) samples have mean `lam`. Subtracting it keeps the
		perturbation zero-mean, like the Gaussian sweep. Otherwise every pixel
		is brightened by about `noise * lam` (up to 1.2 here) and saturated by
		the clip.
		"""
		out = img + noise * (rng.poisson(lam=lam, size=img.shape) - lam)
		return np.clip(out, 0., 1.)

	output_poisson.append(compute_accuracy_and_precision(img_filter))
output_poisson = np.array(output_poisson)

In [None]:
# Accuracy (top) and precision (bottom) against the multiplier applied to the Poisson samples.
fig, (ax_acc, ax_prec) = plt.subplots(2, 1, sharex=True)
ax_acc.plot(noise_level, output_poisson[:, 0], label=f"Accuracy with lam={lam}")
ax_acc.set(title="Segmentation under additive Poisson noise", ylabel="Pixel accuracy")
ax_acc.legend()
ax_prec.plot(noise_level, output_poisson[:, 1], label=f"Precision with lam={lam}")
ax_prec.set(xlabel="Noise scale (multiplier on Poisson samples)", ylabel="Pixel precision")
ax_prec.legend()
plt.show()