In [None]:
import os
from glob import glob
import numpy as np
import pandas as pd

from scipy import stats
from scipy.ndimage import gaussian_filter
from scipy.signal import find_peaks

import pingouin

import matplotlib as mpl
import matplotlib.pyplot as plt  
import seaborn as sns

from datetime import date

import random

In [None]:
# Hide the top and right axes spines in every plot drawn after this cell.
mpl.rcParams.update({
    'axes.spines.right': False,
    'axes.spines.top': False,
})

In [None]:
# Project root. Defaults to the original author's machine; set the environment variable
# MULTIMODAL_PROJECT_ROOT to run elsewhere. The trailing '/' is enforced because every
# path below is built by plain string concatenation onto this value.
mother_path = os.environ.get('MULTIMODAL_PROJECT_ROOT', 'C:/LHY/Multi-modal project/').rstrip('/\\') + '/'

# Parameter setting

In [None]:
# Significance criteria for the per-bin left-vs-right comparisons
sig_alpha = 0.01      # a bin is significant when its t-test p-value is below this
sig_cohend = 0.4      # a bin can join a significant field only if |Cohen's d| exceeds this
cons_bin_crit = 5     # a significant field must span at least this many consecutive bins

# Gaussian smoothing of the plotted mean/SEM traces (applied after the statistics are computed)
gauss_sigma = 2
gauss_on = True

# SEM display in SDF_plot: True -> shaded band, False -> error bars
shade_on = True

# One color per comparison group in SDF_plot, in order: multimodal, visual-only,
# auditory-only, control (elemental), all correct trials, all incorrect trials
colorpalette = ['tab:purple','tab:blue','tab:red','tab:gray','tab:green','tab:orange']

today = str(date.today())

# Data preparation

In [None]:
# Load the per-cell item-selectivity table and label every cell with the comparisons in
# which it was significant: any combination of 'M', 'V', 'A', 'C' (in that order), or
# 'None' when no comparison reached significance.
data_path = f'{mother_path}analysis/result/3.1. Item-selectivity/2023-01-06 (d=0.5)/'
df = pd.read_csv(f'{data_path}2023-01-06_item-selectivity.csv')

sig_flags = df[['M_sig', 'V_sig', 'A_sig', 'C_sig']]
cat_labels = [
    ''.join(letter for letter, flag in zip('MVAC', row) if flag == 1) or 'None'
    for row in sig_flags.itertuples(index=False, name=None)
]
df.insert(5, 'Cat', cat_labels)

# Keep only the cells that were significant in all four comparisons.
MVAC = df.loc[df['Cat'] == 'MVAC'].copy()

In [None]:
# Figures and tables from this analysis are written under save_path; the per-cell
# binned firing-rate exports are read from data_path (made the working directory).
save_path = f'{mother_path}analysis/result/6. Response coding/'
data_path = f'{mother_path}analysis/result/zFR export/13-Apr-2022 (5 trials)'
os.chdir(data_path)

In [None]:
# empty list for result csv file
result = list()  # one dict per analysed cell, appended by the per-cell loop

In [None]:
# Names of all entries in the current working directory (data_path after the cell above).
cell_list = os.listdir()

# Data analysis

In [None]:
def SDF_plot(cond, ax, colorpalette, linewidth, shade_on,
             mean=None, sem=None, sig_fields=None, ylim=None):
    """Plot mean +/- SEM spike-density traces for a set of condition rows on one axes.

    Parameters
    ----------
    cond : sequence of int
        Row indices into ``mean``/``sem``. Rows come in pairs: row ``2k`` is drawn
        solid and row ``2k + 1`` dotted, both in ``colorpalette[k]``.
    ax : matplotlib Axes
        Axes to draw on.
    colorpalette : sequence
        Colors indexed by comparison ``k = row // 2``.
    linewidth : float
        Line width of the mean traces.
    shade_on : bool
        True: SEM drawn as a shaded band; False: SEM drawn as vertical error bars.
    mean, sem : 2-D numpy arrays, optional
        Per-row mean and SEM, shape (n_rows, n_bins). Default to the notebook globals
        ``mean_by_cond`` / ``sem_by_cond``.
    sig_fields : dict, optional
        Maps comparison index ``k`` to an array of significant bin indices.
        Defaults to the notebook global ``sig_field``.
    ylim : (float, float), optional
        ``(y_min, y_max)``; defaults to the notebook globals ``y_min`` / ``y_max``.

    Notes
    -----
    Bins are plotted at 10 ms spacing (x = bin * 10) with a dotted reference line at
    400 ms. The significant bins of the comparison that the LAST entry of ``cond``
    belongs to are marked as short black ticks 0.2 below the top of the y-range.
    """
    # Fall back to the per-cell globals set by the analysis loop (original behaviour).
    if mean is None:
        mean = mean_by_cond
    if sem is None:
        sem = sem_by_cond
    if sig_fields is None:
        sig_fields = sig_field
    if ylim is None:
        ylim = (y_min, y_max)
    lo, hi = ylim

    x = np.arange(mean.shape[1]) * 10
    ax.axvline(x=400, color='black', linewidth=1, linestyle=':')

    comparison = None
    for row in cond:
        comparison = row // 2  # rows 2k and 2k+1 belong to comparison k
        color = colorpalette[comparison]
        style = '-' if row % 2 == 0 else ':'

        ax.plot(x, mean[row, :], color=color, linewidth=linewidth, linestyle=style)
        if shade_on:
            ax.fill_between(x, mean[row, :] - sem[row, :], mean[row, :] + sem[row, :],
                            color=color, alpha=0.2)
        else:
            # Vertical error bars only: the SEM is in firing-rate units, not time units.
            ax.errorbar(x, mean[row, :], yerr=sem[row, :],
                        ecolor=color, elinewidth=1, fmt='none', alpha=0.5)

    ax.tick_params(axis='y', which='both', direction='in')
    ax.set_xticks([0, 200, 400, 600, 800, 950])
    ax.set_xlim([0, 950])
    ax.set_yticks(np.arange(lo, hi + 1))
    ax.set_ylim([lo, hi])
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('z-scored FR')

    if comparison is not None:
        field = np.asarray(sig_fields[comparison])
        ax.scatter(field * 10, np.full(len(field), hi - 0.2),
                   s=10, marker='_', c='black', linewidth=2)

In [None]:
def find_sig_field(sig_bin, cohend, threshold=None, min_bins=None):
    """Grow significant seed bins into contiguous runs of large effect size.

    For each seed bin, collect the contiguous run of bins whose |cohend| exceeds
    ``threshold``: scanning left from the bin just before the seed, and right starting
    at the seed itself. A run is kept only if it contains at least ``min_bins`` bins.

    Parameters
    ----------
    sig_bin : iterable of int
        Seed bin indices (e.g. bins with p < alpha).
    cohend : 1-D array-like
        Per-bin effect sizes; only their absolute values are used.
    threshold : float, optional
        Minimum |effect size| (exclusive). Defaults to the global ``sig_cohend``.
    min_bins : int, optional
        Minimum run length. Defaults to the global ``cons_bin_crit``.

    Returns
    -------
    numpy.ndarray of int
        Sorted, unique bin indices belonging to any kept run (empty if none).

    NOTE(review): the left-hand scan runs even when the seed bin itself does not exceed
    ``threshold``, so a run ending just before the seed can still be kept. Confirm this
    is intended.
    """
    if threshold is None:
        threshold = sig_cohend
    if min_bins is None:
        min_bins = cons_bin_crit

    effect = np.abs(np.asarray(cohend, dtype=float))
    n_bins = len(effect)  # previously hard-coded to 95

    field_bin = []
    for seed in sig_bin:
        seed = int(seed)
        chunk = []
        j = seed - 1
        while j >= 0 and effect[j] > threshold:
            chunk.append(j)
            j -= 1
        j = seed
        while j < n_bins and effect[j] > threshold:
            chunk.append(j)
            j += 1
        if len(chunk) >= min_bins:
            field_bin.extend(chunk)
    return np.unique(field_bin).astype(int)

In [None]:
# For every MVAC cell: compute per-bin left-vs-right statistics for six comparisons,
# correlate correct- vs incorrect-trial selectivity, plot the SDFs, save one figure per
# cell, and record the correlation in `result`.
result = []  # reset so that re-running this cell does not append duplicate rows

for cell_run in MVAC.Key.values:

    os.chdir(data_path)

    # Export files are named '<zero-padded 4-digit key>-...'; sorted for a deterministic pick.
    cell_key = f'{int(cell_run):04d}'
    cell_name = sorted(glob(cell_key + '-*'))

    if not cell_name:
        continue

    data = pd.read_csv(cell_name[0])
    # splitext removes only the extension; str.strip('.csv') would also eat any trailing
    # '.', 'c', 's' or 'v' characters belonging to the name itself.
    cell_stem = os.path.splitext(cell_name[0])[0]

    # Cell information from the '-'-separated file name (parsed from the stem so that a
    # final field does not carry the '.csv' extension).
    # NOTE(review): assumes field order key-rat-session-?-?-region...; confirm with the exporter.
    cell_info = cell_stem.split('-')
    cell_id = int(cell_info[0])
    rat_id = cell_info[1]
    session_id = cell_info[2]
    region = cell_info[5]

    # Trial masks. Rows (2k, 2k+1) form comparison k:
    #   0/1 multimodal, 2/3 visual, 4/5 auditory, 6/7 elemental  (correct; reward Left/Right)
    #   8/9 all correct (reward Left/Right), 10/11 all incorrect (reward Right/Left).
    # The incorrect pair is ordered Right-then-Left, presumably so that its first row, like
    # row 8, holds trials in which the animal went left. TODO: confirm.
    correct = data.Correctness == 1
    incorrect = data.Correctness == 0
    left = data.RWD_Loc == 'Left'
    right = data.RWD_Loc == 'Right'
    cond = []
    for trial_type in ['Multimodal', 'Visual', 'Auditory', 'Elemental']:
        of_type = data.Type == trial_type
        cond += [of_type & left & correct, of_type & right & correct]
    cond += [left & correct, right & correct, right & incorrect, left & incorrect]

    # Time-bin columns start at column 9; the arrays below assume exactly 95 of them.
    n_bins = 95
    mean_by_cond = np.zeros((12, n_bins))
    sem_by_cond = np.zeros((12, n_bins))
    for i in range(12):
        trials = data[cond[i]].iloc[:, 9:].to_numpy()
        mean_by_cond[i, :] = trials.mean(axis=0)
        sem_by_cond[i, :] = stats.sem(trials)

    pval_map = np.zeros((6, n_bins))
    sig_bin = np.zeros((6, n_bins))
    cohend_by_cond = np.zeros((6, n_bins))
    sig_field = dict()

    for i in range(6):
        # Filter the trials once per comparison rather than once per bin.
        group_a = data[cond[i*2]]
        group_b = data[cond[i*2+1]]
        pooled = data[cond[i*2] | cond[i*2+1]]
        for j in range(n_bins):
            col = 9 + j
            comp_data = pooled.iloc[:, col]
            a_vals = group_a.iloc[:, col]
            b_vals = group_b.iloc[:, col]
            if comp_data.empty or (comp_data == comp_data.iloc[0]).all():
                # No variance at all: the t-test is undefined; treat as not significant.
                pval_map[i, j] = 1
                cohend_by_cond[i, j] = 0
            else:
                pval_map[i, j] = stats.ttest_ind(a_vals, b_vals)[1]
                # Effect size: difference of the condition means over the SD of the pooled trials.
                cohend_by_cond[i, j] = (mean_by_cond[i*2, j] - mean_by_cond[i*2+1, j]) / comp_data.std()
            # +1 / -1 where p < sig_alpha, signed by which condition has the larger mean; else 0.
            sig_bin[i, j] = (pval_map[i, j] < sig_alpha) * np.sign(a_vals.mean() - b_vals.mean())

    # Degenerate groups can yield inf/NaN p-values or effect sizes; neutralise them.
    pval_map[~np.isfinite(pval_map)] = 1
    cohend_by_cond[~np.isfinite(cohend_by_cond)] = 0

    sig_field_count = np.zeros((12, 1))
    SI = np.zeros((12, 1))
    peaktime = np.zeros((6, 1))
    binary = np.zeros((6, n_bins))
    for i in range(6):
        sig_field[i] = find_sig_field(np.nonzero(sig_bin[i, :])[0], cohend_by_cond[i, :])
        # `.size` rather than `.any()`: a field made only of bin 0 is still a field.
        if sig_field[i].size > 0:
            peaktime[i] = sig_field[i][np.argmax(np.abs(cohend_by_cond[i, sig_field[i]]))]
            for j in sig_field[i]:
                if cohend_by_cond[i, j] > 0:
                    sig_field_count[2*i] += 1
                    SI[2*i] += cohend_by_cond[i, j]
                    binary[i, j] = 1
                elif cohend_by_cond[i, j] < 0:
                    sig_field_count[2*i+1] += 1
                    SI[2*i+1] += np.abs(cohend_by_cond[i, j])
                    binary[i, j] = -1

    is_sig = np.zeros((4, 1))
    for i in range(4):
        is_sig[i] = int((sig_field_count[2*i] != 0) | (sig_field_count[2*i+1] != 0))

    # Response coding: correlate the effect-size time courses of the all-correct (4) and
    # all-incorrect (5) comparisons from bin 40 (400 ms, the dotted line in SDF_plot) on.
    # (Earlier variants correlated thresholded field maps or the raw mean SDFs instead.)
    corr = stats.pearsonr(cohend_by_cond[4, 40:], cohend_by_cond[5, 40:])
    corr_r = float(corr[0])
    # pearsonr returns NaN for constant input; round(NaN) would raise ValueError.
    corr_rounded = round(corr_r * 1000) / 1000 if np.isfinite(corr_r) else np.nan

    # Gaussian smoothing of the plotted traces only (statistics above use unsmoothed data).
    if gauss_on:
        for i in range(12):
            mean_by_cond[i, :] = gaussian_filter(mean_by_cond[i, :], sigma=gauss_sigma)
            sem_by_cond[i, :] = gaussian_filter(sem_by_cond[i, :], sigma=gauss_sigma)

    # NaN-aware limits: a condition with fewer than two trials has NaN mean/SEM, and a
    # NaN limit would make set_ylim fail.
    y_max = np.round(np.nanmax(mean_by_cond) + np.nanmax(sem_by_cond))
    y_min = np.round(np.nanmin(mean_by_cond) - np.nanmax(sem_by_cond))

    fig = plt.figure(figsize=(8, 12))
    fig.suptitle(cell_stem, fontsize=15)
    gs = fig.add_gridspec(nrows=4, ncols=2)

    panels = [((0, 0), [0, 1], 'Multimodal'),
              ((1, 0), [2, 3], 'Visual-only'),
              ((2, 0), [4, 5], 'Auditory-only'),
              ((3, 0), [6, 7], 'Control'),
              ((0, 1), [8, 9], 'Correct trials (all)'),
              ((1, 1), [10, 11], 'Incorrect trials (all)')]
    for (grid_row, grid_col), cond_rows, title in panels:
        panel_ax = fig.add_subplot(gs[grid_row, grid_col])
        SDF_plot(cond_rows, panel_ax, colorpalette, 2, shade_on)
        panel_ax.set_title(title)
        panel_ax.set_xlabel('Time (ms)')

    text_ax = fig.add_subplot(gs[2, 1])
    text_ax.axis('off')
    text_ax.text(0.1, 0.9, corr_rounded, fontsize=14)

    fig.tight_layout()
    fig.subplots_adjust(top=0.92)

    selectivity_dir = 'response-selective' if corr_r > 0.5 else 'no response-selective'
    final_save_path = save_path + str(date.today()) + '/' + region + '/' + selectivity_dir
    os.makedirs(final_save_path, exist_ok=True)
    fig.savefig(os.path.join(final_save_path, cell_stem + '.png'), dpi=100, facecolor='white')
    plt.close(fig)

    cell_result = {'Key': cell_id,
                   'RatID': rat_id,
                   'Session': session_id,
                   'Region': region,
                   'Response_corr': corr_rounded}

    result.append(cell_result)

In [None]:
# Collect the per-cell response correlations into a table and save it next to today's
# figures. The output directory is created if needed (the loop creates it only when at
# least one cell was analysed), and the file is written by full path instead of via chdir.
result = pd.DataFrame(result)

run_date = str(date.today())
result_dir = save_path + run_date
os.makedirs(result_dir, exist_ok=True)
result.to_csv(os.path.join(result_dir, run_date + '_response_corr.csv'), index=False)


### Comparing selectivity between response and object cells

In [None]:
# Total significant-field size per modality (left- plus right-preferring fields),
# taken from the item-selectivity table loaded above.
df['M_size'] = df['M_L_field_size']+df['M_R_field_size']
df['V_size'] = df['V_L_field_size']+df['V_R_field_size']
df['A_size'] = df['A_L_field_size']+df['A_R_field_size']

# Response-selective ("choice") cells: keys whose correct-vs-incorrect selectivity
# correlation (Response_corr, rounded to 3 decimals) exceeds 0.5, the same cut-off the
# per-cell loop uses to file figures under 'response-selective'. Previously
# `choice_cell_list` was never defined in the notebook (NameError on a fresh kernel).
# NOTE(review): confirm this is the intended definition of a choice cell.
response_corr = pd.DataFrame(result)  # works whether `result` is still a list or already a DataFrame
if 'Response_corr' in response_corr.columns:
    choice_cell_list = response_corr.loc[response_corr['Response_corr'] > 0.5, 'Key'].tolist()
else:
    choice_cell_list = []  # no cells were analysed

df_choice = df[df.Key.isin(choice_cell_list)].copy()
df_nochoice = df[~df.Key.isin(choice_cell_list)].copy()

df_choice.shape

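**TODO:** make sure the list of response-selective cells (`choice_cell_list`) used by the cell above is loaded explicitly; if it is undefined, that cell raises a `NameError` on a fresh kernel.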

In [None]:
# Mean (+/- SEM) significant-field size per modality: choice cells vs. the remaining
# cells that have a non-zero field for that modality, with an independent t-test p-value.
fig, ax = plt.subplots(3, 1, figsize=(5, 8))

x = ['Others', 'Choice']
cond = ['M', 'V', 'A']

for i, modality in enumerate(cond):
    size_col = modality + '_size'
    y1 = df_nochoice.loc[df_nochoice[size_col] != 0, size_col]
    y2 = df_choice[size_col]
    y = [y1.mean(), y2.mean()]
    y_err = [stats.sem(y1), stats.sem(y2)]

    panel = ax[i]
    panel.barh(x, y, xerr=y_err, color=['gray', 'black'])
    # Sizes are in 10 ms bins; relabel the ticks in milliseconds.
    panel.set_xticks([0, 10, 20, 30, 40, 50])
    panel.set_xlim([0, 50])
    panel.set_xticklabels([0, 100, 200, 300, 400, 500])
    if i == len(cond) - 1:
        panel.set_xlabel('Selectivity size (ms)')

    tt = stats.ttest_ind(y1, y2)
    panel.text(30, 0, 'pval =' + str(round(tt.pvalue * (10**5)) / (10**5)))

fig.tight_layout()
fig.savefig(save_path + 'choice cell selectivity.svg')