In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib.pylab as plt
import seaborn as sns
from pathlib import Path
from pystackreg import StackReg

from scipy.stats import zscore

import utils as utl

# User settings

In [None]:
# number of imaging channels
n_ch = 2
# number of z slices
n_z = 15

# frequencies used for imaging, ball velocities, and behavior
f_ca = 2
f_ball = 50
f_beh = 200

# transformation used for registration, see
# https://pystackreg.readthedocs.io/en/latest/readme.html#usage
reg = StackReg.SCALED_ROTATION

# root folder that is searched recursively for tif files
parent_dir = Path(r'\\mpfi.org\public\sb-lab\Nino_2P_for_Salil\for_Nico\stop1_imaging\stop1-GCaMP6f-tdTomato_VNC')

# trial folders to skip (compared against the parent folder of each tif file)
skip = [
    Path(folder)
    for folder in (
        r'\\mpfi.org\public\sb-lab\Nino_2P_for_Salil\for_Nico\stop1_imaging\stop1-GCaMP6f-tdTomato_VNC\fed\female1\trials_to_register\trial9_00001',
        r'\\mpfi.org\public\sb-lab\Nino_2P_for_Salil\for_Nico\stop1_imaging\stop1-GCaMP6f-tdTomato_VNC\fed\female1\trials_to_register\trial8_00001',
    )
]

# tif files matching the selection rule, excluding those in skipped folders
p_tifs = [
    p_tif
    for p_tif in parent_dir.glob('**/trials_to_register/*/trial*_000??.tif')
    if p_tif.parent not in skip
]

# parquet file that holds the data of all trials
p_all = Path('./all_data.parquet')

# Step 1: Registration

In [None]:
# Register every selected tif file and write the aligned stacks, mean images,
# and comparison movies next to it.
for p_tif in p_tifs:
    print()

    def p_out(x):
        """Return the output path `<tif name without suffix>_<x>` in the tif's folder."""
        return p_tif.parent / '{}_{}'.format(p_tif.with_suffix('').name, x)

    # the first output file doubles as the marker for "already registered"
    if p_out('ch1.tif').is_file():
        print(f'INFO output file exists, skipping registration for {p_tif.parent}')
        continue

    print(f'INFO registering {p_tif}')

    # load the stack and split it according to the user settings
    # (previously hard-coded to n_z=15, n_ch=2, which ignored the settings cell)
    stack = utl.load_tiff(p_tif)
    ch1, ch2 = utl.split_channels(stack, n_z=n_z, n_ch=n_ch)
    # presumably a maximum projection along z -- see utils.maxproj_z
    ch1 = utl.maxproj_z(ch1)
    ch2 = utl.maxproj_z(ch2)

    # register: the transformations are computed from ch2 only and then
    # applied to both channels
    tmats = utl.get_tmats(ch2, reg)
    ch1_a = utl.align(ch1, tmats, reg)
    ch2_a = utl.align(ch2, tmats, reg)

    # mean image over the first axis of the aligned stacks
    ch1_am = np.mean(ch1_a, axis=0)
    ch2_am = np.mean(ch2_a, axis=0)

    # save to disk
    # NOTE(review): 'ch1.tif'/'ch2.tif' receive the same aligned stacks as the
    # '*reg.tif' files below -- confirm whether this duplicate is intentional
    utl.write_tif(p_out('ch1.tif'), ch1_a.astype('int16'))
    utl.write_tif(p_out('ch2.tif'), ch2_a.astype('int16'))

    utl.write_tif(p_out('ch1reg.tif'), ch1_a.astype('int16'))
    utl.write_tif(p_out('ch2reg.tif'), ch2_a.astype('int16'))

    utl.save_img(p_out('ch1mean.bmp'), ch1_am)
    utl.save_img(p_out('ch2mean.bmp'), ch2_am)

    # side-by-side movies: channel vs. channel, and raw vs. aligned
    utl.save_dual_movie(p_out('ch1ch2.mp4'), ch1, ch2)
    utl.save_dual_movie(p_out('ch1reg.mp4'), ch1, ch1_a)
    utl.save_dual_movie(p_out('ch2reg.mp4'), ch2, ch2_a)
    


# Step 2: ROI extraction

In [None]:
# Extract ROI traces from the aligned ch1 stack of every selected tif file,
# using the ImageJ ROI set stored in the same folder.
for p_tif in p_tifs:
    print()

    def p_out(x):
        """Return the output path `<tif name without suffix>_<x>` in the tif's folder."""
        return p_tif.parent / '{}_{}'.format(p_tif.with_suffix('').name, x)

    # check if ROI traces have already been extracted
    p_roi = p_out('roi_traces.npy')
    if p_roi.is_file():
        print(f'INFO output files exists, skipping ROI extraction for {p_tif.parent}')
        continue

    # check if only one RoiSet.zip file is present
    l_zip = [ *p_tif.parent.glob('*RoiSet.zip') ]
    if len(l_zip) == 0:
        print(f'INFO no *RoiSet.zip file found. Skipping {p_tif.parent}')
        continue
    elif len(l_zip) > 1:
        print(f'WARNING folder must contain no more than one `*RoiSet.zip` file: skipping {p_tif.parent}')
        continue

    # the aligned stack is the input for trace extraction; skip the trial
    # instead of crashing when it has not been written yet
    p_reg = p_out('ch1reg.tif')
    if not p_reg.is_file():
        print(f'INFO registered stack {p_reg.name} not found, skipping ROI extraction for {p_tif.parent}')
        continue

    p_zip = l_zip[0]
    print(f'INFO loading ROIs from {p_zip}')

    # load aligned ch1 and compute its mean image over the first axis
    stack = utl.load_tiff(p_reg)
    img = np.mean(stack, axis=0)

    # load ROIs and save an overview image with the ROIs drawn on the mean image
    rois = utl.read_imagej_rois(p_zip, img)
    img_rois = utl.draw_rois(img, rois)
    utl.save_img(p_out('ch1mean_rois.bmp'), img_rois)

    # extract traces
    ca = utl.get_mean_trace(rois, stack, subtract_background=True, sigma=0)
    print(f'INFO saving ROI traces to {p_roi}')
    np.save(p_roi, ca)



# Step 3: Merge data

In [None]:
# Merge ROI traces with ball and behavior data for every selected tif file
# and write one parquet file plus summary plots per trial.
for p_tif in p_tifs:
    print()

    def p_out(x):
        """Return the output path `<tif name without suffix>_<x>` in the tif's folder."""
        return p_tif.parent / '{}_{}'.format(p_tif.with_suffix('').name, x)

    p_df = p_out('data.parquet')

    if p_df.is_file():
        print(f'INFO output files exists, skipping data merging for {p_tif.parent}')
        continue

    # load ROI traces
    p_roi = p_out('roi_traces.npy')
    if not p_roi.is_file():
        print(f'INFO file with ROI traces not found, skipping {p_tif.parent}')
        # without this `continue` the loop would carry on with `ca` undefined
        # or, worse, with the traces of the previous trial
        continue
    ca = np.load(p_roi)

    # load behavior data: `<first part of tif name>.mat` holds the ball data,
    # and exactly one `*-actions.mat` file must be present
    p_ball = p_tif.parent / (p_tif.name.split('_')[0] + '.mat')
    l_act = [*p_tif.parent.glob('*-actions.mat')]
    if not (p_ball.is_file() and len(l_act) == 1):
        print(f'INFO matlab files missing or incorrect name: skipping {p_tif.parent}')
        continue

    ball = utl.load_ball(p_ball)

    p_beh = l_act[0]
    beh = utl.load_behavior(p_beh)

    df = utl.upsample_to_behavior(ca, beh, ball, f_ca, f_ball, f_beh)

    df = utl.zscore_rois(df)

    df = utl.convolute_ca_kernel(df, f=f_beh)

    # metadata is parsed from the folder names:
    # .../<cond>/<fly>/trials_to_register/<trial>/<tif file>
    pt = p_tif.parts
    cond, fly, trial = pt[-5], pt[-4], pt[-2]
    df.loc[:, 'cond'] = cond
    df.loc[:, 'fly'] = fly
    df.loc[:, 'trial'] = trial
    print(f'INFO parsing folder names: fly {fly} | trial {trial} | condition {cond}')

    # plot data
    utl.plot_data(df, f_beh, path=p_out('data.png'))
    # plot pearson r heatmap
    utl.plot_corr_heatmap(df, path=p_out('heatmap.png'))
    # plot ccf
    utl.plot_ccf(df, f=f_beh, pool_fly=True, path=p_out('ccf.png'))

    print(f'INFO writing merged data to {p_df}')
    df.to_parquet(p_df)



# Step 4: Merge all trials

In [None]:
# merge all trials and flies

# glob order is not guaranteed; sort so that the row order of the merged
# table is reproducible between runs
p_data_files = sorted(parent_dir.glob('**/trials_to_register/*/*_data.parquet'))

# pd.concat raises an opaque "No objects to concatenate" on an empty list;
# fail with a message that says where the files were expected
if not p_data_files:
    raise ValueError(f'no *_data.parquet files found below {parent_dir}')

frames = []
for p_data in p_data_files:
    print()

    print(f'INFO loading file {p_data}')
    frames.append(pd.read_parquet(p_data))

print(f'INFO writing all data to {p_all}')
df = pd.concat(frames, ignore_index=True)
df.to_parquet(p_all)


In [None]:
# load the merged table back from disk
df = pd.read_parquet(p_all)

# list the trials found for each fly, one line per fly
print('INFO dataframe contains')
for fly, df_fly in df.groupby('fly'):
    trials = ''.join(f'{trial} ' for trial, _ in df_fly.groupby('trial'))
    print(f'     {fly}: {trials}')

# plot averages; the figures are written next to the merged parquet file
p_fig = p_all.parent
utl.plot_corr_heatmap(df, path=p_fig / 'heatmap.png')
utl.plot_ccf(df, f=f_beh, pool_fly=True,  path=p_fig / 'ccf.png')
utl.plot_ccf(df, f=f_beh, pool_fly=False, path=p_fig / 'ccf_indv.png')
