In [1]:
import h5py
import numpy as np
from cellpose import models
from cellpose import utils as cp_utils
from cellstitch.pipeline import *
from cellstitch.evaluation import *
from cellstitch.utils import *
import pandas as pd

In [2]:
# Dataset folder name; it is substituted into the ../DATA/<dataset>/ and
# ./results/<dataset>/ paths used by the cells below.
dataset = "ovules" # or ovules_subsampled

In [3]:
# Test volumes to segment and benchmark, listed by numeric sample id;
# every file is named "N_<id>_final_crop_ds2.npy".
test_filenames = ["N_%d_final_crop_ds2.npy" % sample_id
                  for sample_id in (294, 435, 441, 511, 522, 590, 593)]

# Pipeline Benchmark

Comparison between cellstitch (2D), cellpose3D (2.5D), plantseg (3D), and cellstitch3D, all using the same training set.

### PlantSeg
- First, create a plantseg virtual environment:
    - `conda install -c conda-forge mamba` 
    - `mamba create -n plant-seg -c pytorch -c nvidia -c conda-forge -c lcerrone plantseg pytorch-cuda=11.7` 
- activate the environment: `conda activate plant-seg` 
- download the ovules test dataset: https://osf.io/uzq3w/ to `../DATA/<dataset>/plantseg_test/` 
- set the `path` in `config.yaml` to `../DATA/<dataset>/plantseg_test/` 
- perform segmentation with the `confocal_3D_unet_ovules_ds1x` model by running `plantseg --config config.yaml`

In [15]:
# Folder from which the next cell reads the "<name>_predictions_multicut.h5"
# segmentation files for the selected dataset.
plantseg_results_folder = "../DATA/%s/plantseg_test/PreProcessing/confocal_3D_unet_ovules_ds1x/MultiCut" % dataset

In [16]:
# Convert every PlantSeg prediction (HDF5) into a .npy label volume stored
# under ./results/<dataset>/plantseg/.
for test_filename in test_filenames:
    print("Starting %s" % test_filename)

    # test_filename[:-4] drops the ".npy" suffix to get the prediction's base name.
    prediction_path = "%s/%s_predictions_multicut.h5" % (plantseg_results_folder, test_filename[:-4])
    with h5py.File(prediction_path, "r") as f:
        # Read the whole dataset into memory in a single call instead of
        # iterating over it slice by slice and re-stacking the pieces.
        plantseg = f['segmentation'][...]

    # Label 1 is reset to 0.
    # NOTE(review): presumably 1 is PlantSeg's background label - confirm.
    plantseg[plantseg == 1] = 0
    np.save("./results/%s/plantseg/%s" % (dataset, test_filename), plantseg)

Starting N_294_final_crop_ds2.npy
Starting N_435_final_crop_ds2.npy
Starting N_441_final_crop_ds2.npy
Starting N_511_final_crop_ds2.npy
Starting N_522_final_crop_ds2.npy
Starting N_590_final_crop_ds2.npy
Starting N_593_final_crop_ds2.npy


### Train cellpose model from scratch
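First, we need to prepare the training data for cellpose: every z-slice of each training volume is written out as a 2D `.tif` image together with a matching `_masks.tif` label image.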

In [None]:
# Training volumes, listed by sample id; every file is named "N_<id>_ds2x.npy".
train_filenames = ["N_%s_ds2x.npy" % sample_id for sample_id in (
    "404", "405_A", "405_B", "416", "422", "425", "428", "440",
    "445", "449", "450", "451", "454", "457", "458", "487",
    "509", "512", "517", "534", "535", "536",
)]

# Root folder whose raw/ and labels/ subfolders hold the training volumes.
ovules_folder = "../DATA/ovules"
# Destination folder for the per-slice .tif files written for cellpose training.
cellpose_folder = "../DATA/ovules/cellpose_train"

In [None]:
# Split every training volume into 2D slices along its first axis and write
# each slice, plus the matching label slice, as .tif files into cellpose_folder.
# The output names keep the full source filename (including ".npy") followed by
# the slice index, e.g. "<train_filename>_<z>.tif" and "<train_filename>_<z>_masks.tif".
# NOTE(review): `imageio` is not imported explicitly in the imports cell;
# presumably it arrives through one of the cellstitch star imports - confirm.
for train_filename in train_filenames:
    raw_volume = np.load("%s/raw/%s" % (ovules_folder, train_filename))
    label_volume = np.load("%s/labels/%s" % (ovules_folder, train_filename))

    for z, raw_slice in enumerate(raw_volume):
        imageio.imwrite("%s/%s_%s.tif" % (cellpose_folder, train_filename, z), raw_slice)
        imageio.imwrite("%s/%s_%s_masks.tif" % (cellpose_folder, train_filename, z), label_volume[z])

`python -m cellpose --train --dir ../DATA/ovules/cellpose_train --pretrained_model None --n_epochs 100  --verbose` 

### Generate cellpose3d results

In [4]:
# Path of the model file handed to CellposeModel as `pretrained_model` in the next cell.
# NOTE(review): presumably the output of the training command above - confirm the timestamp matches your run.
model_dir = '../DATA/ovules/cellpose_train/models/cellpose_residual_on_style_on_concatenation_off_cellpose_train_2023_05_08_09_31_06.231473'

In [5]:
# Flow threshold passed to every model.eval call in the cells below.
flow_threshold = 1

# Cellpose model built from the weights at model_dir, with GPU use requested.
model = models.CellposeModel(
    pretrained_model=model_dir,
    gpu=True,
)

In [6]:
# cellpose3d: evaluate each test volume as a whole with do_3D=True and save
# the resulting masks under ./results/<dataset>/cellpose3d/.
for test_filename in test_filenames:
    print("Starting %s" % test_filename)

    raw_path = "../DATA/%s/raw/%s" % (dataset, test_filename)
    output_path = "./results/%s/cellpose3d/%s" % (dataset, test_filename)

    volume = np.load(raw_path)
    # model.eval yields three values; only the first one (the masks) is kept.
    masks, _, _ = model.eval(volume, do_3D=True, flow_threshold=flow_threshold, channels = [0,0])
    np.save(output_path, masks)

Starting N_294_final_crop_ds2.npy
Starting N_435_final_crop_ds2.npy
Starting N_441_final_crop_ds2.npy
Starting N_511_final_crop_ds2.npy
Starting N_522_final_crop_ds2.npy
Starting N_590_final_crop_ds2.npy
Starting N_593_final_crop_ds2.npy


### Generate cellpose2d results

In [7]:
# cellpose2d: evaluate the planes of each test volume as a list of 2D images
# (do_3D=False), then pass the stacked per-plane masks to cellpose's stitch3D
# and save the result under ./results/<dataset>/cellpose2d/.
for test_filename in test_filenames:
    print("Starting %s" % test_filename)

    volume = np.load("../DATA/%s/raw/%s" % (dataset, test_filename))
    planes = list(volume)  # one entry per index along the first axis

    plane_masks, _, _ = model.eval(planes, do_3D=False, flow_threshold=flow_threshold, channels = [0,0])
    masks = cp_utils.stitch3D(np.array(plane_masks))

    output_path = "./results/%s/cellpose2d/%s" % (dataset, test_filename)
    np.save(output_path, masks)

Starting N_294_final_crop_ds2.npy
Starting N_435_final_crop_ds2.npy
Starting N_441_final_crop_ds2.npy
Starting N_511_final_crop_ds2.npy
Starting N_522_final_crop_ds2.npy
Starting N_590_final_crop_ds2.npy
Starting N_593_final_crop_ds2.npy


### Generate cellstitch results

In [7]:
def _segment_planes(volume, axes=None):
    """Evaluate the model on each plane of `volume` and return the stacked masks.

    When `axes` is given, the volume is transposed with it before being split
    into planes, and the stacked masks are transposed with the same `axes`
    afterwards. The two orders used below, (1, 0, 2) and (2, 1, 0), are their
    own inverses, so the masks come back in the volume's original axis order.
    """
    planes = volume if axes is None else volume.transpose(axes)
    plane_masks, _, _ = model.eval(list(planes), flow_threshold=flow_threshold, channels = [0,0])
    stacked = np.array(plane_masks)
    return stacked if axes is None else stacked.transpose(axes)


# cellstitch: segment each volume plane by plane along three axis orders, then
# hand the three mask stacks to full_stitch and save the first one.
for test_filename in test_filenames:
    print("Starting %s" % test_filename)
    volume = np.load("../DATA/%s/raw/%s" % (dataset, test_filename))

    xy_masks = _segment_planes(volume)
    yz_masks = _segment_planes(volume, axes=(1, 0, 2))
    xz_masks = _segment_planes(volume, axes=(2, 1, 0))

    # The return value of full_stitch is discarded and xy_masks itself is saved.
    # NOTE(review): this relies on full_stitch updating its first argument in
    # place - confirm against the installed cellstitch.pipeline.
    full_stitch(xy_masks, yz_masks, xz_masks)

    np.save("./results/%s/cellstitch/%s" % (dataset, test_filename), xy_masks)

Starting N_294_final_crop_ds2.npy
Starting N_435_final_crop_ds2.npy




Starting N_441_final_crop_ds2.npy
Starting N_511_final_crop_ds2.npy




Starting N_522_final_crop_ds2.npy




Starting N_590_final_crop_ds2.npy




Starting N_593_final_crop_ds2.npy




# benchmark results

In [4]:
# Methods to score. Each one needs its masks saved beforehand at
# ./results/<dataset>/<method>/<filename>.
methods = ["plantseg"]

for method in methods:
    print("Starting %s" % method)

    # One row of metrics per test volume; written to a CSV once the method is done.
    data = []
    for filename in test_filenames:
        print("Starting %s" % filename)
        labels = np.load('../DATA/%s/labels/%s' % (dataset, filename))
        masks = np.load("./results/%s/%s/%s" % (dataset, method, filename))

        # NOTE(review): the third argument is presumably the IoU matching
        # threshold - confirm against cellstitch.evaluation.average_precision.
        # The tp / fp / fn counts are taken from the 0.5 call only.
        ap25, _, _, _ = average_precision(labels, masks, 0.25)
        ap50, tp, fp, fn = average_precision(labels, masks, 0.5)
        ap75, _, _, _ = average_precision(labels, masks, 0.75)

        # precision = tp / (tp + fp), defined as 0 when the denominator is 0.
        if (tp + fp) != 0:
            precision = tp / (tp + fp)
        else:
            precision = 0

        # recall = tp / (tp + fn), guarded on its own denominator. The previous
        # guard tested tp + fp and its fallback assigned `precision`, so `recall`
        # could be left unset (or stale from the previous file) or divide by zero.
        if (tp + fn) != 0:
            recall = tp / (tp + fn)
        else:
            recall = 0

        row = [
            filename,
            ap25,
            ap50,
            ap75,
            tp,
            fp,
            fn,
            precision,
            recall
        ]

        data.append(row)

    # Column order must match the order of the values in `row` above.
    df = pd.DataFrame(data, columns=[
        "filename",
        "ap25",
        "ap50",
        "ap75",
        "tp",
        "fp",
        "fn",
        "precision",
        "recall"
    ])

    df.to_csv("./results/%s/%s.csv" % (dataset, method), index=False)

Starting plantseg
Starting N_294_final_crop_ds2.npy


  iou = overlap / (n_pixels_pred + n_pixels_true - overlap)


Starting N_435_final_crop_ds2.npy
Starting N_441_final_crop_ds2.npy
Starting N_511_final_crop_ds2.npy
Starting N_522_final_crop_ds2.npy
Starting N_590_final_crop_ds2.npy
Starting N_593_final_crop_ds2.npy
