In [None]:
%load_ext autoreload
%autoreload 2
import datajoint as dj
from pipeline import lab, get_schema_name, experiment, foraging_model, ephys

# Migrate my previous model-fitting code

## Define models

In [None]:
# Dependency diagram of every table in the foraging_model schema (dj.ERD is the legacy alias of dj.Diagram).
dj.Diagram(foraging_model)

In [None]:
# Show the contents of the ModelClass table.
foraging_model.ModelClass()

In [None]:
# Populate the Model table with the model definitions.
# NOTE(review): this cell is run for its side effect; confirm load_models() is safe to re-run
# (e.g. skips duplicates) so the notebook survives Restart & Run All.
foraging_model.Model.load_models()

In [None]:
# Models whose is_choice_kernel attribute is truthy (the string is used as an SQL WHERE condition).
foraging_model.Model & 'is_choice_kernel'

In [None]:
# Model definitions joined with their Param part-table rows, for model_id 0 only.
(foraging_model.Model() * foraging_model.Model.Param()) & {'model_id': 0}

In [None]:
# The set of keys that populate() will iterate over when fitting FittedSessionModel.
foraging_model.FittedSessionModel.key_source

## Prepare data for each session

In [None]:
# BehaviorTrial and the tables up to two dependency levels downstream of it.
dj.Diagram(experiment.BehaviorTrial) + 2

In [None]:
# WaterPort expanded one level downstream, combined with the SessionBlock diagram.
(dj.Diagram(experiment.WaterPort) + 1) + dj.Diagram(experiment.SessionBlock)

In [None]:
# Example session used to inspect the per-trial inputs to model fitting.
# NOTE(review): 'model_id' is presumably not an attribute of these experiment tables, so the
# restriction likely ignores it — confirm.
key = {'subject_id': 447921, 'session': 3, 'model_id': 5}
(
    experiment.WaterPortChoice.proj(choice='water_port')
    * experiment.BehaviorTrial.proj('outcome', 'early_lick')
    * experiment.SessionBlock.BlockTrial
    & key
)

## Populate model fitting

In [None]:
# Bind the foraging_model schema so its jobs (reservation/error) table can be inspected.
schema = dj.schema(get_schema_name('foraging_model'))
schema.jobs

In [None]:
# Key and host of every job entry, e.g. to see which machine holds or failed on which key.
schema.jobs.fetch('key', 'host')

In [None]:
# Clear the jobs table for this schema.
# NOTE(review): this is unrestricted, so it removes every job entry, possibly including keys
# reserved by workers that are still running. The jobs table's delete() may also skip the usual
# confirmation prompt; verify that. Consider restricting it, e.g. (schema.jobs & 'status = "error"'),
# or guarding it with a flag so Restart & Run All does not clear jobs.
schema.jobs.delete()

In [None]:
# Progress of model fitting: fitted entries vs. all keys scheduled for fitting.
finished = len(foraging_model.FittedSessionModel())
total = len(foraging_model.FittedSessionModel.key_source)
# Guard against an empty key_source (nothing to fit yet) instead of raising ZeroDivisionError.
progress = finished / total if total else 0.0
print(f'Fitted session: {finished}/{total}, {progress:.2%}')

In [None]:
# Display the FittedSessionModel table.
foraging_model.FittedSessionModel()

### Overall statistics for all models

In [None]:
# Per model: its notation, mean lpt_aic / lpt_bic, and the number of FittedSessionModel rows (n).
foraging_model.Model.proj('model_notation') * foraging_model.Model.aggr(
    foraging_model.FittedSessionModel,
    aver_lpt_aic='avg(lpt_aic)',
    aver_lpt_bic='avg(lpt_bic)',
    n='count(*)',
)

In [None]:
# For each (subject_id, session) pair, the largest lpt_aic over all models fitted to it.
dj.U('subject_id', 'session').aggr(
    foraging_model.FittedSessionModel,
    max_lpt_aic='max(lpt_aic)',
)

# Model comparison

In [None]:
from pipeline.plot.foraging_model_plot import plot_session_model_comparison, plot_session_fitted_choice, _get_model_comparison_results

In [None]:
# Pick one example unit and plot the model comparison for its session, ranked by AIC.
date, imec, unit = '2021-04-18', 0, 541
unit_restriction = {
    'session_date': date,
    'subject_id': 473361,
    'insertion_number': imec + 1,  # insertion_number is the imec index offset by one
    'unit_uid': unit,
}
unit_key = (ephys.Unit() * experiment.Session & unit_restriction).fetch1('KEY')
plot_session_model_comparison(unit_key, model_comparison_idx=0, sort='aic')

In [None]:
# Model-comparison results for one example session, sorted by AIC.
# NOTE(review): session 47 is hard-coded and not derived from unit_key; confirm it is the
# session recorded on the date used in the cell above.
# The second return value is discarded under a named placeholder rather than `_`. Assigning
# `_` stops IPython from updating its `_`/`__`/`___` output-history shortcuts.
model_comparison_results, _unused = _get_model_comparison_results({'subject_id': 473361, 'session': 47}, sort='aic')
model_comparison_results

In [None]:
# Plot the fitted choices for the session identified by unit_key.
plot_session_fitted_choice(
    unit_key,
    smooth_factor=7,
    first_n=2,
    last_n=1,
)