In [9]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
from sys import path, platform

## Set project root
if platform == "linux" or platform == "linux2":
    pass
    # root_dir = '/mnt/g/My Drive/Schiller-Gu-Lab/Projects/SlotMachine/slots-reversal-project/'
    # model_functions_path = f'{root_dir}/derivatives/computational_modeling/'
elif platform == "darwin":
    root_dir = '/Volumes/synapse/projects/SlotsTasks/'
elif platform == "win32":
    root_dir = 'Z:/projects/SlotsTasks/'

project_dir = f'{root_dir}/online/prolific-food-craving/'
# analysis_dir = f'{project_dir}/derivatives/decision/'
model_functions_path = f'{root_dir}/bayesian_models/'
## Add model_functions to system path
path.append(model_functions_path)

from slotscraving.decision import Biased, Heuristic, RescorlaWagner, RWDecay, RWRL


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Graphing libraries
import arviz as az
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from slotscraving.decision.utils import plotting, load_data
from scipy.special import expit

In [10]:
path_to_summary = f"{project_dir}/rawdata/clean_df_summary.csv"
path_to_longform = f"{project_dir}/rawdata/clean_df_longform.csv"
df_summary, longform = load_data.load_clean_dbs(path_to_summary, path_to_longform)
netcdf_path = f"{project_dir}/derivatives/decision/output/"

In [11]:
longform['PID'].unique()

array(['61281debe85082cc937dd9ae', '58595b56a3149800011e156e',
       '5ca9e89dcb4af3001440da22', '5c284ae48e845b0001fc40aa',
       '5bfadc1846911f0001d7d1eb', '610008383c0a128712403745',
       '5c3848fc19ceb400010c24a0', '558955ebfdf99b6bd06016c9',
       '57c9c13079ad9b0001e416be', '60b78db6151a3abb7f832c22',
       '5ec6d423fc10270f0bbbc90e', '60ff42b20bee1078e198b4e5',
       '5b2122862942cc0001e5856b', '612682ec747ac2d5df40f7df',
       '615200800c7a074171c4968d', '5765c0fcf2e23200017ded5e',
       '6157b6f47949b07006a288d1', '5beb92fbe0a3940001543d8c',
       '5ee24fba3b868b345bb6f7bf', '5eddb464a77718acaaa72ab5',
       '615ddfac4254098aac0104f6', '5dcc948ed9873592337ae827',
       '6108e0b9b254d1ce92a37b95', '61283769a39dd638c256540c',
       '610099a0bb44106c39030d8f', '60ffd8686289b723fe66c706',
       '61366a0250ad37452b08b9a3', '5e64876fa46cc2204f34e98a',
       '61083562ab4a86a37f0807a8', '6112dba8e2eb90519a46e6cf',
       '61202b6be797314915cb1bd7', '612597436eff2c3e153

## Load all models

In [12]:
biased = Biased.Biased(model_name='biased', save_path=netcdf_path, summary=df_summary, longform=longform)
biased.fit(jupyter=True)
biased.calc_Q_table()
biased.calc_bics()

Participant 45 completed...


In [13]:
heuristic = Heuristic.Heuristic(model_name='heuristic', save_path=netcdf_path, summary=df_summary, longform=longform)
heuristic.fit(jupyter=True)
heuristic.calc_Q_table()
heuristic.calc_bics()

Participant 45 completed...


In [14]:
rw = RescorlaWagner.RW(model_name='rw', save_path=netcdf_path, summary=df_summary, longform=longform)
rw.fit(jupyter=True)
rw.calc_Q_table()
rw.calc_bics()

Participant 45 completed...


In [15]:
rwdecay = RWDecay.RWDecay(model_name='rwdecay', save_path=netcdf_path, summary=df_summary, longform=longform)
rwdecay.fit(jupyter=True)
rwdecay.calc_Q_table()
rwdecay.calc_bics()

Participant 45 completed...


In [16]:
rwrl = RWRL.RWRL(model_name='rwrl', save_path=netcdf_path, summary=df_summary, longform=longform)
rwrl.fit(jupyter=True)
rwrl.calc_Q_table()
rwrl.calc_bics()

Participant 45 completed...


## Model comparison

In [17]:
plotting.plot_model_comparison(
    [biased, heuristic, rw, rwdecay, rwrl], 
    sum=True, 
    metric='BIC',
    y_range=[3050,3375]
)