In [None]:
import os
import numpy as np

import numpy as np
import matplotlib.pyplot as plt
from cellpose import models, io
from cellpose.io import imread

# Suite2p trace extraction
from suite2p.extraction.extract import extraction_wrapper

from cellpose_track2p.io import tiff_loader, get_session_s2p_paths
from cellpose_track2p.compute import get_spks, generate_all_meantiff, masks_to_stat, str_gad
from cellpose_track2p.plot import plot_fov_masks, plot_raster

# prevent having to refresh
%load_ext autoreload
%autoreload


In [None]:
# --- run configuration -------------------------------------------------------
subject_id = 'jm038'  # animal whose sessions are processed in this notebook

# Cellpose segmentation settings (used when building / evaluating the model below)
model_type = 'cyto3'
diameter, flow_threshold, cellprob_threshold = 6, 0.6, -6.0

# presumably limits processing to the first N sessions; change to None later on the server
take_first_nsessions = 3

In [None]:
# Subject name used for the Cellpose ('_gad') output tree. It is derived from
# `subject_id` so the two cannot drift apart when the subject is changed.
cellpose_subject_id = f'{subject_id}_gad'

In [None]:
def str_degad(string, subject_id=None):
    """Map a '_gad' (Cellpose output) path back to the original path.

    Presumably the inverse of `str_gad` from cellpose_track2p — TODO confirm.

    Parameters
    ----------
    string : str
        Path (or any string) to convert.
    subject_id : str, optional
        When given, only the '<subject_id>_gad' token is turned back into
        '<subject_id>'. Any other '_gad' occurrence in the path (e.g. inside a
        session or file name) is left untouched. When None (default), every
        '_gad' occurrence is removed, which is the original behaviour.

    Returns
    -------
    str
        The converted string.
    """
    if subject_id is None:
        return string.replace('_gad', '')
    return string.replace(f'{subject_id}_gad', subject_id)

In [None]:
# Session folders of `subject_id` and the matching suite2p folders
# (presumably limited to the first `take_first_nsessions` sessions).
session_paths, all_s2p_path = get_session_s2p_paths(
    subject_id,
    take_first_nsessions=take_first_nsessions,
)

In [None]:
# Make the full folder structure for the Cellpose outputs: one '_gad' counterpart
# per suite2p folder. exist_ok=True keeps the cell idempotent on re-runs and
# avoids the check-then-create race of os.path.exists + os.makedirs.
for path in all_s2p_path:
    path_gad = str_gad(path, subject_id)
    os.makedirs(path_gad, exist_ok=True)
        

In [None]:
# Build the mean-image TIFFs for every suite2p folder. The next cell reads
# '<gad folder>/meanImg_chan2.tiff', presumably written here — TODO confirm.
generate_all_meantiff(all_s2p_path)

In [None]:
# One channel-2 mean image per session, located in the '_gad' twin of each suite2p folder.
img_paths = list(
    map(
        lambda s2p_dir: os.path.join(str_gad(s2p_dir, subject_id), 'meanImg_chan2.tiff'),
        all_s2p_path,
    )
)

In [None]:
# Segment every mean image with a pretrained Cellpose model.
io.logger_setup()

# model_type: 'cyto', 'nuclei', 'cyto2' or 'cyto3'
model = models.CellposeModel(model_type=model_type)

files = img_paths
imgs = list(map(imread, files))
nimg = len(imgs)

# [0, 0] = grayscale (single channel, no nuclear channel) per the Cellpose docs
channels = [[0, 0]]

masks, flows, styles = model.eval(
    imgs,
    diameter=diameter,
    channels=channels,
    flow_threshold=flow_threshold,
    cellprob_threshold=cellprob_threshold,
)

In [None]:
# Plot every session's mean image (intended: one figure of subplots with matched
# intensity histograms). The first call plots without masks. The second uses the
# default `show_masks`, which adds the Cellpose mask contours.
# (The old TODO to turn this into a function is done: `plot_fov_masks`.)
plot_fov_masks(imgs, masks, show_masks=False)
plot_fov_masks(imgs, masks)


In [None]:
# Map each session folder to its '_gad' counterpart (presumably jm038 -> jm038_gad),
# i.e. the tree where the Cellpose results are stored.
session_paths_gad = list(map(lambda sess: str_gad(sess, subject_id), session_paths))

In [None]:
# Convert the Cellpose masks into suite2p-style ROI files for each '_gad' session.
# Presumably writes the suite2p/plane0/stat.npy and ops.npy that the extraction
# cell loads — TODO confirm.
masks_to_stat(masks, session_paths, session_paths_gad)

In [None]:

def _sorted_files(folder, suffix):
    """Return the names of the files in `folder` ending with `suffix`, sorted by name.

    Raises FileNotFoundError when nothing matches. A missing or empty
    registration export is then reported here, instead of flowing into
    `tiff_loader` / `extraction_wrapper`.
    """
    names = sorted(name for name in os.listdir(folder) if name.endswith(suffix))
    if not names:
        raise FileNotFoundError(f"no files ending in '{suffix}' found in {folder}")
    return names


# Extract F / Fneu (both channels) and spks for the Cellpose ROIs of every session,
# using suite2p's registered TIFFs, and save them into the '_gad' plane0 folder.
# Paths to motion corrected files: for now just the root of the original session.
for session_path in session_paths_gad:
    plane_dir = os.path.join(session_path, 'suite2p', 'plane0')

    ops = np.load(os.path.join(plane_dir, 'ops.npy'), allow_pickle=True).item()
    stat = np.load(os.path.join(plane_dir, 'stat.npy'), allow_pickle=True)

    # registered movies live in the original (non-'_gad') suite2p folder
    # change this before deploying properly
    orig_plane_dir = os.path.join(str_degad(session_path), 'suite2p', 'plane0')
    tiffs_path_ch0 = os.path.join(orig_plane_dir, 'reg_tif')
    tiffs_path_ch1 = os.path.join(orig_plane_dir, 'reg_tif_chan2')

    # sorted by name; presumably tiff_loader concatenates frames in list order and
    # file indices are zero-padded, so name order == acquisition order — TODO confirm
    tiff_files_ch0 = _sorted_files(tiffs_path_ch0, '_chan0.tif')
    tiff_files_ch1 = _sorted_files(tiffs_path_ch1, '_chan1.tif')
    # both channels are passed together to extraction_wrapper, so their exports must match
    if len(tiff_files_ch0) != len(tiff_files_ch1):
        raise ValueError(
            f"{session_path}: {len(tiff_files_ch0)} '_chan0.tif' files in {tiffs_path_ch0} "
            f"but {len(tiff_files_ch1)} '_chan1.tif' files in {tiffs_path_ch1}; "
            "the two channel exports must match"
        )

    tf_ch0 = tiff_loader(tiffs_path_ch0, tiff_files_ch0)
    tf_ch1 = tiff_loader(tiffs_path_ch1, tiff_files_ch1)

    # now extract the traces
    ops['allow_overlap'] = True  # to avoid suite2p bug -> anyways overlap is close to 0 in sparse labelling
    stat, F, Fneu, F_chan2, Fneu_chan2 = extraction_wrapper(
        stat, tf_ch0, f_reg_chan2=tf_ch1, cell_masks=None, neuropil_masks=None, ops=ops
    )

    spks = get_spks(F, Fneu, ops)

    # plot F raster for a quick per-session visual check
    plot_raster(F)

    # save the traces next to the Cellpose-derived stat.npy
    traces = {
        'F': F,
        'Fneu': Fneu,
        'F_chan2': F_chan2,
        'Fneu_chan2': Fneu_chan2,
        'spks': spks,
    }
    for trace_name, trace in traces.items():
        np.save(os.path.join(plane_dir, f'{trace_name}.npy'), trace)
    
    


In [None]:
# TODO: add plotting of the rasters of all sessions together