In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from skimage.exposure import match_histograms

import numpy as np
import matplotlib.pyplot as plt
from cellpose import models, io
from cellpose.io import imread

In [None]:
# --- Subject to process -------------------------------------------------------
subject_id = 'jm038'

# --- Cellpose segmentation settings -------------------------------------------
# Pretrained model name passed to cellpose ('cyto', 'nuclei', 'cyto2', 'cyto3').
model_type = 'cyto3'

# Expected ROI diameter (pixels, per cellpose's `diameter` argument), flow-error
# threshold and cell-probability threshold, all forwarded to `model.eval` below.
diameter, flow_threshold, cellprob_threshold = 6, 0.6, -6.0

In [None]:
# Locate this subject's processed-data folder, list its session directories
# (names ending in '_a', sorted), and keep the suite2p plane0 folder of every
# session that actually has one.
subject_path = os.path.join('data_proc', subject_id[:2], subject_id)

session_paths = sorted(
    os.path.join(subject_path, entry)
    for entry in os.listdir(subject_path)
    if entry.endswith('_a')
)

all_s2p_path = [
    plane_dir
    for plane_dir in (os.path.join(sess, 'suite2p', 'plane0') for sess in session_paths)
    if os.path.exists(plane_dir)
]

In [None]:
# Export each session's channel-2 mean image (from suite2p's ops.npy) as a TIFF
# next to the suite2p output, so cellpose can segment it in a later cell.
all_ch2_meanimg = []

for s2p_path in all_s2p_path:
    ops = np.load(os.path.join(s2p_path, 'ops.npy'), allow_pickle=True).item()

    ch2_meanimg = ops['meanImg_chan2']
    all_ch2_meanimg.append(ch2_meanimg)

    # Clip into the uint16 range before casting. A bare .astype(np.uint16) wraps
    # negative values and values above 65535 around instead of saturating them.
    uint16_max = np.iinfo(np.uint16).max
    ch2_meanimg_uint16 = np.clip(ch2_meanimg * 2, 0, uint16_max).astype(np.uint16)

    tiff_path = os.path.join(s2p_path, 'meanImg_chan2.tiff')
    # NOTE(review): with a 2-D array and `cmap`, plt.imsave rescales to the array's
    # min/max and writes a colormapped 8-bit RGBA image, not raw uint16 data. The
    # x2 scaling therefore mainly affects integer quantisation.
    plt.imsave(tiff_path, ch2_meanimg_uint16, cmap='gray')


In [None]:
# One exported mean-image TIFF per suite2p plane0 folder, in the same order as all_s2p_path.
img_paths = [os.path.join(plane_dir, 'meanImg_chan2.tiff') for plane_dir in all_s2p_path]

In [None]:
io.logger_setup()

# Pretrained cellpose model; model_type is one of 'cyto', 'nuclei', 'cyto2', 'cyto3'.
model = models.CellposeModel(model_type=model_type)

# Load every exported mean image (same order as all_s2p_path).
files = img_paths
imgs = list(map(imread, files))
nimg = len(imgs)

# [[0, 0]]: segment on a single grayscale channel with no nuclear channel
# (cellpose channel convention).
channels = [[0, 0]]

# One label image per input; 0 marks background.
masks, flows, styles = model.eval(
    imgs,
    diameter=diameter,
    channels=channels,
    flow_threshold=flow_threshold,
    cellprob_threshold=cellprob_threshold,
)

In [None]:
# Show all channel-2 mean images side by side. Each one is histogram-matched to the
# first session so brightness is comparable across sessions; no mask contours here.
# TODO: merge with the contour version in the next cell into one plotting function.

# squeeze=False keeps `axs` 2-D even when nimg == 1. Otherwise subplots returns a
# single Axes and axs[i] raises TypeError.
fig, axs = plt.subplots(1, nimg, figsize=(10*nimg, 10), dpi=300, squeeze=False)

for i, img in enumerate(imgs):
    # match the histogram to the first image
    img_matched = match_histograms(img, imgs[0])
    # clip bright outliers at the 74th percentile (display choice), then rescale to [0, 1]
    img_matched = np.clip(img_matched, 0, np.percentile(img_matched, 74))
    lo, hi = img_matched.min(), img_matched.max()
    # a flat image would otherwise divide by zero and render as NaNs
    if hi > lo:
        img_matched = (img_matched - lo) / (hi - lo)
    else:
        img_matched = np.zeros_like(img_matched, dtype=float)

    ax = axs[0, i]
    ax.imshow(img_matched, cmap='gray')
    ax.axis('off')

In [None]:
# Same display as the previous cell (histogram-matched to the first session,
# clipped, rescaled to [0, 1]), with each cellpose ROI outlined as a contour.

# squeeze=False keeps `axs` 2-D even when nimg == 1. Otherwise subplots returns a
# single Axes and axs[i] raises TypeError.
fig, axs = plt.subplots(1, nimg, figsize=(10*nimg, 10), dpi=300, squeeze=False)

for i, (img, mask) in enumerate(zip(imgs, masks)):
    # match the histogram to the first image
    img_matched = match_histograms(img, imgs[0])
    # clip bright outliers at the 74th percentile (display choice), then rescale to [0, 1]
    img_matched = np.clip(img_matched, 0, np.percentile(img_matched, 74))
    lo, hi = img_matched.min(), img_matched.max()
    # a flat image would otherwise divide by zero and render as NaNs
    if hi > lo:
        img_matched = (img_matched - lo) / (hi - lo)
    else:
        img_matched = np.zeros_like(img_matched, dtype=float)

    ax = axs[0, i]
    ax.imshow(img_matched, cmap='gray')
    ax.axis('off')

    # Compute the label set once per image instead of once per ROI. The first
    # entry is skipped because label 0 is background.
    labels = np.unique(mask)
    n_labels = len(labels)
    for u in labels[1:]:
        ax.contour(mask == u, [0.5], colors=[plt.cm.jet(u / n_labels)], linewidths=0.5)


In [None]:
# Mirror every session under a parallel "<subject_id>_cellpose" subject folder, e.g.
# data_proc/jm/jm038/<session> -> data_proc/jm/jm038_cellpose/<session>.
# Only the subject directory is renamed. A plain str.replace(subject_id, ...) would
# also rewrite any session folder name that happens to contain the subject id.
subject_path_cellpose = f'{subject_path}_cellpose'
session_paths_cellpose = [
    os.path.join(subject_path_cellpose, os.path.basename(f)) for f in session_paths
]

In [None]:
# Save the cellpose masks as suite2p-style stat.npy / iscell.npy, together with a
# copy of the session's ops.npy, under the parallel "<subject>_cellpose" tree.

# masks[k] was computed from all_s2p_path[k], which lists only sessions that HAVE a
# suite2p/plane0 folder. Indexing session_paths[k] instead would pair masks with the
# wrong session as soon as one session lacks suite2p output, so map each plane0
# folder back to its own session directory.
cellpose_session_for = dict(zip(session_paths, session_paths_cellpose))

for s2p_path, img_masks in zip(all_s2p_path, masks):
    stat = []
    for u in np.unique(img_masks)[1:]:  # label 0 is background
        ypix, xpix = np.where(img_masks == u)
        npix = len(xpix)
        stat.append({
            'xpix': xpix,
            'ypix': ypix,
            # suite2p stores the ROI median position as [y, x]
            'med': [np.median(ypix), np.median(xpix)],
            'npix': npix,
            'lam': np.ones(npix, np.float32),  # uniform pixel weights
        })

    # Every cellpose ROI is accepted: suite2p's iscell has column 0 = is-cell flag
    # and column 1 = probability.
    iscell = np.ones((len(stat), 2))

    ops = np.load(os.path.join(s2p_path, 'ops.npy'), allow_pickle=True).item()
    # NOTE(review): ops is copied unchanged, so any paths stored inside it
    # (presumably e.g. 'save_path') still point at the original session. Confirm
    # whether downstream code needs them rewritten.

    # s2p_path == <session>/suite2p/plane0, so two dirname() calls give the session dir
    session_path = os.path.dirname(os.path.dirname(s2p_path))
    save_dir = os.path.join(cellpose_session_for[session_path], 'suite2p', 'plane0')
    os.makedirs(save_dir, exist_ok=True)

    np.save(os.path.join(save_dir, 'stat.npy'), stat)
    np.save(os.path.join(save_dir, 'iscell.npy'), iscell)
    np.save(os.path.join(save_dir, 'ops.npy'), ops)
