In [19]:
import imageio
import random
import numpy as np  
import pandas as pd 

from cellpose import models
from cellpose.utils import fill_holes_and_remove_small_masks
from cellstitch.utils import *
from cellstitch.pipeline import *
from cellstitch.evaluation import *

In [2]:
filenames = set(get_filenames("../DATA/ATAS/raw/"))

In [3]:
# test_filenames = set(random.sample(filenames, int(0.3 * len(filenames))))
# train_filenames = filenames - test_filenames

### Train cellpose model from scratch

In [4]:
# generate training data
data_folder = "../DATA/ATAS"
cellpose_folder = "../DATA/ATAS/cellpose_train"

In [5]:
for train_filename in train_filenames: 
    img = np.load("%s/raw/%s" % (data_folder, train_filename))
    labels = np.load("%s/labels/%s" % (data_folder, train_filename)) 
    depth = img.shape[0] 
    
    for i in range(depth): 
        imageio.imwrite("%s/%s_%s.tif" % (cellpose_folder, train_filename, i), img[i])
        imageio.imwrite("%s/%s_%s_masks.tif" % (cellpose_folder, train_filename, i), labels[i])

NameError: name 'train_filenames' is not defined

`python -m cellpose --train --use_gpu --dir ../DATA/ATAS/cellpose_train --pretrained_model None --n_epochs 100  --verbose`

### Test file names

In [4]:
filenames = set(get_filenames("../DATA/ATAS/raw/"))

In [5]:
train_filenames = []

In [6]:
for file in os.listdir("../DATA/ATAS/cellpose_train/"):
    if file.endswith(".npy_95.tif"):
        train_filenames.append(file[:-7])

In [7]:
test_filenames = filenames - set(train_filenames)

### Generate cellpose3d results

In [11]:
model_dir = '../DATA/ATAS/cellpose_train/models/cellpose_residual_on_style_on_concatenation_off_cellpose_train_2023_04_28_14_46_49.933105'

In [12]:
flow_threshold = 1
model = models.CellposeModel(gpu=True, pretrained_model=model_dir)

In [37]:
for test_filename in test_filenames: 
    print("Starting %s" % test_filename)
    img = np.load("../DATA/ATAS/raw/%s" % (test_filename)) 
    masks, _, _ = model.eval(img, do_3D=True, flow_threshold=flow_threshold, channels = [0,0]) 
    np.save("./results/ATAS/pipeline/cellpose3d/%s" % test_filename, masks)

Starting 28hrs_plant1_trim-acylYFP.npy
Starting 20hrs_plant2_trim-acylYFP.npy
Starting 36hrs_plant2_trim-acylYFP.npy
Starting 0hrs_plant18_trim-acylYFP.npy
Starting 72hrs_plant2_trim-acylYFP.npy
Starting 8hrs_plant15_trim-acylYFP.npy
Starting 48hrs_plant4_trim-acylYFP.npy
Starting 60hrs_plant4_trim-acylYFP.npy
Starting 84hrs_plant18_trim-acylYFP.npy
Starting 76hrs_plant15_trim-acylYFP.npy
Starting 20hrs_plant13_trim-acylYFP.npy
Starting 44hrs_plant4_trim-acylYFP.npy
Starting 64hrs_plant13_trim-acylYFP.npy
Starting 40hrs_plant1_trim-acylYFP.npy
Starting 8hrs_plant1_trim-acylYFP.npy
Starting 40hrs_plant2_trim-acylYFP.npy
Starting 76hrs_plant1_trim-acylYFP.npy
Starting 24hrs_plant13_trim-acylYFP.npy
Starting 44hrs_plant18_trim-acylYFP.npy
Starting 36hrs_plant15_trim-acylYFP.npy
Starting 52hrs_plant2_trim-acylYFP.npy
Starting 72hrs_plant18_trim-acylYFP.npy
Starting 60hrs_plant2_trim-acylYFP.npy
Starting 44hrs_plant2_trim-acylYFP.npy
Starting 4hrs_plant15_trim-acylYFP.npy
Starting 0hrs_plan

### Generate cellstitch results

In [28]:
for test_filename in test_filenames: 
    print("Starting %s" % test_filename)
    img = np.load("../DATA/ATAS/raw/%s" % test_filename)
    
    cellstitch, _, _ = model.eval(list(img), flow_threshold=flow_threshold, channels = [0,0])
    cellstitch = np.array(cellstitch)

    yz_masks, _, _ = model.eval(list(img.transpose(1,0,2)), flow_threshold=flow_threshold, channels = [0,0])
    yz_masks = np.array(yz_masks).transpose(1,0,2)

    xz_masks, _, _ = model.eval(list(img.transpose(2,1,0)), flow_threshold=flow_threshold, channels = [0,0])
    xz_masks = np.array(xz_masks).transpose(2,1,0)

    full_stitch(cellstitch, yz_masks, xz_masks)
    
    np.save("./results/ATAS/pipeline/cellstitch/%s" % test_filename, cellstitch) 

Starting 8hrs_plant1_trim-acylYFP.npy




Starting 32hrs_plant13_trim-acylYFP.npy




Starting 20hrs_plant13_trim-acylYFP.npy




Starting 8hrs_plant18_trim-acylYFP.npy




Starting 40hrs_plant1_trim-acylYFP.npy




Starting 76hrs_plant1_trim-acylYFP.npy




Starting 36hrs_plant15_trim-acylYFP.npy




Starting 60hrs_plant4_trim-acylYFP.npy




Starting 4hrs_plant15_trim-acylYFP.npy




Starting 48hrs_plant15_trim-acylYFP.npy




Starting 8hrs_plant2_trim-acylYFP.npy




Starting 44hrs_plant4_trim-acylYFP.npy




Starting 72hrs_plant18_trim-acylYFP.npy




Starting 36hrs_plant2_trim-acylYFP.npy




Starting 44hrs_plant2_trim-acylYFP.npy




Starting 76hrs_plant15_trim-acylYFP.npy




Starting 24hrs_plant15_trim-acylYFP.npy




Starting 12hrs_plant2_trim-acylYFP.npy




Starting 40hrs_plant13_trim-acylYFP.npy




Starting 20hrs_plant2_trim-acylYFP.npy




Starting 64hrs_plant13_trim-acylYFP.npy




Starting 0hrs_plant13_trim-acylYFP.npy




Starting 24hrs_plant13_trim-acylYFP.npy




Starting 0hrs_plant18_trim-acylYFP.npy




Starting 52hrs_plant2_trim-acylYFP.npy




Starting 68hrs_plant1_trim-acylYFP.npy




Starting 28hrs_plant1_trim-acylYFP.npy




Starting 84hrs_plant18_trim-acylYFP.npy




Starting 80hrs_plant15_trim-acylYFP.npy




Starting 60hrs_plant2_trim-acylYFP.npy




Starting 0hrs_plant2_trim-acylYFP.npy




Starting 12hrs_plant18_trim-acylYFP.npy




Starting 44hrs_plant18_trim-acylYFP.npy




Starting 48hrs_plant4_trim-acylYFP.npy




Starting 40hrs_plant2_trim-acylYFP.npy




Starting 72hrs_plant2_trim-acylYFP.npy




Starting 8hrs_plant15_trim-acylYFP.npy




### Generate cellstitch3d results

In [13]:
for test_filename in test_filenames: 
    print("Starting %s" % test_filename)
    img = np.load("../DATA/ATAS/raw/%s" % test_filename)
    
    cellstitch = np.load("./results/ATAS/pipeline/cellpose3d/%s" % test_filename)
    
    yz_masks, _, _ = model.eval(list(img.transpose(1,0,2)), flow_threshold=flow_threshold, channels = [0,0])
    yz_masks = np.array(yz_masks).transpose(1,0,2)

    xz_masks, _, _ = model.eval(list(img.transpose(2,1,0)), flow_threshold=flow_threshold, channels = [0,0])
    xz_masks = np.array(xz_masks).transpose(2,1,0)
    
    full_stitch(cellstitch, yz_masks, xz_masks)
    
    np.save("./results/ATAS/pipeline/cellstitch3d/%s" % test_filename, cellstitch)

Starting 4hrs_plant15_trim-acylYFP.npy




Starting 24hrs_plant15_trim-acylYFP.npy




Starting 48hrs_plant15_trim-acylYFP.npy




Starting 80hrs_plant15_trim-acylYFP.npy




Starting 40hrs_plant1_trim-acylYFP.npy




Starting 72hrs_plant2_trim-acylYFP.npy




Starting 76hrs_plant15_trim-acylYFP.npy




Starting 24hrs_plant13_trim-acylYFP.npy




Starting 68hrs_plant1_trim-acylYFP.npy




Starting 76hrs_plant1_trim-acylYFP.npy




Starting 0hrs_plant13_trim-acylYFP.npy




Starting 8hrs_plant18_trim-acylYFP.npy




Starting 36hrs_plant2_trim-acylYFP.npy




Starting 0hrs_plant2_trim-acylYFP.npy




Starting 12hrs_plant2_trim-acylYFP.npy




Starting 72hrs_plant18_trim-acylYFP.npy




Starting 52hrs_plant2_trim-acylYFP.npy




Starting 60hrs_plant2_trim-acylYFP.npy




Starting 36hrs_plant15_trim-acylYFP.npy




Starting 60hrs_plant4_trim-acylYFP.npy




Starting 12hrs_plant18_trim-acylYFP.npy




Starting 8hrs_plant2_trim-acylYFP.npy




Starting 28hrs_plant1_trim-acylYFP.npy




Starting 32hrs_plant13_trim-acylYFP.npy




Starting 0hrs_plant18_trim-acylYFP.npy




Starting 48hrs_plant4_trim-acylYFP.npy




Starting 44hrs_plant4_trim-acylYFP.npy




Starting 84hrs_plant18_trim-acylYFP.npy




Starting 40hrs_plant2_trim-acylYFP.npy




Starting 8hrs_plant15_trim-acylYFP.npy




Starting 64hrs_plant13_trim-acylYFP.npy




Starting 44hrs_plant2_trim-acylYFP.npy




Starting 40hrs_plant13_trim-acylYFP.npy




Starting 20hrs_plant13_trim-acylYFP.npy




Starting 8hrs_plant1_trim-acylYFP.npy




Starting 44hrs_plant18_trim-acylYFP.npy




Starting 20hrs_plant2_trim-acylYFP.npy




# Benchmark Results

In [8]:
ap_threshold = 0.5

In [28]:
method = "cellstitch3d"

data = [] 
for filename in test_filenames:
    print("Starting %s" % filename)
    labels = np.load('../DATA/ATAS/labels/%s' % filename)
    true_num_cells = np.unique(labels).size - 1 
    true_avg_vol = get_avg_vol(labels) 
    
    masks = np.load("./results/ATAS/pipeline/%s/%s" % (method, filename))
    
    num_cells = np.unique(masks).size - 1
    d_num_cells = abs(num_cells - true_num_cells) / true_num_cells

    avg_vol = get_avg_vol(masks)
    d_avg_vol = abs(true_avg_vol - avg_vol) / true_avg_vol

    ap, tp, fp, fn = average_precision(labels, masks, ap_threshold)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    
    row = [
        filename, 
        d_num_cells, 
        d_avg_vol, 
        ap, 
        tp, 
        fp, 
        fn, 
        precision,
        recall, 
        true_num_cells, 
        true_avg_vol, 
        num_cells, 
        avg_vol
    ]
    
    data.append(row)

df = pd.DataFrame(data, columns=[
    "filename",  
    "d_num_cells", 
    "d_avg_vol", 
    "ap", 
    "tp", 
    "fp", 
    "fn",
    "precision",
    "recall",
    "true_num_cells", 
    "true_avg_vol", 
    "num_cells", 
    "avg_vol"
])

df.to_csv("./results/ATAS/pipeline/%s.csv" % method, index=False)

Starting 76hrs_plant15_trim-acylYFP.npy
Starting 40hrs_plant1_trim-acylYFP.npy
Starting 12hrs_plant18_trim-acylYFP.npy
Starting 44hrs_plant4_trim-acylYFP.npy
Starting 20hrs_plant13_trim-acylYFP.npy
Starting 8hrs_plant1_trim-acylYFP.npy
Starting 36hrs_plant15_trim-acylYFP.npy
Starting 20hrs_plant2_trim-acylYFP.npy
Starting 0hrs_plant18_trim-acylYFP.npy
Starting 32hrs_plant13_trim-acylYFP.npy
Starting 40hrs_plant2_trim-acylYFP.npy
Starting 64hrs_plant13_trim-acylYFP.npy
Starting 60hrs_plant2_trim-acylYFP.npy
Starting 72hrs_plant18_trim-acylYFP.npy
Starting 44hrs_plant18_trim-acylYFP.npy
Starting 68hrs_plant1_trim-acylYFP.npy
Starting 36hrs_plant2_trim-acylYFP.npy
Starting 4hrs_plant15_trim-acylYFP.npy
Starting 72hrs_plant2_trim-acylYFP.npy
Starting 8hrs_plant2_trim-acylYFP.npy
Starting 40hrs_plant13_trim-acylYFP.npy
Starting 84hrs_plant18_trim-acylYFP.npy
Starting 48hrs_plant4_trim-acylYFP.npy
Starting 52hrs_plant2_trim-acylYFP.npy
Starting 0hrs_plant2_trim-acylYFP.npy
Starting 28hrs_pla

# Analyze results

In [29]:
method = 'cellstitch3d'

In [30]:
df = pd.read_csv('./results/ATAS/pipeline/%s.csv' % method)

In [31]:
df.mean()

  df.mean()


d_num_cells           0.604424
d_avg_vol             0.277375
ap                    0.788906
tp                 1763.675676
fp                  503.540541
fn                   38.270270
precision             0.802261
recall                0.977402
true_num_cells      920.108108
true_avg_vol      16411.465053
num_cells          1385.378378
avg_vol           11543.335754
dtype: float64

In [32]:
df.std()

  df.std()


d_num_cells          0.904017
d_avg_vol            0.202845
ap                   0.169275
tp                 522.671665
fp                 582.087937
fn                  17.827333
precision            0.171143
recall               0.012356
true_num_cells     267.135191
true_avg_vol      3972.600788
num_cells          598.932307
avg_vol           3886.932469
dtype: float64

In [17]:
cp3d_df = pd.read_csv('./results/ATAS/pipeline/cellpose3d.csv')

In [18]:
cp3d_df

Unnamed: 0,filename,d_num_cells,d_avg_vol,ap,tp,fp,fn
0,28hrs_plant1_trim-acylYFP.npy,0.054291,0.052825,0.957374,2246,81.0,19.0
1,8hrs_plant15_trim-acylYFP.npy,0.441606,0.315988,0.77728,1560,405.0,42.0
2,52hrs_plant2_trim-acylYFP.npy,0.055291,0.053814,0.941558,2030,92.0,34.0
3,44hrs_plant2_trim-acylYFP.npy,0.01249,0.017088,0.95904,2318,57.0,42.0
4,76hrs_plant15_trim-acylYFP.npy,2.524642,0.724024,0.417428,1188,1623.0,35.0
5,76hrs_plant1_trim-acylYFP.npy,0.566667,0.376663,0.739779,3076,1000.0,82.0
6,40hrs_plant2_trim-acylYFP.npy,0.020017,0.024512,0.959931,2228,58.0,35.0
7,36hrs_plant2_trim-acylYFP.npy,0.045326,0.049644,0.946445,2050,82.0,34.0
8,12hrs_plant18_trim-acylYFP.npy,0.038462,0.039095,0.961656,1580,47.0,16.0
9,8hrs_plant1_trim-acylYFP.npy,0.03208,0.033158,0.956995,1758,54.0,25.0


In [25]:
mask = np.load("./results/ATAS/pipeline/cellstitch3d/76hrs_plant15_trim-acylYFP.npy")

In [21]:
labels = np.load('../DATA/ATAS/labels/76hrs_plant15_trim-acylYFP.npy')

In [22]:
average_precision(labels, mask, 0.5)

[0.2954717935745004, 1168, 2740.0, 45.0]

In [27]:
average_precision(labels, mask, 0.5)

[0.3686868686868687, 1168, 1955.0, 45.0]