In [None]:
import glob
import os
import logging
import sys

import numpy as np
import tifffile
from skimage import io
import skimage as ski
from cellpose import models
import napari
import torch

Set up logging so that cellpose prints progress information during training.

In [None]:
# Route INFO-level log records to the notebook's stdout so cellpose's
# `logging` output (e.g. training progress) is visible in the cell output.
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)

# Re-running this cell must not stack a second handler on the root logger;
# otherwise every log line is printed once per execution of the cell. The
# handler is tagged with a name so it can be recognised on re-runs.
_STDOUT_HANDLER_NAME = "notebook-stdout"
if not any(hd.name == _STDOUT_HANDLER_NAME for hd in root_logger.handlers):
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.INFO)
    stdout_handler.name = _STDOUT_HANDLER_NAME
    root_logger.addHandler(stdout_handler)

In [None]:
# Open a napari viewer; later cells add the raw image, the cellpose masks and
# a flow output to it as layers.
viewer = napari.Viewer()

In [None]:
# Sample images to annotate, sorted so the index used below is stable.
images = sorted(glob.glob("Data/Training/sample_images/*.tif"))
# glob silently returns [] for a wrong (working-directory-relative) path; fail
# here with a clear message instead of an IndexError when an image is picked.
assert images, "no .tif files found in Data/Training/sample_images/ (check the working directory)"
len(images)

In [None]:
# Build the pretrained cyto2 model used for the initial segmentation.
if sys.platform == 'darwin' and torch.backends.mps.is_available():
    # Apple-silicon GPU: hand cellpose the MPS device explicitly. Because an
    # explicit device is passed, the gpu flag presumably does not select the
    # device here; confirm against the installed cellpose version.
    d = torch.device('mps')
    model = models.Cellpose(gpu=False, device=d, model_type='cyto2')
else:
    # Linux/Windows, or a Mac without MPS (e.g. Intel): gpu=True asks cellpose
    # to use a CUDA GPU. It is expected to fall back to CPU when none is
    # usable; verify in the log output.
    model = models.Cellpose(gpu=True, model_type='cyto2')

In [None]:
# Pick one sample image to annotate. The trailing comment records indices that
# have been used (presumably 14 and 6); change the index to label another image.
fname = images[6]  # 14, 6
x = tifffile.imread(fname)
# Display the raw image's shape as the cell output.
x.shape

In [None]:
# Start from an empty viewer and show the raw sample image.
viewer.layers.clear()
viewer.add_image(x)

Try the cellpose cyto2 model

In [None]:

# Segment the sample with the pretrained cyto2 model. channels=[0, 0] is
# presumably cellpose's "single grayscale channel" setting, and diameter=30 the
# expected object diameter in pixels (per cellpose docs; confirm).
# Cellpose.eval returns four values; only the masks and flows are kept.
masks, flows, _, _ = model.eval(x, channels=[0, 0], diameter=30)

In [None]:
# Overlay the predicted masks as a Labels layer so they can be corrected by
# hand in napari before being pulled back out of the viewer.
viewer.add_labels(masks)

### Label images for training
- View image
- model.eval the image
- add labels
- edit the labels
- get labels out of the viewer
- create a stack with the image and the mask
- save to file

In [None]:
# Pull the (possibly hand-edited) labels back out of the viewer. Select the
# top-most Labels layer explicitly instead of assuming it is the last layer, so
# an extra points/shapes/image layer added while editing is not saved as the
# mask by mistake.
labels_layers = [lyr for lyr in viewer.layers if isinstance(lyr, napari.layers.Labels)]
assert labels_layers, "no Labels layer in the viewer; run model.eval and add_labels first"
masks = labels_layers[-1].data
# Stack as (image, mask): channel 0 is the image and channel 1 the label mask,
# matching how the training loader reads x[0] / x[1].
tx = np.stack([x, masks])
tx.shape, tx.dtype

In [None]:
# Save the (image, mask) stack as a uint16 TIFF named after the source image.
out_dir = "Data/Training/for_training"
# tifffile cannot create missing parent directories; create it on first use.
os.makedirs(out_dir, exist_ok=True)
# astype(np.uint16) would silently wrap label ids (or image values) outside
# 0..65535. NOTE(review): a float-valued image would also be truncated here.
assert tx.min() >= 0 and tx.max() <= np.iinfo(np.uint16).max, "stack values do not fit in uint16"
bn = os.path.basename(fname)
# NOTE(review): training reads from Data/Training/training_images/, so saved
# stacks presumably get copied there by hand; confirm the workflow.
tifffile.imwrite(f"{out_dir}/{bn}", tx.astype(np.uint16))

In [None]:
# Annotated (image, mask) stacks used for training, in a stable order.
files = sorted(glob.glob("Data/Training/training_images/*.tif"))
# An empty list (wrong working directory, or stacks not yet copied into
# training_images/) would otherwise only fail later with a confusing error.
assert files, "no .tif files found in Data/Training/training_images/ (check the working directory)"
len(files)

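Read the images for training. It is essential that the masks be label images (one integer id per object), not binary images.
Some of the images in this training set are known to be binary, so run skimage's `label` on the second channel to turn each mask into a label image. Note that `label` cannot separate objects that touch in a binary mask; those have to be split by hand.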

In [None]:
# Load every training stack: channel 0 is the image, channel 1 the mask.
# skimage.measure.label connects neighbouring pixels of equal value, so a
# binary mask becomes one id per connected component. A mask that is already
# a label image keeps differently-valued objects separate, but gets renumbered.
images = list()
masks = list()

for f in files:
    # Use a dedicated name rather than `x`, so the annotation image loaded
    # earlier is not clobbered if the labelling cells are re-run afterwards.
    stack = tifffile.imread(f)
    assert stack.ndim >= 3 and stack.shape[0] >= 2, (
        f"{f}: expected an (image, mask) stack, got shape {stack.shape}"
    )
    images.append(stack[0])
    masks.append(ski.measure.label(stack[1].astype(np.uint16)))
    

In [None]:
# Sanity-check every mask rather than one hard-coded index (masks[5] raised
# IndexError with fewer than six files). Each entry is (min, max). Because
# label() numbers objects consecutively, max should equal the object count:
# 0 means no objects, and 1 a single object (possibly an unlabelled mask).
[(int(m.min()), int(m.max())) for m in masks]

### Training

For training, use `models.CellposeModel` instead of `models.Cellpose`; `train` is called on this model object below.


In [None]:
# Build the CellposeModel (starting from cyto2 weights) that will be fine-tuned.
if sys.platform == 'darwin' and torch.backends.mps.is_available():
    # Apple-silicon GPU: hand cellpose the MPS device explicitly. Because an
    # explicit device is passed, the gpu flag presumably does not select the
    # device here; confirm against the installed cellpose version.
    d = torch.device('mps')
    model = models.CellposeModel(gpu=False, device=d, model_type='cyto2')
else:
    # Linux/Windows, or a Mac without MPS (e.g. Intel): gpu=True asks cellpose
    # to use a CUDA GPU. It is expected to fall back to CPU when none is
    # usable; verify in the log output.
    model = models.CellposeModel(gpu=True, model_type='cyto2')

In [None]:
# Fine-tune cyto2 on the loaded image/mask pairs. Weights are written under
# save_path='models' with model_name='custom'.
N_EPOCHS = 300
NIMG_PER_EPOCH = 24
BATCH_SIZE = 16
SEED = 0

# Seed the global RNGs so reruns are comparable. Cellpose presumably draws its
# sampling/augmentation randomness from these; GPU kernels may still add some
# nondeterminism.
np.random.seed(SEED)
torch.manual_seed(SEED)

# min_train_masks=1 keeps images that contain only a single object.
model_path = model.train(images, masks, channels=[0, 0], save_path='models',
                         n_epochs=N_EPOCHS, nimg_per_epoch=NIMG_PER_EPOCH,
                         model_name='custom', batch_size=BATCH_SIZE,
                         min_train_masks=1)
# Display whatever train() returns (presumably the saved model's path).
model_path

Now check how well the model is doing.

`CellposeModel.eval` returns only three values (unlike `Cellpose.eval` above, which returned four), so unpack three and drop the last `_`.

In [None]:
# Run the fine-tuned model on one annotated stack (x[0] is the image channel).
# NOTE(review): files[idx] is part of the training set, so this overstates how
# well the model generalises; also check a held-out image.
idx = 7
x = tifffile.imread(files[idx])
# cellprob_threshold / flow_threshold are set explicitly; the cellpose defaults
# are presumably 0.0 and 0.4 (confirm), so these settings likely admit more,
# lower-confidence objects.
masks, flows, _ = model.eval(x[0], channels=[0, 0],
                             cellprob_threshold=-3, flow_threshold=.5)

In [None]:
# Replace the viewer contents with the evaluated image and the fine-tuned
# model's masks for visual inspection.
viewer.layers.clear()
viewer.add_image(x[0])
viewer.add_labels(masks)

In [None]:
# Show the third flow output as an image layer. Per the cellpose docs, flows[2]
# is presumably the cell-probability map; confirm for the installed version.
viewer.add_image(flows[2])