# Imports

In [None]:
import seaborn as sns
from matplotlib import pyplot as plt

In [None]:
from djimaging.user.alpha.schemas.alpha_schema import *
from djimaging.user.alpha.utils import populate_alpha
from djimaging.utils.dj_utils import get_primary_key

# Point the alpha project at the schema named SCHEMA_PREFIX + "glu" and load its tables.
# create_schema=False / create_tables=False: attach to an existing schema rather than creating it.
populate_alpha.load_alpha_config(schema_name=populate_alpha.SCHEMA_PREFIX + "glu")
populate_alpha.load_alpha_schema(create_schema=False, create_tables=False)

In [None]:
# Display the schema object (presumably provided by the alpha_schema star import above).
schema

## Preprocessing

In [None]:
# Populate the traces used for the RF GLMs (verbose progress, processes=20).
populate_alpha.populate_rf_glms_traces(verbose=True, processes=20)

In [None]:
# Show the parameter sets available for GLMDNoiseTrace.
GLMDNoiseTraceParams()

In [None]:
# Plot one GLMDNoiseTrace entry zoomed to xlim=(100, 101).
# NOTE(review): x-axis units are not visible here — presumably seconds; confirm.
GLMDNoiseTrace().plot1(xlim=(100, 101))

## Fit RF GLMs

In [None]:
# Show the RF GLM parameter sets.
RFGLMParams()

In [None]:
import time

from datajoint.errors import LostConnectionError
from tqdm.notebook import tqdm

# Fit RFGLM for every key in its key_source that has no RFGLM entry yet.
# Each key is computed and inserted individually.
# An insert that fails with LostConnectionError is retried once after a 3 s pause;
# a second failure propagates and stops the loop (entries inserted so far are kept).
for key in tqdm((RFGLM().key_source - RFGLM().proj()).fetch('KEY')):
    rf_entry = RFGLM()._fetch_and_compute(key=key, clear_outputs=True, suppress_outputs=True)
    try:
        RFGLM().insert1(rf_entry, allow_direct_insert=True)
    except LostConnectionError:
        # Give the database connection a moment before the single retry.
        time.sleep(3)
        RFGLM().insert1(rf_entry, allow_direct_insert=True)

In [None]:
# Plot one RFGLM entry fitted with parameter set rf_glm_params_id=10.
(RFGLM() & "rf_glm_params_id=10").plot1()

## Metrics

In [None]:
# Populate the tables derived from the fitted RF GLMs (verbose progress, processes=20).
populate_alpha.populate_rf_glm_properties(verbose=True, processes=20)

In [None]:
# Populate the 2D DoG and 2D Gaussian fits of the GLM RFs (processes=10 each).
FitPosDoG2DRFGLM().populate(display_progress=True, processes=10)
FitPosGauss2DRFGLM().populate(display_progress=True, processes=10)

## Fit RF

### Parametric

In [None]:
# Show the SplitRFGLM parameter sets.
SplitRFGLMParams()

In [None]:
# Plot one SplitRFGLM entry for rf_glm_params_id=10.
(SplitRFGLM() & "rf_glm_params_id=10").plot1()

In [None]:
# Plot one 2D Gaussian RF fit for rf_glm_params_id=10.
(FitPosGauss2DRFGLM() & "rf_glm_params_id=10").plot1()

### Contours

In [None]:
# Register contour parameter set 1 (no blur, 'amp_one' normalisation, three contour
# levels); skip_duplicates=True lets the cell be re-run. Then display the table.
GLMContoursParams().add_default(
    rf_contours_params_id=1,
    blur_std=0.,
    blur_npix=0.,
    norm_kind='amp_one',
    levels=(0.3, 0.35, 0.4),
    skip_duplicates=True,
)
GLMContoursParams()

In [None]:
# Compute RF contours; make_kwargs=dict(plot=False) is forwarded to the make step
# (presumably to suppress per-entry plotting during populate).
GLMContours().populate(make_kwargs=dict(plot=False), processes=10, display_progress=True)

In [None]:
# Compute metrics derived from the RF contours.
GLMContourMetrics().populate(processes=10, display_progress=True)

In [None]:
# Compute contour offsets.
# NOTE(review): runs with processes=1 while the neighbouring populates use 10 — confirm this is intentional.
GLMContourOffset().populate(processes=1, display_progress=True)

In [None]:
# Draw one random key from GLMContours restricted to rf_cdia_um > 70 and
# rf_glm_params_id = 10 whose SplitRFGLM split_qidx lies strictly between 0.3 and 0.5
# (np.random.choice raises ValueError if nothing matches), then show every RF view of it.
key = np.random.choice(
    (
        GLMContours()
        & "rf_cdia_um>70"
        & "rf_glm_params_id=10"
        & (SplitRFGLM() & "split_qidx>0.3" & "split_qidx<0.5")
    ).fetch('KEY')
)
(SplitRFGLM() & key).plot1()
(GLMContours() & key).plot1()
(FitPosGauss2DRFGLM() & key).plot1()
(FitPosDoG2DRFGLM() & key).plot1()

## Compare fits

In [None]:
# Join, for entries whose RoiKind is 'roi', the SplitRFGLM quality index (split_qidx)
# with the DoG-fit quality index (rf_qidx); reset_index turns the key attributes,
# including rf_glm_params_id, into DataFrame columns so they can be used as hue/groups.
df_q_rf = (
    (SplitRFGLM & (RoiKind & "roi_kind='roi'")).proj("split_qidx")
    * (FitPosDoG2DRFGLM & (RoiKind & "roi_kind='roi'")).proj("rf_qidx")
).fetch(format='frame').reset_index()
sns.pairplot(data=df_q_rf, vars=['split_qidx', 'rf_qidx'], hue='rf_glm_params_id', palette='tab10',
             kind='kde', plot_kws=dict(levels=[0.25, 0.5, 0.75]));

# Per GLM parameter set, count entries passing both quality thresholds (> 0.45).
# Group by the column name, not a one-element list: with a list key, pandas >= 2.0
# yields 1-tuples such as (10,) as group keys, which is what would get printed.
for rf_glm_params_id, df_q_rf_i in df_q_rf.groupby('rf_glm_params_id'):
    print(rf_glm_params_id, np.sum((df_q_rf_i.split_qidx > 0.45) & (df_q_rf_i.rf_qidx > 0.45)))

In [None]:
# Overlay rf_cdia_um distributions for rf_glm_params_id=10: Gaussian fits with
# rf_qidx > 0.2, and the contour-based estimates (semi-transparent).
# NOTE(review): this uses FitGauss2DRFGLM, whereas the rest of the notebook populates and
# plots FitPosGauss2DRFGLM — confirm the intended table.
# NOTE(review): each plt.hist call picks its own automatic bins from its own data range,
# so bar heights/widths of the two histograms are not directly comparable.
plt.hist((FitGauss2DRFGLM & "rf_glm_params_id=10" & "rf_qidx>0.2").fetch('rf_cdia_um'));
plt.hist((GLMContours & "rf_glm_params_id=10").fetch('rf_cdia_um'), alpha=0.5);

In [None]:
# SplitRFGLM entries with split quality index above 0.45.
(SplitRFGLM() & "split_qidx>0.45")

In [None]:
# Inspect 'roi' entries (rf_glm_params_id 10) whose DoG-fit quality rf_qidx is below
# `thresh` although their SplitRFGLM split_qidx exceeds it.
thresh = 0.35
(
    FitPosDoG2DRFGLM
    & f"rf_qidx<{thresh}"
    & (RoiKind & "roi_kind='roi'")
    & (SplitRFGLM() & f"split_qidx>{thresh}")
    & dict(rf_glm_params_id=10)
).plot1()

In [None]:
# Plot TempRFGLMProperties for rf_glm_params_id=10 (';' suppresses the text output).
(TempRFGLMProperties() & dict(rf_glm_params_id=10)).plot();

In [None]:
def plot_glm_tab_param_ids(data_tab, param, i_list=(1, 2, 3, 4)):
    """Compare one attribute of ``data_tab`` across several ``rf_glm_params_id`` values.

    For each id in ``i_list`` the table is restricted to that id and ``param`` is
    renamed to ``f"{param}_{id}"``. The restrictions are joined and fetched as a
    DataFrame. For every id after the first, ``param`` for that id (y-axis) is
    plotted against ``param`` for ``i_list[0]`` (x-axis) with a second-order
    regression fit and an identity line.

    Parameters
    ----------
    data_tab : DataJoint table (class or instance)
        Must contain ``rf_glm_params_id`` and ``param``.
    param : str
        Name of the attribute to compare.
    i_list : sequence of int, optional
        ``rf_glm_params_id`` values; the first one is the reference. At least two
        ids are required.

    Raises
    ------
    ValueError
        If ``i_list`` contains fewer than two ids.
    """
    i_list = list(i_list)
    if len(i_list) < 2:
        raise ValueError(f"i_list needs at least two rf_glm_params_id values, got {i_list}")

    def get_tab(i):
        # Also rename rf_glm_params_id: otherwise the join below would match on it,
        # and since every restriction holds a different id, the join would be empty.
        return (data_tab & dict(rf_glm_params_id=i)).proj(
            **{f"{param}_{i}": param, f"rf_glm_params_id_{i}": 'rf_glm_params_id'})

    tab = get_tab(i_list[0])
    for i in i_list[1:]:
        tab *= get_tab(i)

    df_param = tab.fetch(format='frame')

    n_cols = len(i_list) - 1
    # squeeze=False keeps `axs` a 2-D array even for a single comparison; with the
    # default squeeze, plt.subplots(1, 1) returns a bare Axes that cannot be indexed.
    fig, axs = plt.subplots(1, n_cols, figsize=(n_cols * 3.5, 3), sharex='all', sharey='all',
                            squeeze=False)
    for ax, i in zip(axs.flat, i_list[1:]):
        sns.regplot(ax=ax, data=df_param, x=f'{param}_{i_list[0]}', y=f'{param}_{i}',
                    scatter_kws=dict(s=2, alpha=0.5, color='gray'), order=2)

    for ax in axs.flat:
        ax.axline(xy1=(0, 0), xy2=(1, 1), c='k')
        ax.set_aspect('equal', 'box')

    plt.tight_layout()
    plt.show()

In [None]:
# Compare split_qidx of parameter sets 2 and 10 against reference set 1.
plot_glm_tab_param_ids(
    data_tab=SplitRFGLM,
    param='split_qidx',
    i_list=[1, 2, 10],
)

In [None]:
# Fetch the RFGLMParams row as a dict (fetch1 expects the table to hold exactly one row).
RFGLMParams().fetch1()

## Estimate release events per second

In [None]:
# Register the default EventsPerSecond parameter set (events fit, stimulus reference
# time, no low-pass or blurring); skip_duplicates=True lets the cell be re-run.
# Then display the parameter table.
EventsPerSecondParams().add_default(
    dnoise_params_id=1,
    fupsample_trace=4,
    fupsample_stim=12,
    lowpass_cutoff=0,
    ref_time='stim',
    fit_kind='events',
    pre_blur_sigma_s=0.,
    post_blur_sigma_s=0.,
    skip_duplicates=True,
)
EventsPerSecondParams()

In [None]:
# Estimate events per second for all pending keys (processes=20).
EventsPerSecond().populate(display_progress=True, processes=20)

In [None]:
# Plot one EventsPerSecond entry zoomed to xlim=(20, 30) (units presumably seconds — confirm).
EventsPerSecond().plot1(xlim=(20, 30))

In [None]:
# Display the EventsPerSecond table.
EventsPerSecond()

In [None]:
# Inspect a single EventsPerSecond entry: fetch its 'time', 'trace' and 'stim' attributes
# for the key returned by get_primary_key, and display the 'stim' values.
# NOTE(review): binding the name `time` here shadows the stdlib `time` module in the
# notebook namespace (it is imported in the RFGLM fitting cell); consider renaming.
key = get_primary_key(EventsPerSecond)
time, trace, events_per_frame = (EventsPerSecond & key).fetch1('time', 'trace', 'stim')
events_per_frame

In [None]:
# Pool the 'stim' arrays of all EventsPerSecond entries into one flat array.
# A DataJoint fetch of a per-entry array attribute returns an object array holding one
# array per row. Boolean masking of that (arr[arr > 0]) raises, and plt.hist would draw
# one dataset per row instead of the pooled distribution.
# NOTE(review): assumes each 'stim' value is a numeric array; np.ravel flattens any
# extra dimensions.
_stim_per_entry = (EventsPerSecond).fetch('stim')
all_events_per_frame = (
    np.concatenate([np.ravel(stim) for stim in _stim_per_entry])
    if len(_stim_per_entry) > 0
    else np.array([])
)

In [None]:
# Mean over the strictly positive values of all_events_per_frame.
# NOTE(review): requires all_events_per_frame to be a flat numeric array; a raw DataJoint
# fetch of a per-entry array attribute yields one array per row (object dtype), for which
# this boolean indexing fails.
np.mean(all_events_per_frame[all_events_per_frame>0])

In [None]:
# Histogram (100 bins) of all_events_per_frame; ';' suppresses the notebook's text output.
# NOTE(review): if all_events_per_frame holds one array per entry, matplotlib treats each
# as a separate dataset instead of pooling them.
plt.hist(all_events_per_frame, bins=100);