In [1]:
# %matplotlib inline
from opconNosepokeFunctions import *


toneStartMarker = 23
trialStartMarker = 81
trialHitMarker = 51
trialMissMarker = 86
trialUnrewMarker = 87
sessionStartMarker = 61
rewardProbMarker = 83
trialProbMarker = 88
lickMarker = 22
list_sessionMarker = [12, 13]

# trializer_v4 (now with lick times!!)

In [4]:
animal = 'Raltz'
box = 1
date = '2024-05-20'

files = exclude_files(get_files(f"L:/4portProb/Box{box}/{animal}/"))
files = [file for file in files if (file >=f'{animal}-{date}*.dat')]
files.sort()
list_sessionMarker = [13]

df = merge_files_to_df(files)

sessdf = trializer_v4(df, list_sessionMarker,
                      trialStartMarker, trialHitMarker,
                      trialMissMarker, sessionStartMarker,
                      trialProbMarker, rewardProbMarker,
                      lickMarker, animal)

In [3]:
def trializer_v4(df, list_sessionMarker,
                 trialStartMarker, trialHitMarker,
                 trialMissMarker, sessionStartMarker,
                 trialProbMarker, rewardProbMarker,
                 lickMarker, animal, arms = 4):
    
    ''' Function to convert a df of joined .dat files with column names in order :
    'event', 'value', 'sm', 'tp', 'smtime', 'eptime', 'dw', 'ww', 'gpio_lick'
    
    INPUTS: df, list of session state machine numbers, trial start marker, trial rewarded marker,
    trial miss/unrewarded marker, session start marker, trial probability marker, session reward prob marker, 
    lick marker, animal name, num arms in task
    
    OUTPUT: single trial-wise dataframe of any number of animals
    2024-11-22
    '''
    # initialize var
    sessnum = 0
    npcounter = 0
    sess_list = []
    rewprobl = []
    
    event_lookup = {
        trialHitMarker: 1,
        trialMissMarker: 0,
    }
    list_events = [trialStartMarker, trialHitMarker, trialMissMarker, trialProbMarker, sessionStartMarker,
                   trialProbMarker, rewardProbMarker, lickMarker]
    
    # filter events as np array to increase speed and avoid repetitive indexing
    filtered_events = df[(df['sm'].isin(list_sessionMarker)) & (df['event'].isin(list_events))].to_numpy()
    indices = np.where((filtered_events[:, 0] == trialStartMarker) | (filtered_events[:, 0] == sessionStartMarker) |
                       (filtered_events[:, 0] == rewardProbMarker))[0]
    
    for ind in indices:
        # extract information from the current trialStartMarker
        trialstart = filtered_events[ind, 4]
        port = filtered_events[ind, 1]
        datetime = filtered_events[ind, 5]
        task = filtered_events[ind, 2]
        event = filtered_events[ind, 0]
    
        # select events occurring after trialStartMarker
        nextEvents = filtered_events[ind:]
    
        try:
            # find the index of trialMissMarker or trialHitMarker in nextEvents
            status_indices = np.where((nextEvents[:, 0] == trialMissMarker) | (nextEvents[:, 0] == trialHitMarker))[0]
            lick_indices = np.where(((nextEvents[:, 0]) == lickMarker) & (nextEvents[:, 1] == 1) & (nextEvents[:, -8] == 1))[0]
            statusIndex = status_indices[0] if status_indices.size > 0 else None
            lickIndex = lick_indices[lick_indices<status_indices[1]] if status_indices.size > 1 else lick_indices
            
            ms_licktimes = str(nextEvents[lickIndex][:, 4]).strip('[]')
            ms_lastlick = nextEvents[lickIndex][:, 4][-1]
            s_licktimes = str(nextEvents[lickIndex][:, 5]).strip('[]')
    
            # find the index of trialProbMarker in nextEvents
            rewprob_index = np.where(nextEvents[:, 0] == trialProbMarker)[0][0] 
            #contLine: if np.any(nextEvents[:, 0] == trialProbMarker) else None
    
            if statusIndex is not None:
                # find reward, trialend, and rewprob based on identified events
                reward = event_lookup[nextEvents[statusIndex, 0]]
                trialend = nextEvents[statusIndex, 4]
                rewprob = nextEvents[rewprob_index, 1] if rewprob_index is not None else np.nan
                
            else:
                # if trialMissMarker or trialHitMarker not found, nan
                reward = np.nan
                trialend = np.nan
                rewprob = np.nan
                ms_lastlick = np.nan
                
        except IndexError:
            # set values to nan in case of issues
            reward = np.nan
            trialend = np.nan
            rewprob = np.nan
            ms_lastlick = np.nan
        rewprobl_str = str(np.array(rewprobl, dtype = int))
        
        # append to sess_list
        sess_list.append([npcounter, trialstart, port, reward, trialend, sessnum,
                          rewprob, datetime, task, event, rewprobl_str,
                          ms_licktimes, ms_lastlick, s_licktimes])
        npcounter += 1
    
        if filtered_events[ind, 0] == sessionStartMarker:
            # update session number when sessionStartMarker found
            rewprobl = []
            sessnum +=1
    
        if filtered_events[ind, 0] == rewardProbMarker:
            rewprobl.append(filtered_events[ind, 4])
    
    # create pandas dataframe from the list of all session data
    sessdf =  pd.DataFrame(sess_list, columns = ['trial', 'trialstart',
                                                 'port', 'reward', 'trialend',
                                                 'session', 'rewprob', 'eptime',
                                                 'task', 'event', 'rewprobfull',
                                                 'ms_licktimes', 'ms_lastlick', 's_licktimes'])
    
    # remove all lines corresponding to start of a session or reward prob marker
    sessdf = sessdf[(sessdf.event!=sessionStartMarker) & (sessdf.event!=rewardProbMarker)].reset_index(drop = True).drop(columns = 'event')
    
    # remove all lines that have (trialend-trialstart)> 5500 or negative (sum of tone + wait time)
    # is this check required?
    
    # reset trial numbers
    sessdf['trial'] = np.arange(len(sessdf))
    
    # add animal information
    sessdf['animal'] = animal
    
    # add datetime
    sessdf['datetime'] = pd.to_datetime(sessdf.eptime, unit='s')

    return sessdf

In [95]:
df['datetime'] = pd.to_datetime(df.eptime, unit='s')

In [103]:
df[(df.datetime > datetime.datetime(2024,11, 12, 3, 0, 0)) & (df.event == 22) & (df.datetime < datetime.datetime(2024,11,13, 3, 40, 0)) & (df.value == 0)]

Unnamed: 0,event,value,sm,tp,smtime,eptime,dw,ww,gpio_lick,datetime
6541394,22,0.0,13.0,0.0,57051.0,1.731380e+09,3259.0,85.0,0.0,2024-11-12 03:01:08
6541452,22,0.0,13.0,0.0,64514.0,1.731380e+09,3260.0,86.0,0.0,2024-11-12 03:01:15
6541456,22,0.0,13.0,0.0,64623.0,1.731380e+09,3260.0,86.0,0.0,2024-11-12 03:01:15
6541524,22,0.0,13.0,0.0,72849.0,1.731380e+09,3261.0,87.0,0.0,2024-11-12 03:01:24
6541547,22,0.0,13.0,0.0,77885.0,1.731380e+09,3261.0,87.0,0.0,2024-11-12 03:01:29
...,...,...,...,...,...,...,...,...,...,...
6570627,22,0.0,13.0,0.0,1390720.0,1.731468e+09,181.0,378.0,0.0,2024-11-13 03:23:22
6570705,22,0.0,13.0,0.0,1405916.0,1.731468e+09,182.0,379.0,0.0,2024-11-13 03:23:37
6570710,22,0.0,13.0,0.0,1408507.0,1.731468e+09,182.0,379.0,0.0,2024-11-13 03:23:39
6570719,22,0.0,13.0,0.0,1424203.0,1.731468e+09,182.0,379.0,0.0,2024-11-13 03:23:55


In [85]:
times = sessdf[sessdf.datetime>datetime.datetime(2024,11,13)].s_licktimes.to_numpy()
times = [xs.split() for xs in times]
times = np.array([int(float(x)) for xs in times for x in xs])

In [92]:
np.unique(np.append(sessdf[sessdf.datetime>datetime.datetime(2024,11,13)].eptime, np.unique(times)))

array([1.73146684e+09, 1.73146685e+09, 1.73146687e+09, 1.73146689e+09,
       1.73146689e+09, 1.73146690e+09, 1.73146690e+09, 1.73146691e+09,
       1.73146692e+09, 1.73146692e+09, 1.73146693e+09, 1.73146694e+09,
       1.73146694e+09, 1.73146695e+09, 1.73146696e+09, 1.73146697e+09,
       1.73146699e+09, 1.73146700e+09, 1.73146701e+09, 1.73146702e+09,
       1.73146702e+09, 1.73146703e+09, 1.73146703e+09, 1.73146703e+09,
       1.73146704e+09, 1.73146706e+09, 1.73146706e+09, 1.73146706e+09,
       1.73146708e+09, 1.73146708e+09, 1.73146709e+09, 1.73146710e+09,
       1.73146711e+09, 1.73146712e+09, 1.73146713e+09, 1.73146714e+09,
       1.73146715e+09, 1.73146715e+09, 1.73146716e+09, 1.73146717e+09,
       1.73146717e+09, 1.73146718e+09, 1.73146718e+09, 1.73146719e+09,
       1.73146719e+09, 1.73146720e+09, 1.73146720e+09, 1.73146720e+09,
       1.73146722e+09, 1.73146722e+09, 1.73146723e+09, 1.73146725e+09,
       1.73146726e+09, 1.73146727e+09, 1.73146727e+09, 1.73146728e+09,
      

In [9]:
### exclude files if required, by expression
# default: exclude_files(files, exclusion=['Roselia.+.dat', 'test-2022.+.dat', 'test-2023.+.dat', 'Azurill.+.dat'])
animal = 'Togepi'
box = 7
date = '2024-08-04'

files = exclude_files(get_files(f"L:/4portProb/Box{box}/{animal}/"))
files = [file for file in files if (file >=f'{animal}-{date}*.dat')]
files.sort()
list_sessionMarker = [12, 13]

df = merge_files_to_df(files)
sessdf = trializer_v3(df, list_sessionMarker,
                      trialStartMarker, trialHitMarker,
                      trialMissMarker, sessionStartMarker,
                      trialProbMarker, rewardProbMarker, animal)
sessdf.task = sessdf.task.replace({13:'unstr', 12:'str'})
# reducing filesize on disk
sessdf["animal"] = sessdf["animal"].astype("category")
sessdf["task"] = sessdf['task'].astype("category")
# sessdf["rewprobfull"] = sessdf['rewprobfull'].astype(str).astype("category")

In [None]:
# check which sms on this day
stateMachines = np.sort(df['sm'].unique())
print(stateMachines)
SMlist = ['DailyWater0', 'WeeklyWater1', 'HouseLightsOn2', 'HouseLightsOff3', 'HouseLightsOnToneWater4',
           'BlinkLightsToneWater5', 'NosepokeNonSession6', 'LickTest7', 'LickTrain8', 'Nosepoke2sound9','Nosepoke1sound10',
          'NosepokeProb7011', 'NosepokeStr12', 'NosepokeUNStr13', 'NosepokeProb4014','ReadSessionParams15','NosepokeNonSession216','DoNothing17']
# session_rew_daily=(max(df[(df[2]<11) & (df[2]>6)][6]) - min(df[(df[2]<11) & (df[2]>6)][6]))
# non_session_rew_daily=max(df[df[2].isin([6,14])][6])-min(df[df[2].isin([6,14])][6])
# print('max day reward from sessions =', session_rew_daily)
# print('max day reward from non-session =',non_session_rew_daily)
# print('total reward through np =', session_rew_daily+non_session_rew_daily)

# Nosepoke session analysis

In [33]:
# %matplotlib nbagg
%matplotlib qt
import ipywidgets as widgets
from IPython.display import display
from matplotlib.patches import Rectangle
# import matplotlib.pyplot as plt
# import seaborn as sns


# Create a function to update the plot based on user input
def update_plot(session_range):
    sessdf_filtered = sessdf[sessdf['session'].between(*session_range)]
    markers = {0: "$o$", 1: "o"}
#     markers = {0: 2, 1:3}
#     fig = plt.figure(figsize=(15, 8))
    g = sns.relplot(data=sessdf_filtered, y='port', x='trial', 
                    hue='rewprob',
                    aspect=11.7/5.5,
                    linewidth=0, s=100, palette = 'Blues', marker = '$o$',
                    markers=markers, style='reward',
                    hue_norm=(-20, 80))
    axes = g.axes.flatten()
    last = sessdf.groupby('session')['trial'].max().to_list()
    start = sessdf.groupby('session')['trial'].min().to_list()
    highport = [ind+1 for i in sessdf.groupby('session').head(1).rewprobfull for ind, j in enumerate(i) if j>=80]
    for ax in axes:
        for ind, l in enumerate(last):
            ax.axvline(l, linewidth=1, color='grey')
#             ax.axvline(, linewidth = 1, color = 'red')
            ax.set_yticks([1, 2, 3, 4])
            ax.add_patch(Rectangle((start[ind]-1, highport[ind]-0.5), l-start[ind], 1, fc = 'xkcd:blue', alpha = 0.1))
    
    plt.ylim(0.5, 4.5)
    plt.title('Choices', y=1.05)
#     g._legend.set_title('Rew. %')
    g._legend.texts[0].set_text("Reward %")
    g._legend.texts[5].set_text("Outcome")
#     plt.xticks(ticks = [], labels = [])

    plt.xlabel(' ', labelpad=20)
    plt.ylabel('Port')
    
    # Calculate x-axis limits based on the displayed sessions
    x_limit = (sessdf_filtered['trial'].min(), sessdf_filtered['trial'].max())
    plt.xlim(*x_limit)
    sns.despine(bottom = True)
#     plt.savefig('C:/Users/dlab/OneDrive - Indian Institute of Science/Drawings/plots/15012024/eg.svg', dpi = 600)
    plt.show()

# Define the session range using a widget
session_range_slider = widgets.IntRangeSlider(
    value=[sessdf['session'].min(), sessdf['session'].max()],
    min=sessdf['session'].min(),
    max=sessdf['session'].max(),
    step=1,
    description='Session Range:',
    continuous_update=False
)

# Create an interactive plot using interactive
interactive_plot = widgets.interactive(update_plot, session_range=session_range_slider)

# Display the interactive plot
display(interactive_plot)
# plt.xticks(ticks = [], labels = [])
# plt.tight_layout()
# sns.set_context('talk')
# plt.title('Blissey - DLS then DMS lesion')

interactive(children=(IntRangeSlider(value=(1, 50), continuous_update=False, description='Session Range:', max…

In [50]:
# bias analysis
# set_cwd('/home/rishika/sim/')
# from supplementaryFunctions import *
trialsinsess = 0
# new_unstr_sess_dict = loader('new_unstructured')
# sessdf = new_unstr_sess_dict['/home/rishika/sim/NAS_rishika/4portProb/Box2']
fig = plt.figure(figsize = (8,7))
ax = plt.subplot(111)
filtered = sessdf.groupby('session').filter(lambda x: x.reward.size >= trialsinsess)
for i in range(1,5):
    fewih = filtered[filtered['port']==i].groupby(['session', 'rewprob']).count()['port']/filtered.groupby(['session']).size()

    sns.scatterplot(data=filtered[filtered['port']==i].groupby(['session', 'rewprob']).count(), 
            x = 'rewprob',
            y=filtered[filtered['port']==i].groupby(['session', 'rewprob']).count()['port']/filtered.groupby(['session']).size(), 
                    alpha=0.1, ax=ax)
    sns.lineplot(data=fewih.groupby('rewprob').mean(), ax=ax, legend = 'auto', linewidth = 2)
#     ax.set_title(box)
    ax.set_ylabel('choice probability')
    ax.set_xlabel('reward percent')
    sns.despine()
    ax.legend(['_','1','_', '_','2', '_', '_','3', '_','_','4', '_'])
plt.suptitle(f'Choice Probability vs. Reward Percent, per arm, animal = {animal}')

plt.tight_layout()
plt.show()

  ax.legend(['_','1','_', '_','2', '_', '_','3', '_','_','4', '_'])
  ax.legend(['_','1','_', '_','2', '_', '_','3', '_','_','4', '_'])
  ax.legend(['_','1','_', '_','2', '_', '_','3', '_','_','4', '_'])
  ax.legend(['_','1','_', '_','2', '_', '_','3', '_','_','4', '_'])
  ax.legend(['_','1','_', '_','2', '_', '_','3', '_','_','4', '_'])
  ax.legend(['_','1','_', '_','2', '_', '_','3', '_','_','4', '_'])
  ax.legend(['_','1','_', '_','2', '_', '_','3', '_','_','4', '_'])
  ax.legend(['_','1','_', '_','2', '_', '_','3', '_','_','4', '_'])
  ax.legend(['_','1','_', '_','2', '_', '_','3', '_','_','4', '_'])
  ax.legend(['_','1','_', '_','2', '_', '_','3', '_','_','4', '_'])
  ax.legend(['_','1','_', '_','2', '_', '_','3', '_','_','4', '_'])
  ax.legend(['_','1','_', '_','2', '_', '_','3', '_','_','4', '_'])
  ax.legend(['_','1','_', '_','2', '_', '_','3', '_','_','4', '_'])
  ax.legend(['_','1','_', '_','2', '_', '_','3', '_','_','4', '_'])
  ax.legend(['_','1','_', '_','2', '_', '_','3',

In [104]:
# reward collection latency
g = sns.relplot(data=sessdf, x=sessdf['trialend']-sessdf['trialstart'], y='trial', hue='reward',
            palette=['k', 'r'])
axes = g.axes.flatten()

# find the last trial number of a session
last = sessdf.groupby('session')['trial'].max()
for ax in axes:
    for l in last:
        ax.axhline(l, linewidth = 0.7, color='grey')
        ax.set_xlabel('Latency (ms)')
        ax.set_xlim(0,5600)
        ax.set_ylim(bottom=-5)
sns.set_context('notebook')
plt.suptitle('Latency', y=1.05)
# plt.savefig('plots/25082023/dms_D03_latency.png', dpi = 300)

  self._figure.tight_layout(*args, **kwargs)


Text(0.5, 1.05, 'Latency')

In [None]:
# moving average rew rate
window = 50
mov_avg = np.convolve(sessdf['reward'].to_list(), np.ones(window))/window
g = sns.relplot(data = mov_avg, color = 'k', kind = 'line')
axes = g.axes.flatten()

# find the last trial number of a session
last = sessdf.groupby('session#')['trial#'].max()
for ax in axes:
    for l in last:
        ax.axvline(l, linewidth = 1, color='grey')
        ax.axhline(np.mean(df[df[0]==83][0:4][4]/100), c='r')
        ax.set_xlabel('trial#')
        ax.set_ylabel(f'Moving average reward rate, window = {window}')
ax.set_ylim([0, 1])
ax.set_xlim(window, len(mov_avg)-window)
# plt.savefig('plots/25082023/dms_D03_mavgrr.png', dpi = 300)

In [201]:
# number of trials in each session
g = sns.relplot(data=sessdf['session'].value_counts(), kind='scatter', color='k')
g.set_xlabels('session')
g.set_ylabels('# of trials')
# for ax in g.axes.flat:
#     for i in sessdf['session']:
#         if i%3==0:
#             ax.axvline(x=i, color='xkcd:light grey', linestyle='-.')
plt.axhline(np.mean(sessdf['session'].value_counts()), c='r')
# plt.savefig('plots/25082023/dms_D03_numtrials.png', dpi = 300)

  self._figure.tight_layout(*args, **kwargs)


<matplotlib.lines.Line2D at 0x2021b1e97b0>

In [None]:
prob = winstay_loseshift(sessdf)
plt.bar(['Rew_Stay', 'Rew_Shift', 'No_Rew_Stay', 'No_Rew_Shift'], prob, color=['r', 'r', 'k', 'k'])
sns.despine()
print(sum(prob))
plt.title(f'Probabilities for win-stay, lose-shift. Sum of bars = {sum(prob[0:2]), sum(prob[2:])}')

In [None]:
arms = 4
rewardProb = rew_prob_extractor(df, arms, rewardProbMarker)
# calculate cumulative regret for all sessions     
sess_regret = {}
for sess, prob in rewardProb.items():
    max_arm = prob.index(max(prob))+1
    sess_regret[sess] = max_arm - (sessdf[sessdf['session#']==sess]['port'].values)
    
regret = [num for sublist in list(sess_regret.values()) for num in sublist]

# plot all rewardProb
fig = plt.figure(figsize = (10, 4))
for sess in range(len(rewardProb)):
    ax = plt.subplot(1, 3, sess+1)
    ax.bar(np.arange(1,5), rewardProb[sess], color='k')
    ax.set_title(f'session # {sess}')
    ax.set_xticks([1,2,3,4])
    ax.set_xlabel('port')
    ax.set_ylabel('reward percent')
    ax.spines[['right', 'top']].set_visible(False)
plt.tight_layout()
plt.suptitle('Reward probabilities', y=1.05)

In [None]:
###### plot cumulative regret 
g = sns.relplot(data = sessdf, x = 'trial#', y = np.cumsum(regret), color = 'k', kind = 'line')
# find the last trial number of a session
last = sessdf.groupby('session#')['trial#'].max()
for ax in g.axes.flat:
    for l in last:
        ax.axvline(l, linewidth = 1, color='grey')
ax.set_ylabel('cumulative regret')

In [None]:
# what is distance moved given last trial was not rewarded?
dist = []
dist_unrew = []
dist_rew = []
for i in range(len(sessdf)):
    if i+1<len(sessdf):
        if sessdf['reward'][i]==0:
            dist_unrew.append(sessdf['port'][i]-sessdf['port'][i+1])
        else:
            dist_rew.append(sessdf['port'][i]-sessdf['port'][i+1])
plt.plot(dist_unrew, 'k')
sns.despine()
# plt.plot(dist_rew, 'r')