# Installation

This schema loads the raw recording data instead of HDF5 (.h5) files.
Therefore, you need to install the ScanM Python package: https://github.com/eulerlab/ScanM_support


# Imports

In [None]:
import os
import numpy as np
from matplotlib import pyplot as plt
import datetime

# Database

In [None]:
from djimaging.user.alpha.utils.populate_alpha import load_alpha_config, SCHEMA_PREFIX

# Load the project configuration for the soma schema
# (presumably DataJoint connection settings -- see populate_alpha).
load_alpha_config(schema_name=SCHEMA_PREFIX + "soma")

In [None]:
from djimaging.utils.dj_utils import activate_schema
# The star import provides `schema` and the table classes used below
# (UserInfo, Field, Presentation, Traces, ...) and presumably `dj`.
from djimaging.user.alpha.schemas.alpha_somas_schema import *

# Activate the schema; create_schema/create_tables allow creating missing ones.
activate_schema(schema=schema, create_schema=True, create_tables=True)

# ERD

In [None]:
import warnings

# Draw the entity-relationship diagram of the schema, silencing FutureWarnings.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", FutureWarning)
    display(dj.ERD(schema))

# Upload user

In [None]:
from alphacnn.paths import PROJECT_ROOT

# User/project settings uploaded to UserInfo in the next cell.
# NOTE(review): the *_loc entries presumably give the position of each metadata
# item (datatype, animal, region, field, stimulus, condition) in the data file
# names -- confirm against the djimaging UserInfo documentation.
userinfo = {
    'experimenter': 'Oesterle',
    'data_dir': os.path.join(PROJECT_ROOT, 'data/Oesterle/sONa_somas/'),
    'datatype_loc': 0,
    'animal_loc': 1,
    'region_loc': 2,
    'field_loc': 3,
    'stimulus_loc': 4,
    'condition_loc': 5,
    'outline_alias': 'outline_edge_cut_ol_vessels_vessel',
}

# Fail early with a clear message if the data directory is missing. An explicit
# raise is used instead of `assert`, which is skipped when Python runs with -O.
if not os.path.isdir(userinfo['data_dir']):
    raise FileNotFoundError(f"Data directory does not exist: {userinfo['data_dir']}")

In [None]:
# Store the user settings defined above in the UserInfo table.
UserInfo().upload_user(userinfo)

In [None]:
# Plot the data files in the selected folder
UserInfo().plot1(key=None, show_pre=False, show_raw=False, show_header=True)

In [None]:
# Default raw-data parameters: load raw recordings (from_raw_data=True) and
# igor_roi_masks='no' (presumably: do not use ROI masks stored by Igor).
RawDataParams().add_default(from_raw_data=True, igor_roi_masks='no')
RawDataParams()

## Populate data

### Experiments

In [None]:
# Scan the data directory for experiments and insert them.
Experiment().rescan_filesystem(verboselvl=2)
Experiment()

### Fields

In [None]:
Field().rescan_filesystem(verboselvl=2)

In [None]:
Field().plot1(key=None)

### Stimuli

#### Add default stimuli

In [None]:
# Register a no-stimulus entry and the global chirp stimulus ('gChirp').
Stimulus().add_nostim(skip_duplicates=True)
Stimulus().add_chirp(spatialextent=1000, stim_name='gChirp', alias="chirp_gchirp_globalchirp", skip_duplicates=True)

In [None]:
def color2word(r, g, b):
    """Return a one-letter label for a spot colour given as (r, g, b).

    The red channel is ignored (red is always off for these spots).
    'W' if green and blue are both on, 'G' if only green is on, and
    'B' in every other case (including all channels off).
    """
    if g <= 0:
        return 'B'
    return 'W' if b > 0 else 'G'

In [None]:
import random

# Per-spot trial info for the vspots stimulus; filled at the end of this cell.
vspots_trial_info = []

# Stimulus parameters.
# NOTE(review): presumably these mirror the stimulus script used during the
# recordings -- the spot order below is reconstructed from them, so confirm
# before changing anything.
p = {
    "nTrials": 2,
    "TimeOn_s": 1.0,
    "TimeOff_s": 1.0,
    "PauseInterColor_s": 1.0,
    "PauseInterTrial_s": 0.0,
    "SpotSizes_um": (100, 200, 300, 400, 600, 1000),
    # RGB and BGR for setup 1 and 3, respectively. but red is always off.
    "Spot_RGB_order": ((255, 255, 255), (255, 0, 255), (0, 255, 0),
                       (255, 255, 255), (0, 255, 0), (255, 0, 255),),
    "durFr_s": 1 / 60.0,  # Frame duration
    "nFrPerMarker": 3,
}

# The seed and the exact sequence of random.shuffle calls below determine the
# reconstructed spot-size order; keep both unchanged.
random_seed = 555
random.seed(random_seed)

# One independently shuffled copy of the spot sizes per trial.
size_sequences = []
sizes = list(p["SpotSizes_um"][:])
for iT in range(p["nTrials"]):
    sizes_i = sizes.copy()
    random.shuffle(sizes_i)
    size_sequences.append(sizes_i)

p['SpotSizesShuffled_um'] = size_sequences

# Define stimulus objects
# size2idx_dict maps each spot size to a 1-based index (not used elsewhere in
# the visible notebook).
size2idx_dict = {}
for idx, size in enumerate(p["SpotSizes_um"], start=1):
    size2idx_dict[size] = idx

# One entry per spot, ordered trial -> colour -> size, named
# <colour letter><size zero-padded to 4 digits> (e.g. 'W0100'), 2 triggers each.
# Every entry of Spot_RGB_order reads the same reversed, so the [::-1] does not
# change the resulting letter for the current palette.
for size_sequence in size_sequences:
    for color in p["Spot_RGB_order"]:
        for size in size_sequence:
            vspots_trial_info.append({
                'name': color2word(*color[::-1]) + str(size).zfill(4), 'ntrigger': 2
            }) # RGB inverted for setup 3

In [None]:
# Total trigger count of the (non-repeated) vspots stimulus, derived from the
# per-spot trigger counts so it stays in sync with vspots_trial_info
# (2 trials x 6 colours x 6 sizes x 2 triggers = 144 for the parameters above).
# NOTE(review): assumes ntrigger_rep is the total number of triggers over all
# entries of trial_info, as the previously hard-coded 144 implied.
vspots_ntrigger = sum(trial['ntrigger'] for trial in vspots_trial_info)

Stimulus().add_stimulus(
    stim_name='vspots', alias='vspots_spots', isrepeated=0, ntrigger_rep=vspots_ntrigger,
    stim_dict=dict(trigger_delay=-1),
    trial_info=vspots_trial_info, skip_duplicates=True
)

In [None]:
# Inspect the trial info stored for the vspots stimulus.
(Stimulus() & dict(stim_name='vspots')).fetch('trial_info')

## Presentations

In [None]:
# Errors are suppressed (suppress_errors=True), so check for missing entries afterwards.
Presentation().populate(processes=20, display_progress=True, suppress_errors=True)

In [None]:
Presentation().plot1(key=None)

# AutoROIs

ROI masks for traces: be conservative here and do not include pixels at the border, so that you can easily align the ROIs across stimuli.

In [None]:
# Populate HighRes in random order; errors are raised (suppress_errors=False).
HighRes().populate(display_progress=True, suppress_errors=False, order='random')

In [None]:
HighRes()

In [None]:
HighRes().plot1()

## Masks for traces

Don't use border pixels

In [None]:
# If you have saved some of your AutoROIs ROI masks, you can load them here.
RoiMask().rescan_filesystem(verboselvl=2, roi_mask_dir='ROIs')
RoiMask()

In [None]:
# Find all the fields that still require a ROI mask.
missing_fields = RoiMask().list_missing_field()

In [None]:
if len(missing_fields) > 0:
    field_key = missing_fields.pop()  # Pick one field
    
    # Load ROI canvas, draw the ROI mask, clean it if you want, shift if you want.
    # You can then save it to a file to be able to load it again later.
    roi_canvas = RoiMask().draw_roi_mask(field_key=field_key, canvas_width=15, autorois_models=None, roi_mask_dir='ROIs', max_shift=10)
    display(roi_canvas.start_gui())

In [None]:
# Re-open the ROI canvas for one specific, hard-coded field.
# NOTE(review): 'roi_id' is not part of a field key; presumably it is ignored
# by draw_roi_mask -- confirm.
field_key = {'experimenter': 'Oesterle', 'date': datetime.date(2024, 7, 18), 'exp_num': 1, 'raw_id': 1, 'field': 'GCL4', 'roi_id': 1}
roi_canvas = RoiMask().draw_roi_mask(field_key=field_key, canvas_width=15, autorois_models=None, roi_mask_dir='ROIs', max_shift=10)
display(roi_canvas.start_gui())

In [None]:
# Load the just saved ROI mask
RoiMask().rescan_filesystem(verboselvl=1, roi_mask_dir='ROIs',)

In [None]:
RoiMask().plot1()

In [None]:
Roi().populate(processes=20, display_progress=True)

In [None]:
Roi().plot1(key=None)

# Responses

## Traces

In [None]:
# Extract the ROI traces in parallel.
Traces().populate(processes=20, display_progress=True)
Traces()

## Preprocessed traces

In [None]:
# Register the default preprocessing parameters.
PreprocessParams().add_default(skip_duplicates=True)
PreprocessParams()

In [None]:
PreprocessTraces().populate(processes=20, display_progress=True)
PreprocessTraces()

In [None]:
PreprocessTraces().plot1()

## Blank corrupted repetitions

Because of an issue with the setup, some recordings contain z-shifts that corrupted parts of the data.
Here we remove the corrupted repetitions entirely.

### Flag corrupted reps

In [None]:
# Keys of all vspots presentations still to be inspected; the next cell pops one per run.
keys = (Presentation & dict(stim_name='vspots')).fetch('KEY')
# Collects (key, start_frame, stop_frame) tuples flagged in the interactive viewer.
problem_keys = []

In [None]:
## Repeat the following for all keys

In [None]:
from djimaging.utils import scanm_utils

# Take the next presentation to inspect (rerun this cell once per key).
key = keys.pop(0)

# Names of the main and the alternative stack channel shown in the viewer.
data_name, alt_name = (UserInfo & key).fetch1('data_stack_name', 'alt_stack_name')
filepath, triggertimes = (Presentation & key).fetch1(Presentation().filepath, 'triggertimes')

from_raw_data = (RawDataParams & key).fetch1('from_raw_data')
ch_stacks, wparams = scanm_utils.load_stacks(filepath, from_raw_data=from_raw_data,
                                             ch_names=('wDataCh0', 'wDataCh1'))

# Frame count of the channel that is actually displayed, so the slider range
# always matches the shown stack (previously hard-coded to 'wDataCh0').
n_frames = ch_stacks[data_name].shape[2]

# NOTE(review): fetch1 requires exactly one Traces row for this presentation,
# i.e. presumably one ROI per field in this soma dataset -- confirm.
trace_times, trace = (Traces() & key).fetch1('trace_times', 'trace')

In [None]:
# Show the key of the presentation currently being inspected.
key

In [None]:
from ipywidgets import widgets

layout = widgets.Layout(width='1000px')

# Current frame plus the left/right borders of a segment to flag; all sliders
# span the whole stack.
w_frame = widgets.IntSlider(value=0, min=0, max=n_frames - 1, step=1, layout=layout)
w_left = widgets.IntSlider(value=0, min=0, max=n_frames - 1, step=1, layout=layout)
w_right = widgets.IntSlider(value=n_frames - 1, min=0, max=n_frames - 1, step=1, layout=layout)
# Ticking this box makes the interactive callback store the selected range.
w_save = widgets.Checkbox(False)

# Fixed y-range of the trace panel.
tmin, tmax = np.min(trace), np.max(trace)

# Per-frame 98th percentile of the alternative channel, ignoring a 3-pixel border.
ch1_trace = np.nanpercentile(ch_stacks[alt_name][3:-3, 3:-3], q=98, axis=(0, 1))
cmin, cmax = np.min(ch1_trace), np.max(ch1_trace)

# Fixed grey-value ranges so frames are displayed with comparable brightness.
ch0_min, ch0_max = np.min(ch_stacks[data_name]), np.max(ch_stacks[data_name])
ch1_min, ch1_max = np.min(ch_stacks[alt_name]), np.max(ch_stacks[alt_name])

In [None]:
def _fit_plot(frame=0, left=0, right=n_frames - 1):
    """Draw the inspection figure for one frame and a candidate segment.

    Rows 1-2: the ROI trace and the per-frame percentile trace of the
    alternative channel, each marked with the current frame (red), the segment
    borders `left`/`right` (black) and the trigger times (orange; thick every
    36 triggers). Row 3: the current frame of both channels with fixed ranges.
    Reads the notebook globals set up in the previous cells.
    """
    fig, axs = plt.subplot_mosaic([['C', 'C'], ['D', 'D'], ['A', 'B']], figsize=(10, 6), height_ratios=(1, 1, 3))

    image_panels = (('A', data_name, ch0_min, ch0_max),
                    ('B', alt_name, ch1_min, ch1_max))
    for panel, stack_name, lo, hi in image_panels:
        axs[panel].imshow(ch_stacks[stack_name][:, :, frame], cmap='gray', vmin=lo, vmax=hi)

    trace_panels = (('C', trace, tmin, tmax),
                    ('D', ch1_trace, cmin, cmax))
    for panel, values, lo, hi in trace_panels:
        ax = axs[panel]
        ax.plot(trace_times, values)
        ax.axvline(trace_times[frame], c='r')
        for border in (left, right):
            ax.axvline(trace_times[border], c='k')
        ax.vlines(triggertimes[::12], lo, hi, color='orange', ls='-', zorder=-100)
        ax.vlines(triggertimes[::36], lo, hi, color='orange', ls='-', lw=3, zorder=-100)

    plt.tight_layout()

In [None]:
@widgets.interact(frame=w_frame, left=w_left, right=w_right, save=w_save)
def plot_fit(frame=0, left=0, right=n_frames - 1, save=False):
    """Interactive callback: redraw the figure and optionally flag a segment.

    When `save` is ticked, the frames between the two border sliders (either
    order, inclusive) are appended to problem_keys as (key, start, stop) with an
    exclusive stop index, and the checkbox is reset.
    """
    _fit_plot(frame=frame, left=left, right=right)

    if not save:
        return
    start, stop = min(left, right), max(left, right) + 1
    problem_keys.append((key, start, stop))
    w_save.value = False

In [None]:
# Show the key of the presentation that was just inspected.
key

### Change in database

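A positive value `n` keeps only the first `n` repetitions (everything after them is set to NaN). <br>
A negative value `-x` blanks out repetition `x` (1-based). One repetition spans 36 triggers.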

In [None]:
# (reps, presentation key) pairs of vspots presentations to clip:
# reps > 0 keeps the first `reps` repetitions, reps < 0 blanks repetition |reps|.
clip_keys = [
    (reps, {'experimenter': 'Oesterle',
            'date': date,
            'exp_num': exp_num,
            'raw_id': 1,
            'field': field,
            'stim_name': 'vspots',
            'condition': 'control'})
    for reps, date, exp_num, field in (
        (2, datetime.date(2024, 7, 16), 2, 'GCL0'),
        (2, datetime.date(2024, 7, 18), 1, 'GCL4'),
        (-2, datetime.date(2024, 7, 18), 1, 'GCL0'),
        (-2, datetime.date(2024, 7, 18), 1, 'GCL3'),
    )
]

In [None]:
from IPython.display import clear_output
from djimaging.utils import trace_utils

# Number of triggers per repetition of the vspots stimulus; the thick trigger
# lines in the plots mark these repetition boundaries.
N_TRIGGERS_PER_REP = 36


def _rep_start_idx(rep_triggertimes, rep_trace_times, rep_idx):
    """Return the trace index at which 0-based repetition `rep_idx` starts.

    Located from the first trigger of that repetition via
    trace_utils.find_closest_after. If that trigger does not exist (the
    repetition would start after the last trigger), len(rep_trace_times) is
    returned, i.e. the end of the trace, instead of raising an IndexError.
    """
    trigger_idx = rep_idx * N_TRIGGERS_PER_REP
    if trigger_idx >= len(rep_triggertimes):
        return len(rep_trace_times)
    return trace_utils.find_closest_after(rep_triggertimes[trigger_idx], rep_trace_times, as_index=True)


for reps, clip_key in clip_keys:
    # Dedicated names, so the globals `trace`, `trace_times` and `triggertimes`
    # read by the interactive viewer above are not overwritten.
    clip_triggertimes = (Presentation & clip_key).fetch1('triggertimes')
    clip_trace, clip_trace_times = (PreprocessTraces & clip_key).fetch1(
        'preprocess_trace', 'preprocess_trace_times')
    raw_trace = (Traces & clip_key).fetch1('trace')

    new_trace = clip_trace.copy()

    if reps > 0:
        # Keep the first `reps` repetitions: blank from one sample after the
        # start of repetition reps + 1 to the end of the trace (nothing is
        # blanked if there is no such repetition).
        idx1 = _rep_start_idx(clip_triggertimes, clip_trace_times, reps) + 1
        new_trace[idx1:] = np.nan

    elif reps < 0:
        # Blank repetition |reps| (1-based); the last repetition runs to the end.
        rep_idx = abs(reps) - 1
        if rep_idx * N_TRIGGERS_PER_REP >= len(clip_triggertimes):
            raise ValueError(
                f"Cannot blank repetition {abs(reps)} of {clip_key}: "
                f"only {len(clip_triggertimes)} triggers were recorded.")
        idx0 = _rep_start_idx(clip_triggertimes, clip_trace_times, rep_idx)
        idx1 = _rep_start_idx(clip_triggertimes, clip_trace_times, rep_idx + 1)
        new_trace[idx0:idx1] = np.nan

    # Left: raw trace; right: preprocessed trace after blanking. Red lines every
    # 12 triggers (thin) and at every repetition start (thick).
    # NOTE(review): the raw trace is drawn against the preprocessed time axis;
    # this assumes both have the same length -- confirm.
    fig, axs = plt.subplots(1, 2, figsize=(12, 3))
    panels = ((axs[0], raw_trace, raw_trace), (axs[1], new_trace, clip_trace))
    for ax, plotted, y_ref in panels:
        ax.plot(clip_trace_times, plotted)
        ax.vlines(clip_triggertimes[::12], np.nanmin(y_ref), np.nanmax(y_ref), color='r')
        ax.vlines(clip_triggertimes[::N_TRIGGERS_PER_REP], np.nanmin(y_ref), np.nanmax(y_ref),
                  color='r', lw=3)
    plt.show()

    # The database row is overwritten here, before the prompt below; answering
    # anything other than 'y' only stops processing of the remaining keys.
    # NOTE(review): only preprocess_trace is blanked; smoothed_trace is left
    # unchanged -- confirm downstream tables do not read it.
    PreprocessTraces().update1({**(PreprocessTraces & clip_key).fetch1('KEY'), "preprocess_trace": new_trace})

    if input("Continue? [y/n]") != 'y':
        break
    clear_output()

## Snippets

In [None]:
# Cut the gChirp responses into snippets.
Snippets().populate(dict(stim_name='gChirp'), processes=20, display_progress=True)
Snippets()

In [None]:
# Plot the snippets of one gChirp entry as a quick check (nothing if none exist).
# Only the primary key is fetched, and a dedicated name keeps the global `key`
# used by the vspots inspection cells intact.
gchirp_keys = (Snippets & dict(stim_name='gChirp')).fetch('KEY', limit=1)
if len(gchirp_keys) > 0:
    Snippets().plot1(key=gchirp_keys[0])

In [None]:
# Populate Averages with a single process.
Averages().populate(processes=1, display_progress=True)
Averages()

In [None]:
(Averages & dict(stim_name='gChirp')).plot()

In [None]:
# Plot up to three distinct, randomly chosen averages. Indices are sampled
# without replacement, so no entry is plotted twice, and the cell also works
# when fewer than three entries exist. A dedicated loop name keeps the global
# `key` intact.
average_keys = Averages.fetch('KEY')
n_examples = min(3, len(average_keys))
for avg_idx in np.random.choice(len(average_keys), size=n_examples, replace=False):
    Averages().plot1(key=average_keys[avg_idx])

## Quality

In [None]:
# Compute the chirp quality index.
ChirpQI().populate(display_progress=True, processes=20)
ChirpQI()

In [None]:
# Plot an average restricted to entries with a chirp quality index above 0.9.
(Averages() & (ChirpQI & "qidx>0.9")).plot1()

## Spot response

In [None]:
GroupSnippets().populate()

In [None]:
GroupSnippets().plot1();

In [None]:
# make_kwargs=dict(plot=1) is forwarded to the table's make (presumably to plot each entry).
WbgSpots().populate(processes=1, make_kwargs=dict(plot=1))

# Soma sizes

In [None]:
# If you have saved some of your AutoROIs ROI masks, you can load them here.
SizeRoiMask().rescan_filesystem(verboselvl=2, roi_mask_dir='ROIs_size')
SizeRoiMask()

In [None]:
# Find all the fields that still require a ROI mask.
missing_fields = SizeRoiMask().list_missing_field()
print(len(missing_fields))

In [None]:
# Pick one field that still needs a soma-size ROI mask. Guarded like the
# corresponding RoiMask cell, so running it with no fields left does not raise
# IndexError from pop().
if len(missing_fields) > 0:
    field_key = missing_fields.pop()  # Pick one field

    # Load ROI canvas, draw the ROI mask, clean it if you want, shift if you want.
    # You can then save it to a file to be able to load it again later.
    roi_canvas = SizeRoiMask().draw_roi_mask(field_key=field_key, canvas_width=15, autorois_models=None, roi_mask_dir='ROIs_size', max_shift=10, show_diagnostics=False)
    display(roi_canvas.start_gui())
else:
    print("No fields without a SizeRoiMask left.")

In [None]:
# Load the just saved ROI mask
SizeRoiMask().rescan_filesystem(verboselvl=1, roi_mask_dir='ROIs_size',)

In [None]:
SizeRoi().populate(processes=20, display_progress=True)
SizeRoi()

In [None]:
# Distribution of ROI diameters (roi_dia_um, presumably in micrometres).
plt.hist(SizeRoi().fetch('roi_dia_um'))

In [None]:
# Averages of the cells with a diameter below 17 (the next cell inspects < 18).
(Averages & ( SizeRoi() & 'roi_dia_um<17')).plot1()

In [None]:
# Double-check the smallest cells: are they non-round? How does their size
# compare to that of the other cells in the same field?
# Note: this loop overwrites the global `key`.

for key in (SizeRoi() & "roi_dia_um<18").fetch('KEY'):
    print(key)
    (SizeRoi & key).plot1()
    plt.show()

In [None]:
# WARNING: destructive -- deletes this Field entry (a DataJoint delete, which
# presumably cascades to dependent entries and asks for confirmation in safe
# mode). Only run deliberately.
# NOTE(review): 'roi_id' is not a field attribute; DataJoint is expected to
# ignore it in the restriction -- confirm that only the intended GCL4 field matches.
(Field & {'experimenter': 'Oesterle', 'date': datetime.date(2024, 7, 18), 'exp_num': 1, 'raw_id': 1, 'field': 'GCL4', 'roi_id': 1}).delete()

# Recording location

In [None]:
# From Notes
# Offset of the 2024-06-20 field's stage position (absx, absy) from the point
# (344, -1608) noted in the lab notes (presumably the optic-disk position).
# fetch1 requires exactly one field on that date.
absx, absy = (Field & dict(date='2024-06-20')).fetch1('absx', 'absy')
odx = absx - 344
ody = absy - (-1608)

odx, ody

In [None]:
# Populate the optic-disk table with a single process.
OpticDisk().populate(processes=1, display_progress=True)
OpticDisk()

In [None]:
RelativeFieldLocation().populate(processes=20, display_progress=True)
RelativeFieldLocation()

In [None]:
RelativeFieldLocation().plot()

In [None]:
RetinalFieldLocation().populate(processes=20, display_progress=True)
RetinalFieldLocation()

In [None]:
RetinalFieldLocation().plot()

In [None]:
UserInfo()