# Interpolation-related benchmarks

CellStitch segmentation
- Comparison between CellStitch (on anisotropic images) and other methods (on isotropically upsampled images)

CellStitch interpolation
- Comparison of CellStitch interpolation vs. image interpolation (bilinear)
- Comparison of 1-Wasserstein vs. 2-Wasserstein distances in the cost-matrix design

## Segmentation benchmarks

In [1]:
import h5py
import numpy as np
import pandas as pd

from scipy.ndimage import zoom
from cellpose import models
from cellstitch.evaluation import *

from IPython.display import display

In [2]:
# Ovule test volumes used throughout this notebook, identified by sample number.
_test_ids = (294, 435, 441, 511, 522, 590, 593)
test_filenames = ["N_%d_final_crop_ds2.npy" % sample_id for sample_id in _test_ids]

# Factor by which the raw images are upsampled along z (axis 0) below.
anisotropy = 4

### Generate images

First, upsample each raw test image along z by the anisotropy factor and store the interpolated images.

In [4]:
# Upsample every raw test image along z (axis 0) by `anisotropy` and save the
# result under ../DATA/ovules_interp/raw/ with the same file name.
# NOTE(review): `zoom` is called with its default spline order (order=3, cubic),
# whereas `upsample_img` in the interpolation section defaults to bilinear
# (order=1) — confirm which interpolation is intended for this baseline.
for test_filename in test_filenames:
    print("Starting %s ..." % test_filename)
    src_path = '../DATA/ovules/raw/%s' % test_filename
    dst_path = "../DATA/ovules_interp/raw/%s" % test_filename
    np.save(dst_path, zoom(np.load(src_path), (anisotropy, 1, 1)))

Starting N_294_final_crop_ds2.npy ...
Starting N_435_final_crop_ds2.npy ...
Starting N_441_final_crop_ds2.npy ...
Starting N_511_final_crop_ds2.npy ...
Starting N_522_final_crop_ds2.npy ...
Starting N_590_final_crop_ds2.npy ...
Starting N_593_final_crop_ds2.npy ...


### Cellpose3d

Benchmark cellpose3d on interpolated images. 

In [4]:
# Cellpose checkpoint trained under ../DATA/ovules/cellpose_train, run on the GPU.
model_dir = (
    '../DATA/ovules/cellpose_train/models/'
    'cellpose_residual_on_style_on_concatenation_off_cellpose_train_2023_05_08_09_31_06.231473'
)
flow_threshold = 1
model = models.CellposeModel(pretrained_model=model_dir, gpu=True)

In [None]:
import os

# Run Cellpose3D on every z-upsampled test image and save the predicted masks.
# Images whose prediction file already exists are skipped, so the cell can be
# resumed after an interruption without hard-coding a slice of `test_filenames`
# (the previous `test_filenames[1:]` silently skipped the first image on a fresh run).
for test_filename in test_filenames:
    out_path = "./results/ovules_interp/%s" % test_filename
    if os.path.exists(out_path):
        print("Skipping %s (prediction already exists)" % test_filename)
        continue

    print("Starting %s" % test_filename)
    img = np.load("../DATA/ovules_interp/raw/%s" % test_filename)
    masks, _, _ = model.eval(img, do_3D=True, flow_threshold=flow_threshold, channels=[0, 0])
    np.save(out_path, masks)

Starting N_435_final_crop_ds2.npy


In [5]:
# Sanity check: report the shape of every z-upsampled image.
for fname in test_filenames:
    upsampled = np.load("../DATA/ovules_interp/raw/%s" % fname)
    print(fname, upsampled.shape)

N_294_final_crop_ds2.npy (1280, 960, 1000)
N_435_final_crop_ds2.npy (1552, 1101, 1110)
N_441_final_crop_ds2.npy (1776, 1095, 1028)
N_511_final_crop_ds2.npy (1040, 810, 715)
N_522_final_crop_ds2.npy (1480, 810, 935)
N_590_final_crop_ds2.npy (680, 555, 770)
N_593_final_crop_ds2.npy (560, 480, 1203)


Evaluate: 

In [10]:
# Evaluate the Cellpose3D predictions (made on z-upsampled images) against the
# ground-truth labels at the original z-resolution.
# NOTE(review): N_435 and N_441 are not in this list — presumably because their
# Cellpose3D predictions were not produced; confirm before reporting results.
data = []

for test_filename in ["N_294_final_crop_ds2.npy",
                      "N_511_final_crop_ds2.npy",
                      "N_522_final_crop_ds2.npy",
                      "N_590_final_crop_ds2.npy",
                      "N_593_final_crop_ds2.npy"]:
    print("Starting %s" % test_filename)
    labels = np.load('../DATA/ovules/labels/%s' % test_filename)
    masks = np.load("./results/ovules_interp/%s" % test_filename)
    masks = masks[::anisotropy]  # keep every `anisotropy`-th z-slice -> original size

    ap25, _, _, _ = average_precision(labels, masks, 0.25)
    ap50, tp, fp, fn = average_precision(labels, masks, 0.5)
    ap75, _, _, _ = average_precision(labels, masks, 0.75)

    # Precision/recall at IoU 0.5; defined as 0 when the denominator is 0.
    # (Recall must be guarded by tp + fn, and its fallback must set `recall`.)
    precision = tp / (tp + fp) if (tp + fp) != 0 else 0
    recall = tp / (tp + fn) if (tp + fn) != 0 else 0

    data.append([
        test_filename,
        ap25,
        ap50,
        ap75,
        tp,
        fp,
        fn,
        precision,
        recall,
    ])

Starting N_294_final_crop_ds2.npy
Starting N_511_final_crop_ds2.npy
Starting N_522_final_crop_ds2.npy
Starting N_590_final_crop_ds2.npy
Starting N_593_final_crop_ds2.npy


NameError: name 'pd' is not defined

In [12]:
# Collect the per-image Cellpose3D metrics into a table and write it to CSV.
metric_columns = ["filename", "ap25", "ap50", "ap75", "tp", "fp", "fn", "precision", "recall"]
df = pd.DataFrame(data, columns=metric_columns)
df.to_csv("./results/ovules_interp/cellpose3d.csv", index=False)

### PlantSeg

In [19]:
# Write each z-upsampled image to an HDF5 file with a "raw" dataset, as input
# for PlantSeg.
for test_filename in test_filenames:
    stem = test_filename[0:-4]  # strip the ".npy" extension
    img = np.load("../DATA/ovules_interp/raw/%s" % test_filename)
    h5_path = "../DATA/ovules_interp/plantseg_test/%s.h5" % stem
    with h5py.File(h5_path, 'w') as hf:
        hf.create_dataset("raw", data=img)

- activate the environment: `conda activate plant-seg` 
- set the `path` in `config.yaml` to `../DATA/ovules_interp/plantseg_test/` 
- perform segmentation with the `confocal_3D_unet_ovules_ds1x` model by running `plantseg --config config.yaml`

In [3]:
# Folder holding PlantSeg's multicut segmentations (read in the next cell).
plantseg_results_folder = (
    "../DATA/ovules_interp/plantseg_test/"
    "PreProcessing/confocal_3D_unet_ovules_ds1x/MultiCut"
)

In [4]:
# Convert each PlantSeg multicut segmentation (HDF5) into a .npy label volume.
for test_filename in test_filenames:
    print("Starting %s" % test_filename)

    h5_path = "%s/%s_predictions_multicut.h5" % (plantseg_results_folder, test_filename[:-4])
    with h5py.File(h5_path, "r") as f:
        # Read the whole dataset in one call rather than stacking it slice by slice.
        plantseg = f['segmentation'][:]

    # Remap label 1 to 0 so it matches the background label of the other methods.
    # NOTE(review): assumes PlantSeg uses label 1 for background — confirm.
    plantseg[plantseg == 1] = 0
    np.save("./results/ovules_interp/plantseg/%s" % test_filename, plantseg)

Starting N_294_final_crop_ds2.npy
Starting N_435_final_crop_ds2.npy
Starting N_441_final_crop_ds2.npy
Starting N_511_final_crop_ds2.npy
Starting N_522_final_crop_ds2.npy
Starting N_590_final_crop_ds2.npy
Starting N_593_final_crop_ds2.npy


In [5]:
# Evaluate the PlantSeg predictions (made on z-upsampled images) against the
# ground-truth labels at the original z-resolution.
data = []

for test_filename in test_filenames:
    print("Starting %s" % test_filename)
    labels = np.load('../DATA/ovules/labels/%s' % test_filename)
    masks = np.load("./results/ovules_interp/plantseg/%s" % test_filename)
    masks = masks[::anisotropy]  # keep every `anisotropy`-th z-slice -> original size

    ap25, _, _, _ = average_precision(labels, masks, 0.25)
    ap50, tp, fp, fn = average_precision(labels, masks, 0.5)
    ap75, _, _, _ = average_precision(labels, masks, 0.75)

    # Precision/recall at IoU 0.5; defined as 0 when the denominator is 0.
    # (Recall must be guarded by tp + fn, and its fallback must set `recall`.)
    precision = tp / (tp + fp) if (tp + fp) != 0 else 0
    recall = tp / (tp + fn) if (tp + fn) != 0 else 0

    data.append([
        test_filename,
        ap25,
        ap50,
        ap75,
        tp,
        fp,
        fn,
        precision,
        recall,
    ])

Starting N_294_final_crop_ds2.npy


  iou = overlap / (n_pixels_pred + n_pixels_true - overlap)


Starting N_435_final_crop_ds2.npy
Starting N_441_final_crop_ds2.npy
Starting N_511_final_crop_ds2.npy
Starting N_522_final_crop_ds2.npy
Starting N_590_final_crop_ds2.npy
Starting N_593_final_crop_ds2.npy


In [6]:
# Collect the per-image PlantSeg metrics into a table and write it to CSV.
metric_columns = ["filename", "ap25", "ap50", "ap75", "tp", "fp", "fn", "precision", "recall"]
df = pd.DataFrame(data, columns=metric_columns)
df.to_csv("./results/ovules_interp/plantseg.csv", index=False)

## Interpolation benchmark

In [None]:
import os
import sys
import time

import cv2
import h5py
import napari
import numpy as np
import pandas as pd

from skimage.color import label2rgb

In [None]:
sys.path.append('../')  # make packages in the parent directory (e.g. `cellstitch`) importable
from cellpose import models as cp_models
from cellstitch.interpolate import full_interpolate
from cellstitch import evaluation


In [None]:
import matplotlib.pyplot as plt
import matplotlib.font_manager
import seaborn as sns
from matplotlib import rcParams

sns.set_style('white')


def _installed_font_families():
    """Return the set of family names of the system fonts matplotlib can load."""
    families = set()
    for font_path in matplotlib.font_manager.findSystemFonts():
        try:
            families.add(matplotlib.font_manager.get_font(font_path).family_name)
        except RuntimeError:
            # Skip font files that fail to load.
            pass
    return families


# Prefer Helvetica; fall back to FreeSans when Helvetica is not installed.
plot_font = 'Helvetica' if 'Helvetica' in _installed_font_families() else 'FreeSans'

rcParams['font.family'] = plot_font
rcParams.update({'font.size': 15})

# Render math text in the regular (non-italic) font.
plt.rcParams.update({'mathtext.default': 'regular'})

### Load dataset
Seven ovule images (test set).

In [None]:
data_path = '../data/ovules/test/'

# Open each test-set HDF5 file explicitly read-only (before h5py 3.0 the default
# mode was 'a', i.e. read/write/create) and keep a handle to its 'raw' dataset.
# The h5py Datasets are read lazily, so each one keeps its file open while in use.
imgs = [
    h5py.File(os.path.join(data_path, f), 'r')['raw']
    for f in sorted(os.listdir(data_path))
    if f.endswith('h5')
]

In [None]:
def upsample_img(img, anisotropy, method='bilinear'):
    """Upsample a 3D volume along its first (z) axis with spline interpolation.

    Parameters
    ----------
    img : array-like
        3D volume; only axis 0 is rescaled.
    anisotropy : float
        Zoom factor applied to axis 0 (axes 1 and 2 keep their size).
    method : {'bilinear', 'bicubic'}
        Interpolation method, mapped to the ``scipy.ndimage.zoom`` spline order
        (1 for 'bilinear', 3 for 'bicubic').

    Returns
    -------
    numpy.ndarray
        The zoomed volume.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names. (A real exception is
        used instead of ``assert``, which is stripped when Python runs with -O.)
    """
    spline_orders = {'bilinear': 1, 'bicubic': 3}
    if method not in spline_orders:
        raise ValueError(
            "method must be 'bilinear' or 'bicubic', got {!r}".format(method)
        )
    return zoom(img, (anisotropy, 1, 1), order=spline_orders[method])

### (1). Cellpose 3D on isotropically upsampled images

In [None]:
# Same Cellpose checkpoint, loaded from ../results and run on the CPU.
cp_checkpoint = (
    '../results/'
    'cellpose_residual_on_style_on_concatenation_off_cellpose_train_2023_05_08_09_31_06.231473'
)
cp_model = cp_models.CellposeModel(pretrained_model=cp_checkpoint, gpu=False)

In [None]:
import gc  # needed for the explicit collection below (was previously never imported)

# Run Cellpose3D on each test image after upsampling z by `upsample_factor`
# (bilinear, the `upsample_img` default), then keep only the predictions on the
# slices that correspond to the original z-positions.
upsample_factor = 4
preds = []

for img in imgs:
    up_img = upsample_img(img, upsample_factor)
    res = cp_model.eval(up_img, do_3D=True, channels=[0, 0])
    upsampled_pred = res[0].copy()
    del up_img, res

    # Recover predictions on the original slices (every `upsample_factor`-th one).
    # `.copy()` detaches the result from the large upsampled array so it can be freed.
    preds.append(upsampled_pred[::upsample_factor].copy())

    del upsampled_pred
    gc.collect()

### (2). CellStitch interpolation (1-Wasserstein vs. 2-Wasserstein)

In [None]:
low_anis_path = '../results/ovules/cellstitch/'
high_anis_path = '../results/ovules_subsampled/cellstitch/'
mask_path = '../data/ovules/test/'

# Mask predictions made on the z-subsampled (high-anisotropy) images.
high_anis_preds = [
    np.load(os.path.join(high_anis_path, f)).astype(np.int32)
    for f in sorted(os.listdir(high_anis_path))
    if f[-3:] == 'npy'
]

# 1-Wasserstein (cityblock ground cost). Each timing uses its own start/end,
# so the two reported durations no longer overwrite each other.
t0 = time.perf_counter()
low_anis_interp_preds = [
    full_interpolate(high_anis_pred, dist='cityblock')
    for high_anis_pred in high_anis_preds
]
l1_seconds = time.perf_counter() - t0

# 2-Wasserstein (euclidean ground cost)
t0 = time.perf_counter()
low_anis_interp_l2_preds = [
    full_interpolate(high_anis_pred, dist='euclidean')
    for high_anis_pred in high_anis_preds
]
l2_seconds = time.perf_counter() - t0

print('1-Wasserstein takes {} seconds'.format(l1_seconds))
print('2-Wasserstein takes {} seconds'.format(l2_seconds))

# Load ground-truth masks. `fnames` is built from the same filtered .h5 list as
# `masks`, so row labels stay aligned with the evaluated volumes.
mask_files = [f for f in sorted(os.listdir(mask_path)) if f[-2:] == 'h5']
fnames = [f.rpartition('.')[0] for f in mask_files]
masks = []
for f in mask_files:
    with h5py.File(os.path.join(mask_path, f), 'r') as hf:
        masks.append(hf['label'][:])

assert len(masks) == len(low_anis_interp_preds), (
    "number of ground-truth masks ({}) and interpolated predictions ({}) differ"
    .format(len(masks), len(low_anis_interp_preds))
)


def _safe_ratio(num, den):
    """Return num / den, or 0.0 when den is 0 (no predictions / no ground truth)."""
    return num / den if den != 0 else 0.0


interp_metrics = np.zeros((len(low_anis_interp_preds), 4))

for i, (y_true, y_pred) in enumerate(zip(masks, low_anis_interp_preds)):
    # Interpolated predictions can be shallower in z than the mask (presumably
    # depth(mask) - 1); compare only the overlapping leading slices.
    y_pred = y_pred.astype(np.int64)
    y_true = y_true[:y_pred.shape[0], :, :]
    ap, tp, fp, fn = evaluation.average_precision(y_true, y_pred, 0.5)
    prec = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    f1 = _safe_ratio(tp, tp + 0.5 * (fp + fn))

    interp_metrics[i] = [ap, prec, recall, f1]

interp_metrics_df = pd.DataFrame(interp_metrics, index=fnames, columns=['ap', 'prec', 'recall', 'f1'])
display(interp_metrics_df)

# Save results; keep the file names (the DataFrame index) as a "filename" column.
interp_metrics_df.to_csv('../results/ovules/interp_metrics_manhattan.csv', index_label='filename')