## Example notebook for running CellStitch

In [None]:
import os
import numpy as np
import torch
import tifffile
from cellpose.models import Cellpose
from skimage import io
import matplotlib.pyplot as plt


import h5py

from cellstitch.pipeline import full_stitch

In [None]:
# Global plotting style shared by every figure in this notebook
from IPython.display import display
from matplotlib import rcParams

rcParams['font.size'] = 10

### (1). Load example pairs of raw image & ground-truth mask

In [None]:
# Fill in the filename of the raw image (including its path)
filename = 'Test_images/BFP_60.tif'
# maskname = '<path>/<filename>'

# Fill in the directory where the stitched mask should be stored
output_path = 'output/'
output_filename = 'BFP_60.npy'
# Create the output directory now: later np.save calls into output_path raise
# FileNotFoundError if it does not exist yet.
os.makedirs(output_path, exist_ok=True)
# TODO preprocessing: reduce the stack from 21 to 7 slices
# (original note said "avy" - presumably "average"; confirm)

In [None]:
def get_files(folder_path):
    """Return the names (not full paths) of the regular files directly inside ``folder_path``.

    Subdirectories are skipped; names come back in ``os.listdir`` order.
    """
    names = []
    for entry in os.listdir(folder_path):
        if os.path.isfile(os.path.join(folder_path, entry)):
            names.append(entry)
    return names


# Batch mode: segment every TIFF in INPUT_DIR as one 2D image with Cellpose and save
# one mask array per file into output_path.
INPUT_DIR = "Test_images"
file_list = sorted(f for f in get_files(INPUT_DIR) if f.lower().endswith(('.tif', '.tiff')))

# The model and threshold do not depend on the file, so create them once, not per iteration.
use_gpu = torch.cuda.is_available()
model = Cellpose(model_type='cyto3', gpu=use_gpu)
flow_threshold = 0.4
os.makedirs(output_path, exist_ok=True)

for file in file_list:
    print(f"doing {file}")
    image_path = os.path.join(INPUT_DIR, file)
    # Bind the loaded image so the eval below segments this file, not a stale `img`
    img = tifffile.imread(image_path)

    # [img] -> Cellpose evaluates a single image; assumes each file is one 2D image (TODO confirm)
    xy_masks, _, _, _ = model.eval([img], flow_threshold=flow_threshold, channels=[0, 0])
    xy_masks = np.array(xy_masks)
    print(np.unique(xy_masks))
    output_filename = f'{file}.npy'
    print(output_filename)

    np.save(os.path.join(output_path, output_filename), xy_masks)
    image_mask = xy_masks  # identical to what was just saved; no need to re-read it from disk

    # Show the mask of the single evaluated image (axis 0 indexes the images passed to eval)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.imshow(image_mask[0], cmap="gray")
    ax.set_title(f"{output_filename}")
    ax.axis("off")
    plt.show()
    plt.close(fig)  # release figure memory when looping over many files
    


    


    


Example code snippet
```python
filename = '../data/plant_atlas/Anther/Anther_20.tif'
maskname = '../data/plant_atlas/Anther_masks/Anther_20.tif'

output_path = '../results/'
if not os.path.exists(output_path):
    os.makedirs(output_path, exist_ok=True)
    
output_filename = 'cellstitch_sample_pred.npy'
```

In [None]:
# Load the raw image, choosing a reader from the file extension (case-insensitive).
name_lower = filename.lower()
if name_lower.endswith('.npy'):  # image in .npy format
    img = np.load(filename)
elif name_lower.endswith(('.tif', '.tiff')):  # image in TIFF format
    img = tifffile.imread(filename)
else:
    # Fall back to scikit-image's generic reader; chain the original error for debugging.
    try:
        img = io.imread(filename)
    except Exception as exc:
        raise IOError('Failed to load image {}'.format(filename)) from exc
print(img.shape)

# Optional ground-truth mask (set `maskname` in the config cell first):
# mask_lower = maskname.lower()
# if mask_lower.endswith('.npy'):  # mask in .npy format
#     mask = np.load(maskname)
# elif mask_lower.endswith(('.tif', '.tiff')):  # mask in TIFF format
#     mask = tifffile.imread(maskname)
# else:
#     try:
#         mask = io.imread(maskname)
#     except Exception as exc:
#         raise IOError('Failed to load mask {}'.format(maskname)) from exc

In [None]:

# Optional: read the raw stack from an HDF5 file with a 'raw' dataset instead.
# Guarded by extension because h5py.File raises on non-HDF5 input such as the default
# .tif filename, which made this cell fail under Restart & Run All.
N_H5_SLICES = 20  # keep only the first N slices of the 'raw' dataset
if filename.lower().endswith(('.h5', '.hdf5')):
    with h5py.File(filename, 'r') as f:
        # Use print(list(f.keys())) to list the datasets available in the file
        img = np.array(f['raw'][:N_H5_SLICES])
    print(img.shape)


In [None]:
import numpy as np
from PIL import Image

# Contrast-stretch a greyscale version of `filename` to the full 0-255 range with a LUT.
# NOTE: Image.open reads only the first frame of a multi-page TIFF.

# Open the input image as numpy array, convert to greyscale ("L") and drop alpha
npImage = np.array(Image.open(filename).convert("L"))

# Brightness range, i.e. darkest and lightest pixels. Cast to Python int: with NumPy >= 2 a
# uint8 value of 255 plus 1 wraps to 0, which would make the LUT slice below empty and the
# assignment fail. Distinct names also avoid shadowing the builtins min/max.
lo = int(npImage.min())
hi = int(npImage.max())

# Make a LUT (Look-Up Table) that maps [lo, hi] linearly onto [0, 255]
LUT = np.zeros(256, dtype=np.uint8)
LUT[lo:hi + 1] = np.linspace(start=0, stop=255, num=(hi - lo) + 1, endpoint=True, dtype=np.uint8)
npImage_stretched = LUT[npImage]
print(type(npImage_stretched))

print(npImage_stretched.shape)

In [None]:
# Preview z-slices 3-7 of the loaded stack (indices beyond the stack depth are skipped).
# TODO: convert grayscale to colour here
preview_slices = [z for z in range(3, 8) if z < img.shape[0]]
if preview_slices:
    fig, axes = plt.subplots(1, len(preview_slices), figsize=(30, 15), squeeze=False)
    for ax, i in zip(axes[0], preview_slices):
        ax.imshow(img[i], cmap="gray")
        ax.set_title(f"Slice {i}")
        ax.axis("off")
    plt.show()
# If it looks good, duplicate it for each slice so the output has the same dimensions

In [None]:
# Show the greyscale image loaded via PIL (npImage) at full size.
fig, ax = plt.subplots(figsize=(20, 10))
ax.imshow(npImage, cmap="gray")
# Title by file name: npImage is a single 2D image, not an indexed slice of the stack
ax.set_title(os.path.basename(filename))
ax.axis("off")

plt.show()

### (2). Define configs & parameters

In [None]:
# Load the Cellpose model used as the 2D backbone segmenter.
# You can also replace it with any 2D segmentation model that works best for your dataset.
use_gpu = torch.cuda.is_available()
print(use_gpu)
model = Cellpose(model_type='cyto2', gpu=use_gpu)
# Passed as flow_threshold to every model.eval call below
flow_threshold = 0.4

In [None]:
# Segment each z-slice independently in 2D: list(img) yields one 2D array per slice, so the
# stacked result has the same (Z, Y, X) layout as img. Passing [img] would instead hand
# Cellpose the whole stack as a single image.
assert img.ndim == 3, f"expected a 3D (Z, Y, X) stack, got shape {img.shape}"
xy_masks, _, _, _ = model.eval(list(img), flow_threshold=flow_threshold, channels=[0, 0])
xy_masks = np.array(xy_masks)
print(np.unique(xy_masks))  # only [0] means nothing was segmented

In [None]:
# File name (inside output_path) that the save cells below write to
output_filename = "BFP_60_cyto2_cluster.npy"



### (3). Run CellStitch

In [None]:
# Segment the stack plane-by-plane along the two other orientations, then let CellStitch's
# full_stitch combine the three orientations' 2D labels into 3D labels.
print(flow_threshold)
assert img.ndim == 3, f"expected a 3D (Z, Y, X) stack, got shape {img.shape}"

# transpose(1, 0, 2) puts Y first, so list(...) yields one Z-X plane per Y index; applying the
# same transpose again (it is its own inverse) restores the (Z, Y, X) layout of the masks.
yz_masks, _, _, _ = model.eval(list(img.transpose(1, 0, 2)), flow_threshold=flow_threshold, channels=[0, 0])
yz_masks = np.array(yz_masks).transpose(1, 0, 2)
print("done yz")

# transpose(2, 1, 0) puts X first: one Y-Z plane per X index; transposing back restores (Z, Y, X).
xz_masks, _, _, _ = model.eval(list(img.transpose(2, 1, 0)), flow_threshold=flow_threshold, channels=[0, 0])
xz_masks = np.array(xz_masks).transpose(2, 1, 0)
print("done xz")

cellstitch_masks = full_stitch(xy_masks, yz_masks, xz_masks)

### (4). Save the Stitching results:

In [None]:
# Save the per-slice (XY) Cellpose masks under their own "_xy" name: the next cell saves the
# stitched result under output_filename, which would otherwise overwrite this file.
os.makedirs(output_path, exist_ok=True)
xy_output_filename = os.path.splitext(output_filename)[0] + '_xy.npy'
np.save(os.path.join(output_path, xy_output_filename), xy_masks)


In [None]:
# Save the stitched labels, creating the output directory first if it is missing.
os.makedirs(output_path, exist_ok=True)
np.save(os.path.join(output_path, output_filename), cellstitch_masks)

---

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Load the stitched masks

In [None]:
# Previously saved mask arrays, loaded for side-by-side comparison.
# File names suggest different inputs (BFP_60 vs. HDF5) and Cellpose models - presumably; confirm.
cellstitch_masks1 = np.load(file="output/60_nuclei_c2.npy")
cellstitch_masks2 = np.load(file="output/60_cyto3_c2.npy")
cellstitch_masks3 = np.load(file="output/h5_nuclei.npy")
cellstitch_masks4 = np.load(file="output/h5_cyto2.npy")
cellstitch_masks5 = np.load(file="output/h5_cyto3.npy")

In [None]:
# Mask array presumably written by the batch segmentation loop for tcell_T010.tif
image_mask = np.load(file="output/tcell_T010.tif.npy")

In [None]:
# Sanity-check and display the loaded mask. Axis 0 indexes the images that were passed to
# model.eval when the mask was produced (presumably one per file; confirm).
n_labels = np.unique(image_mask[image_mask != 0]).size
print("Labels detected:", n_labels)  # 0 means no masks were detected
print("Shape of image_mask:", image_mask.shape)

fig, ax = plt.subplots(figsize=(30, 15))
ax.imshow(image_mask[0], cmap="gray")  # first entry along axis 0
ax.set_title("Slice 0")
ax.axis("off")
plt.show()

In [None]:
# imshow needs a 2D (or RGB/RGBA) array, but masks saved by the batch loop carry a leading
# singleton axis from np.array([...]); drop it, falling back to the first slice of a real stack.
mask_2d = np.squeeze(image_mask)
if mask_2d.ndim > 2:
    mask_2d = mask_2d[0]

fig, ax = plt.subplots(figsize=(20, 10))
ax.imshow(mask_2d, cmap="gray")
ax.set_title("Slice")
ax.axis("off")

plt.show()

In [None]:
import tifffile

# Save the stitched labels as a TIFF inside output_path (a configurable, relative location
# instead of a hard-coded, user-specific absolute path).
# uint16 only holds label IDs up to 65535; switch to uint32 beyond that so IDs don't wrap.
max_label = int(cellstitch_masks.max()) if cellstitch_masks.size else 0
label_dtype = np.uint16 if max_label <= np.iinfo(np.uint16).max else np.uint32
os.makedirs(output_path, exist_ok=True)
tifffile.imwrite(os.path.join(output_path, "stitched_mask.tif"), cellstitch_masks.astype(label_dtype))


**TODO / next steps**
- Extract statistics from the output: number of cells segmented; cell volume (number of pixels per cell) and its distribution over the whole cell population; overlap in pixel assignment between cells (?).
- With 3 time points: link cells across time (how?). Overlay the masks, treat overlapping cells as the same cell, and propagate labels through time. Complication: cells move.
- Possible linking criterion: minimise the total shift, i.e. compute distances between candidate matches and pick the assignment with the smallest summed distance (the cell that moves the least). This is a non-trivial matching problem.
- Compute on t vs. z and take a max projection for each (might be worth pursuing).
- Convert microscope files (.czi) to .tif.