## Choice response coding
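This notebook identifies time bins in which firing rates are strongly modulated by the choice response rather than by the identity of the object. For each cell, the R-squared of the basic multiple linear regression (MLR) model is compared with that of an extended model that additionally includes a choice-response term. Bins where the R-squared gain exceeds a criterion derived from shuffled data are marked as choice-response bins.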

In [1]:
import os
from pathlib import Path
import numpy as np
import pandas as pd

from scipy import stats
from scipy.ndimage import gaussian_filter
from sklearn.linear_model import LinearRegression

import matplotlib as mpl
import matplotlib.pyplot as plt  

from datetime import date
import time

import random

from joblib import Parallel, delayed

import h5py

In [2]:
# Hide the top and right axis spines on every figure drawn after this cell.
mpl.rcParams.update({'axes.spines.top': False, 'axes.spines.right': False})

In [3]:
# Root folder of the project data. Defaults to the original D: drive location;
# set the MULTIMODAL_PROJECT_ROOT environment variable to run on another machine.
mother_path = Path(os.environ.get('MULTIMODAL_PROJECT_ROOT', 'D:/Multi-modal project/'))

### Parameter setting

In [4]:
# Sigma (in time bins) of the Gaussian kernel used to smooth firing-rate traces.
gauss_sigma = 1

# Line colours per stimulus type, in plotting order:
# multimodal, visual-only, auditory-only, control (elemental).
color = ['mediumorchid', 'cornflowerblue', 'lightcoral', 'gray']
# Not referenced by the cells visible in this notebook.
color2 = ['cyan', 'magenta', 'brown']
# Line styles: solid for boy-goal trials, dotted for egg-goal trials.
linestyle = ['-', ':']

# ISO date (YYYY-MM-DD) used to name this run's output folder.
today = date.today().isoformat()

### Data preparation

In [12]:
save_path = mother_path /'analysis'/'result'/'4. Choice response coding'/today
cell_path = mother_path/'analysis'/'result'/'zFR export'/'13-Apr-2022 (5 trials)'
data_path = mother_path /'analysis'/'result'/'3. Multiple regression for object-selectivity'

# os.listdir order is filesystem-dependent, so sort for a reproducible run order.
# Only CSV exports are treated as per-cell files.
cell_list = sorted(name for name in os.listdir(cell_path) if name.lower().endswith('.csv'))

# Date-stamped folder of the multiple-regression run whose results are loaded.
regression_date = '2023-04-28'
# Per-cell groups hold R-squared values of the basic and choice-extended models,
# for both the real and the shuffled data. The handle is closed in the last cell.
f = h5py.File(data_path/regression_date/f'{regression_date}_multiple_regression_result.hdf5','r')

### Data analysis

In [6]:
def plot_SDF_beta(df,linewidth,smooth,save,save_format):
    """
    Plot one cell's choice-response R-squared difference and its mean
    firing-rate patterns, split by stimulus type and reward location.

    Figure layout (4 x 2 axes):
      - ax[0,0]: R-squared difference (extended minus basic model) per time
        bin, with the shuffle-derived criterion drawn as a dotted red line.
      - ax[1,0]: all eight condition means.
      - ax[2,0] / ax[3,0]: boy-goal / egg-goal trials only.
      - ax[0..3,1]: Multimodal, Visual-only, Auditory-only and Control
        (Elemental) trials, each showing both reward locations.
    Time bins listed in ``response_bin`` are marked with black stars.

    Parameters
    ----------
    df : pandas.DataFrame
        Trial table with 'Type' and 'RWD_Loc' columns and 95 consecutive
        firing-rate columns starting at 'Var10' (plotted as 10-ms bins).
    linewidth : float
        Line width for all traces.
    smooth : bool
        If truthy, Gaussian-smooth every mean and SEM trace with sigma
        ``gauss_sigma``.
    save : bool
        If truthy, save to ``save_path/save_format/region/<cell>.<ext>``
        and close the figure. Otherwise the figure is left open for display.
    save_format : {'png', 'svg'}
        Output format. Only checked when ``save`` is truthy.

    Raises
    ------
    ValueError
        If ``save`` is truthy and ``save_format`` is neither 'png' nor 'svg'.
        Previously such calls silently saved nothing.

    Notes
    -----
    This function reads notebook globals rather than parameters:
      - set by the analysis loop: cell_name, boy_goal, egg_goal, region,
        r_squared_diff, r_squared_diff_crit, response_bin;
      - set by the config cells: gauss_sigma, color, linestyle, save_path.
    """
    if save and save_format not in ('png','svg'):
        raise ValueError(f"save_format must be 'png' or 'svg', got {save_format!r}")

    n_bins = 95  # number of firing-rate columns per trial

    # Condition order matters below: index % 2 selects the reward location
    # (0 = boy goal, 1 = egg goal) and index // 2 selects the stimulus type.
    goals = (boy_goal, egg_goal)
    stim_types = ('Multimodal','Visual','Auditory','Elemental')
    cond = [(df.Type==stim)&(df.RWD_Loc==goal) for stim in stim_types for goal in goals]

    fr_id = df.columns.get_loc('Var10')  # index of the first firing-rate column

    fr_mean = np.zeros((len(cond),n_bins))
    fr_sem = np.zeros((len(cond),n_bins))
    for i,mask in enumerate(cond):
        fr = df[mask].iloc[:,fr_id:fr_id+n_bins].to_numpy()
        fr_mean[i,:] = fr.mean(axis=0)
        fr_sem[i,:] = stats.sem(fr)

    if smooth:
        for i in range(len(cond)):
            fr_mean[i,:] = gaussian_filter(fr_mean[i,:],sigma=gauss_sigma)
            fr_sem[i,:] = gaussian_filter(fr_sem[i,:],sigma=gauss_sigma)

    # Integer y-limits enclosing mean +/- SEM of every trace.
    y_max = np.ceil(np.max(fr_mean+fr_sem))
    y_min = np.ceil(np.abs(np.min(fr_mean-fr_sem)))*-1

    # Remove the file extension. str.strip('.csv') would strip the characters
    # '.', 'c', 's', 'v' from both ends instead of removing the suffix.
    cell_file_name = os.path.splitext(cell_name)[0]
    fig,ax = plt.subplots(4,2,figsize=(8,10))
    fig.suptitle(cell_file_name,fontsize=15)
    x = np.arange(n_bins)*10

    ax[0,0].plot(x,r_squared_diff,linewidth=linewidth,color='black')
    ax[0,0].axhline(y=r_squared_diff_crit,color='red',linestyle=':')
    ax[0,0].text(50,r_squared_diff_crit+0.005,f'{r_squared_diff_crit:.3f}',color='red')
    ax[0,0].set_xticks([0,400,950])
    ax[0,0].set_xlim([0,950])
    ax[0,0].set_xlabel('Time (ms)',fontsize=12)
    ax[0,0].set_ylabel('R-squared difference',fontsize=12)

    # Each condition is drawn on three panels: the all-conditions panel,
    # its reward-location panel, and its stimulus-type panel.
    for i in range(len(cond)):
        ls = linestyle[i%2]
        c = color[i//2]
        lower = fr_mean[i,:]-fr_sem[i,:]
        upper = fr_mean[i,:]+fr_sem[i,:]
        for panel in (ax[1,0], ax[i%2+2,0], ax[i//2,1]):
            panel.plot(x,fr_mean[i,:],color=c,linewidth=linewidth,linestyle=ls)
            panel.fill_between(x,lower,upper,color=c,alpha=0.2)

    titles = ['All conditions','Boy','Egg','Multimodal','Visual-only','Auditory-only','Control']
    panels = [ax[1,0],ax[2,0],ax[3,0],ax[0,1],ax[1,1],ax[2,1],ax[3,1]]
    star_y = np.full(len(response_bin), y_max-0.2)  # star markers just below the top
    for panel,title in zip(panels,titles):
        panel.scatter(response_bin*10,star_y,marker='*',color='black',s=10)
        panel.set_title(title,fontsize=13)
        panel.set_ylabel('z-scored FR',fontsize=12)
        panel.set_xlabel('Time (ms)',fontsize=12)
        panel.set_xticks([0,400,950])
        panel.set_xlim([0,950])
        panel.set_yticks(np.arange(y_min,y_max+0.1,1))

    fig.tight_layout()

    if save:
        fig_path = save_path/save_format/region
        os.makedirs(fig_path,exist_ok=True)
        if save_format=='png':
            fig.savefig(fig_path/f'{cell_file_name}.png',dpi=100,facecolor='white')
        else:
            fig.savefig(fig_path/f'{cell_file_name}.svg')
        # Close saved figures so a loop over hundreds of cells does not accumulate them.
        plt.close(fig)

In [7]:
def save_result(f):
    """
    Write one cell's regression results into an open, writable HDF5 file.

    Creates a group named after ``cell_id`` in ``f`` with eight datasets:
    beta_coef / r_squared for the basic model and their *_choice variants,
    each with a *_shuffle counterpart. It also sets the attributes Rat,
    Region and Session, plus Visual/Auditory flags that are 1 when
    ``vis_sig_bin[0]`` / ``aud_sig_bin[0]`` is non-empty and 0 otherwise.

    NOTE(review): every value is read from notebook globals. beta_coef*,
    rsquare*, vis_sig_bin and aud_sig_bin are not assigned in the cells
    visible here, and this function is not called in this notebook. It looks
    like a leftover from the regression notebook; confirm before reusing.
    """
    group = f.create_group(str(cell_id))
    add = group.create_dataset
    add('beta_coef',data=beta_coef)
    add('beta_coef_choice',data=beta_coef_choice)
    add('beta_coef_shuffle',data=beta_coef_shuffle)
    add('beta_coef_choice_shuffle',data=beta_coef_choice_shuffle)
    add('r_squared',data=rsquare)
    add('r_squared_choice',data=rsquare_choice)
    add('r_squared_shuffle',data=rsquare_shuffle)
    add('r_squared_choice_shuffle',data=rsquare_choice_shuffle)

    meta = group.attrs
    meta['Rat'] = rat_id
    meta['Region'] = region
    meta['Session'] = session_id

    meta['Visual'] = int(len(vis_sig_bin[0]) > 0)
    meta['Auditory'] = int(len(aud_sig_bin[0]) > 0)

In [8]:
%%time
for cell_run,cell_name in enumerate(cell_list):
    loop_start = time.time()
    # get information about the cell
    cell_info = cell_name.split('-')
    cell_id, rat_id, session_id, region = int(cell_info[0]), cell_info[1], cell_info[2], cell_info[5]
    
    if (rat_id=='654')&(session_id=='4'):
        continue
    
    # load cell data
    df = pd.read_csv(cell_path/cell_name)
    df.drop(df[df.Correctness==0].index,inplace=True)
    df.reset_index(inplace=True,drop=True)
    df[['Visual','Auditory']] = df[['Visual','Auditory']].fillna('no')
    
    boy_goal = df.loc[df['Visual']=='Boy','RWD_Loc'].values[0]
    boy_aud = df.loc[df['RWD_Loc']==boy_goal,'Auditory'].values[0]
    
    egg_goal = df.loc[df['Visual']=='Egg','RWD_Loc'].values[0]
    egg_aud = df.loc[df['RWD_Loc']==egg_goal,'Auditory'].values[0]
    
    # load r-squared values from subsampled dataset
    r_squared = f[f'{cell_id}/r_squared']
    r_squared_choice = f[f'{cell_id}/r_squared_choice']
    
    # load r-squared values from shuffled dataset
    r_squared_shuffle = f[f'{cell_id}/r_squared_shuffle']
    r_squared_choice_shuffle = f[f'{cell_id}/r_squared_choice_shuffle']

    # calculate r-squared difference between the extended and basic model
    r_squared_diff = np.mean(r_squared_choice,axis=0) - np.mean(r_squared,axis=0)
    r_squared_diff_shuffled = np.subtract(r_squared_choice_shuffle,r_squared_shuffle)
    r_squared_diff_shuffled = np.ravel(r_squared_diff_shuffled)

    #r_squared_diff_crit = [np.percentile(r_squared_diff_shuffled[:,x],90) for x in range(95)]
    r_squared_diff_crit = np.percentile(r_squared_diff_shuffled,85)
    response_bin = np.where(r_squared_diff>r_squared_diff_crit)[0]
    
    plot_SDF_beta(df,2,1,1,'png')
                
    loop_end = time.time()
    loop_time = divmod(loop_end-loop_start,60)
    print(cell_name.strip('.csv'), f'////// {cell_run+1}/{len(cell_list)} completed  //////  {int(loop_time[0])} min {loop_time[1]:.2f} sec')

0003-600-1-1-Crossmodal-TeV-deep-(-7.32 mm)-TT4.1 ////// 1/888 completed  //////  0 min 1.04 sec
0004-600-1-1-Crossmodal-TeV-deep-(-7.32 mm)-TT4.2 ////// 2/888 completed  //////  0 min 0.74 sec
0005-600-1-1-Crossmodal-TeV-deep-(-7.32 mm)-TT4.3 ////// 3/888 completed  //////  0 min 0.72 sec
0006-600-1-1-Crossmodal-PER-superficial-(-7.2 mm)-TT5.1 ////// 4/888 completed  //////  0 min 0.72 sec
0007-600-1-1-Crossmodal-PER-superficial-(-7.2 mm)-TT5.2 ////// 5/888 completed  //////  0 min 0.78 sec
0008-600-1-1-Crossmodal-PER-superficial-(-7.2 mm)-TT5.3 ////// 6/888 completed  //////  0 min 0.83 sec
0010-600-1-1-Crossmodal-TeV-deep-(-6.96 mm)-TT6.2 ////// 7/888 completed  //////  0 min 0.80 sec
0011-600-1-1-Crossmodal-PER-superficial-(-6.48 mm)-TT7.1 ////// 8/888 completed  //////  0 min 0.87 sec
0013-600-1-1-Crossmodal-PER-deep-(-6.48 mm)-TT8.2 ////// 9/888 completed  //////  0 min 0.82 sec
0014-600-1-1-Crossmodal-PER-deep-(-6.48 mm)-TT8.3 ////// 10/888 completed  //////  0 min 0.81 sec
0015

0093-600-5-5-Crossmodal-POR-deep-(-7.56 mm)-TT24.2 ////// 82/888 completed  //////  0 min 0.72 sec
0094-600-5-5-Crossmodal-POR-deep-(-7.56 mm)-TT24.3 ////// 83/888 completed  //////  0 min 1.76 sec
0095-600-5-5-Crossmodal-POR-deep-(-7.56 mm)-TT24.4 ////// 84/888 completed  //////  0 min 0.82 sec
0096-600-5-5-Crossmodal-POR-deep-(-7.56 mm)-TT24.5 ////// 85/888 completed  //////  0 min 0.69 sec
0174-602-1-1-Crossmodal-PER-deep-(-5.28 mm)-TT6.1 ////// 86/888 completed  //////  0 min 0.71 sec
0175-602-1-1-Crossmodal-PER-deep-(-5.28 mm)-TT6.10 ////// 87/888 completed  //////  0 min 0.84 sec
0176-602-1-1-Crossmodal-PER-deep-(-5.28 mm)-TT6.11 ////// 88/888 completed  //////  0 min 0.90 sec
0177-602-1-1-Crossmodal-PER-deep-(-5.28 mm)-TT6.2 ////// 89/888 completed  //////  0 min 0.80 sec
0178-602-1-1-Crossmodal-PER-deep-(-5.28 mm)-TT6.3 ////// 90/888 completed  //////  0 min 1.14 sec
0179-602-1-1-Crossmodal-PER-deep-(-5.28 mm)-TT6.4 ////// 91/888 completed  //////  0 min 0.94 sec
0181-602-1-1-C

0279-602-3-3-Crossmodal-PER-superficial-(-4.68 mm)-TT10.1 ////// 163/888 completed  //////  0 min 2.67 sec
0281-602-3-3-Crossmodal-PER-superficial-(-4.68 mm)-TT10.3 ////// 164/888 completed  //////  0 min 0.74 sec
0283-602-3-3-Crossmodal-PER-superficial-(-4.56 mm)-TT11.2 ////// 165/888 completed  //////  0 min 0.93 sec
0284-602-3-3-Crossmodal-PER-superficial-(-4.56 mm)-TT11.3 ////// 166/888 completed  //////  0 min 0.82 sec
0285-602-3-3-Crossmodal-PER-superficial-(-4.56 mm)-TT11.4 ////// 167/888 completed  //////  0 min 0.73 sec
0286-602-3-3-Crossmodal-PER-superficial-(-4.2 mm)-TT12.1 ////// 168/888 completed  //////  0 min 0.97 sec
0287-602-3-3-Crossmodal-PER-superficial-(-4.2 mm)-TT12.2 ////// 169/888 completed  //////  0 min 0.89 sec
0291-602-3-3-Crossmodal-PER-superficial-(-4.2 mm)-TT12.6 ////// 170/888 completed  //////  0 min 1.35 sec
0292-602-3-3-Crossmodal-PER-superficial-(-4.2 mm)-TT12.7 ////// 171/888 completed  //////  0 min 1.16 sec
0293-602-3-3-Crossmodal-PER-superficial-(

1023-640-2-2-Crossmodal-PER-deep-(-4.56 mm)-TT11.5 ////// 243/888 completed  //////  0 min 1.08 sec
1024-640-2-2-Crossmodal-PER-deep-(-4.56 mm)-TT11.6 ////// 244/888 completed  //////  0 min 1.08 sec
1025-640-2-2-Crossmodal-PER-superficial-(-4.56 mm)-TT13.1 ////// 245/888 completed  //////  0 min 0.85 sec
1026-640-2-2-Crossmodal-PER-superficial-(-4.56 mm)-TT13.2 ////// 246/888 completed  //////  0 min 1.22 sec
1027-640-2-2-Crossmodal-PER-superficial-(-4.56 mm)-TT13.3 ////// 247/888 completed  //////  0 min 1.06 sec
1028-640-2-2-Crossmodal-PER-superficial-(-4.56 mm)-TT13.4 ////// 248/888 completed  //////  0 min 0.94 sec
1029-640-2-2-Crossmodal-PER-superficial-(-4.56 mm)-TT13.5 ////// 249/888 completed  //////  0 min 1.13 sec
1030-640-2-2-Crossmodal-PER-superficial-(-4.56 mm)-TT13.6 ////// 250/888 completed  //////  0 min 0.77 sec
1033-640-2-2-Crossmodal-PER-superficial-(-4.68 mm)-TT16.1 ////// 251/888 completed  //////  0 min 0.82 sec
1034-640-2-2-Crossmodal-TeV-superficial-(-5.64 mm)-

1340-647-4-4-Crossmodal-PER-superficial-(-5.64 mm)-TT17.6 ////// 324/888 completed  //////  0 min 1.00 sec
1341-647-4-4-Crossmodal-PER-superficial-(-5.88 mm)-TT19.1 ////// 325/888 completed  //////  0 min 1.08 sec
1343-647-4-4-Crossmodal-PER-superficial-(-5.88 mm)-TT19.3 ////// 326/888 completed  //////  0 min 1.12 sec
1345-647-4-4-Crossmodal-PER-superficial-(-5.88 mm)-TT19.5 ////// 327/888 completed  //////  0 min 1.18 sec
1346-647-4-4-Crossmodal-TeV-deep-(-6.12 mm)-TT21.1 ////// 328/888 completed  //////  0 min 1.11 sec
1347-647-4-4-Crossmodal-TeV-deep-(-6.12 mm)-TT21.2 ////// 329/888 completed  //////  0 min 1.15 sec
1348-647-4-4-Crossmodal-TeV-deep-(-6.12 mm)-TT21.3 ////// 330/888 completed  //////  0 min 0.80 sec
1349-647-4-4-Crossmodal-TeV-deep-(-6.12 mm)-TT21.4 ////// 331/888 completed  //////  0 min 0.74 sec
1350-647-4-4-Crossmodal-TeV-deep-(-6.12 mm)-TT21.5 ////// 332/888 completed  //////  0 min 0.68 sec
1351-647-4-4-Crossmodal-TeV-deep-(-6.12 mm)-TT21.6 ////// 333/888 comple

1560-654-1-1-Crossmodal-Auditory-deep-(-6.24 mm)-TT24.3 ////// 404/888 completed  //////  0 min 0.76 sec
1561-654-1-1-Crossmodal-Auditory-deep-(-6.24 mm)-TT24.4 ////// 405/888 completed  //////  0 min 0.73 sec
1567-654-2-2-Crossmodal-Visual-superficial-(-6.48 mm)-TT3.1 ////// 406/888 completed  //////  0 min 0.75 sec
1568-654-2-2-Crossmodal-Visual-deep-(-6.24 mm)-TT4.1 ////// 407/888 completed  //////  0 min 0.68 sec
1569-654-2-2-Crossmodal-Visual-deep-(-6.24 mm)-TT4.2 ////// 408/888 completed  //////  0 min 0.68 sec
1570-654-2-2-Crossmodal-Visual-deep-(-5.64 mm)-TT6.1 ////// 409/888 completed  //////  0 min 0.67 sec
1571-654-2-2-Crossmodal-Visual-superficial-(-5.28 mm)-TT8.1 ////// 410/888 completed  //////  0 min 0.66 sec
1572-654-2-2-Crossmodal-PER-deep-(-5.4 mm)-TT10.1 ////// 411/888 completed  //////  0 min 0.68 sec
1574-654-2-2-Crossmodal-PER-deep-(-5.4 mm)-TT10.3 ////// 412/888 completed  //////  0 min 0.87 sec
1576-654-2-2-Crossmodal-PER-deep-(-5.4 mm)-TT10.5 ////// 413/888 com

KeyError: 'Unable to open object (component not found)'

In [9]:
f.close()  # release the HDF5 results handle opened in the data-preparation cell
print('END')

END
