## Example notebook for running CellStitch

In [1]:
import os
import numpy as np
import torch
import tifffile
from cellpose.models import Cellpose
from skimage import io

from cellstitch.pipeline import full_stitch



In [2]:
# Plotting specifications: shared matplotlib defaults for any figures drawn below.
# NOTE(review): `display` is not used in the cells shown here; kept for interactive use.
from IPython.display import display
from matplotlib import rcParams

rcParams['font.size'] = 10  # global default font size for matplotlib text

### (1). Load an example pair of raw image & ground-truth mask

In [3]:
# --- Inputs: full paths (directory + file name) to fill in ------------------
filename = '<path>/<filename>'   # raw image volume
maskname = '<path>/<filename>'   # matching ground-truth mask

# --- Outputs: where the stitched mask will be written -----------------------
output_path = '<path>'           # target directory
output_filename = '<filename>'   # file name inside output_path

Example code snippet:
```python
filename = '../data/plant_atlas/Anther/Anther_20.tif'
maskname = '../data/plant_atlas/Anther_masks/Anther_20.tif'

output_path = '../results/'
if not os.path.exists(output_path):
    os.makedirs(output_path, exist_ok=True)

output_filename = 'cellstitch_sample_pred.npy'
```

In [4]:
# Load the raw image volume and its ground-truth mask from disk.
def load_volume(path):
    """Read an image or mask array from ``path``.

    The reader is chosen from the file extension (case-insensitive):
      * ``.npy``           -> ``np.load``
      * ``.tif`` / ``.tiff`` -> ``tifffile.imread``
      * anything else      -> ``skimage.io.imread``

    Args:
        path: path to the file, including its name.

    Returns:
        The array returned by the selected reader.

    Raises:
        IOError: if the fallback ``skimage.io.imread`` reader fails; the
            underlying exception is chained as ``__cause__``.
    """
    extension = os.path.splitext(path)[1].lower()
    if extension == '.npy':
        return np.load(path)
    if extension in ('.tif', '.tiff'):
        return tifffile.imread(path)
    try:
        return io.imread(path)
    except Exception as err:
        # Report the path that actually failed and keep the original cause.
        raise IOError('Failed to load image {}'.format(path)) from err


# Each file is dispatched on its OWN extension (image and mask may differ).
img = load_volume(filename)
mask = load_volume(maskname)


### (2). Define configuration & parameters

In [5]:
# Backbone 2D segmentation model used on every slice (Cellpose, 'cyto2' weights).
# Any 2D segmentation model that works well on your data can be swapped in, as long
# as it offers the same `eval(...)` call used in the "Run CellStitch" cell.

# NOTE(review): Cellpose documents 0.4 as the flow_threshold default; 1 is more
# permissive (keeps more masks). Confirm this value is intended for your dataset.
flow_threshold = 1

use_gpu = torch.cuda.is_available()  # use the GPU only when CUDA is available
model = Cellpose(model_type='cyto2', gpu=use_gpu)


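### (3). Run CellStitch

Segment every 2D slice of the volume along three orthogonal orientations with the 2D model, then combine the three stacks of 2D masks into 3D masks with `full_stitch`.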

In [7]:
# Segment every 2D slice of the volume in three orthogonal orientations with the
# 2D backbone model, then stitch the three mask stacks into 3D instances.
def segment_along_axes(seg_model, volume, axes, threshold):
    """Segment ``volume`` slice-by-slice after permuting its axes to ``axes``.

    ``volume.transpose(axes)`` is split along its first axis into 2D slices,
    each slice is segmented by ``seg_model.eval``, and the stacked masks are
    permuted back so they align voxel-for-voxel with ``volume``.

    Args:
        seg_model: model exposing ``eval(list_of_2d, flow_threshold=..., channels=...)``
            that returns a 4-tuple whose first element is the list of masks.
        volume: 3D array to segment.
        axes: axis permutation applied before slicing, e.g. ``(1, 0, 2)``.
        threshold: value forwarded as ``flow_threshold``.

    Returns:
        Array of per-slice masks, in ``volume``'s original axis order.
    """
    slices = list(volume.transpose(axes))
    # channels=[0, 0]: per Cellpose docs, segment channel 0 (grayscale), no nuclear channel.
    slice_masks, _, _, _ = seg_model.eval(slices, flow_threshold=threshold, channels=[0, 0])
    # np.argsort(axes) is the inverse permutation, restoring volume's axis order.
    return np.array(slice_masks).transpose(np.argsort(axes))


xy_masks = segment_along_axes(model, img, (0, 1, 2), flow_threshold)
yz_masks = segment_along_axes(model, img, (1, 0, 2), flow_threshold)
xz_masks = segment_along_axes(model, img, (2, 1, 0), flow_threshold)

cellstitch_masks = full_stitch(xy_masks, yz_masks, xz_masks)



### (4). Save the stitching results

In [21]:
# Persist the stitched 3D masks. np.save appends '.npy' to the file name when it
# does not already end with that extension.
if output_path:  # an empty path means the current directory; os.makedirs('') would raise
    os.makedirs(output_path, exist_ok=True)  # np.save does not create missing directories
np.save(os.path.join(output_path, output_filename), cellstitch_masks)

---