# BioImageIO Colab - CellPose Finetuning

This notebook demonstrates how to use image annotations collected from [BioImage.IO Colab](https://bioimage-io.github.io/bioimageio-colab/) to train and fine-tune a CellPose model.

For more details on cellpose 2.0 check out the [paper](https://www.biorxiv.org/content/10.1101/2022.04.01.486764v1) or the [talk](https://www.youtube.com/watch?v=3ydtAhfq6H0).

*Most of this notebook is based on the original [CellPose notebook](https://github.com/mouseland/cellpose).*

## 0.1. Installation

We will first install all the dependencies required for cellpose 2.0. By default, the GPU build of torch is already installed in Colab notebooks.

In [None]:
#@markdown ###Install these required dependencies:

#@markdown * tifffile
#@markdown * matplotlib
#@markdown * opencv-python-headless
#@markdown * cellpose

## Install required dependencies

!pip install tifffile matplotlib "opencv-python-headless<=4.3" cellpose

#@markdown You will have to restart the runtime after this finishes to include the new packages. In the menu above do: Runtime --> Restart session

#@markdown Don't worry about the errors that pip gives below; they are resolved by the end of the installation. We apologise for the ugly installation - a consequence of using Colab.

## 0.2. Mount Google Drive

Please mount your Google Drive and find your [BioImage.IO Colab](https://bioimage-io.github.io/bioimageio-colab/) folder with source images and annotations. Mounting also lets you save any models you train to your Google Drive.

In [None]:

#@markdown ###Connect your Google Drive to Colab

from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

#@markdown * Connect to Google Drive.

#@markdown * Sign into your Google Account.

#@markdown * Click "Continue"

#@markdown * Select the permission to "See, edit, create, and delete all of your Google Drive files".

#@markdown * Click "Continue"

#@markdown Your Google Drive folder should now be available here as "gdrive".


# 1. Display manual annotations

In [None]:
#@markdown ###Select your mounted folder from BioImage.IO Colab:
import os

path2images = "/content/gdrive/MyDrive/hpa_demo" #@param {type:"string"}
path2annotations = os.path.join(path2images, "annotations")

We will first match all pairs of source images and annotation masks, then display them: six randomly chosen pairs if at least six are available, otherwise only the first pair.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tifffile import imread


# List to hold pairs of image and corresponding annotation masks
image_annotation_pairs = []

# Get list of all images and annotations
image_files = os.listdir(path2images)
annotation_files = os.listdir(path2annotations)

# Iterate through each image in the images folder
for image_file in image_files:
    # Get the base name of the image file (without extension)
    image_name = os.path.splitext(image_file)[0]

    # Find all corresponding annotation masks in the annotations folder
    corresponding_masks = [os.path.join(path2annotations, annotation_file)
                           for annotation_file in annotation_files
                           if annotation_file.startswith(image_name) and annotation_file.endswith('.tif')]

    # If any corresponding masks are found, add them as tuples to the list
    if corresponding_masks:
        image_path = os.path.join(path2images, image_file)
        for mask in corresponding_masks:
            image_annotation_pairs.append((image_path, mask))

# Print the number of annotations
num_pairs = len(image_annotation_pairs)
print(f"Number of annotations: {num_pairs}")

def read_image(path):
    img = imread(path)
    if img.ndim == 3 and img.shape[0] == 3:
        img = np.transpose(img, [1, 2, 0])
    return img

if num_pairs < 6:
    # Plot one single annotation starting from the first pair
    k = 0
    plt.figure(figsize=(10, 20))
    plt.subplot(1, 2, 1)
    plt.imshow(read_image(image_annotation_pairs[k][0]))
    plt.title(f"Image: {os.path.basename(image_annotation_pairs[k][0])}")
    plt.subplot(1, 2, 2)
    plt.imshow(read_image(image_annotation_pairs[k][1]))
    plt.title(f"Annotation: {os.path.basename(image_annotation_pairs[k][1])}")
    plt.show()
else:
    # Plot several random annotations
    choices = np.random.choice(num_pairs, 6, replace=False)
    plt.figure(figsize=(17, 5))
    for i in range(6):
        plt.subplot(2, 6, 2 * (i + 1) - 1)
        plt.imshow(read_image(image_annotation_pairs[choices[i]][0]))
        plt.axis('off')
        plt.title(f"{os.path.basename(image_annotation_pairs[choices[i]][0])}")
        plt.subplot(2, 6, 2 * (i + 1))
        plt.imshow(read_image(image_annotation_pairs[choices[i]][1]))
        plt.axis('off')
        plt.title("Annotation")
    plt.show()
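
Note that the prefix test `annotation_file.startswith(image_name)` also accepts masks of a longer image name that shares the same prefix (for instance `img_1` and `img_10_mask_1.tif`). The following is a minimal sketch of a stricter match, assuming annotation files are named `<image name>_mask_<number>.tif`, which is the pattern the filename handling in section 2.1 relies on. `match_masks` is a hypothetical helper, not part of BioImage.IO Colab or Cellpose.

```python
import os

def match_masks(image_file, annotation_files):
    """Return annotation files named '<image name>_mask_<n>.tif' for one image."""
    image_name = os.path.splitext(image_file)[0]
    # Including the '_mask_' separator stops 'img_1' from matching 'img_10_...'
    prefix = image_name + "_mask_"
    return sorted(
        f for f in annotation_files
        if f.startswith(prefix) and f.endswith(".tif")
    )

# Toy file names (assumed for illustration only)
annotation_files = ["img_1_mask_1.tif", "img_1_mask_2.tif", "img_10_mask_1.tif"]
print(match_masks("img_1.tif", annotation_files))
print(match_masks("img_10.tif", annotation_files))
```

With the plain prefix test, `img_1.tif` would also be paired with `img_10_mask_1.tif`.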

# 2. Running cellpose 2.0 with a GPU

In [None]:
#@markdown ### Check CUDA version and that GPU is working in cellpose. Also import other libraries.
!nvcc --version
!nvidia-smi

import shutil
from cellpose import core, utils, io, models, metrics
from glob import glob

use_GPU = core.use_gpu()
yn = ['NO', 'YES']
print(f'>>> GPU activated? {yn[use_GPU]}')

## 2.1. Split manual annotations into train and test

**Paths for training, predictions and results**

**`train_dir`, `test_dir`:** Paths to the folders that hold the training images and masks (`train_dir`) and the test images and masks (`test_dir`). You can leave `test_dir` blank, but keeping some test images is recommended so that you can check the model's performance. To find a folder's path, open Files on the left of the notebook, navigate to the folder, right-click it, choose **Copy path**, and paste the path into the matching box below.
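
Images and masks are paired through their file names: this notebook stores each mask next to its image with a `_seg` suffix and loads them with `mask_filter="_seg"`. As a rough sanity check, the sketch below lists images that have no matching mask. `find_unpaired` is a hypothetical helper, and the file names are made up for illustration.

```python
import os

def find_unpaired(filenames, mask_suffix="_seg"):
    """Return .tif images that have no '<name>_seg.tif' partner in the listing."""
    names = set(filenames)
    unpaired = []
    for name in sorted(names):
        stem, ext = os.path.splitext(name)
        # Skip non-TIFF files and the mask files themselves
        if ext != ".tif" or stem.endswith(mask_suffix):
            continue
        if f"{stem}{mask_suffix}.tif" not in names:
            unpaired.append(name)
    return unpaired

listing = ["a_1.tif", "a_1_seg.tif", "b_1.tif"]
print(find_unpaired(listing))  # ['b_1.tif']
```

On a real folder this could be called as `find_unpaired(os.listdir(train_dir))` once the folder has been filled.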


In [None]:
from tifffile import imwrite

#@markdown Run this cell to split your pairs of images and annotations into training and testing groups and then save them to `train_dir` and `test_dir`.

#@markdown ###Path to images and masks:
train_dir = "/content/train" #@param {type:"string"}
test_dir = "/content/test" #@param {type:"string"}
os.makedirs(train_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)

# Check if the folders are empty
if os.listdir(train_dir):
    print(f"Warning: Folder '{train_dir}' is not empty.")
if os.listdir(test_dir):
    print(f"Warning: Folder '{test_dir}' is not empty.")

# Define where the patch file will be saved
base = os.path.dirname(train_dir)

#@markdown ###Define the test portion

test_portion = 0.2 #@param {type:"slider", min:0.01, max:0.99, step:0.01}

# Calculate the number of test images
t = np.floor(test_portion * len(image_annotation_pairs)).astype(int)

# Function to convert RGB images to grayscale
def rgb2gray(image):
    if len(image.shape) == 3:
        if image.shape[0] == 3:
            image = np.transpose(image, [1, 2, 0])
        image = np.float32(image)
        image = np.dot(image[..., :3], [0.2989, 0.5870, 0.1140])
        image = np.uint8(image)  # uint8 keeps the 0-255 range (int8 would wrap values above 127); assumes 8-bit input
    return image

# Helper function to generate new filenames
def generate_new_filenames(image_path, annotation_path):
    image_name = os.path.basename(image_path)
    base_name = image_name.split('.tif')[0]

    # Extract the mask number from the annotation filename
    mask_number = annotation_path.split('_mask_')[-1].split('.tif')[0]

    # Generate new filenames
    new_image_name = f"{base_name}_{mask_number}.tif"
    new_annotation_name = f"{base_name}_{mask_number}_seg.tif"

    return new_image_name, new_annotation_name

# Split the data into training and test sets
for i in range(len(image_annotation_pairs) - t):
    image_path, annotation_path = image_annotation_pairs[i]

    # Generate new filenames
    new_image_name, new_annotation_name = generate_new_filenames(image_path, annotation_path)

    # Process and save the training images
    image = imread(image_path)
    image = rgb2gray(image)
    imwrite(os.path.join(train_dir, new_image_name), image)

    # Copy the corresponding annotation with the new name
    shutil.copyfile(annotation_path, os.path.join(train_dir, new_annotation_name))

for i in range(len(image_annotation_pairs) - t, len(image_annotation_pairs)):
    image_path, annotation_path = image_annotation_pairs[i]

    # Generate new filenames
    new_image_name, new_annotation_name = generate_new_filenames(image_path, annotation_path)

    # Process and save the test images
    image = imread(image_path)
    image = rgb2gray(image)
    imwrite(os.path.join(test_dir, new_image_name), image)

    # Copy the corresponding annotation with the new name
    shutil.copyfile(annotation_path, os.path.join(test_dir, new_annotation_name))

# Print the number of training and test images
print(f"Training images: {len(os.listdir(train_dir)) // 2}")
print(f"Test images: {len(os.listdir(test_dir)) // 2}")
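
The loops above take the last `t` pairs as the test set, so the split follows the order in which the files were listed. If a random split is preferred, a minimal sketch with a fixed seed could look like this. `split_pairs` and its `seed` argument are hypothetical and not part of the notebook.

```python
import numpy as np

def split_pairs(pairs, test_portion=0.2, seed=0):
    """Shuffle the pairs reproducibly and return (train_pairs, test_pairs)."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(pairs))
    # Same rounding rule as the notebook: floor(test_portion * number of pairs)
    n_test = int(np.floor(test_portion * len(pairs)))
    test_idx = order[:n_test]
    train_idx = order[n_test:]
    return [pairs[i] for i in train_idx], [pairs[i] for i in test_idx]

# Toy pairs (assumed file names for illustration)
pairs = [(f"img_{i}.tif", f"img_{i}_mask_1.tif") for i in range(10)]
train_pairs, test_pairs = split_pairs(pairs, test_portion=0.2)
print(len(train_pairs), len(test_pairs))  # 8 2
```

When one image has several masks, a random split can place masks of the same image in both sets; grouping by image before splitting avoids that.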

## 2.2. Training parameters

**Pretrained model and new model name**

Fill out the form below with the parameters to start training.

**`initial_model`:** Choose a model from the cellpose [model zoo](https://cellpose.readthedocs.io/en/latest/models.html#model-zoo) to start from.

**`model_name`:** Enter the name under which the trained model will be saved (inside the `models` folder of `train_dir`).

**Training parameters**

**`number_of_epochs`:** The number of epochs (rounds) to train the network for. At least 100 epochs are recommended, but 250 epochs are sometimes necessary, particularly when training from scratch. **Default value: 100**



In [None]:
# model name and path
#@markdown ###Name of the pretrained model to start from and new model name:
from cellpose import models
initial_model = "cyto3" #@param ["cyto", "cyto3","nuclei","tissuenet_cp3", "livecell_cp3", "yeast_PhC_cp3", "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", "deepbacs_cp3", "scratch"]
model_name = "CP_HPA_CrowdSourcing" #@param {type:"string"}

# other parameters for training.
#@markdown ###Training Parameters:
#@markdown Number of epochs:
n_epochs = 10 #@param {type:"number"}

Channel_to_use_for_training = "Grayscale" #@param ["Grayscale", "Blue", "Green", "Red"]

# @markdown ###If you have a secondary channel that can be used for training, for instance nuclei, choose it here:

Second_training_channel= "None" #@param ["None", "Blue", "Green", "Red"]


#@markdown ###Advanced Parameters

Use_Default_Advanced_Parameters = False #@param {type:"boolean"}
#@markdown ###If not, please input:
learning_rate = 0.000001 #@param {type:"number"}
weight_decay = 0.0001 #@param {type:"number"}

if (Use_Default_Advanced_Parameters):
  print("Default advanced parameters enabled")
  learning_rate = 0.1
  weight_decay = 0.0001

# Check whether a model with the same name already exists and warn if so
model_path = os.path.join(train_dir, 'models')
if os.path.exists(os.path.join(model_path, model_name)):
  print("!! WARNING: "+model_name+" already exists and will be overwritten when training runs !!")

if len(test_dir) == 0:
  test_dir = None

# Here we match the channel to number
if Channel_to_use_for_training == "Grayscale":
  chan = 0
elif Channel_to_use_for_training == "Blue":
  chan = 3
elif Channel_to_use_for_training == "Green":
  chan = 2
elif Channel_to_use_for_training == "Red":
  chan = 1


if Second_training_channel == "Blue":
  chan2 = 3
elif Second_training_channel == "Green":
  chan2 = 2
elif Second_training_channel == "Red":
  chan2 = 1
elif Second_training_channel == "None":
  chan2 = 0

if initial_model=='scratch':
  initial_model = 'None'
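
The channel numbers used above follow a fixed mapping (0 for grayscale or no second channel, 1 for red, 2 for green, 3 for blue). As an alternative to the `if`/`elif` chains, the same mapping can be sketched as a dictionary lookup. `CHANNEL_IDS` and `channels_from_names` are hypothetical names introduced only for this example.

```python
# Mapping taken from the if/elif chains in the cell above
CHANNEL_IDS = {"Grayscale": 0, "None": 0, "Red": 1, "Green": 2, "Blue": 3}

def channels_from_names(primary, secondary="None"):
    """Translate the form choices into the [chan, chan2] list passed to Cellpose."""
    return [CHANNEL_IDS[primary], CHANNEL_IDS[secondary]]

print(channels_from_names("Grayscale"))    # [0, 0]
print(channels_from_names("Red", "Blue"))  # [1, 3]
```

An unknown channel name raises a `KeyError` here instead of silently leaving `chan` undefined.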

Here is the equivalent training command for the command line. If you run it locally, make sure to correct the paths for your computer.

In [None]:
# run_str = f'python -m cellpose --use_gpu --verbose --train --dir {train_dir} --pretrained_model {initial_model} --chan {chan} --chan2 {chan2} --n_epochs {n_epochs} --learning_rate {learning_rate} --weight_decay {weight_decay}'
# if test_dir is not None:
#     run_str += f' --test_dir {test_dir}'
# run_str += ' --mask_filter _seg.npy' # if you want to use _seg.npy files for training
# print(run_str)

## 2.3. Train new model

Train the model in the notebook using the settings from the form above.

In [None]:
from cellpose import train

# start logger (to see training across epochs)
logger = io.logger_setup()

# DEFINE CELLPOSE MODEL (without size model)
model = models.CellposeModel(gpu=use_GPU, model_type=initial_model)

# set channels
channels = [chan, chan2]

# get files
output = io.load_train_test_data(train_dir, test_dir, mask_filter="_seg")
train_data, train_labels, _, test_data, test_labels, _ = output

new_model_path = train.train_seg(model.net, train_data=train_data,
                              train_labels=train_labels,
                              test_data=test_data,
                              test_labels=test_labels,
                              channels=channels,
                              save_path=train_dir,
                              n_epochs=n_epochs,
                              learning_rate=learning_rate,
                              weight_decay=weight_decay,
                              SGD=True,
                              nimg_per_epoch=1,
                              model_name=model_name,
                              min_train_masks=1)

# diameter of labels in training images
# use model diameter if user diameter is 0
diameter=0
diameter = model.diam_labels if diameter==0 else diameter
diam_labels = model.diam_labels.item()

## 2.4. Evaluate on test data (optional)

If you have test data, check the model's performance on it.

In [None]:
# get files (during training, test_data is transformed so we will load it again)
output = io.load_train_test_data(test_dir, mask_filter='_seg')
test_data, test_labels = output[:2]

# run model on test images
masks = model.eval(test_data,
                   channels=[chan, chan2],
                   diameter=diam_labels)[0]

# check performance using ground truth labels
ap = metrics.average_precision(test_labels, masks)[0]
print('')
print(f'>>> average precision at iou threshold 0.5 = {ap[:,0].mean():.3f}')
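
For intuition about the number printed above: at a given IoU threshold, a predicted mask counts as a true positive (TP) when it is matched to a ground-truth mask with IoU at or above the threshold; unmatched predictions are false positives (FP), unmatched ground-truth masks are false negatives (FN), and average precision is TP / (TP + FP + FN). The sketch below is a simplified illustration using greedy one-to-one matching on toy label images. It is not the implementation behind `metrics.average_precision`, and the helper names are hypothetical.

```python
import numpy as np

def iou_matrix(true_labels, pred_labels):
    """IoU between every ground-truth label and every predicted label (0 = background)."""
    true_ids = np.unique(true_labels)
    true_ids = true_ids[true_ids > 0]
    pred_ids = np.unique(pred_labels)
    pred_ids = pred_ids[pred_ids > 0]
    iou = np.zeros((len(true_ids), len(pred_ids)))
    for i, t in enumerate(true_ids):
        for j, p in enumerate(pred_ids):
            inter = np.logical_and(true_labels == t, pred_labels == p).sum()
            union = np.logical_or(true_labels == t, pred_labels == p).sum()
            iou[i, j] = inter / union
    return iou

def average_precision_at(true_labels, pred_labels, threshold=0.5):
    """TP / (TP + FP + FN) with greedy one-to-one matching at one IoU threshold."""
    iou = iou_matrix(true_labels, pred_labels)
    n_true, n_pred = iou.shape
    matched_true, matched_pred = set(), set()
    # Visit candidate pairs from highest IoU to lowest
    for flat in np.argsort(iou, axis=None)[::-1]:
        i, j = np.unravel_index(flat, iou.shape)
        if iou[i, j] < threshold:
            break
        if i not in matched_true and j not in matched_pred:
            matched_true.add(i)
            matched_pred.add(j)
    tp = len(matched_true)
    fp = n_pred - tp
    fn = n_true - tp
    total = tp + fp + fn
    return tp / total if total > 0 else 0.0

# Two true cells; one prediction overlaps the first cell, one lies on background
true_labels = np.zeros((6, 6), dtype=int)
true_labels[0:3, 0:3] = 1
true_labels[3:6, 3:6] = 2
pred_labels = np.zeros((6, 6), dtype=int)
pred_labels[0:3, 0:2] = 1   # IoU with true cell 1 is 6/9
pred_labels[0:2, 4:6] = 2   # overlaps no true cell
print(round(average_precision_at(true_labels, pred_labels), 3))  # 0.333
```

Here TP = 1, FP = 1 and FN = 1, which gives 1/3.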

Plot the test images alongside the predicted and true labels.

In [None]:
plt.figure(figsize=(12,8), dpi=150)
# show at most 5 test images
cols = 5 if len(test_data)>5 else len(test_data)
for k,im in enumerate(test_data):
    if k<cols:
      img = im.copy()
      plt.subplot(3,cols, k+1)
      img = np.vstack((img, np.zeros_like(img)[:1]))
      #img = img.transpose(1,2,0)
      plt.imshow(img)
      plt.axis('off')
      if k==0:
          plt.title('image')

      plt.subplot(3,cols, cols + k+1)
      plt.imshow(masks[k])
      plt.axis('off')
      if k==0:
          plt.title('predicted labels')

      plt.subplot(3,cols, 2*cols+ k+1)
      plt.imshow(test_labels[k])
      plt.axis('off')
      if k==0:
          plt.title('true labels')
plt.tight_layout()

# 3. Use custom model to segment images

Use the custom model trained above, or upload your own model to Google Drive or the Colab runtime.

## Parameters

In [None]:
# model name and path

#@markdown ###Custom model path (full path):

model_path = "/content/train/models/CP_HPA_CrowdSourcing" #@param {type:"string"}

#@markdown ###Path to images:

dir = "/content/test" #@param {type:"string"}

#@markdown ###Channel Parameters:

Channel_to_use_for_segmentation = "Red" #@param ["Grayscale", "Blue", "Green", "Red"]

# @markdown If you have a secondary channel that can be used, for instance nuclei, choose it here:

Second_segmentation_channel= "Blue" #@param ["None", "Blue", "Green", "Red"]


# Here we match the channel to number
if Channel_to_use_for_segmentation == "Grayscale":
  chan = 0
elif Channel_to_use_for_segmentation == "Blue":
  chan = 3
elif Channel_to_use_for_segmentation == "Green":
  chan = 2
elif Channel_to_use_for_segmentation == "Red":
  chan = 1


if Second_segmentation_channel == "Blue":
  chan2 = 3
elif Second_segmentation_channel == "Green":
  chan2 = 2
elif Second_segmentation_channel == "Red":
  chan2 = 1
elif Second_segmentation_channel == "None":
  chan2 = 0

#@markdown ### Segmentation parameters:

#@markdown diameter of cells (set to zero to use diameter from training set):
diameter = 0 #@param {type:"number"}
#@markdown threshold on flow error to accept a mask (set higher to get more cells, e.g. in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded):
flow_threshold = 0.4 #@param {type:"slider", min:0.0, max:3.0, step:0.1}
#@markdown threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)):
cellprob_threshold=0 #@param {type:"slider", min:-6, max:6, step:1}


If you're using the example test data, we'll copy it to a new folder.

In [None]:
src = 'human_in_the_loop/test'
if dir[:len(src)] == src:
    files = io.get_image_files(dir, '_masks')
    dir = 'human_in_the_loop/eval/'
    os.makedirs(dir, exist_ok=True)
    for f in files:
        dst = dir + os.path.split(f)[1]
        print(f'{f} > {dst}')
        shutil.copyfile(f, dst)

Here is the equivalent command for running the custom model on the command line. If you run it locally, make sure to correct the paths for your computer.

In [None]:
run_str = f'python -m cellpose --use_gpu --verbose --dir {dir} --pretrained_model {model_path} --chan {chan} --chan2 {chan2} --diameter {diameter} --flow_threshold {flow_threshold} --cellprob_threshold {cellprob_threshold}'
print(run_str)

## Run custom model

How to run the custom model in a notebook:

In [None]:
# gets image files in dir (ignoring image files ending in _masks)
files = io.get_image_files(dir, '_masks')
print(files)
images = [io.imread(f) for f in files]

# declare model
model = models.CellposeModel(gpu=use_GPU,  # use the GPU only if the check in section 2 found one
                             pretrained_model=model_path)

# use model diameter if user diameter is 0
diameter = model.diam_labels if diameter==0 else diameter

# run model on test images
masks, flows, styles = model.eval(images,
                                  channels=[chan, chan2],
                                  diameter=diameter,
                                  flow_threshold=flow_threshold,
                                  cellprob_threshold=cellprob_threshold
                                  )

## Save output to *_seg.npy

You will see the saved files in the Files tab, and you can download them from there.

In [None]:
from cellpose import io

io.masks_flows_to_seg(images,
                      masks,
                      flows,
                      files,
                      channels=[chan, chan2],
                      diams=diameter*np.ones(len(masks)),
                      )
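
A `_seg.npy` file stores a Python dictionary, so it is read back with `np.load(..., allow_pickle=True).item()`. The sketch below demonstrates the round trip with a toy stand-in dictionary that contains only a `masks` entry; it is an assumption for illustration, and real Cellpose output holds additional entries. The file name `example_seg.npy` is made up.

```python
import os
import tempfile
import numpy as np

# Toy stand-in for a Cellpose result: a label image stored under "masks"
toy_result = {"masks": np.array([[0, 1, 1], [2, 2, 0]], dtype=np.uint16)}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "example_seg.npy")
    np.save(path, toy_result)  # numpy pickles the dictionary into a 0-d object array
    loaded = np.load(path, allow_pickle=True).item()  # .item() unwraps the dictionary

print(sorted(loaded.keys()))
print(loaded["masks"].max())  # highest label in the toy mask
```

Only load pickled files that you created yourself or otherwise trust, since unpickling can execute arbitrary code.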

## Save output masks to TIFFs/PNGs or txt files for ImageJ

In [None]:
io.save_masks(images,
              masks,
              flows,
              files,
              channels=[chan, chan2],
              png=True, # save masks as PNGs and save example image
              tif=True, # save masks as TIFFs
              save_txt=True, # save txt outlines for ImageJ
              save_flows=False, # save flows as TIFFs
              save_outlines=False, # save outlines as TIFFs
              save_mpl=True # make matplotlib fig to view (WARNING: SLOW W/ LARGE IMAGES)
              )


In [None]:
f = files[0]
plt.figure(figsize=(12,4), dpi=300)
plt.imshow(io.imread(os.path.splitext(f)[0] + '_cp_output.png'))
plt.axis('off')
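
Each predicted mask is a label image in which 0 is background and every cell has its own positive integer. As a short follow-up, the sketch below counts the cells in such an image and measures their areas in pixels. `cell_areas` is a hypothetical helper and the array is toy data.

```python
import numpy as np

def cell_areas(mask):
    """Return {label: pixel count} for every non-background label in a mask."""
    labels, counts = np.unique(mask, return_counts=True)
    return {int(l): int(c) for l, c in zip(labels, counts) if l != 0}

mask = np.array([[0, 1, 1, 0],
                 [0, 1, 2, 2],
                 [3, 0, 2, 2]])
areas = cell_areas(mask)
print(len(areas))  # 3 cells
print(areas)       # {1: 3, 2: 4, 3: 1}
```

With the variables from this notebook, `cell_areas(masks[0])` would summarise the first segmented image.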