In [None]:
!pip install pymaxflow
!git clone https://github.com/jiyuuchc/cellcutter.git

In [4]:
import sys
import time
from os.path import join

import numpy as np
from numpy.random import default_rng
import matplotlib as mpl
import matplotlib.pyplot as plt
from skimage.color import label2rgb

# Make the freshly cloned cellcutter repository importable.
sys.path.append('cellcutter')
import cellcutter
import cellcutter.utils
from cellcutter.markers import blob_detection

np.set_printoptions(precision=4)
mpl.rcParams['image.cmap'] = 'gray'

### Load some image data

In [None]:
# Load the demo image stack shipped with the cellcutter repository.
# Using np.load as a context manager closes the underlying .npz file handle
# once the array has been read into memory.
with np.load('cellcutter/data/a1data.npz') as data:
    train_data = data['data']
input_img = train_data[..., 0]    # channel used as the CNN input below
nucleus_img = train_data[..., 2]  # channel used for nucleus marker detection

# check the images on a fixed 200x200 crop
roi = np.s_[100:300, 200:400]
fig, ax = plt.subplots(1, 2)
ax[0].imshow(input_img[roi])
ax[0].set_title('input')
ax[0].axis('off')
ax[1].imshow(nucleus_img[roi])
ax[1].set_title('nucleus')
ax[1].axis('off');  # trailing ';' hides the axis-limits tuple in the cell output

### Define the area of analysis

We use the standard graph cut algorithm to produce a binary mask that defines the area of cells. 

The cellcutter.utils module comes with a convenient function to do this.

In [None]:
# Foreground (cell-area) mask from the graph-cut helper; the keyword values
# are the ones this demo uses for the a1data image.
mask = cellcutter.utils.graph_cut(input_img, prior=0.985, max_weight=10, sigma=0.03)

# check results on the same 200x200 crop as above
roi = np.s_[100:300, 200:400]
fig, ax = plt.subplots(1, 2)
ax[0].imshow(input_img[roi])
ax[0].set_title('input')
ax[0].axis('off')
ax[1].imshow(mask[roi])
ax[1].set_title('graph-cut mask')
ax[1].axis('off');  # trailing ';' hides the axis-limits tuple in the cell output

### Get the marker locations
We compute the marker locations from the nucleus image.

Here we use a simple blob-detection algorithm, which works well enough for this demonstration; algorithms dedicated to nucleus segmentation will give better markers.

In [None]:
# Detect nuclei as bright blobs with cellcutter's simple blob detector.
# The positional arguments (10, 5, 0.1) are forwarded unchanged; see the
# blob_detection signature for their meaning.
markers = blob_detection(nucleus_img, 10, 5, 0.1)

# check the result: each row of `markers` is (row, col), while matplotlib
# patches take (x, y) = (col, row), so swap and shift into crop coordinates.
row0, col0 = 100, 200
roi = np.s_[row0:row0 + 200, col0:col0 + 200]
fig, ax = plt.subplots(1, 2)
ax[0].imshow(nucleus_img[roi])
ax[0].set_title('nucleus')
ax[0].axis('off')
ax[1].imshow(nucleus_img[roi])
ax[1].set_title('markers')
ax[1].axis('off')
for row, col in markers:
    # markers outside the crop land outside the displayed area
    ax[1].add_patch(plt.Circle((col - col0, row - row0), 10, color='gray', linewidth=2, fill=False))


### CNN segmentation

Here we train a CNN segmentation model on the fluorescence input. Training takes a couple of minutes.

In [None]:
# cellcutter.Dataset receives the complement of the graph-cut cell-area mask
# (i.e. the pixels outside the cells) as mask_img.
# np.logical_not is used instead of `~`: both agree for a boolean mask, but
# `~` on an integer 0/1 mask gives -1/-2 (both truthy) instead of an inversion.
dataset = cellcutter.Dataset(input_img, markers, mask_img=np.logical_not(mask))

N_EPOCHS = 30
# NOTE(review): no random seed is set, so if training is stochastic
# (weight init / sampling) results will vary between runs.
start = time.perf_counter()  # monotonic clock, suited to measuring durations
model = cellcutter.UNet4(bn=True)
cellcutter.train_self_supervised(dataset, model, n_epochs=N_EPOCHS)

print('Elapsed time: %f' % (time.perf_counter() - start))

### Check the segmentation results

In [None]:
# Render the trained model's segmentation onto blank integer canvases the size
# of the input image: a label image (label2rgb treats 0 as background) and,
# presumably, the segment borders from draw_border.
label = cellcutter.utils.draw_label(dataset, model, np.zeros_like(input_img, dtype=int))
rgb = label2rgb(label, bg_label=0)
border = cellcutter.utils.draw_border(dataset, model, np.zeros_like(input_img, dtype=int))

# Show input, colour-coded labels and borders side by side on the same crop.
roi = np.s_[100:300, 200:400]
panels = [(input_img, 'input'), (rgb, 'labels'), (border, 'borders')]
fig, ax = plt.subplots(1, 3)
for axis, (img, title) in zip(ax, panels):
    axis.imshow(img[roi])
    axis.set_title(title)
    axis.axis('off')
