# Imports

In [None]:
import seaborn as sns
from matplotlib import pyplot as plt

In [None]:
%%capture
# (%%capture must stay the first line: it suppresses all output of this cell.)

# Table classes used throughout the notebook come from this star import.
# NOTE(review): names like `dj`, `np` and `schema` are used later but never imported
# explicitly in the visible cells; presumably they also come from this star import — TODO confirm.
from djimaging.user.alpha.schemas.alpha_schema import *
from djimaging.user.alpha.utils import populate_alpha
from djimaging.utils.dj_utils import get_primary_key

# Load the config for the "<SCHEMA_PREFIX>ca" schema, then create the schema and its tables.
populate_alpha.load_alpha_config(schema_name=populate_alpha.SCHEMA_PREFIX + "ca")
populate_alpha.load_alpha_schema(create_schema=True, create_tables=True)

# ERD

In [None]:
import warnings

# Draw the entity-relationship diagram of the schema; FutureWarnings raised while
# rendering are silenced only inside this block.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    display(dj.ERD(schema))

# Core

In [None]:
if len(Experiment().proj()) == 0:
    populate_alpha.populate_experiments(verbose=True)

In [None]:
populate_alpha.populate_core(verbose=True, processes=20)

In [None]:
(PreprocessTraces() & dict(preprocess_id=1)).plot1()

In [None]:
(PreprocessTraces() & dict(preprocess_id=2)).plot1()

# Cell positions

In [None]:
populate_alpha.populate_cell_positions()

for key in (RetinalFieldLocationWing() & "wing_side='v'").proj():
    RetinalFieldLocationWing().update1(dict(**key, wing_side='n'))

populate_alpha.populate_cell_tags()

In [None]:
RetinalFieldLocationTableParams().fetch('table_path')

In [None]:
RetinalFieldLocationCat().plot('n_tvd_side')

In [None]:
RetinalFieldLocationWing().plot()

# Field to stack matching

## Make sure field stamps are consistent

In [None]:
# Report every field whose stored position (absx/absy/absz) is more than max_offset_um
# away from the scan position of any of its presentations.
max_offset_um = 10

field_keys = (Field & (RoiKind & 'roi_kind="roi"')).fetch('KEY')

for field_key in field_keys:
    field_xyz = (Field & field_key).fetch('absx', 'absy', 'absz')
    pres_xyz = (Presentation.ScanInfo() & field_key).fetch('xcoord_um', 'ycoord_um', 'zcoord_um')

    # Per-axis offsets of each presentation relative to the field position.
    dx, dy, dz = (pres_coords - field_coord for pres_coords, field_coord in zip(pres_xyz, field_xyz))
    dists = (dx**2 + dy**2 + dz**2) ** 0.5

    if not np.any(dists > max_offset_um):
        continue

    print(field_key)
    print(dists)
    print('x:', dx)
    print('y:', dy)
    print('z:', dz)
    print()

# Morphology

In [None]:
populate_alpha.populate_morphology(verbose=True)

In [None]:
populate_alpha.populate_additional_morph_metrics(verbose=True)

In [None]:
ConvexHull().plot1();

## Match fields

In [None]:
populate_alpha.populate_fit_to_morphology(verbose=True)

### Detect unmatched Fields

Find Fields with no match on the morphology. <br>
Add their keys to the delete keys in the populate script to remove them from the analysis.

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(12, 3))
sns.histplot(ax=axs[0], data=FieldStackPos.fetch(format='frame'), x='rec_cpos_stack_fit_dist')
sns.histplot(ax=axs[1], data=FieldStackPos.FitInfo.fetch(format='frame'), x='score')
plt.show()

In [None]:
for key in (FieldStackPos() & "rec_c_warning_flag = 1"):
    FieldStackPos().plot1(key=key)
    plt.show()
print('done')

In [None]:
# for row in (FieldStackPos.FitInfo() & "score<-2"):
#     FieldStackPos().plot1(key=get_primary_key(FieldStackPos(), key=row))

In [None]:
for key in (FieldStackPos() & "rec_cpos_stack_fit_dist>40").fetch('KEY'):
    print(key)
    FieldStackPos().plot1(key=key)

## Plot ROIs on stack

### ROIs on Morph

In [None]:
FieldPosMetrics().populate(processes=20, display_progress=True)

In [None]:
FieldPosMetrics().plot1()

### Relative ROI positions

In [None]:
RelativeRoiPos().populate(display_progress=True, processes=20)

In [None]:
pres_key = get_primary_key(Presentation)
Presentation().plot1(pres_key, plot_field_rois=False)
RelativeRoiPos().plot(pres_key)

In [None]:
(FieldPosMetrics & pres_key).plot1()

In [None]:
# Sanity check: for one randomly chosen field (noise stimuli only), plot the ROI positions
# (left) next to the relative ROI offsets dx/dy (right) for visual comparison.
df_roi_pos = (FieldPosMetrics.RoiPosMetrics().proj('roi_pos_xyz') * RelativeRoiPos() & [dict(stim_name='noise_2500'), dict(stim_name='noise_1500')]).fetch(format='frame').reset_index()
df_roi_pos['roi_pos_x'] = df_roi_pos['roi_pos_xyz'].apply(lambda x: x[0])
df_roi_pos['roi_pos_y'] = df_roi_pos['roi_pos_xyz'].apply(lambda x: x[1])

groups = [group for _, group in df_roi_pos.groupby(['experimenter', 'date', 'exp_num', 'field'])]
# Pick a group by index. np.random.choice(np.asarray(groups, dtype=object)) breaks whenever
# the DataFrames share a shape (e.g. a single group): numpy then builds a multi-dimensional
# object array and choice() raises "a must be 1-dimensional".
df_field = groups[np.random.randint(len(groups))]

fig, axs = plt.subplots(1, 2, figsize=(12, 3))
sns.scatterplot(ax=axs[0], data=df_field, x='roi_pos_x', y='roi_pos_y')
sns.scatterplot(ax=axs[1], data=df_field, x='roi_dx_um', y='roi_dy_um')
plt.tight_layout()
plt.show()

# Field and soma ROIs

Field ROIs simply combine all ROIs of a field into a single ROI.

In [None]:
populate_alpha.add_field_rois(verbose=True)

In [None]:
FieldRoiPosMetrics().populate(display_progress=True, processes=1)

In [None]:
RoiKind().populate(display_progress=True, processes=20)
RoiKind()

In [None]:
plt.hist(RoiKind().fetch('roi_kind'));

In [None]:
# Populate core after adding Field ROIs
populate_alpha.populate_core(verbose=True)
populate_alpha.populate_metrics(verbose=True, processes=20)

## Compare proximal dendrites / soma ROIs

In [None]:
df_soma_keys = (Presentation & (RoiKind & "roi_kind='soma'") & ["stim_name='noise_1500'", "stim_name='noise_2500'"]).proj().fetch(format='frame').reset_index()
df_soma_keys['sort_name'] = df_soma_keys.apply(lambda r: r['field'] if r['cond1'] == 'control' else r['cond1'] + r['field'] + '_', axis=1)
df_soma_keys['field_base'] = df_soma_keys.apply(lambda r: r['field'][:2].lower() if r['cond1'] == 'control' else r['cond1'].lower(), axis=1)
df_soma_keys = df_soma_keys.sort_values(['date', 'exp_num', 'sort_name'])

In [None]:
df_soma_keys.head(3)

In [None]:
# Find any hand-drawn SomaROIs that are not soma ROIs in the actual definition.
# Groups that contain a control recording are skipped; for the rest, print the sorted
# soma distances of the base field and plot the base field followed by each recording.
for (date, exp_num, field_base), group in df_soma_keys.groupby(['date', 'exp_num', 'field_base']):
    if 'control' in group['cond1'].values:
        continue

    print(group['cond1'].values)

    base_key = dict(date=date, exp_num=exp_num, field=field_base)
    soma_dists = (FieldPosMetrics.RoiPosMetrics & base_key).fetch('d_dist_to_soma')
    print(np.sort(soma_dists))

    (Field & base_key).plot1()
    plt.show()

    for _, row in group.iterrows():
        (Field & row.to_dict()).plot1()
        plt.show()

    print("---------------------------")

## Look at true soma responses
I.e., compare responses of proximal dendrites to somatic responses.

In [None]:
# Channel-0 stack averages of Field ROIs lying closer than max_soma_dist_um to the soma.
max_soma_dist_um = 70
near_soma_fields = (
    Field.StackAverages()
    & "ch_name='wDataCh0'"
    & (RoiKind & "roi_kind='field'")
    & (FieldRoiPosMetrics() & f"d_dist_to_soma<{max_soma_dist_um}")
)
ch_avgs, potential_soma_keys = near_soma_fields.fetch('ch_average', 'KEY')
len(ch_avgs)

In [None]:
from djimaging.utils import math_utils

# Show every candidate's channel-0 stack average in a near-square grid. Each panel is titled
# with its index into potential_soma_keys, which is what true_soma_idxs refers to.
n_cols = max(int(np.ceil(np.sqrt(len(ch_avgs)))), 1)  # at least one column, even with no candidates
n_rows = max(int(np.ceil(len(ch_avgs) / n_cols)), 1)

# squeeze=False always returns a 2D array of Axes, so flatten() also works for a 1x1 grid.
fig, axs = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(15, 10), sharex='all', sharey='all',
                        squeeze=False)
axs = axs.flatten()

for i, (ax, ch_avg) in enumerate(zip(axs, ch_avgs)):
    norm_avg = math_utils.normalize_soft_zero_one(ch_avg)
    # Exponent < 1 brightens dim structures; transposed, presumably so x runs horizontally.
    ax.imshow((norm_avg**0.7).T)
    ax.set_title(i)

# Hide the grid cells that have no candidate to show.
for ax in axs[len(ch_avgs):]:
    ax.set_axis_off()

In [None]:
true_soma_idxs = [3, 8, 12]
true_soma_keys = [potential_soma_keys[idx] for idx in true_soma_idxs]

In [None]:
for idx, key in zip(true_soma_idxs, true_soma_keys):
    key = key.copy()
    key['field'] = key['field'][:-8]
    print(idx)
    print(key)
    FieldStackPos().plot1(key=key)
    plt.show()

In [None]:
# Manually draw soma ROIs: open the interactive ROI canvas for ONE true-soma field per run
# (the loop stops after the first field it shows). Drawn masks go to "<h5 path>_TrueSomaROI.pkl".
from djimaging.autorois.roi_canvas import InteractiveRoiCanvas
from djimaging.utils import scanm_utils
import os

only_new = False  # True: skip fields for which every presentation already has a ROI pickle

for key in true_soma_keys[0:]:
    filepaths, stim_names = (Presentation & key).fetch('h5_header', 'stim_name')
    outputfiles = [filepath.replace('.h5', '_TrueSomaROI.pkl') for filepath in filepaths]

    if only_new and np.all([os.path.exists(outputfile) for outputfile in outputfiles]):
        continue

    # Load both channels of every presentation of this field.
    ch0_stacks, ch1_stacks = [], []
    for filepath in filepaths:
        ch_stacks, wparams = scanm_utils.load_stacks_from_h5(filepath)
        ch0_stacks.append(ch_stacks['wDataCh0'])
        ch1_stacks.append(ch_stacks['wDataCh1'])

    # Reference presentation: the first noise stimulus (index 0 if there is none).
    is_noise = ['noise' in stim_name for stim_name in stim_names]
    main_stim_idx = np.argmax(is_noise)
    gui = InteractiveRoiCanvas(
        ch0_stacks=ch0_stacks,
        ch1_stacks=ch1_stacks,
        output_files=outputfiles,
        canvas_width=25,
        stim_names=stim_names,
        main_stim_idx=main_stim_idx,
    )

    (Field & key).plot1()
    plt.show()
    display(gui.start_gui())

    break  # one field per run

In [None]:
from djimaging.utils import mask_utils
import pickle

# Register every selected true soma as a new pseudo-field named like its Field ROI with
# 'FieldROI' replaced by 'TrueSomaROI'. The Field and Presentation entries (incl. the
# RoiMask, StackAverages and ScanInfo parts) of the Field ROI are copied, their ROI masks are
# replaced by the hand-drawn masks (pickles presumably written by the drawing cell above),
# and everything is inserted directly.
for key in true_soma_keys:
    key = (Field & key).fetch1('KEY')

    # Get data
    field_entry = (Field & key).fetch1()
    field_mask_entry = (Field.RoiMask & key).fetch1()
    field_avg_entries = (Field.StackAverages & key).fetch()

    pres_entries = (Presentation & key).fetch()
    pres_info_entries = (Presentation.ScanInfo & key).fetch()
    pres_mask_entries = (Presentation.RoiMask & key).fetch()
    pres_avg_entries = (Presentation.StackAverages & key).fetch()

    # Change data

    ## Change field name
    field_name = field_entry['field'].replace('FieldROI', 'TrueSomaROI')
    new_key = key.copy()
    new_key['field'] = field_name

    # Also catches names without 'FieldROI' (new_key == key), so nothing is duplicated.
    if len(Field & new_key) > 0:
        print(f'Entry already present {new_key}')
        continue

    field_entry['field'] = field_name
    field_mask_entry['field'] = field_name
    for entries in (field_avg_entries, pres_entries, pres_mask_entries, pres_avg_entries, pres_info_entries):
        for entry in entries:
            entry['field'] = field_name

    ## Load ROI masks, one per presentation
    roi_masks = []
    for pres_entry in pres_entries:
        # Use the named 'h5_header' attribute (the one the drawing cell built the pickle paths
        # from) instead of a fragile positional index into the record.
        roimask_file = pres_entry['h5_header'].replace('.h5', '_TrueSomaROI.pkl')
        # NOTE: unpickling can execute arbitrary code; only load pickles produced by this pipeline.
        with open(roimask_file, 'rb') as f:
            roi_mask = mask_utils.to_igor_format(pickle.load(f).copy())
        roi_masks.append(roi_mask)

    ## Change ROI masks
    # Reference mask: the first noise presentation (index 0 if there is none).
    main_stim_idx = np.argmax(['noise' in pres_entry["stim_name"] for pres_entry in pres_entries])
    field_mask_entry['roi_mask'] = roi_masks[main_stim_idx]

    # Pair masks with RoiMask rows via the Presentation primary key: the separate fetch() calls
    # above do not guarantee that rows come back in the same order.
    pres_pk = Presentation().primary_key
    roi_mask_by_pres = {tuple(pres_entry[attr] for attr in pres_pk): roi_mask
                        for pres_entry, roi_mask in zip(pres_entries, roi_masks)}
    for entry in pres_mask_entries:
        roi_mask = roi_mask_by_pres[tuple(entry[attr] for attr in pres_pk)]
        pres_and_field_mask = mask_utils.compare_roi_masks(
            roi_mask=roi_mask, ref_roi_mask=roi_masks[main_stim_idx], max_shift=2, bg_val=1)[0]
        entry['roi_mask'] = roi_mask
        entry['pres_and_field_mask'] = pres_and_field_mask

    print(f'Adding {new_key}')

    # Insert everything in one transaction so a failure cannot leave a half-registered field behind.
    with Field().connection.transaction:
        Field().insert1(field_entry, allow_direct_insert=True)
        Field.RoiMask().insert1(field_mask_entry, allow_direct_insert=True)
        Field.StackAverages().insert(field_avg_entries, allow_direct_insert=True)

        Presentation().insert(pres_entries, allow_direct_insert=True)
        Presentation.ScanInfo().insert(pres_info_entries, allow_direct_insert=True)
        Presentation.RoiMask().insert(pres_mask_entries, allow_direct_insert=True)
        Presentation.StackAverages().insert(pres_avg_entries, allow_direct_insert=True)

In [None]:
RoiKind().populate(display_progress=True)
np.unique(RoiKind().fetch('roi_kind'), return_counts=True)

In [None]:
populate_alpha.populate_core()

In [None]:
# For each chirp stimulus (gChirp first, then lChirp): print the chirp quality index and plot
# the average response of every soma ROI.
for chirp_name in ('gChirp', 'lChirp'):
    soma_chirp_keys = (Averages & (RoiKind & dict(roi_kind='soma')) & dict(stim_name=chirp_name)).fetch('KEY')
    for key in soma_chirp_keys:
        print((ChirpQI & key).fetch1('qidx'))
        Averages().plot1(key)

# Responses

## Surround index

In [None]:
SineSpotSurroundIndex().populate(processes=20, display_progress=True)

In [None]:
SineSpotSurroundIndex().plot1()

In [None]:
plt.hist(SineSpotSurroundIndex().fetch('sinespot_surround_index'), bins=51);

In [None]:
ChirpSurroundIndex().populate(processes=20, display_progress=True)

In [None]:
ChirpSurroundIndex().plot1()

In [None]:
plt.hist(ChirpSurroundIndex().fetch('chirp_surround_index'), bins=51);

## Quality

In [None]:
populate_alpha.populate_quality(verbose=True, processes=20)

In [None]:
QualityParams()

In [None]:
(QualityIndex() & (RoiKind & "roi_kind='roi'")).plot()

In [None]:
(ChirpQI() & (RoiKind & "roi_kind='field'")).plot()

In [None]:
(QualityIndex() & (RoiKind & "roi_kind='soma'")).plot()

# Text

In [None]:
np.unique([tt.shape[1] for tt in (Snippets() & dict(stim_name='gChirp')).fetch("triggertimes_snippets")], return_counts=True)

In [None]:
np.unique([tt.shape[1] for tt in (Snippets() & dict(stim_name='lChirp')).fetch("triggertimes_snippets")], return_counts=True)

In [None]:
np.unique([tt.shape[1] for tt in (Snippets() & dict(stim_name='sinespot')).fetch("triggertimes_snippets")], return_counts=True)

# Convex Hull

In [None]:
ConvexHull().plot1()

In [None]:
df_convex_hull = (ConvexHull() * RetinalFieldLocationWing().proj(group="wing_side")).fetch(format='frame')
df_convex_hull.head()

In [None]:
df_convex_hull.to_csv('data/convex_hull_calcium.csv')