In [1]:
# importing packages
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import matplotlib.lines as mlines
from matplotlib.gridspec import GridSpec

%matplotlib qt

from scipy.optimize import curve_fit
from scipy.special import expit

import ssm
from ssm import transitions 

import scipy.io
import os
import seaborn as sb
import pandas as pd

In [47]:
# load data
# The .mat path defaults to the original author's machine; set the RISKDATA_MAT
# environment variable to load a copy stored elsewhere without editing this cell.
filename = os.environ.get(
    'RISKDATA_MAT',
    'C:/Users/Asus/Desktop/PhD/rotations/AnnJeff/data/riskdata/animalspaper.mat')
riskcolumns = ['animal', 'sessionid', 'trialnumber', 'trial_block', 'lotterymag', 'lotteryprob',
               'surebetmag', 'rewardreceived', 'lotterychoice', 'lotteryoutcome']

matfile = scipy.io.loadmat(filename)   # dict of MATLAB variables
data = matfile['riskydata']            # numeric matrix; each row becomes one df row
# Fail with a readable message instead of a pandas length-mismatch error if the
# stored matrix does not have exactly one column per name above.
if data.ndim != 2 or data.shape[1] != len(riskcolumns):
    raise ValueError(f"expected a 2-D 'riskydata' array with {len(riskcolumns)} columns, "
                     f"got shape {data.shape}")
df = pd.DataFrame(data, columns=riskcolumns)

In [61]:
# getting relevant variables from data
# distinct animal ids (in order of first appearance in df) and how many there are
animals = pd.unique(df['animal'])
nanimals = animals.shape[0]

In [75]:
# inspect the loaded trial table (pandas truncates the display to the first/last rows)
df

Unnamed: 0,animal,sessionid,trialnumber,trial_block,lotterymag,lotteryprob,surebetmag,rewardreceived,lotterychoice,lotteryoutcome
0,2152.0,158487.0,2.0,2.0,0.0,0.55,3.0,24.0,0.0,0.0
1,2152.0,158487.0,4.0,4.0,8.0,0.55,3.0,24.0,0.0,1.0
2,2152.0,158487.0,6.0,6.0,0.0,0.55,3.0,24.0,0.0,1.0
3,2152.0,158487.0,7.0,7.0,4.0,0.55,3.0,24.0,0.0,1.0
4,2152.0,158487.0,9.0,9.0,4.0,0.55,3.0,24.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...
8427,2166.0,182807.0,91.0,33.0,4.0,0.60,3.0,48.0,0.0,0.0
8428,2166.0,182807.0,92.0,34.0,4.0,0.60,3.0,48.0,0.0,0.0
8429,2166.0,182807.0,94.0,36.0,0.0,0.60,3.0,48.0,0.0,1.0
8430,2166.0,182807.0,95.0,37.0,16.0,0.60,3.0,256.0,1.0,1.0


In [120]:
# creating input structure compatible with the ssm package
# Result, per animal (same order as `animals`), one entry per session (in order of
# first appearance in df):
#   inputs[a][s] : float array (ntrials, inputdim); column 0 = normalized deltaEV,
#                  column 1 = constant 1 (presumably the GLM bias input)
#   Y[a][s]      : int array (ntrials, 1) holding `lotterychoice` (0 = surebet chosen)
inputdim = 2


def _unique_scalar(values, name, session):
    """Return the single distinct value of the Series `values`; raise ValueError otherwise.

    The deltaEV formula assumes one surebet magnitude, one surebet reward and one
    lottery probability per session. Checking here replaces the opaque broadcasting
    error (or silently wrong elementwise arithmetic) the unchecked `.unique()` arrays gave,
    e.g. for a session in which the animal never chose the surebet.
    """
    distinct = values.unique()
    if len(distinct) != 1:
        raise ValueError(
            f"session {session}: expected exactly one {name} value, found {distinct}")
    return distinct[0]


def _session_arrays(sessiondata, inputdim):
    """Return (inputs, y) arrays for one session's trials.

    deltaEV = rwdmult * (lotterymag * lotteryprob) - sbrwd, where sbrwd is the reward
    received on surebet trials (lotterychoice == 0) and rwdmult = sbrwd / surebetmag.
    deltaEV is divided by its largest absolute value: dividing by max(deltaEV) flipped
    the regressor's sign whenever every deltaEV was negative and divided by zero
    when the max was 0.
    """
    session = sessiondata.sessionid.iloc[0]
    surebet = sessiondata[sessiondata.lotterychoice == 0]
    sbmag = _unique_scalar(surebet.surebetmag, 'surebetmag', session)
    sbrwd = _unique_scalar(surebet.rewardreceived, 'surebet rewardreceived', session)
    lotteryprob = _unique_scalar(sessiondata.lotteryprob, 'lotteryprob', session)
    rwdmult = sbrwd / sbmag

    deltaEV = rwdmult * (sessiondata.lotterymag.to_numpy() * lotteryprob) - sbrwd
    scale = np.max(np.abs(deltaEV))
    # an all-zero deltaEV stays zero rather than turning into NaN
    normdeltaEV = deltaEV / scale if scale > 0 else deltaEV

    ntrials = len(sessiondata)
    sessioninputs = np.ones([ntrials, inputdim])
    sessioninputs[:, 0] = normdeltaEV
    sessiony = np.zeros([ntrials, 1], dtype=int)
    sessiony[:, 0] = sessiondata.lotterychoice.to_numpy()
    return sessioninputs, sessiony


inputs = []
Y = []
for animal in animals:
    animaldata = df[df.animal == animal]
    sessions = animaldata.sessionid.unique()   # order of first appearance
    nsessions = len(sessions)

    inputsaux = []
    yaux = []
    for session in sessions:
        sessiondata = animaldata[animaldata.sessionid == session]
        sessioninputs, sessiony = _session_arrays(sessiondata, inputdim)
        inputsaux.append(sessioninputs)
        yaux.append(sessiony)

    inputs.append(inputsaux)
    Y.append(yaux)

In [139]:
# setting glm-hmm parameters
animalidx = 0               # NOTE(review): unused in the visible cells — confirm before removing
nstates = 3                 # number of latent HMM states
obsdim = 1                  # one observation (the choice) per trial
inputdim = 2                # must match the inputs built above: normalized deltaEV + constant column
ncategories = 2             # choice categories: surebet (0) / lottery (1)
ncatergories = ncategories  # alias kept for the original (misspelled) name
niterations = 200           # maximum EM iterations
tolerance = 10**-5          # EM convergence tolerance passed to fit()
randomseed = 0              # NOTE(review): assumes ssm initializes from numpy's global RNG — confirm

pstate = []    # per animal: list (one per session) of posterior state probabilities
glmhmms = []   # per animal: fitted model (previously only the last animal's model survived)
fit_lls = []   # per animal: value returned by fit() (presumably the EM log-likelihood trace)
for aa in range(nanimals):
    # per-animal seed: each fit is reproducible regardless of loop order
    np.random.seed(randomseed + aa)
    glmhmm = ssm.HMM(nstates, obsdim, inputdim, observations='input_driven_obs',
                     observation_kwargs=dict(C=ncategories), transitions="standard")
    glmhmm2fit = glmhmm.fit(Y[aa], inputs=inputs[aa], method='em',
                            num_iters=niterations, tolerance=tolerance)

    # distinct loop names so the module-level `data` / `inputs` are not shadowed
    posterior_pstate = [glmhmm.expected_states(data=sessiony, input=sessioninputs)[0]
                        for sessiony, sessioninputs in zip(Y[aa], inputs[aa])]
    glmhmms.append(glmhmm)
    fit_lls.append(glmhmm2fit)
    pstate.append(posterior_pstate)

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [None]:
# add state info to dataframe