In [1]:
%load_ext autoreload
%autoreload 2
import datajoint as dj
from pipeline import lab, get_schema_name, experiment, foraging_model, ephys

Connecting hanhou@datajoint.mesoscale-activity-map.org:3306


# Migrate my previous foraging-model fitting code into the DataJoint pipeline

## Define models

In [None]:
# Schema diagram of the whole foraging_model module (`dj.ERD` is an alias of `dj.Diagram`).
dj.Diagram(foraging_model)

In [None]:
# Lookup table of model classes known to the pipeline.
model_class_table = foraging_model.ModelClass()
model_class_table

In [3]:
# Run the pipeline's `Model.load_models()` helper; presumably it (re)loads the model
# definitions listed in the next cell. TODO confirm against pipeline.foraging_model.
model_table_cls = foraging_model.Model
model_table_cls.load_models()

In [5]:
# All model definitions currently stored in foraging_model.Model.
model_table = foraging_model.Model()
model_table

model_id,"model_class  e.g. LossCounting, RW1972, Hattori2019",model_notation,n_params  Effective param count,is_bias,is_epsilon_greedy,is_softmax,is_choice_kernel,desc  Long name,fit_cmd  Fitting command compatible with the Dynamic-Foraing repo
0,LossCounting,"LossCounting ($\mu_{LC}$, $\sigma_{LC}$)",2,0,0,0,0,"LossCounting: mean, std, no bias",=BLOB=
1,RW1972,"RW1972_epsi ($\alpha$, $\epsilon$)",2,0,1,0,0,"SuttonBarto: epsilon, no bias",=BLOB=
2,RW1972,"RW1972_softmax ($\alpha$, $\sigma$)",2,0,0,1,0,"SuttonBarto: softmax, no bias",=BLOB=
3,LNP,"LNP_softmax ($\tau_1$, $\sigma$)",2,0,0,1,0,"Sugrue2004, Corrado2005: one tau, no bias",=BLOB=
4,LNP,"LNP_softmax ($\tau_1$, $\tau_2$, $w_{\tau_1}$, $\sigma$)",4,0,0,1,0,"Corrado2005, Iigaya2019: two taus, no bias",=BLOB=
5,Bari2019,"Bari2019 ($\alpha$, $\delta$, $\sigma$)",3,0,0,1,0,"RL: chosen, unchosen, softmax, no bias",=BLOB=
6,Hattori2019,"Hattori2019 ($\alpha_{rew}$, $\alpha_{unr}$, $\sigma$)",3,0,0,1,0,"RL: rew, unrew, softmax, no bias",=BLOB=
7,Hattori2019,"Hattori2019 ($\alpha_{rew}$, $\alpha_{unr}$, $\delta$, $\sigma$)",4,0,0,1,0,"RL: rew, unrew, unchosen, softmax, no bias",=BLOB=
8,RW1972,"RW1972_epsi ($\alpha$, $\epsilon$, $b_L$)",3,1,1,0,0,SuttonBarto: epsilon,=BLOB=
9,RW1972,"RW1972_softmax ($\alpha$, $\sigma$, $b_L$)",3,1,0,1,0,SuttonBarto: softmax,=BLOB=


In [None]:
# Parameters of model 0, joined with its Model row.
(foraging_model.Model() * foraging_model.Model.Param()) & {'model_id': 0}

In [None]:
# Keys FittedSessionModel.populate() would iterate over.
fitted_session_model = foraging_model.FittedSessionModel
fitted_session_model.key_source

## Prepare data for each session

In [None]:
# BehaviorTrial plus the tables up to two dependency levels downstream of it.
dj.Diagram(experiment.BehaviorTrial) + 2

In [None]:
# WaterPort (expanded one level downstream) combined with the SessionBlock diagram.
waterport_diagram = dj.Diagram(experiment.WaterPort) + 1
waterport_diagram + dj.Diagram(experiment.SessionBlock)

In [None]:
# Trial-level data for one session: the chosen water port (renamed to `choice`),
# trial outcome / early-lick flags, and the block each trial belongs to.
# NOTE(review): 'model_id' presumably has no matching attribute in these experiment
# tables, in which case DataJoint ignores it when restricting; TODO confirm.
key = {'subject_id': 447921, 'session': 3, 'model_id': 5}
session_trials = (
    experiment.WaterPortChoice.proj(choice='water_port')
    * experiment.BehaviorTrial.proj('outcome', 'early_lick')
    * experiment.SessionBlock.BlockTrial
)
session_trials & key

## Populate model fitting

In [11]:
# Handle on the foraging_model schema, used below to inspect and clean its jobs table.
schema_name = get_schema_name('foraging_model')
schema = dj.schema(schema_name)
schema.jobs

table_name  className of the table,key_hash  key hash,"status  if tuple is missing, the job is available",key  structure containing the key,error_message  error message returned if failed,error_stack  error stack if failed,user  database user,host  system hostname,pid  system process id,connection_id  connection_id(),timestamp  automatic timestamp
__fitted_session_model,bd62297b213a9f874e5768ee44c1996c,reserved,=BLOB=,,=BLOB=,hanhou@206.241.1.244,Han-itx,10180,123163,2021-08-15 10:14:13
__fitted_session_model,becb647d2cbf6d2698e935d968768f92,reserved,=BLOB=,,=BLOB=,hanhou@206.241.1.244,Han-itx,2320,123162,2021-08-15 10:11:53
__fitted_session_model,d647911f7271e52f91b2f4aa6fe47937,reserved,=BLOB=,,=BLOB=,hanhou@206.241.1.244,Han-itx,23892,123167,2021-08-15 10:14:17
__fitted_session_model,e0299b386a1f499eefe1f3a78b9202db,reserved,=BLOB=,,=BLOB=,hanhou@206.241.1.244,Han-itx,17852,123160,2021-08-15 10:13:37


In [8]:
# Keys currently present in the jobs table and the host that registered each one.
jobs_table = schema.jobs
jobs_table.fetch('key', 'host')

[array([{'subject_id': 482350, 'session': 16, 'model_id': 22},
        {'subject_id': 482350, 'session': 18, 'model_id': 22},
        {'subject_id': 482350, 'session': 37, 'model_id': 22},
        {'subject_id': 482350, 'session': 6, 'model_id': 22},
        {'subject_id': 482350, 'session': 22, 'model_id': 22}],
       dtype=object),
 array(['Han-itx', 'Han-itx', 'Han-itx', 'Han-itx', 'Han-itx'],
       dtype=object)]

In [41]:
# Clear job entries so that a later populate(reserve_jobs=True) can pick those keys up again.
# An unrestricted `schema.jobs.delete()` also removes *reserved* rows held by workers that
# may still be running (see the `reserved` rows in the jobs listing earlier in this
# notebook). Another worker could then reserve and fit the same key a second time.
# Only add 'reserved' to this list after confirming those worker processes have stopped.
JOB_STATUSES_TO_CLEAR = ['error']
# A list of dicts restricts with OR; an empty list matches nothing, so nothing is deleted.
(schema.jobs & [{'status': status} for status in JOB_STATUSES_TO_CLEAR]).delete()

In [3]:
# Progress of model fitting: fitted entries vs. all keys in FittedSessionModel's key_source.
finished = len(foraging_model.FittedSessionModel())
total = len(foraging_model.FittedSessionModel.key_source)
# Guard against an empty key_source, which would otherwise raise ZeroDivisionError.
fraction_done = finished / total if total else 0.0
print(f'Fitted session: {finished}/{total}, {fraction_done:.2%}')

Fitted session: 45/13840, 0.33%


In [None]:
# Preview of the fitted (session, model) entries.
fitted_sessions = foraging_model.FittedSessionModel()
fitted_sessions

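Overall statistics for all models: average `lpt_aic` / `lpt_bic` and number of fitted sessions per model, then the maximum `lpt_aic` reached in each session.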

In [None]:
# Per model: mean lpt_aic, mean lpt_bic and number of fitted entries,
# shown next to the model's notation.
model_fit_summary = foraging_model.Model.aggr(
    foraging_model.FittedSessionModel,
    aver_lpt_aic='avg(lpt_aic)',
    aver_lpt_bic='avg(lpt_bic)',
    n='count(*)',
)
foraging_model.Model.proj('model_notation') * model_fit_summary

In [None]:
# For every (subject_id, session), the largest lpt_aic over all models fitted to it.
dj.U('subject_id', 'session').aggr(
    foraging_model.FittedSessionModel,
    max_lpt_aic='max(lpt_aic)',
)

# Model comparison

## Populate

In [3]:
# Fill FittedSessionModelComparison for all keys in its key_source not yet computed.
comparison_table = foraging_model.FittedSessionModelComparison
comparison_table.populate(display_progress=True)

FittedSessionModelComparison: 100%|████████████████████████████████████████████████| 2450/2450 [03:07<00:00, 13.10it/s]


## Plotting

In [22]:
from pipeline.plot.foraging_model_plot import plot_session_model_comparison, plot_session_fitted_choice, _get_model_comparison_results

In [None]:
# Select one unit by recording date, probe index and unit UID, then plot the
# model comparison for its session.
date, imec, unit = '2021-04-18', 0, 541
unit_restriction = {
    'subject_id': 473361,
    'session_date': date,
    # imec index looks 0-based while insertion_number is 1-based; TODO confirm.
    'insertion_number': imec + 1,
    'unit_uid': unit,
}
unit_key = (ephys.Unit() * experiment.Session & unit_restriction).fetch1("KEY")
plot_session_model_comparison(unit_key, model_comparison_idx=0, sort='aic')

In [None]:
# Model-comparison results for one session (sort='aic'), using the plot module's private
# helper. Only the first return value is kept.
session_key = {'subject_id': 473361, 'session': 47}
comparison_results, _ = _get_model_comparison_results(session_key, sort='aic')
comparison_results

In [None]:
# Fitted-choice plot for the session of the unit selected above.
plot_session_fitted_choice(
    unit_key,
    first_n=2,
    last_n=1,
    smooth_factor=7,
)