In [1]:
import os

import h5py
import imageio
import numpy as np
import pandas as pd
import tifffile
from cellpose import models
from cellpose import utils as cp_utils

from cellstitch.pipeline import *
from cellstitch.evaluation import *
import h5py

In [13]:
# Dataset to process; every input path, output path and filename below is built from it.
# NOTE(review): the saved cell outputs further down print Valve_* / Filament_* names,
# so they were produced with a different `dataset` value than "Sepal" - re-run to refresh.
dataset = "Sepal"

# Train cellpose model from scratch

In [None]:
# Export training volumes 00-69 as per-slice 2D TIFF image/mask pairs for the
# `python -m cellpose --train --dir ../DATA/<dataset>/cellpose_train` command below.
train_dir = "../DATA/%s/cellpose_train" % dataset
os.makedirs(train_dir, exist_ok=True)  # imageio.imwrite does not create missing directories

for i in range(70):
    filename = "%s_%02d" % (dataset, i)
    img = np.load("../DATA/%s/%s.npy" % (dataset, filename))
    labels = np.load("../DATA/%s/%s_masks.npy" % (dataset, filename))
    depth = img.shape[0]

    # Separate slice index `z` so the volume index `i` is not shadowed.
    for z in range(depth):
        imageio.imwrite("%s/%s_%s.tif" % (train_dir, filename, z), img[z])
        imageio.imwrite("%s/%s_%s_masks.tif" % (train_dir, filename, z), labels[z])

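Then train a cellpose model from scratch on the exported slices (replace `<dataset>` with the value of `dataset`):

`python -m cellpose --train --use_gpu --dir ../DATA/<dataset>/cellpose_train --pretrained_model None --n_epochs 100  --verbose`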

###  Generate cellpose3d results

In [3]:
# Checkpoint produced by the training run above (name presumably generated by cellpose).
model_name = 'cellpose_residual_on_style_on_concatenation_off_cellpose_train_2023_05_30_20_25_23.614818'
model_dir = f"../DATA/{dataset}/cellpose_train/models/{model_name}"

In [4]:
flow_threshold = 1  # shared by every model.eval call in this notebook
# Load the custom-trained weights located at `model_dir`, on the GPU.
model = models.CellposeModel(pretrained_model=model_dir, gpu=True)

In [None]:
# Cellpose 3D baseline: segment each held-out volume (70-99) with do_3D=True
# and save the label volume as .npy.
cellpose3d_dir = "./results/%s/cellpose3d" % dataset
os.makedirs(cellpose3d_dir, exist_ok=True)  # np.save fails if the directory is missing

for i in range(70, 100):
    test_filename = "%s_%02d" % (dataset, i)
    print("Starting %s" % test_filename)
    img = np.load("../DATA/%s/%s.npy" % (dataset, test_filename))
    masks, _, _ = model.eval(img, do_3D=True, flow_threshold=flow_threshold, channels=[0, 0])
    np.save("%s/%s.npy" % (cellpose3d_dir, test_filename), masks)

### Generate cellpose2d results

In [5]:
# Cellpose 2D baseline: segment each slice of the held-out volumes (70-99)
# independently, then link the per-slice labels into 3D objects with
# cellpose's stitch3D (default stitch_threshold).
cellpose2d_dir = "./results/%s/cellpose2d" % dataset
os.makedirs(cellpose2d_dir, exist_ok=True)  # np.save fails if the directory is missing

for i in range(70, 100):
    test_filename = "%s_%02d" % (dataset, i)
    print("Starting %s" % test_filename)
    img = np.load("../DATA/%s/%s.npy" % (dataset, test_filename))
    # list(img) feeds the volume to model.eval as a list of 2D slices along axis 0.
    masks, _, _ = model.eval(list(img), do_3D=False, flow_threshold=flow_threshold, channels=[0, 0])
    masks = cp_utils.stitch3D(np.array(masks))

    np.save("%s/%s.npy" % (cellpose2d_dir, test_filename), masks)

Starting Valve_70
Starting Valve_71
Starting Valve_72
Starting Valve_73
Starting Valve_74
Starting Valve_75
Starting Valve_76
Starting Valve_77
Starting Valve_78
Starting Valve_79
Starting Valve_80
Starting Valve_81
Starting Valve_82
Starting Valve_83
Starting Valve_84
Starting Valve_85
Starting Valve_86
Starting Valve_87
Starting Valve_88
Starting Valve_89
Starting Valve_90
Starting Valve_91
Starting Valve_92
Starting Valve_93
Starting Valve_94
Starting Valve_95
Starting Valve_96
Starting Valve_97
Starting Valve_98
Starting Valve_99


### Generate cellstitch results

In [6]:
# CellStitch: segment 2D slices along each of the three axes with the same
# model, then combine them with full_stitch and save the result.
cellstitch_dir = "./results/%s/cellstitch" % dataset
os.makedirs(cellstitch_dir, exist_ok=True)  # np.save fails if the directory is missing

for i in range(70, 100):
    test_filename = "%s_%02d" % (dataset, i)
    print("Starting %s" % test_filename)
    img = np.load("../DATA/%s/%s.npy" % (dataset, test_filename))

    # Slices along axis 0: the masks passed first to full_stitch and saved below.
    xy_masks, _, _ = model.eval(list(img), flow_threshold=flow_threshold, channels=[0, 0])
    xy_masks = np.array(xy_masks)

    # Slices along axis 1; transpose(1, 0, 2) is its own inverse, so the
    # second transpose restores the original axis order.
    yz_masks, _, _ = model.eval(list(img.transpose(1, 0, 2)), flow_threshold=flow_threshold, channels=[0, 0])
    yz_masks = np.array(yz_masks).transpose(1, 0, 2)

    # Slices along axis 2; transpose(2, 1, 0) is likewise its own inverse.
    xz_masks, _, _ = model.eval(list(img.transpose(2, 1, 0)), flow_threshold=flow_threshold, channels=[0, 0])
    xz_masks = np.array(xz_masks).transpose(2, 1, 0)

    # NOTE(review): the return value is discarded, so this relies on full_stitch
    # updating `xy_masks` in place - confirm against cellstitch.pipeline.
    full_stitch(xy_masks, yz_masks, xz_masks)

    np.save("%s/%s.npy" % (cellstitch_dir, test_filename), xy_masks)

Starting Valve_70
Starting Valve_71
Starting Valve_72
Starting Valve_73
Starting Valve_74
Starting Valve_75
Starting Valve_76
Starting Valve_77
Starting Valve_78
Starting Valve_79
Starting Valve_80
Starting Valve_81
Starting Valve_82
Starting Valve_83
Starting Valve_84
Starting Valve_85
Starting Valve_86
Starting Valve_87
Starting Valve_88
Starting Valve_89
Starting Valve_90
Starting Valve_91
Starting Valve_92
Starting Valve_93
Starting Valve_94
Starting Valve_95
Starting Valve_96
Starting Valve_97
Starting Valve_98
Starting Valve_99


## PlantSeg
### PlantSeg 
- First, create a PlantSeg virtual environment:
    - `conda install -c conda-forge mamba`
    - `mamba create -n plant-seg -c pytorch -c nvidia -c conda-forge -c lcerrone plantseg pytorch-cuda=11.7`
- Activate the environment: `conda activate plant-seg`
- Download the ovules test dataset (https://osf.io/uzq3w/) to `../DATA/<dataset>/plantseg_test/`
- Set the `path` in `config.yaml` to `../DATA/<dataset>/plantseg_test/`
- Run segmentation with the `generic_confocal_3D_unet` model: `plantseg --config config.yaml`. The cells below read the resulting `<name>_predictions_multicut.h5` files from `../DATA/<dataset>/plantseg_test/PreProcessing/generic_confocal_3D_unet/MultiCut/`.

In [3]:
# Folder holding PlantSeg's `<name>_predictions_multicut.h5` outputs for this dataset.
plantseg_results_folder = f"../DATA/{dataset}/plantseg_test/PreProcessing/generic_confocal_3D_unet/MultiCut"

In [4]:
# Convert PlantSeg's HDF5 multicut output to .npy under ./results/<dataset>/plantseg/,
# alongside the other methods' results.
plantseg_dir = "./results/%s/plantseg" % dataset
os.makedirs(plantseg_dir, exist_ok=True)  # np.save fails if the directory is missing

for i in range(70, 100):
    test_filename = "%s_%02d" % (dataset, i)
    print("Starting %s" % test_filename)

    with h5py.File("%s/%s_predictions_multicut.h5" % (plantseg_results_folder, test_filename), "r") as f:
        # Read the whole dataset in one call rather than slice by slice via list().
        plantseg = f['segmentation'][()]

    # PlantSeg output uses label 1 for what is presumably background; set it to 0 (TODO confirm).
    plantseg[plantseg == 1] = 0
    np.save("%s/%s.npy" % (plantseg_dir, test_filename), plantseg)

Starting Filament_70
Starting Filament_71
Starting Filament_72
Starting Filament_73
Starting Filament_74
Starting Filament_75
Starting Filament_76
Starting Filament_77
Starting Filament_78
Starting Filament_79
Starting Filament_80
Starting Filament_81
Starting Filament_82
Starting Filament_83
Starting Filament_84
Starting Filament_85
Starting Filament_86
Starting Filament_87
Starting Filament_88
Starting Filament_89
Starting Filament_90
Starting Filament_91
Starting Filament_92
Starting Filament_93
Starting Filament_94
Starting Filament_95
Starting Filament_96
Starting Filament_97
Starting Filament_98
Starting Filament_99


# Benchmark results

In [14]:
plantseg_results_folder = "../DATA/{}/plantseg_test/PreProcessing/generic_confocal_3D_unet/MultiCut".format(dataset)
methods = ["plantseg"]  # methods to benchmark; the next cell writes one CSV per entry

In [15]:
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator as float(s), yielding NaN where the denominator is 0.

    Works for Python or numpy scalars and for arrays; scalars come back as a Python float.
    """
    numerator = np.asarray(numerator, dtype=float)
    denominator = np.asarray(denominator, dtype=float)
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = numerator / denominator
    return ratio.item() if ratio.ndim == 0 else ratio


def _load_prediction(method, filename):
    """Load the predicted label volume of `method` for one test file.

    Reads ./results/<dataset>/<method>/<filename>.npy. For "plantseg", falls back
    to the raw multicut HDF5 output (label 1 set to 0, as in the conversion cell)
    when that .npy has not been generated.
    """
    npy_path = "./results/%s/%s/%s.npy" % (dataset, method, filename)
    if method == "plantseg" and not os.path.exists(npy_path):
        with h5py.File("%s/%s_predictions_multicut.h5" % (plantseg_results_folder, filename), "r") as f:
            masks = f['segmentation'][()]
        masks[masks == 1] = 0  # plantseg use 1 as labels (presumably background)
        return masks
    return np.load(npy_path)


# Score each method on the held-out volumes (70-99) and write one CSV per method.
os.makedirs("./results/%s" % dataset, exist_ok=True)

for method in methods:
    data = []
    for i in range(70, 100):
        filename = "%s_%02d" % (dataset, i)
        labels = np.load('../DATA/%s/%s_masks.npy' % (dataset, filename))
        # Load this method's own predictions (previously PlantSeg was always read here).
        masks = _load_prediction(method, filename)

        ap25, _, _, _ = average_precision(labels, masks, 0.25)
        ap50, tp, fp, fn = average_precision(labels, masks, 0.5)
        ap75, _, _, _ = average_precision(labels, masks, 0.75)

        # Precision/recall use the counts from the 0.5 threshold call; NaN instead
        # of a ZeroDivisionError when there are no predictions or no labels.
        precision = _safe_ratio(tp, tp + fp)
        recall = _safe_ratio(tp, tp + fn)

        data.append([filename, ap25, ap50, ap75, tp, fp, fn, precision, recall])

    df = pd.DataFrame(data, columns=[
        "filename",
        "ap25",
        "ap50",
        "ap75",
        "tp",
        "fp",
        "fn",
        "precision",
        "recall"
    ])

    df.to_csv("./results/%s/%s.csv" % (dataset, method), index=False)

  iou = overlap / (n_pixels_pred + n_pixels_true - overlap)
