In [17]:
# importing packages
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib qt
from scipy.optimize import curve_fit
import ssm
from ssm import transitions 
import csv
import scipy.io
import os
import seaborn as sb
from matplotlib.lines import Line2D
import matplotlib.lines as mlines
from matplotlib.gridspec import GridSpec
from scipy.special import expit

In [18]:
# load data
# Root of the local data tree. Override it with the ANNJEFF_DATA_ROOT environment
# variable when running elsewhere; the default keeps the original location.
data_root = os.environ.get('ANNJEFF_DATA_ROOT', 'C:/Users/Asus/Desktop/PhD/rotations/AnnJeff/data')
filename = data_root + '/riskdata/animalspaper.mat'
matfile = scipy.io.loadmat(filename)
# trial-by-column matrix; `header` names each column, with its index as a suffix
data = matfile['riskydata']
header =  ['animal0', 'sessionid1', 'trialnumber2', 'trial_block3', 'lotterymag4', 'lotteryprob5', 
           'surebetmag6', 'rewardreceived7', 'lotterychoice8', 'lotteryoutcome9']
# output folder for saved figures (per-animal sub-folders are joined onto it later)
rootpath = data_root + '/hmmglmdata/'

In [19]:
# defining plot colors
# stimcolors: sequential palette, indexed by deltaEV level in the session-summary plots
# colors: one entry per HMM state (only 3, so state-indexed plots need nstates <= 3)
stimcolors = sb.color_palette(palette = "rocket_r")
colors = list(('gold', 'plum', 'lightskyblue'))

In [20]:
# defining psychometric curve function
def psycurves(x, u, l, b, s):
  """ 4-parameter logistic psychometric function.

    Evaluates l + (u - l) * sigmoid((x - b) / s).

    x (ndarray or float): stimulus values (here deltaEV) at which to evaluate the curve
    u (float): upper asymptote (limit of the curve as x -> +inf, for s > 0)
    l (float): lower asymptote (limit of the curve as x -> -inf, for s > 0)
    b (float): bias / midpoint: the x value where the curve equals (u + l) / 2
    s (float): scale (inverse slope); a larger s gives a shallower curve

  Returns:
    ndarray or float: curve value at each x (same shape as x)
  """
  # expit(z) == exp(z) / (1 + exp(z)), but it does not overflow to inf / inf = nan
  # for large z (which otherwise poisons curve_fit's residuals)
  return (u - l) * expit((x - b) / s) + l

In [21]:
# defining psychometric curve function
def psycurves2(x, u, l, b, s):
  """ 4-parameter logistic psychometric function with lapse rates.

    Evaluates u + (1 - u - l) * sigmoid((x - b) * s).

    x (ndarray or float): stimulus values at which to evaluate the curve
    u (float): lower lapse; the curve tends to u as x -> -inf (for s > 0)
    l (float): upper lapse; the curve tends to 1 - l as x -> +inf (for s > 0)
    b (float): bias / midpoint: the x value where the curve equals (1 + u - l) / 2
    s (float): slope (multiplies x - b, unlike psycurves where the scale divides it)

  Returns:
    ndarray or float: curve value at each x (same shape as x)
  """
  # expit((x - b) * s) == 1 / (1 + exp(-(x - b) * s)) without overflow warnings
  return u + (1 - u - l) * expit((x - b) * s)

In [22]:
# per-animal, per-session probability of choosing the lottery at each lottery magnitude,
# plus the matching deltaEV (scaled lottery EV minus sure-bet reward) for each magnitude
animalid = np.unique(data[:,0])
nstim = len(np.unique(data[:,4]))
nanimals = len(animalid)

plottery = []
deltaEV = []
for aa in range(nanimals):
    animalflags = data[:,0] == animalid[aa]
    sessions = np.unique(data[animalflags, 1])

    # keep only sessions that contain every lottery magnitude; a skipped session
    # must not leave an all-zero row behind (it would bias the session means and
    # add a spurious 0 to np.unique(deltaEV[aa]) in the plotting cell)
    aa_plottery = []
    aa_deltaEV = []
    for session in sessions:
        sessionflags = data[:,1] == session
        ss_data = data[animalflags & sessionflags, :]

        lotterymag = np.unique(ss_data[:,4])
        if len(lotterymag) < nstim:
            continue

        # reward scaling inferred from sure-bet trials (lotterychoice column == 0)
        surebet_flags = ss_data[:, 8] == 0
        surebet_mag = np.unique(ss_data[surebet_flags, 6])
        surebet_rwd = np.unique(ss_data[surebet_flags, 7])
        rwdmultiplier = surebet_rwd/surebet_mag
        # NOTE(review): assumes one lottery probability and one sure-bet reward per session
        lotteryprob = np.unique(ss_data[:,5])

        aa_deltaEV.append(np.round((rwdmultiplier * (lotterymag * lotteryprob)) - surebet_rwd, 1))
        # fraction of trials at each magnitude on which the lottery was chosen
        aa_plottery.append(np.array([np.mean(ss_data[ss_data[:, 4] == mag, 8] == 1)
                                     for mag in lotterymag]))

    # (n_valid_sessions, nstim); reshape also gives (0, nstim) when nothing was kept
    plottery.append(np.array(aa_plottery).reshape(-1, nstim))
    deltaEV.append(np.array(aa_deltaEV).reshape(-1, nstim))

In [23]:
# plot psychometric curves: per-animal mean p(risk) across sessions (+/- SEM)
# with a 4-parameter logistic fit, one figure per animal
for aa in range(nanimals):
    nsessions = plottery[aa].shape[0]
    if nsessions == 0:
        # no usable sessions for this animal: nothing to average or fit
        continue

    fig, ax = plt.subplots(1, 1, figsize = (5, 5))
    # x positions: distinct deltaEV values across sessions
    # NOTE(review): assumes every session shares the same deltaEV set (one value per column)
    lotterymag = np.unique(deltaEV[aa])
    lotterymean = np.mean(plottery[aa], 0)
    lotterysem = np.std(plottery[aa], 0)/np.sqrt(nsessions)

    ax.errorbar(lotterymag, lotterymean, lotterysem,
                marker = 'o', linestyle = 'None', 
                markerfacecolor = 'k', markeredgecolor = 'k',
                ecolor = 'k', markersize = 5)

    popt, _ = curve_fit(psycurves, lotterymag, lotterymean)
    xx = np.linspace(min(lotterymag) - 10, max(lotterymag) + 10, 1000)
    ax.plot(xx, psycurves(xx, *popt), c = 'k')

    # reference lines at deltaEV = 0 and p(risk) = .5
    ax.axvline(0, linestyle = ':', c = 'k',  lw = .5)
    ax.axhline(.5, linestyle = ':', c = 'k', lw = .5)

    # axis options (raw string: '\D' is not a valid escape sequence)
    ax.set_ylim((-.05, 1.05))
    ax.set_xlabel(r'$\Delta$ EV')
    ax.set_title('animal : ' + str(animalid[aa]))
    ax.set_ylabel('p(risk)')
    ax.set_yticks([0, .5, 1])
    ax.set_xticks(lotterymag)
    ax.tick_params(axis = 'both', labelsize = 7.5)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ##saving figures
    #path = os.path.join(rootpath, str(animalid[aa]))
    #os.mkdir(path)
    #plt.savefig(path + '/alldata_' + str(animalid[aa]) + '.png')

In [109]:
# getting data for a specific animal: per-session choices (Y) and GLM design matrices (inputs)
animalidx = 0
animalflags = data[:,0] == animalid[animalidx]
sessions = np.unique(data[animalflags, 1])
nsessions = len(sessions)

Y = []
inputs = []
ninputs = 3   # design-matrix columns: [normalized deltaEV, bias (ones), previous choice]
tt_deltaEV = []
prevchoice = []
normtt_deltaEV = []
for ss in range(nsessions):
    sessionflags = data[:, 1] == sessions[ss]
    # restrict to this animal too, so ntrials always matches the rows in ss_data
    ss_data = data[animalflags & sessionflags, :]
    ntrials = ss_data.shape[0]
    choice = ss_data[:,8]
    lotterymag = ss_data[:,4]

    # session specific settings
    # NOTE(review): assumes one sure-bet magnitude/reward and one lottery probability per session
    surebet_flags = choice == 0
    surebet_mag = np.unique(ss_data[surebet_flags, 6])
    surebet_rwd = np.unique(ss_data[surebet_flags, 7])
    rwdmultiplier = surebet_rwd/surebet_mag
    lotteryprob = np.unique(ss_data[:,5])

    # trial-wise deltaEV (scaled lottery EV minus sure-bet reward) and its
    # session-max-normalized version used as the GLM stimulus input
    # NOTE(review): if every deltaEV in a session is negative, dividing by the max flips the sign - confirm
    ss_deltaEV = rwdmultiplier * (lotterymag * lotteryprob) - surebet_rwd
    ss_normdeltaEV = ss_deltaEV/np.max(ss_deltaEV)

    # previous trial's choice; 0 on the first trial of the session
    ss_prevchoice = np.zeros(ntrials)
    ss_prevchoice[1:] = choice[:-1]

    Y.append(choice.astype(int).reshape(ntrials, 1))
    tt_deltaEV.append(ss_deltaEV.reshape(ntrials, 1))
    normtt_deltaEV.append(ss_normdeltaEV.reshape(ntrials, 1))
    prevchoice.append(ss_prevchoice.reshape(ntrials, 1))

    ss_inputs = np.ones([ntrials, ninputs])   # column 1 stays 1 (bias)
    ss_inputs[:, 0] = ss_normdeltaEV
    ss_inputs[:, 2] = ss_prevchoice
    inputs.append(ss_inputs)

In [110]:
# selecting inputs to use: keep normalized deltaEV (column 0) and bias (column 1),
# dropping the previous-choice column (column 2)
inputs = [inputs[sessidx][:, :2] for sessidx in range(nsessions)]

In [94]:
## setting glm-hmm parameters
#nstates = np.array([1,2,3,4,5])
#obsdim = 1
#inputdim = inputs[0].shape[1]
#ncatergories = 2
#niterations = 200
#fig = plt.figure()
#for kk in nstates:
#    glmhmm = ssm.HMM(kk, obsdim, inputdim,  observations = 'input_driven_obs',
#                observation_kwargs = dict(C = ncatergories), transitions = "standard")
#
#    glmhmm2fit = glmhmm.fit(Y, inputs = inputs, method = 'em', num_iters = niterations, tolerance = 10**-5)
#
#    # plot LL convergence
#    plt.plot(glmhmm2fit, label = str(kk))
#
#    # axis options
#    plt.title(str(kk) + ' states')
#    plt.legend(loc = 'lower right')
#    plt.xlabel('EM iteration')
#    plt.xlim(0, len(glmhmm2fit))
#    plt.ylabel('log probability')
#    plt.show()
    
##saving figures
#path = os.path.join(rootpath, str(animalid[animalidx]))
#plt.savefig(path + '/LL_' + str(animalid[animalidx])+ '_' + str(inputdim) + 'inputs' + '.png')    

In [111]:
# setting glm-hmm parameters
nstates = 3
obsdim = 1
inputdim = inputs[0].shape[1]
ncatergories = 2   # binary choice categories
niterations = 150
fitseed = 0        # fixes the random parameter initialization so the EM fit is reproducible

# NOTE(review): assumes ssm draws its initialization from numpy's global RNG - confirm
np.random.seed(fitseed)
glmhmm = ssm.HMM(nstates, obsdim, inputdim,  observations = 'input_driven_obs',
                observation_kwargs = dict(C = ncatergories), transitions = "standard")

# presumably the per-iteration log-probability trace (it is plotted as such in the LL cell)
glmhmm2fit = glmhmm.fit(Y, inputs = inputs, method = 'em', num_iters = niterations, tolerance = 10**-5)

  0%|          | 0/150 [00:00<?, ?it/s]

In [112]:
# get expected states: for each session, the posterior state probabilities
# (indexed [trial, state] by the cells below); names chosen to avoid shadowing `data` / `inputs`
posterior_pstate = [
    glmhmm.expected_states(data = sess_y, input = sess_inputs)[0]
    for sess_y, sess_inputs in zip(Y, inputs)
]

In [113]:
# getting states (discrete): most probable state on each trial, one (ntrials, 1) int array per session
sessions = np.unique(data[animalflags, 1])
nsessions = len(sessions)
Z = []
for ss in range(nsessions):
    posteriors = posterior_pstate[ss]
    # argmax yields exactly one index per trial (first state on ties), so the
    # (ntrials, 1) integer layout is always valid; ntrials comes from the posteriors
    Z.append(np.argmax(posteriors, axis = 1).astype(int).reshape(-1, 1))


  if sys.path[0] == '':


In [114]:
# session stack to make psychometrics: concatenate per-session arrays in session order
statesflat = np.concatenate(Z)                # (total trials, 1) most probable state
evsflat = np.concatenate(normtt_deltaEV)      # (total trials, 1) normalized deltaEV
stimflat = np.concatenate(tt_deltaEV)         # (total trials, 1) deltaEV
# take choices from Y, which follows the same session order as Z, so choices stay
# trial-aligned with statesflat even if rows of `data` are not grouped by sorted session id
choicesflat = np.concatenate(Y)[:, 0]
prevchoiceflat = np.concatenate(prevchoice)   # (total trials, 1) previous choice
nstim = np.unique(stimflat).shape[0]

In [124]:
# summary of learned glm parameters and hmm states: weights, fractional occupancy, transition matrix
fig, ax = plt.subplots(1, 3, figsize = (15, 5))
statelabels = [str(kk + 1) for kk in range(nstates)]
inputlabels = [r'$\Delta$ EV', 'bias', 'prev choice'][:inputdim]

# plot glm weights for each covariate; params are negated so that expit(w . x)
# can be read as p(risk) in the psychometric cells
# NOTE(review): assumes ssm's input-driven observations give p(y = 1) = sigmoid(-params . x) - confirm
w = - glmhmm.observations.params
for kk in range(nstates):
    ax[0].plot(range(inputdim), w[kk][0], marker = 'o', 
               color = colors[kk], alpha = .7, linestyle = '-', label = 'state ' + str(kk+1))
ax[0].set_ylabel('glm weight')
ax[0].set_xticks(range(inputdim))
ax[0].set_xticklabels(inputlabels)
ax[0].axhline(y = 0, color = 'k', alpha = .5, ls = '--')
ax[0].legend()
ax[0].set_title('glm weights')
ax[0].spines['top'].set_visible(False)
ax[0].spines['right'].set_visible(False)

# plot fractional occupancy: share of trials whose most probable state is zz
for zz in range(nstates):
    occ = np.mean(statesflat == zz)
    ax[1].bar(zz, occ, width = .75, color = colors[zz], alpha = .7)
ax[1].set_ylim((0, 1))
ax[1].set_xlabel('state')
ax[1].set_ylabel('state fractional occupancy')
ax[1].set_xticks(range(nstates))
ax[1].set_xticklabels(statelabels)
ax[1].spines['top'].set_visible(False)
ax[1].spines['right'].set_visible(False)

# plot transition matrix, annotated with its rounded entries
transmat = np.round(glmhmm.transitions.transition_matrix, 2)
ax[2].imshow(transmat, vmin = -.8, vmax = 1, cmap = 'bone')
for ii in range(nstates):
    for jj in range(nstates):
        ax[2].text(jj, ii, str(transmat[ii, jj]), ha = 'center', va = 'center', color = 'k')
ax[2].set_xlim(-.5, nstates - .5)
ax[2].set_ylim(nstates - .5, -.5)
ax[2].set_xticks(range(nstates))
ax[2].set_xticklabels(statelabels)
ax[2].set_yticks(range(nstates))
ax[2].set_yticklabels(statelabels)
ax[2].set_ylabel('state at trial (t)')
ax[2].set_xlabel('state at trial (t + 1)')
ax[2].set_title('transition matrix')

##saving figures
path = os.path.join(rootpath, str(animalid[animalidx]))
os.makedirs(path, exist_ok = True)   # savefig does not create missing folders
fig.savefig(path + '/learned_params' + str(animalid[animalidx])+ '_' + str(inputdim) + 'inputs' + '.png')

In [119]:
# plotting discrimination behavior for each state (4-parameter psycurve):
# empirical p(risk) +/- SEM at each deltaEV, split by most probable state, plus a logistic fit
fig, ax = plt.subplots(1, 1, figsize = (10, 10))

stimvalues = np.unique(stimflat)
# named so they do not overwrite the per-animal `plottery` list used by earlier cells
state_plottery = np.zeros([nstim, nstates])
state_semlottery = np.zeros([nstim, nstates])
for zz in range(nstates):
    stateflags = statesflat[:,0] == zz

    for st, stim in enumerate(stimvalues):
        lotterychoices = choicesflat[stateflags & (stimflat[:,0] == stim)]
        if lotterychoices.size == 0:
            # no trials of this state at this deltaEV
            state_plottery[st, zz] = np.nan
            state_semlottery[st, zz] = np.nan
            continue
        state_plottery[st, zz] = np.mean(lotterychoices)
        state_semlottery[st, zz] = np.std(lotterychoices)/np.sqrt(len(lotterychoices))

    # plot psychometric curves
    ax.errorbar(stimvalues, state_plottery[:, zz], state_semlottery[:, zz], color = colors[zz],
                marker = 'o', linestyle = 'None')

    # curve_fit rejects NaNs and needs at least as many points as the 4 parameters
    fitmask = ~np.isnan(state_plottery[:, zz])
    if np.sum(fitmask) >= 4:
        popt, _ = curve_fit(psycurves, stimvalues[fitmask], state_plottery[fitmask, zz], xtol = 1e-5)
        xx = np.linspace(min(stimvalues) - 10, max(stimvalues) + 10, 1000)
        ax.plot(xx, psycurves(xx, *popt), c = colors[zz], lw = 2)

    ax.text(100, .2 - 0.05 * (zz + 1), 'trials in state '+ str(zz + 1) + ' = ' + str(sum(stateflags)), color = colors[zz])

# reference lines at deltaEV = 0 and p(risk) = .5
ax.axvline(0, linestyle = ':', c = 'k',  lw = .5)
ax.axhline(.5, linestyle = ':', c = 'k', lw = .5)

# axis options
ax.set_ylim((-.05, 1.05))
ax.set_xlabel(r'$\Delta$ EV')
ax.set_ylabel('p(risk)')
ax.set_yticks([0, .5, 1])
ax.set_xticks(stimvalues)
ax.tick_params(axis = 'both', labelsize = 7.5)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

##saving figures
#path = os.path.join(rootpath, str(animalid[animalidx]))
#plt.savefig(path + '/psycurves_' + str(animalid[animalidx]) + '_' + str(inputdim) + 'inputs' + '.png')

In [118]:
# plot psycurves glm: logistic curve implied by each state's glm weights (left)
# next to the weights themselves (right)
fig, ax = plt.subplots(1, 2, figsize = (10, 10))
# params are negated so that expit(w . x) is read as p(risk)
w = - glmhmm.observations.params
inputlabels = [r'$\Delta$ EV ', 'bias', 'prev choice'][:inputdim]

# input grid in normalized units (what the model saw) and the raw deltaEV grid used
# for the x axis
# NOTE(review): the two grids line up only if normalization is a single positive scale factor
stimvals = np.linspace(np.min(evsflat), np.max(evsflat), 1000)
xxstimset = np.linspace(np.min(stimflat), np.max(stimflat), 1000)
# design matrix [normalized deltaEV, bias]; assumes the model was fit with exactly these inputs
xx = np.array([
    stimvals,
    np.repeat(1, len(stimvals))]).T

for zz in range(nstates):
    stateflags = statesflat[:,0] == zz
    weights = w[zz, :]
    print(weights)

    wx = np.matmul(xx, weights[0])
    ax[0].plot(xxstimset, expit(wx), color = colors[zz], lw = 3)

    # reference lines at deltaEV = 0 and p(risk) = .5
    ax[0].axvline(0, linestyle = ':', c = 'k',  lw = .5)
    ax[0].axhline(.5, linestyle = ':', c = 'k', lw = .5)
    ax[0].text(50, .2 - 0.05 * (zz + 1), 'trials in state '+ str(zz + 1) + ' = ' + str(sum(stateflags)), color = colors[zz])

# axis options; tick font size goes through tick_params because set_yticks only
# accepts text kwargs together with labels
ax[0].set_ylim((-.05, 1.05))
ax[0].set_xlabel(r'$\Delta$ EV', fontsize = 15)
ax[0].set_ylabel('p(risk)', fontsize = 15)
ax[0].set_yticks([0, .5, 1])
ax[0].tick_params(axis = 'both', labelsize = 15)
ax[0].spines['top'].set_visible(False)
ax[0].spines['right'].set_visible(False)

# plot glm weights for each covariate
for kk in range(nstates):
    ax[1].plot(range(inputdim), w[kk][0], marker = 'o', 
               color = colors[kk], alpha = .7, linestyle = '-', label = 'state ' + str(kk+1))

ax[1].set_ylabel('glm weight', fontsize = 15)
ax[1].set_xticks(range(inputdim))
ax[1].set_xticklabels(inputlabels, fontsize = 15)
ax[1].axhline(y = 0, color = 'k', alpha = .5, ls = '--')
ax[1].legend()
ax[1].set_title('glm weights')
ax[1].spines['top'].set_visible(False)
ax[1].spines['right'].set_visible(False)

[[6.55072632 0.33903388]]
[[ 9.95730287 -0.83309305]]
[[ 2.28027028 -1.69870906]]


In [53]:
# sanity check: scipy's expit(z) matches the explicit logistic 1 / (1 + exp(-z)) for z = -1.7044
for logistic_value in (expit(-1.7044), 1/(1+np.exp(1.7044))):
    print(logistic_value)

0.15389147308888582
0.15389147308888582


In [123]:
# session summary with psycurves & glm parameters: one figure per session showing
# trial-by-trial choices, state posteriors, glm weights and per-state glm psychometric curves
# params are negated so that expit(w . x) is read as p(risk)
w = - glmhmm.observations.params
inputlabels = [r'$\Delta$ EV ', 'bias', 'prev choice'][:inputdim]
path = os.path.join(rootpath, str(animalid[animalidx]))
os.makedirs(path, exist_ok = True)   # savefig does not create missing folders

# the per-state glm curves do not depend on the session, so build their grids once:
# normalized-deltaEV inputs (padded by 1) and the raw deltaEV range used for the x axis
# NOTE(review): assumes the model was fit with exactly [normalized deltaEV, bias] as inputs
stimvals = np.linspace(np.min(evsflat)-1, np.max(evsflat)+1, 1000)
xxstimset = np.linspace(np.min(stimflat), np.max(stimflat), 1000)
xx = np.array([
    stimvals,
    np.repeat(1, len(stimvals))]).T

for ss in range(nsessions):
    fig = plt.figure(figsize = (25, 10))
    fig.subplots_adjust(left = 0.1, wspace = 0.9)

    # plot choices; rows are restricted to this animal, the same way Y / tt_deltaEV were built
    sessionflags = animalflags & (data[:, 1] == sessions[ss])
    ntrials = sum(sessionflags)
    choice = data[sessionflags, 8]
    # NOTE(review): column 7 is 'rewardreceived' in `header` (column 9 is 'lotteryoutcome');
    # a nonzero value is treated as a lottery win below - confirm
    lotteryoutcome = data[sessionflags, 7]
    stimset = np.unique(tt_deltaEV[ss])

    ax1 = plt.subplot2grid((2,4), (1,0), colspan = 3)
    for tt in range(ntrials):
        stimidx = np.where(stimset == tt_deltaEV[ss][tt][0])[0][0]
        if choice[tt] == 0:
            # sure bet: solid patch
            trialpatch = matplotlib.patches.Rectangle((tt, 0), 1, 1, color = stimcolors[stimidx])
            ax1.add_patch(trialpatch)
        elif choice[tt] == 1 and lotteryoutcome[tt] != 0:
            # lottery chosen and rewarded ('win')
            trialpatch = matplotlib.patches.Rectangle((tt, 0), 1, 1, color = stimcolors[stimidx], alpha = 0.7)
            trialpatch.set_hatch('*')
            ax1.add_patch(trialpatch)
        elif choice[tt] == 1 and lotteryoutcome[tt] == 0:
            # lottery chosen and unrewarded ('lose')
            trialpatch = matplotlib.patches.Rectangle((tt, 0), 1, 1, color = stimcolors[stimidx], alpha = 0.7)
            trialpatch.set_hatch('x')
            ax1.add_patch(trialpatch)

    # axis options
    ax1.set_xlim((0, ntrials))
    ax1.set_yticks([])
    ax1.set_xlabel('trial number')
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['left'].set_visible(False)

    # legend options: one color per deltaEV level, then marker keys for the trial types
    lotteryclrs = [Line2D([0], [0], color = stimcolors[st], lw = 4) for st in range(len(stimset))]
    lotteryvalues = [r'$\Delta$ EV' + str(ev) for ev in stimset]
    lotteryclrs.extend((mlines.Line2D([], [], color = 'k', lw = 4, marker = 's', linestyle = 'None'),
                        mlines.Line2D([], [], color = 'k', lw = 4, marker = '*', linestyle = 'None'),
                        mlines.Line2D([], [], color = 'k', lw = 4, marker = 'x', linestyle = 'None')))
    lotteryvalues.extend(('surebet', 'win', 'lose'))

    ax1.legend(lotteryclrs, lotteryvalues, bbox_to_anchor=[1, 1], title = '$\\bf{trial type}$')

    # plot posteriors
    ax2 = plt.subplot2grid((2,4), (0,0), colspan = 3, sharex = ax1)
    for kk in range(nstates):
        ax2.plot(posterior_pstate[ss][:, kk], label = 'state ' + str(kk + 1), lw = 1.5,
                 color = colors[kk])
    ax2.set_ylim((-0.01, 1.01))
    ax2.set_yticks([0, 0.5, 1])
    ax2.set_xlabel('trial number')
    ax2.set_ylabel('p(state)')
    ax2.set_title('session number ' + str(ss+1))
    ax2.legend()
    ax2.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)

    # plot glm weights for each covariate (one tick per fitted input)
    ax3 = plt.subplot2grid((2,4), (0,3), rowspan = 1)
    for kk in range(nstates):
        ax3.plot(range(inputdim), w[kk][0], marker = 'o', 
                 color = colors[kk], alpha = .7, linestyle = '-', label = 'state ' + str(kk+1))
    ax3.set_ylabel('glm weight')
    ax3.set_xticks(range(inputdim))
    ax3.set_xticklabels(inputlabels)
    ax3.axhline(y = 0, color = 'k', alpha = .5, ls = '--')
    ax3.legend()
    ax3.set_title('glm weights')
    ax3.spines['top'].set_visible(False)
    ax3.spines['right'].set_visible(False)

    # plot psycurves
    ax4 = plt.subplot2grid((2,4), (1,3), rowspan = 1)
    for kk in range(nstates):
        stateflags = statesflat[:,0] == kk
        wx = np.matmul(xx, w[kk, 0])
        ax4.plot(xxstimset, expit(wx), color = colors[kk])
        # trial counts are pooled over all of this animal's sessions, not just session ss
        ax4.text(50, .2 - 0.05 * (kk + 1), 'trials in state '+ str(kk + 1) + ' = ' + str(sum(stateflags)), color = colors[kk])

    # reference lines at deltaEV = 0 and p(risk) = .5
    ax4.axvline(0, linestyle = ':', c = 'k',  lw = .5)
    ax4.axhline(.5, linestyle = ':', c = 'k', lw = .5)

    # axis options
    ax4.set_ylim((-.05, 1.05))
    ax4.set_xlabel(r'$\Delta$ EV')
    ax4.set_ylabel('p(risk)')
    ax4.set_yticks([0, .5, 1])
    ax4.tick_params(axis = 'both', labelsize = 7.5)
    ax4.spines['top'].set_visible(False)
    ax4.spines['right'].set_visible(False)

    fig.savefig(path + '/sessionsummary_' + str(ss+1) + '_' + str(animalid[animalidx]) + '_' + str(inputdim) + 'inputs' + '.png')

In [594]:
# For each animal, collect the distinct deltaEV values across its sessions
# (the print of the raw and max-normalized sets is currently commented out).
for aa in range(nanimals):
    deltaEVset = np.unique(deltaEV[aa])
    #print('animal ' + str(aa + 1) + ': ' + '\n' + 'deltaEV set : ' + str(deltaEVset) + '\n'+'normalized deltaEV set : ' + str((deltaEVset/np.max(deltaEVset))) + '\n')