In [None]:
import os
import os.path
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import ticker
from matplotlib.colors import LinearSegmentedColormap

In [None]:
# Output directory for exported figures.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
figureTargetFolder = Path(
    r"C:\Users\neurogears\Documents\git\CF_Hardware\device.pump\Exp_Data&Code\AnalysisCode\Figures"
)

# Global visualisation settings
sns.set_style('darkgrid')  # alternatives: whitegrid, dark, white, ticks

# Font sizes: axes title and axis labels, tick labels, legend, default text.
plt.rc('axes', titlesize=18, labelsize=14)
plt.rc('xtick', labelsize=13)
plt.rc('ytick', labelsize=13)
plt.rc('legend', fontsize=13)
plt.rc('font', size=13)

# Font type used in PDF/PS exports and the default font family.
mpl.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'Arial',
})

In [None]:
def compute_probability_matrix(df):
    """Compute P(choice == 'X') for every (reward X, reward notX) combination.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``nprotocols_X``, ``nprotocols_notX``,
        ``trial_type`` and ``reward_pump``. Only rows with
        ``trial_type == "valid"`` contribute to the probabilities.

    Returns
    -------
    prot_values : list of int
        Sorted unique values of ``nprotocols_X``. The same list labels both
        matrix axes, so ``nprotocols_notX`` values that never occur in
        ``nprotocols_X`` are not represented.
    prot_matrix : numpy.ndarray, shape (len(prot_values), len(prot_values))
        ``prot_matrix[row, col]`` is the fraction of valid trials with
        ``nprotocols_notX == prot_values[row]`` and
        ``nprotocols_X == prot_values[col]`` in which ``reward_pump == 'X'``.
        Combinations without any valid trial are NaN.
    """
    prot_values = [int(value) for value in np.sort(df.nprotocols_X.unique())]
    n_levels = len(prot_values)

    # Start from NaN so that combinations without data need no special casing.
    prot_matrix = np.full((n_levels, n_levels), np.nan)

    # The validity filter does not depend on the loop variables: apply it once.
    valid_trials = df[df.trial_type == "valid"]

    for col, x_level in enumerate(prot_values):            # columns: reward on X
        matches_x = valid_trials.nprotocols_X == x_level
        for row, notx_level in enumerate(prot_values):      # rows: reward on notX
            choices = valid_trials.loc[
                matches_x & (valid_trials.nprotocols_notX == notx_level), 'reward_pump'
            ]
            if len(choices) > 0:
                prot_matrix[row, col] = (choices == 'X').sum() / len(choices)

    return prot_values, prot_matrix


In [None]:
# Load the behavioural dataset; the path is relative to the notebook's working directory.
unibandits = pd.read_csv('../../ExperimentalData/Behavior/bhv_dataset.csv')

# Animal names used to select rows (via the `animal` column) when plotting, one panel column each.
animal_list = ['Aluminium', 'Silicon']

In [None]:
# Probability of choosing the high option for every combination of
# animal, absolute log2 reward ratio and trial position within the block.
probdf_group_cols = ['animal', 'abs_log2_X_over_notX', 'valid_within_block']

# Filter once and use the SAME trials for the group keys and for the counts.
# Previously the keys came from the `notlast15 == True` subset while the counts
# were re-queried from all valid trials, so the excluded trials leaked back in.
# NOTE(review): this changes the plotted values wherever such trials exist -- confirm intent.
probdf_trials = unibandits.query('valid_trial == True and notlast15 == True')

# One row per group; `high_chosen` holds the array of individual choices of that group.
probdf = probdf_trials.groupby(probdf_group_cols)['high_chosen'].agg(list).reset_index()
probdf['high_chosen'] = probdf.high_chosen.map(np.asarray)

probdf['total_count'] = probdf.high_chosen.map(len)
probdf['high_count'] = probdf.high_chosen.map(np.sum)

probdf['probability_high'] = probdf.high_count / probdf.total_count


In [None]:
# Segment data for a blue -> dark -> red colormap: blue fades out over the lower
# half of the range, red fades in over the upper half, green stays at zero.
cdict1 = {
    'red': (
        (0.0, 0.0, 0.0),
        (0.5, 0.0, 0.1),
        (1.0, 1.0, 1.0),
    ),
    'green': (
        (0.0, 0.0, 0.0),
        (1.0, 0.0, 0.0),
    ),
    'blue': (
        (0.0, 0.0, 1.0),
        (0.5, 0.1, 0.0),
        (1.0, 0.0, 0.0),
    ),
}
blue_red1 = LinearSegmentedColormap('BlueRed1', cdict1)

# One column per animal. Top row: P(X) heatmap over reward combinations.
# Bottom row: P(high) as a function of trial position within the block.
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 8), tight_layout=True)

for col in range(2):
    animal = animal_list[col]
    animal_selector = f'animal == "{animal}"'
    heat_ax = axs[0, col]
    line_ax = axs[1, col]

    # --- heatmap panel ---
    prot_values, prot_matrix = compute_probability_matrix(unibandits.query(animal_selector))
    sns.heatmap(prot_matrix, ax=heat_ax, cmap=blue_red1, vmin=0, vmax=1, center=0.5)
    heat_ax.set_xlabel('X')
    heat_ax.set_ylabel('notX')
    heat_ax.set_title(f'{animal}')
    heat_ax.set_xticklabels(prot_values)
    heat_ax.set_yticklabels(prot_values)

    # --- line panel ---
    sns.lineplot(
        data=probdf.query(animal_selector),
        x='valid_within_block',
        y='probability_high',
        hue='abs_log2_X_over_notX',
        ax=line_ax,
    )
    line_ax.set_ylabel('P (high)')
    line_ax.set_xlabel('Trial since transition')
    line_ax.set_xlim(0, 30)
    line_ax.set_ylim(0, 1)
    line_ax.set_yticks((0, 0.5, 1))

fig.savefig(figureTargetFolder / "BehaviorAnalysis.pdf")

plt.show()

Example blocks

In [None]:
# Example session: valid trials of one animal on one day.
# `.copy()` makes this an independent frame, so the column assignments below do not
# write into a slice of `unibandits` (avoids SettingWithCopyWarning / lost writes).
sessdf = unibandits.query('animal == "Aluminium" and date == 220603 and trial_type == "valid"').copy()

# Number of 'X' choices among the 15 trials preceding each trial. The first 15 trials
# have no complete history and stay NaN. The rolling sum covers trials [k-14, k];
# shifting by one turns that into "the 15 trials before this one".
# Unlike the previous explicit loop (which stopped at len - 1), the final trial of the
# session is now filled in as well.
chose_X = (sessdf.reward_pump == 'X').astype(float)
sessdf['Xcount_previous_15_choices'] = chose_X.rolling(15).sum().shift(1)

sessdf['probX_last15'] = (sessdf.Xcount_previous_15_choices / 15).astype(float)

# Restrict the example to blocks after the first one; copy for the same reason as above.
egdf = sessdf.query('blockno > 1').copy()
# Rescale P(X) from [0, 1] to [-4, 4].
egdf['choice'] = (egdf.probX_last15 - .5) * 8

poke_color_dic = {
    'X': '#813cb0',
    'notX': '#419c65'}

# One row per block: first/last trial number and the colour of the block's `highest_pump`.
colordf = pd.DataFrame()
colordf['blockno'] = egdf.blockno.unique()
colordf['start_trial'] = colordf.blockno.apply(lambda x: egdf.query(f'blockno == {x}').trialno.values[0])
# A block ends where the next one starts; the last block ends at the last trial shown.
colordf['end_trial'] = colordf.start_trial.shift(-1)
colordf.loc[colordf.index[-1], 'end_trial'] = egdf.trialno.values[-1]
colordf['highest_pump'] = colordf.blockno.apply(lambda x: egdf.query(f'blockno == {x}').highest_pump.values[0])
colordf['color'] = colordf.highest_pump.map(poke_color_dic)

In [None]:
# Two stacked panels sharing the trial axis:
# top = running P(X) over the previous 15 choices, bottom = reward sizes of both options.
fig, axs = plt.subplots(2, sharex=True, figsize=(10, 8), tight_layout=True)
choice_ax, reward_ax = axs

sns.lineplot(ax=choice_ax, data=egdf, x='trialno', y='probX_last15', color='black', label='choice')

# Shade every block in both panels with the colour assigned to it in `colordf`.
for block in colordf.itertuples(index=False):
    for ax in axs:
        ax.axvspan(block.start_trial, block.end_trial, facecolor=block.color, alpha=.1)

# Reference line at P(X) = 0.5.
choice_ax.axhline(.5, color='black', lw=1)

sns.lineplot(ax=reward_ax, data=egdf, x='trialno', y='nprotocols_X', label='reward X', color=poke_color_dic['X'])
sns.lineplot(ax=reward_ax, data=egdf, x='trialno', y='nprotocols_notX', label='reward notX', color=poke_color_dic['notX'])

# Dashed line at every block change except the first one.
for transition_trial in egdf.query('block_changed == True').trialno.values[1:]:
    for ax in axs:
        ax.axvline(transition_trial, color='grey', ls='dashed', lw=1)

# The x axis is shared, so setting the limits on one panel applies to both.
reward_ax.set_xlim(egdf.trialno.values[0], egdf.trialno.values[-1])

choice_ax.set_ylabel('P(X)')
reward_ax.set_ylabel('Reward (number of protocols)')
reward_ax.set_xlabel('Trials')
choice_ax.set_ylim(0, 1)

# Log2 y axis with plain integer tick labels. Requires `ticker` from matplotlib
# (imported in the notebook's import cell).
reward_ax.set_yscale('log', base=2)
reward_ax.yaxis.set_major_formatter(ticker.StrMethodFormatter('{x:,.0f}'))

fig.suptitle('Aluminium 220603')
sns.despine(fig=fig, top=True)
fig.savefig('example_blocks_background_rwd_2subplots.png', facecolor="white")
plt.show()