# Imports

In [None]:
import pandas as pd
from matplotlib import pyplot as plt

In [None]:
from djimaging.user.alpha.schemas.alpha_schema import *
from djimaging.user.alpha.utils import populate_alpha

# Choose indicator by uncommenting one of the following lines
indicator = 'calcium'
#indicator = 'glutamate'

# Map the indicator to the suffix of its schema name. Fail loudly on an unknown
# value instead of silently falling back to the glutamate schema (as a plain
# `"ca" if indicator == 'calcium' else "glu"` would do for e.g. a typo).
_indicator_schema_suffixes = {'calcium': 'ca', 'glutamate': 'glu'}
if indicator not in _indicator_schema_suffixes:
    raise ValueError(
        f"Unknown indicator {indicator!r}; expected one of {sorted(_indicator_schema_suffixes)}")

populate_alpha.load_alpha_config(
    schema_name=populate_alpha.SCHEMA_PREFIX + _indicator_schema_suffixes[indicator])
populate_alpha.load_alpha_schema(create_schema=False, create_tables=False)

In [None]:
# Display the `schema` object as this cell's output (presumably provided by the
# star import from alpha_schema above -- confirm).
schema

# Dewarping

In [None]:
import tifffile as tiff
import os

from djimaging.utils import scanm_utils
from djimaging.utils.math_utils import normalize_soft_zero_one
from djimaging.user.alpha.tables.unwarped_morph import unwarp_utils

In [None]:
def norm_stack(ch_stack, half_window=10):
    """Fill invalid voxels, floor the background and rescale a stack to 0..255.

    A fill value is computed as the 0.5th percentile (ignoring NaNs) of the
    layers around the middle of the third axis: ``half_window`` layers on each
    side of the middle layer, clamped at the first layer. Non-finite voxels and
    all voxels at or below the fill value are set to it. The stack is then passed
    to ``normalize_soft_zero_one(dq=.5, dq_high=99.5, clip=True)``, presumably a
    percentile-based rescaling to [0, 1] -- confirm. Finally it is multiplied by
    255 and cast to ``np.uint16``.

    Note: ``ch_stack`` is modified in place (fill and floor) before the
    normalized array is returned.

    Parameters
    ----------
    ch_stack : np.ndarray
        At least 3D stack; the third axis is treated as the layer axis. Assumed
        to be floating point so that non-finite voxels can occur -- TODO confirm
        against the loader.
    half_window : int, optional
        Number of layers on each side of the middle layer used to estimate the
        fill value (default 10, the previously hard-coded value).

    Returns
    -------
    np.ndarray
        The normalized stack as ``np.uint16`` with values in 0..255.
    """
    mid_layer = ch_stack.shape[2] // 2
    # Clamp the start index: for stacks with fewer than 2 * half_window layers a
    # negative start would wrap around and pick layers from the end of the stack.
    first_layer = max(mid_layer - half_window, 0)
    last_layer = mid_layer + half_window

    ch_fill = np.nanpercentile(ch_stack[:, :, first_layer:last_layer], q=.5)
    ch_stack[~np.isfinite(ch_stack)] = ch_fill
    ch_stack[ch_stack <= ch_fill] = ch_fill
    ch_stack = (normalize_soft_zero_one(ch_stack, dq=.5, dq_high=99.5, clip=True) * 255).astype(np.uint16)

    return ch_stack

In [None]:
def get_stacks(stack_h5_file):
    """Load both channels of a stack file and normalize each with ``norm_stack``.

    Parameters
    ----------
    stack_h5_file : str
        Path handed to ``scanm_utils.load_stacks_from_h5``.

    Returns
    -------
    tuple
        ``(ch0_stack, ch1_stack)``: the ``'wDataCh0'`` and ``'wDataCh1'`` stacks
        after ``norm_stack``. The second return value of the loader is unused.
    """
    stacks_by_name, _ = scanm_utils.load_stacks_from_h5(stack_h5_file)
    raw_ch0, raw_ch1 = stacks_by_name['wDataCh0'], stacks_by_name['wDataCh1']
    return norm_stack(raw_ch0), norm_stack(raw_ch1)

## Create stacks

In [None]:
from tqdm.notebook import tqdm

keys = (Field() & "field='stack'").fetch('KEY')

# Skip stacks for which both channel tiffs already exist.
only_new_tifs = True

if input('Create tiff files? [y/n]').lower() == 'y':
    for key in tqdm(keys):
        stack_h5_file = (Field() & key).fetch1('fromfile')

        # Outputs go to the '/Morph/' mirror of the '/Pre/' input path. Strip the
        # extension with splitext: str.replace('.h5', ...) would also rewrite any
        # '.h5' appearing earlier in the path.
        output_base = os.path.splitext(stack_h5_file.replace('/Pre/', '/Morph/'))[0]
        ch0_outputfile = output_base + '_ch0.tif'
        ch1_outputfile = output_base + '_ch1.tif'

        # Only treat the stack as done when BOTH tiffs exist, so a partially
        # written pair is regenerated.
        if os.path.exists(ch0_outputfile) and os.path.exists(ch1_outputfile):
            print(ch1_outputfile, 'already existed!')
            if only_new_tifs:
                continue

        outputdir = os.path.dirname(ch1_outputfile)
        if not os.path.isdir(outputdir):
            raise FileNotFoundError(outputdir)

        ch0_stack, ch1_stack = get_stacks(stack_h5_file)

        # `.T` reverses the axis order before writing (tifffile stores the
        # leading axis as pages).
        tiff.imwrite(ch0_outputfile, ch0_stack.T)
        tiff.imwrite(ch1_outputfile, ch1_stack.T)

# Unwarp stacks

## Iterate over files

This cell was used to step through the output directories one at a time, which made it easy to open each cell's files in NeuTube.

In [None]:
# Primary keys of all Field entries whose `field` attribute equals 'stack'.
keys = (Field() & dict(field='stack')).fetch('KEY')

In [None]:
from IPython.display import clear_output

if input('Iterate files? [y/n]').lower() == 'y':
    # Step through the stack recordings one at a time (consuming `keys`), showing
    # each '/Morph/' output directory and its contents so the files can be opened
    # by hand (e.g. in NeuTube) before moving on.
    while keys:
        key = keys.pop(0)
        stack_h5_file = (Field() & key).fetch1('fromfile')
        outputdir = os.path.dirname(stack_h5_file.replace('/Pre/', '/Morph/'))

        print(outputdir)
        if os.path.isdir(outputdir):
            print(os.listdir(outputdir))
        else:
            # Report instead of crashing, so the remaining keys can still be visited.
            print('Directory does not exist!')

        # Lower-case the answer, consistent with the other prompts in this notebook.
        if input('Continue [y/n]').lower() != 'y':
            break
        clear_output()

## Create flat morphs

In [None]:
%matplotlib widget

from IPython.display import clear_output

keys = (Field() & "field='stack'").fetch('KEY')

plt.close('all')

interactive = True  # show diagnostic plots and stop after the first processed cell
only_new = False  # skip cells whose 'morph_flat.swc' already exists
cell_tags_only = ['t3']  # only process these cell tags; empty list / None processes all
plot_stack = False  # (interactive only) also load the stack and plot it with the vessel traces; slow

# NOTE(review): `keys` is re-fetched each time this cell runs, so with
# interactive=True and only_new=False every re-run processes the same first
# matching cell again.

while keys:
    plt.close('all')

    key = keys.pop(0)

    stack_h5_file, pixel_size_um, z_step_um = (Field() & key).fetch1('fromfile', 'pixel_size_um', 'z_step_um')
    swc_path = (SWC & key).fetch1('swc_path')

    outputdir = os.path.dirname(stack_h5_file.replace('/Pre/', '/Morph/'))
    outputfile_swc = os.path.join(outputdir, 'morph_flat.swc')
    # Build the pixel-space output path explicitly: str.replace('_flat', ...) on
    # the full path would also rewrite any directory name containing '_flat'.
    outputfile_swc_px = os.path.join(outputdir, 'morph_px_space.swc')

    cell_tag = (CellTags & key).fetch1('cell_tag')

    if only_new and os.path.exists(outputfile_swc):
        clear_output()
        continue

    if cell_tags_only and cell_tag not in cell_tags_only:
        continue

    # Fetch data
    voxel_size = np.array([pixel_size_um, pixel_size_um, z_step_um])
    df_paths = pd.DataFrame((MorphPaths() & key).fetch1('df_paths'))

    # Load the traced morphology
    df_swc = unwarp_utils.pd_read_swc(swc_path)

    if not (df_swc.type == 3).any():
        # No node carries type 3: label every node as 3 and the first node as 1
        # (SWC convention: 3 = dendrite, 1 = soma).
        df_swc["type"] = 3
        # Address the first row via its actual index label: `.at[0, ...]` would
        # append a new row if the index does not contain the label 0.
        df_swc.loc[df_swc.index[0], "type"] = 1

    # Vessel traces stored next to the morphology; presumably they mark the
    # lower/upper boundaries used for flattening -- confirm.
    df_lower_bvs = unwarp_utils.pd_read_swc(os.path.join(outputdir, 'lower_bvs.swc'))
    df_upper_bvs = unwarp_utils.pd_read_swc(os.path.join(outputdir, 'upper_bvs.swc'))

    # Fit gams
    gam_lower = unwarp_utils.fit_gam(df_lower_bvs, f_space=50, lam=0.01, plane='none', penalties='derivative')
    gam_upper = unwarp_utils.fit_gam(df_upper_bvs, f_space=50, lam=0.01, plane='none', penalties='derivative')

    # Plot vessels
    if interactive:
        if plot_stack:
            # Loading and normalizing the full stack is expensive, so only do it
            # when this plot is actually requested.
            soma_xyz_px = np.asarray((MorphPaths & key).fetch1('soma_xyz')) / voxel_size
            ch0_stack, ch1_stack = get_stacks(stack_h5_file)
            unwarp_utils.plot_stack_and_vessels(
                df_lower=df_lower_bvs, df_upper=df_upper_bvs, stack=ch1_stack, soma_xyz_px=soma_xyz_px)
        unwarp_utils.plot_fits(gam_lower, df_lower_bvs, gam_upper, df_upper_bvs,
                               [path / voxel_size for path in df_paths.path])

    df_swc_flat, d_med_um = unwarp_utils.unwarp_swc(
        df_swc, gam_lower, df_lower_bvs, gam_upper, df_upper_bvs, pixel_size_um, z_step_um, plot=False)

    print(f'IPL-width={d_med_um:.2f} um')

    # Copy of the morphology with coordinates divided by the voxel size (um -> px).
    df_swc_px = df_swc.copy()
    df_swc_px.x /= pixel_size_um
    df_swc_px.y /= pixel_size_um
    df_swc_px.z /= z_step_um

    unwarp_utils.pd_save_swc(df_swc_px, outputfile_swc_px)
    unwarp_utils.pd_save_swc(df_swc_flat, outputfile_swc, comment=f'IPL-width={d_med_um:.2f} um')

    print(f'saved to {outputfile_swc}', end='\n\n')

    if interactive:
        # x-z projection before (left) and after (right) unwarping.
        fig, axs = plt.subplots(1, 2, figsize=(12, 3))
        axs[0].plot(df_swc.x, df_swc.z, '.')
        axs[0].set_title('original')
        axs[1].plot(df_swc_flat.x, df_swc_flat.z, '.')
        axs[1].set_title('flat')
        plt.show()
        # Interactive mode inspects a single cell per run.
        break