# Schema

In [None]:
from tqdm.notebook import tqdm
import numpy as np
import os

from alphacnn.database.encoder_schema import *
from alphacnn.utils.data_utils import load_config
from alphacnn import paths

# Connect using the DataJoint config file; schema and table creation are both enabled.
# The schema name is the project prefix followed by 'encoder'.
connect_to_database(
    dj_config_file=paths.CONFIG_FILE,
    create_tables=True,
    create_schema=True,
    schema_name=paths.SCHEMA_PREFIX + 'encoder',
)
encoder_schema

# ERD

In [None]:
import warnings

# Render the schema diagram; FutureWarnings raised while doing so are silenced
# only inside this context manager.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    display(dj.ERD(encoder_schema))

# Stimulus

## Stimulus IDs

In [None]:
# Load the stimulus configuration file from paths.CONF_STIM_PATH and register it under id 1.
stimulus_config_file = "f002_f003_rot_1975_w_and_wo_test.yml"
stim_conf = load_config(
    os.path.join(paths.CONF_STIM_PATH, stimulus_config_file)
)
StimulusConfig().add_stim(
    stimulus_config_id=1,
    stimulus_config_file=stimulus_config_file,
    stimulus_dict=stim_conf,
)
StimulusConfig()

In [None]:
# Run populate() on StimulusIDs, then display the table.
StimulusIDs().populate()
StimulusIDs()

## Load Stimuli

In [None]:
# Use the first primary key of the StimulusConfig x StimulusIDs join as an example entry
# and fetch the attributes stored for it.
key = (StimulusConfig().proj() * StimulusIDs().proj()).fetch('KEY')[0]
stimulus_dict = (StimulusConfig & key).fetch1('stimulus_dict')
stimulus_file, video_dict, wo_cricket = (StimulusIDs & key).fetch1(
    'stimulus_file',
    'video_dict',
    'wo_cricket',
)

In [None]:
# Load the file referenced by video_dict['video_path']; as the last expression of the cell,
# the loaded object is displayed in full.
# NOTE(review): this may print a very large array — consider displaying only its shape.
np.load(video_dict['video_path'])

In [None]:
# Display the StimulusIDs entries matching `key`.
(StimulusIDs & key)

In [None]:
# Populate Stimulus in a single process with error suppression enabled; the return value is
# kept so that failed keys can be inspected afterwards.
err_list = Stimulus().populate(display_progress=True, processes=1, suppress_errors=True)

In [None]:
# The first element of each 'error_list' entry is used as the key of the failed job;
# the matching StimulusIDs rows are deleted.
failed_entries = err_list['error_list']
if failed_entries:
    failed_keys = [entry[0] for entry in failed_entries]
    (StimulusIDs & failed_keys).delete()

In [None]:
# Display the Stimulus table.
Stimulus()

In [None]:
# For every stimulus config: sample one StimulusIDs key at random and plot all
# StimulusIDs entries that match it once 'wo_cricket' is ignored.
for config_key in StimulusConfig.proj():
    id_frame = (StimulusIDs & config_key).proj().fetch(format='frame')
    key = id_frame.sample(1).reset_index().iloc[0].to_dict()
    # Remove 'wo_cricket' from the key so the restriction below no longer constrains it.
    del key['wo_cricket']
    print(key)
    for sub_key in (StimulusIDs & key).proj():
        print(sub_key)
        Stimulus().plot1(sub_key, n_rows=1, sym=True)

# BCs

## BC sRF configs

In [None]:
# Display the BCsRfConfig table before new entries are added.
BCsRfConfig()

In [None]:
# Register two BC spatial-RF configs per stimulus config: id 0 is 'ws', id 1 is 'ss'
# (ids follow the order of this mapping). All use bc_cdist=15.
srf_files_by_name = dict(ws='strf_cluster1.pkl', ss='strf_cluster3.pkl')
for stimulus_config_id in StimulusConfig.fetch('stimulus_config_id'):
    for srf_id, (srf_name, srf_file) in enumerate(srf_files_by_name.items()):
        BCsRfConfig().add_from_file(
            bc_srf_config_id=srf_id,
            bc_srf_config_name=srf_name,
            bc_cdist=15,
            file=srf_file,
            stimulus_config_id=stimulus_config_id,
        )
BCsRfConfig()

In [None]:
# Plot the registered BC spatial-RF configs using the table's own plot method.
BCsRfConfig().plot()

## Spatial RFs

### Rect

In [None]:
# Populate the spatial-RF outputs; the batch sizes are forwarded to make() via make_kwargs.
BCSpatialRFOutput().populate(
    make_kwargs={'batch_size': 16, 'batch_size_frames': 64},
)
BCSpatialRFOutput()

In [None]:
# For every stimulus config: sample one stimulus at random, then plot the stimulus and the
# spatial-RF output of each BC sRF config for all matching StimulusIDs entries.
for config_key in StimulusConfig.proj():
    id_frame = (StimulusIDs & config_key).proj().fetch(format='frame')
    key = id_frame.sample(1).reset_index().iloc[0].to_dict()
    # Remove 'wo_cricket' from the key so the restriction below no longer constrains it.
    del key['wo_cricket']
    print(key)
    for sub_key in (StimulusIDs & key).proj():
        print('Stimulus', sub_key)
        Stimulus().plot1(sub_key, n_rows=1, sym=True)
        matching_srf_keys = (BCsRfConfig * (StimulusIDs & sub_key)).proj()
        for sub_sub_key in matching_srf_keys:
            print('BC sRF', sub_sub_key)
            BCSpatialRFOutput().plot1(key=sub_sub_key, sym=True)
        print()

### Estimate NLs from data

#### Compute mean outputs of spatial RFs

In [None]:
# Fetch the spatial-RF outputs of every entry whose BC sRF config is named 'ss'
# and concatenate them along the first axis.
ss_configs = BCsRfConfig & dict(bc_srf_config_name='ss')
bc_srf_outputs_ss = np.concatenate(
    (BCSpatialRFOutput() & ss_configs).fetch('bc_srf_output')
)

In [None]:
# Histogram of the values at 100 first-axis indices drawn at random (with replacement).
n_entries_ss = bc_srf_outputs_ss.shape[0]
rnd_idxs = np.random.choice(n_entries_ss, 100)
plt.hist(bc_srf_outputs_ss[rnd_idxs, :, :].flat, bins=201);

In [None]:
# Summary statistics over all 'ss' spatial-RF output values:
# mean, standard deviation, and the 5th / 95th percentiles.
mu_ss, sd_ss = bc_srf_outputs_ss.mean(), bc_srf_outputs_ss.std()
q05_ss, q95_ss = np.percentile(bc_srf_outputs_ss, q=[5, 95])

In [None]:
# Print the 'ss' statistics in the order: mean, SD, 5th percentile, 95th percentile.
print(mu_ss, sd_ss, q05_ss, q95_ss)

In [None]:
# NOTE(review): this replaces the sd_ss computed from the data with a hard-coded constant —
# presumably a value frozen from an earlier run; confirm this override is intended.
sd_ss = 0.020349585

In [None]:
# Fetch the spatial-RF outputs of every entry whose BC sRF config is named 'ws'
# and concatenate them along the first axis.
ws_configs = BCsRfConfig & dict(bc_srf_config_name='ws')
bc_srf_outputs_ws = np.concatenate(
    (BCSpatialRFOutput() & ws_configs).fetch('bc_srf_output')
)

In [None]:
# Histogram of the values at 100 first-axis indices drawn at random (with replacement).
n_entries_ws = bc_srf_outputs_ws.shape[0]
rnd_idxs = np.random.choice(n_entries_ws, 100)
plt.hist(bc_srf_outputs_ws[rnd_idxs, :, :].flat, bins=201);

In [None]:
# Summary statistics over all 'ws' spatial-RF output values:
# mean, standard deviation, and the 5th / 95th percentiles.
mu_ws, sd_ws = bc_srf_outputs_ws.mean(), bc_srf_outputs_ws.std()
q05_ws, q95_ws = np.percentile(bc_srf_outputs_ws, q=[5, 95])

In [None]:
# Print the 'ws' statistics in the order: mean, SD, 5th percentile, 95th percentile.
print(mu_ws, sd_ws, q05_ws, q95_ws)

In [None]:
# NOTE(review): this replaces the sd_ws computed from the data with a hard-coded constant —
# presumably a value frozen from an earlier run; confirm this override is intended.
sd_ws = 0.09235746

In [None]:
# Compare both distributions side by side: 'ss' at x=1, 'ws' at x=2.
# Per group, the first trace shows mean / 5th / 95th percentile, the second mean -/+ one SD.
x_ss, x_ws = 1.0, 2.0

plt.plot([x_ss] * 3, [mu_ss, q05_ss, q95_ss], '.')
plt.plot([x_ss] * 2, [mu_ss - sd_ss, mu_ss + sd_ss], '.')

plt.plot([x_ws] * 3, [mu_ws, q05_ws, q95_ws], '.')
plt.plot([x_ws] * 2, [mu_ws - sd_ws, mu_ws + sd_ws], '.')

Assume we have maximum steady release 5 vesicles per second. <br>
The framerate is 60 frames per second, so we have 5 vesicles per 60 frames, ~0.083 vesicles per frame. <br>
From the events-per-second estimate we also get around 0.1 events per frame, so that probably makes sense. <br>

The RRP should be around 8 vesicles, which can be released in a few frames. <br>

In [None]:
from djimaging.tables.receptivefield.non_linearities import apply_sigmoid

In [None]:
# Sigmoid parameters for the 'ws' pathway; d is set to the mean sRF output.
sigmoid_params_ws = dict(k=3., q=29, b=20, v=0, d=mu_ws)

# NOTE(review): this input grid spans -3..+3 SD around zero, whereas the probes below are
# taken around mu_ws — confirm the grid is not meant to be centred on mu_ws.
data_inputs_ws = np.linspace(-3*sd_ws, +3*sd_ws, 100)

# Sigmoid output at the mean and at 1, 2 and 3 SD above it.
for y_probe in (mu_ws, mu_ws+1*sd_ws, mu_ws+2*sd_ws, mu_ws+3*sd_ws):
    print(apply_sigmoid(y=y_probe, **sigmoid_params_ws))

fig, axs = plt.subplots(1, 2, figsize=(12, 3))
ax_curve, ax_hist = axs

# Left: the sigmoid over the input grid.
ax_curve.plot(data_inputs_ws, apply_sigmoid(y=data_inputs_ws, **sigmoid_params_ws))

# Right: distribution of sigmoid outputs for the randomly sampled sRF outputs.
sampled_inputs_ws = bc_srf_outputs_ws[rnd_idxs, :, :].flatten()
data_outputs_ws = apply_sigmoid(y=sampled_inputs_ws, **sigmoid_params_ws)
ax_hist.hist(data_outputs_ws, bins=100);
# NOTE(review): the line is drawn at the *fraction* of outputs above 1, on an axis of output
# values — confirm this is intended rather than e.g. the mean output.
ax_hist.axvline(np.mean(data_outputs_ws > 1))

plt.show()

In [None]:
# Sigmoid parameters for the 'ss' pathway: same as 'ws' except that parameter b (20 for 'ws')
# is rescaled by sd_ws/sd_ss and d is set to the 'ss' mean.
# NOTE(review): the input grid spans -3..+3 SD around zero, whereas the probes and the
# normalised x-axis below are relative to mu_ss — confirm the grid should not be centred on mu_ss.
data_inputs_ss = np.linspace(-3*sd_ss, +3*sd_ss, 100)
sigmoid_params_ss = dict(k=3., q=29, b=20*sd_ws/sd_ss, v=0, d=mu_ss)

# Sigmoid output at the mean and at 1, 2 and 3 SD above it.
print(apply_sigmoid(y=mu_ss, **sigmoid_params_ss))
print(apply_sigmoid(y=mu_ss+1*sd_ss, **sigmoid_params_ss))
print(apply_sigmoid(y=mu_ss+2*sd_ss, **sigmoid_params_ss))
print(apply_sigmoid(y=mu_ss+3*sd_ss, **sigmoid_params_ss))

fig, axs = plt.subplots(1, 2, figsize=(12, 3))
# Left: both sigmoids over inputs expressed in SD units relative to their own mean.
axs[0].plot((data_inputs_ss - mu_ss)/sd_ss, apply_sigmoid(y=data_inputs_ss, **sigmoid_params_ss))
axs[0].plot((data_inputs_ws - mu_ws)/sd_ws, apply_sigmoid(y=data_inputs_ws, **sigmoid_params_ws))

# Draw the sample indices from the 'ss' array itself. The shared `rnd_idxs` was last drawn
# from the length of the 'ws' array, so reusing it here would index 'ss' with indices that
# are only valid if both arrays happen to have the same first-axis length.
rnd_idxs_ss = np.random.choice(bc_srf_outputs_ss.shape[0], 100)

# Right: distribution of sigmoid outputs for the randomly sampled sRF outputs.
data_outputs_ss = apply_sigmoid(y=bc_srf_outputs_ss[rnd_idxs_ss, :, :].flatten(), **sigmoid_params_ss)
axs[1].hist(data_outputs_ss, bins=100);
# NOTE(review): the line is drawn at the *fraction* of outputs above 1, on an axis of output
# values — confirm this is intended rather than e.g. the mean output.
axs[1].axvline(np.mean(data_outputs_ss > 1))

plt.show()

#### Add nls

In [None]:
# Pack the 'ws' sigmoid parameters into an array in the fixed order k, q, b, v, d.
sigmoid_param_order = ('k', 'q', 'b', 'v', 'd')
sigmoid_params_ws_arr = np.array(
    [sigmoid_params_ws[param] for param in sigmoid_param_order]
)
sigmoid_params_ws_arr

In [None]:
# Pack the 'ss' sigmoid parameters into an array in the fixed order k, q, b, v, d.
sigmoid_param_order = ('k', 'q', 'b', 'v', 'd')
sigmoid_params_ss_arr = np.array(
    [sigmoid_params_ss[param] for param in sigmoid_param_order]
)
sigmoid_params_ss_arr

In [None]:
# Register the two non-linearities for every stimulus config: id 0 is 'ws', id 1 is 'ss'.
rect_configs = (
    (0, 'ws', sigmoid_params_ws_arr),
    (1, 'ss', sigmoid_params_ss_arr),
)
for stimulus_config_id in StimulusConfig.fetch('stimulus_config_id'):
    for rect_id, rect_name, nl_params in rect_configs:
        BCRectConfig().add_from_data(
            bc_rect_config_id=rect_id,
            bc_rect_config_name=rect_name,
            stimulus_config_id=stimulus_config_id,
            nl=nl_params,
        )
BCRectConfig()

### BC Output

In [None]:
# Populate the rectified BC outputs with a progress bar; batch sizes are forwarded to make().
BCRectOutput.populate(
    display_progress=True,
    make_kwargs={'batch_size': 16, 'batch_size_frames': 256},
)
BCRectOutput()

In [None]:
# Call the table's populate_missing() with the same batch sizes as the populate() call,
# then display the table.
# NOTE(review): presumably this fills entries the regular populate left out — verify
# against the table definition.
BCRectOutput.populate_missing(make_kwargs=dict(batch_size=16, batch_size_frames=256))
BCRectOutput()

In [None]:
# Display the BCRectOutput table.
BCRectOutput()

In [None]:
# For each stimulus config that has rectified BC output: sample one stimulus at random,
# plot it, and plot the rectified output of every BC sRF config available for it.
for stim_key in (StimulusConfig & BCRectOutput).proj():
    print(stim_key)
    candidate_frame = (Stimulus & BCRectOutput & stim_key).proj().fetch(format='frame')
    key = candidate_frame.sample(1).reset_index().iloc[0].to_dict()
    print(key)
    print('Stimulus')
    Stimulus().plot1(key=key, n_rows=1, sym=True)
    for bc_key in (BCsRfConfig & BCRectOutput & key).proj():
        print('BC', bc_key)
        # Merge into a separate dict so the sampled stimulus key itself stays untouched.
        rect_key = {**key, **bc_key}
        print('BC Rect')
        BCRectOutput().plot1(key=rect_key, sym=False)

    print()

## Noise

In [None]:
# Register the "med" noise configuration (bc_stddev=0.1, core_seed=123) once for every
# (stimulus config, BC sRF config) combination.
for stimulus_config_id in StimulusConfig.fetch('stimulus_config_id'):
    # Restrict the BC configs to this stimulus config and de-duplicate their ids. An
    # unrestricted fetch returns one id per table row, so the same id can come back several
    # times when more than one stimulus config exists, and add() would then be called
    # repeatedly for the same combination.
    stim_restriction = dict(stimulus_config_id=stimulus_config_id)
    bc_srf_config_ids = np.unique((BCsRfConfig & stim_restriction).fetch('bc_srf_config_id'))
    for bc_srf_config_id in bc_srf_config_ids:
        key = dict(stimulus_config_id=stimulus_config_id, bc_srf_config_id=bc_srf_config_id)
        BCNoiseConfigCore().add(**key, noise_id=1, noise_name="med", noise_dict=dict(bc_stddev=0.1), core_seed=123)
BCNoiseConfigCore()

In [None]:
# Populate BCNoiseConfig with a progress bar, then display the table.
BCNoiseConfig().populate(display_progress=True)
BCNoiseConfig()

In [None]:
# Add noise seed samples (n_samples_tot=1) for every existing
# (stimulus config, BC sRF config, noise config) combination.
for stimulus_config_id in StimulusConfig.fetch('stimulus_config_id'):
    # Restrict and de-duplicate the ids at each level. Unrestricted fetches return one id
    # per table row, so the same id can come back several times (and ids belonging to other
    # combinations would be mixed in) once more than one stimulus config exists.
    stim_restriction = dict(stimulus_config_id=stimulus_config_id)
    bc_srf_config_ids = np.unique((BCsRfConfig & stim_restriction).fetch('bc_srf_config_id'))
    for bc_srf_config_id in bc_srf_config_ids:
        combo = dict(stimulus_config_id=stimulus_config_id, bc_srf_config_id=bc_srf_config_id)
        bc_noise_ids = np.unique((BCNoiseConfigCore & combo).fetch('bc_noise_id'))
        for bc_noise_id in tqdm(bc_noise_ids):
            key = dict(stimulus_config_id=stimulus_config_id, bc_srf_config_id=bc_srf_config_id, bc_noise_id=bc_noise_id)
            BCNoiseSeeds().add_samples(**key, n_samples_tot=1)
BCNoiseSeeds()

### BC Noise Output

In [None]:
# Populate BCNoiseOutput, restricted to entries with bc_noise_id=1, with a progress bar.
BCNoiseOutput().populate(dict(bc_noise_id=1), display_progress=True)

In [None]:
# For every (stimulus config, BC sRF config) combination: sample one noisy BC output at
# random, plot its stimulus, and plot the noisy output of every BC sRF config with
# rectified output for that entry.
for config_key in (StimulusConfig.proj() * BCsRfConfig.proj()):
    noise_keys = (BCNoiseOutput & config_key).proj()
    if len(noise_keys) == 0:
        # Nothing populated for this combination; sampling one row from an empty frame
        # would raise, so skip it.
        continue

    key = noise_keys.fetch(format='frame').sample(1).reset_index().iloc[0].to_dict()
    print(key)
    print('Stimulus')
    Stimulus().plot1(key, n_rows=1, sym=True)
    # Drop the BC-specific ids so the restriction below is not tied to the sampled BC config.
    key.pop('bc_rect_config_id')
    key.pop('bc_srf_config_id')
    for bc_key in ((BCsRfConfig & BCRectOutput) & key).proj():
        sub_key = {**key, **bc_key}
        print('BC Noise', bc_key)
        BCNoiseOutput().plot1(key=sub_key, sym=True)
    print()

# RGCs

## Define synaptic weights

In [None]:
# Show the ids and names of the registered BC sRF configs (referenced by the RGC weights below).
BCsRfConfig().fetch('bc_srf_config_id', 'bc_srf_config_name')

In [None]:
# Show the ids and names of the registered BC rect configs (referenced by the RGC weights below).
BCRectConfig().fetch('bc_rect_config_id', 'bc_rect_config_name')

In [None]:
# RGC definitions: (rgc_id, rgc_name, rgc_cdist, rgc_rf_dia, weights for input 1, weights for input 2).
# Input 1 always refers to BC sRF/rect config 0 and input 2 to config 1; an input with
# std=cut=w_tot=0 carries zero total weight.
rgc_specs = [
    (0, 'nsl', 150, 405,
     dict(std=101., cut=193., w_tot=1.), dict(std=0., cut=0., w_tot=0.)),
    (1, 'tmp', 75, 315,
     dict(std=31., cut=133., w_tot=0.5), dict(std=31., cut=133., w_tot=0.5)),
    (2, 'tmp_ws', 75, 315,
     dict(std=31., cut=133., w_tot=1.0), dict(std=0., cut=0., w_tot=0.0)),
    (3, 'tmp_ss', 75, 315,
     dict(std=0., cut=0., w_tot=0.0), dict(std=31., cut=133., w_tot=1.0)),
]

for rgc_id, rgc_name, rgc_cdist, rgc_rf_dia, weights_1, weights_2 in rgc_specs:
    RGCSynapticWeights().add(
        rgc_id=rgc_id, rgc_name=rgc_name, rgc_cdist=rgc_cdist, rgc_rf_dia=rgc_rf_dia,
        bc_srf_config_id_1=0, bc_rect_config_id_1=0, config_weights_1=weights_1,
        bc_srf_config_id_2=1, bc_rect_config_id_2=1, config_weights_2=weights_2,
    )

In [None]:
# Display the RGCSynapticWeights table.
RGCSynapticWeights()

## Synaptic inputs

In [None]:
# Populate the RGC synaptic inputs in random key order; batch sizes are forwarded to make().
RGCSynapticInputs().populate(
    display_progress=True,
    make_kwargs={'batch_size': 16, 'batch_size_frames': 256},
    order='random',
)

In [None]:
# Pick one stimulus id at random and, for every (stimulus config, RGC) combination that has
# synaptic inputs for it, plot the stimulus, the BC spatial-RF output and the RGC inputs.
stimulus_id = StimulusIDs.fetch(format='frame').reset_index().sample(1).iloc[0].stimulus_id
stimulus_restriction = dict(stimulus_id=stimulus_id)

for combo_key in (StimulusConfig.proj() * RGCSynapticWeights.proj()):
    rgc_input_keys = (RGCSynapticInputs & combo_key & stimulus_restriction).proj()
    # Combinations without any populated inputs for this stimulus are skipped.
    if len(rgc_input_keys) == 0:
        continue

    key = rgc_input_keys.fetch(format='frame').sample(1).reset_index().iloc[0].to_dict()
    print(key)
    print('Stimulus')
    Stimulus().plot1(key, n_rows=1, sym=True, drop_first_n=0)
    print('BCs')
    BCSpatialRFOutput().plot1(key=key)
    print('RGCs')
    RGCSynapticInputs().plot1(key=key)
    print()

# Plot

In [None]:
from alphacnn.database.encoder_utils import plot_simulation

# Plot the simulation for the first stimulus config id in StimulusConfig, with bc_config=1.
stimulus_config_id = StimulusConfig.fetch("stimulus_config_id")[0]
plot_simulation(
    {'stimulus_config_id': stimulus_config_id},
    bc_config=1,
)