In [1]:
import os
import numpy as np
import pandas as pd

from scipy import stats
from scipy.ndimage import gaussian_filter
from scipy.signal import find_peaks

import pingouin

import matplotlib as mpl
import matplotlib.pyplot as plt  
import seaborn as sns

from datetime import date
import random
import warnings

In [2]:
warnings.filterwarnings("ignore", category=UserWarning)

In [3]:
# Hide the top and right axis spines in every figure drawn after this cell.
mpl.rcParams.update({'axes.spines.top': False,
                     'axes.spines.right': False})

In [4]:
# Project root. Override with the MULTIMODAL_ROOT environment variable instead of
# editing this cell on another machine. Keep the trailing '/': later cells build
# paths by plain string concatenation (mother_path + 'analysis/...').
mother_path = os.environ.get('MULTIMODAL_ROOT', 'D:/Multi-modal project/')

# Parameter setting

In [5]:
# --- statistical criteria for selectivity ---
sig_alpha = 0.01      # t-test p-value threshold for a significant time bin
sig_cohend = 0.4      # minimum |Cohen's d| for a bin to belong to a selectivity field
cons_bin_crit = 5     # minimum number of consecutive bins that form a field

# --- Gaussian smoothing of the mean/SEM traces before plotting ---
gauss_on = True
gauss_sigma = 2

only_PER = False   # analyze only PER neurons? (not referenced in the cells shown here)

# --- modality conditions and their plot colors (same order) ---
condition = ['Multimodal','Visual','Auditory','Control']
colorpalette = ['mediumorchid','cornflowerblue','lightcoral','gray']

cmap = 'PiYG'      # diverging colormap for the selectivity maps

today = str(date.today())

# Data preparation

In [6]:
# Cluster summary table (not referenced in the cells shown here).
# Paths are built the same way for all three locations: mother_path already ends with '/'.
unit_summary = pd.read_csv(mother_path+'analysis/result/1. Cluster summary/clusterSummary.csv')
save_path = mother_path+'analysis/result/4. Selectivity correlations/'
data_path = mother_path+'analysis/result/zFR export/13-Apr-2022 (5 trials)/'
# The next cells list the per-cell CSV files relative to the working directory.
os.chdir(data_path)

In [7]:
# One dict per analyzed cell; converted to the summary CSV in the export cell.
result = list()

In [8]:
# Per-cell CSV exports in the working directory (set to data_path above).
# Only regular .csv files are kept, so stray files (e.g. desktop.ini, Thumbs.db)
# are not parsed as cells. The list is sorted so the processing order, and thus
# the row order of the exported summary, does not depend on the file system.
cell_list = sorted(name for name in os.listdir(os.curdir)
                   if name.lower().endswith('.csv') and os.path.isfile(name))

# Data analysis

In [9]:
def SDF_plot(cond,ax,colorpalette,linewidth,shade_on):
    """Draw mean firing-rate traces (SDFs) of several trial conditions on one axis.

    Parameters
    ----------
    cond : sequence of int
        Row indices into the global ``mean_by_cond`` / ``sem_by_cond`` arrays.
        Rows 2k and 2k+1 belong to modality condition k and share color
        ``colorpalette[k]``. Even rows are drawn solid, odd rows dotted.
    ax : matplotlib Axes
        Axis to draw into.
    colorpalette : sequence of colors
        One color per modality condition.
    linewidth : float
        Line width of the mean traces.
    shade_on : bool
        If True, show the SEM as a shaded band. Otherwise show it as vertical
        error bars.

    Notes
    -----
    Reads notebook globals assigned by the main analysis loop:
    ``mean_by_cond``, ``sem_by_cond``, ``is_sig``, ``peaktime``, ``sig_field``,
    ``y_min`` and ``y_max``. The significance markers (peak line and field
    ticks) use the modality condition of the last entry in ``cond``.
    """
    cond_list = list(cond)
    # one sample per time bin, placed 10 ms apart
    x = np.arange(mean_by_cond.shape[1])*10
    # reference line at 400 ms (presumably an event onset -- confirm)
    ax.axvline(x=400,color='black',linewidth=1,linestyle=':')

    cond_idx = None
    for i in cond_list:
        # rows 2k / 2k+1 -> modality condition k
        cond_idx = i // 2
        current_color = colorpalette[cond_idx]
        current_style = '-' if i % 2 == 0 else ':'

        ax.plot(x,mean_by_cond[i,:],color=current_color,
                linewidth=linewidth,linestyle=current_style)
        if shade_on:
            ax.fill_between(x, mean_by_cond[i,:]-sem_by_cond[i,:],
                            mean_by_cond[i,:]+sem_by_cond[i,:],
                            color=current_color, alpha=0.2)
        else:
            # the SEM is an uncertainty of the y value only, so no xerr
            ax.errorbar(x, mean_by_cond[i,:],yerr=sem_by_cond[i,:],
                        ecolor=current_color,elinewidth=1,fmt='None',alpha=0.5)

    if cond_idx is not None and np.any(is_sig[cond_idx]!=0):
        # short bar at the peak-selectivity bin (converted to ms)
        ax.axvline(x=np.asarray(peaktime[cond_idx]).item()*10,
                   ymax=0.1,color='black',linewidth=3)
    ax.set_xlim([0,950])
    ax.tick_params(axis='y',which='both',direction='in')
    ax.tick_params(axis='x',which='both',bottom=False,top=False,labelbottom=False)
    ax.set_yticks(np.arange(y_min,(y_max+1)))
    ax.set_ylim([y_min,y_max])
    ax.set_ylabel('z-scored FR')
    if cond_idx is not None:
        # tick marks just below the top of the axis for every significant-field bin
        marks = np.asarray(sig_field[cond_idx])
        ax.scatter(marks*10,np.full(len(marks),y_max-0.2),
                   s=10,marker='_',c='black',linewidth=2)

In [10]:
def find_sig_field(sig_bin, cohend, threshold=None, min_bins=None):
    """Expand significant time bins into contiguous fields of large effect size.

    For each anchor bin in ``sig_bin``, take the contiguous run of bins that
    contains the anchor and whose |Cohen's d| exceeds ``threshold``. Keep the
    run if it spans at least ``min_bins`` bins. An anchor whose own |d| does
    not exceed the threshold belongs to no such run and is ignored.

    Parameters
    ----------
    sig_bin : iterable of int
        Indices of the bins that passed the significance test.
    cohend : 1-D array-like of float
        Cohen's d per time bin. Only the magnitude is used.
    threshold : float, optional
        Minimum |d| for a bin to belong to a field. Defaults to the global
        ``sig_cohend``.
    min_bins : int, optional
        Minimum number of consecutive bins in a field. Defaults to the global
        ``cons_bin_crit``.

    Returns
    -------
    numpy.ndarray of int
        Sorted, unique indices of all bins in the accepted fields (possibly empty).
    """
    if threshold is None:
        threshold = sig_cohend
    if min_bins is None:
        min_bins = cons_bin_crit
    strong = np.abs(np.asarray(cohend, dtype=float)) > threshold
    n_bins = len(strong)
    field_bin = set()
    for anchor in sig_bin:
        anchor = int(anchor)
        if not strong[anchor]:
            continue
        # grow [start, stop) outward from the anchor while the effect stays large
        start = anchor
        while start > 0 and strong[start - 1]:
            start -= 1
        stop = anchor + 1
        while stop < n_bins and strong[stop]:
            stop += 1
        if stop - start >= min_bins:
            field_bin.update(range(start, stop))
    return np.array(sorted(field_bin), dtype=int)

In [11]:
def selectivity_map(cond, ax, selectivity=None, colormap=None):
    """Draw one condition's selectivity map as a single-row heat map.

    Parameters
    ----------
    cond : int
        Row of the selectivity array to draw (modality condition index).
    ax : matplotlib Axes
        Axis to draw into.
    selectivity : 2-D array-like, optional
        Condition x time-bin map with values shown on a fixed [-1, 1] color
        scale. Defaults to the global ``cohend_sig`` set by the main analysis loop.
    colormap : str or Colormap, optional
        Defaults to the global ``cmap``.

    The map is drawn on a millisecond axis. Bins are 10 ms apart and pixel k
    is centered at 10*k ms, with x limits [0, 950]. This matches the x
    positions and limits used by ``SDF_plot``, so the map lines up with the
    SDF panel drawn above it.
    """
    if selectivity is None:
        selectivity = cohend_sig
    if colormap is None:
        colormap = cmap
    row = np.asarray(selectivity)[cond, :][np.newaxis, :]
    n_bins = row.shape[1]
    bin_ms = 10
    # extent = (left, right, bottom, top) for the default origin='upper'
    ax.imshow(row, aspect="auto", cmap=colormap, vmin=-1, vmax=1,
              extent=(-bin_ms / 2, n_bins * bin_ms - bin_ms / 2, 0.5, -0.5))
    ax.set_xlim([0, 950])
    ax.tick_params(axis='y', which='both', left=False, labelleft=False)
    ticks = [0, 200, 400, 600, 800, 950]
    ax.set_xticks(ticks)
    ax.set_xticklabels(ticks)
    ax.set_xlabel('Time (ms)')

### Main analysis
For each cell, plot the object-selectivity map alongside the spike density functions (SDFs), and record the correlations (Kendall's τ) between the selectivity maps of the modality conditions.

In [12]:
# For each exported cell file:
# - compute condition-averaged firing-rate traces;
# - compute Left- vs Right-reward (RWD_Loc) selectivity per time bin
#   (t-test p-value and Cohen's d);
# - find significant selectivity fields;
# - compute Kendall's tau between the selectivity maps of the Multimodal,
#   Visual and Auditory conditions.
# One SVG figure per cell is saved and one summary row per cell is appended
# to `result`.
#
# NOTE: SDF_plot and selectivity_map read several variables assigned in this
# loop as notebook globals (mean_by_cond, sem_by_cond, is_sig, peaktime,
# sig_field, y_min, y_max, cohend_sig) -- keep those names unchanged.
fr_start = 9   # index of the first firing-rate (time-bin) column in the CSVs
n_bins = 95    # number of time bins per trial (plotted 10 ms apart)

for cell_run in cell_list:
    data = pd.read_csv(os.path.join(data_path, cell_run))
    # we only use correct trials in the current analysis
    data = data[data.Correctness != 0].copy()
    data.iloc[:, fr_start:] = round(data.iloc[:, fr_start:]*10000)/10000

    # get information about the cell from its dash-separated file name
    # (splitext, not str.strip('.csv'), which would also eat trailing 'c'/'s'/'v'/'.')
    cell_name = os.path.splitext(cell_run)[0]
    cell_info = cell_run.split('-')
    cell_id = int(cell_info[0])
    rat_id = cell_info[1]
    session_id = cell_info[2]
    region = cell_info[5]

    # trial masks: rows 2k / 2k+1 = Left / Right reward location of modality k
    cond = [(data.Type=='Multimodal')&(data.RWD_Loc=='Left'),
            (data.Type=='Multimodal')&(data.RWD_Loc=='Right'),
            (data.Type=='Visual')&(data.RWD_Loc=='Left'),
            (data.Type=='Visual')&(data.RWD_Loc=='Right'),
            (data.Type=='Auditory')&(data.RWD_Loc=='Left'),
            (data.Type=='Auditory')&(data.RWD_Loc=='Right'),
            (data.Type=='Elemental')&(data.RWD_Loc=='Left'),
            (data.Type=='Elemental')&(data.RWD_Loc=='Right')]

    mean_by_cond = np.zeros((8, n_bins))
    sem_by_cond = np.zeros((8, n_bins))
    for i in range(8):
        rates = data[cond[i]].iloc[:, fr_start:].to_numpy()
        mean_by_cond[i, :] = rates.mean(axis=0)
        sem_by_cond[i, :] = stats.sem(rates)

    pval_map = np.zeros((4, n_bins))
    sig_bin = np.zeros((4, n_bins))
    cohend_by_cond = np.zeros((4, n_bins))
    sig_field = dict()

    for i in range(4):
        # masks are loop-invariant over time bins, so slice each group once
        left = data[cond[i*2]].iloc[:, fr_start:]
        right = data[cond[i*2+1]].iloc[:, fr_start:]
        pooled = data[cond[i*2] | cond[i*2+1]].iloc[:, fr_start:]
        for j in range(n_bins):
            comp_data = pooled.iloc[:, j]
            if all(x == comp_data.iloc[0] for x in comp_data):
                # constant (or empty) response: no test possible, no effect
                pval_map[i, j] = 1
                cohend_by_cond[i, j] = 0
            else:
                pval_map[i, j] = stats.ttest_ind(left.iloc[:, j], right.iloc[:, j])[1]
                # effect size normalized by the SD of the pooled Left+Right trials
                cohend_by_cond[i, j] = (mean_by_cond[i*2, j] - mean_by_cond[i*2+1, j]) / comp_data.std()
            # +1 / -1: significant with higher Left / Right mean; 0: not significant
            sig_bin[i, j] = (pval_map[i, j] < sig_alpha)*np.sign(left.iloc[:, j].mean() - right.iloc[:, j].mean())

    # inf/NaN come from empty or degenerate groups; treat them as "no effect"
    pval_map[~np.isfinite(pval_map)] = 1
    cohend_by_cond[~np.isfinite(cohend_by_cond)] = 0
    # 0 * NaN (NaN mean difference) is NaN, which np.nonzero would count as significant
    sig_bin[np.isnan(sig_bin)] = 0

    sig_field_count = np.zeros((8, 1))   # per row 2k/2k+1: number of field bins with d > 0 / d < 0
    SI = np.zeros((8, 1))                # per row 2k/2k+1: summed |d| over those bins (not exported)
    peaktime = np.zeros((4, 1))          # per condition: field bin with the largest |d|
    binary = np.zeros((4, n_bins))       # sign of d inside significant fields, 0 elsewhere
    for i in range(4):
        sig_field[i] = find_sig_field(np.nonzero(sig_bin[i, :])[0], cohend_by_cond[i, :])
        # test emptiness with .size: .any() is False for a field that is only bin 0
        if sig_field[i].size > 0:
            peaktime[i] = sig_field[i][np.argmax(np.abs(cohend_by_cond[i, sig_field[i]]))]
            for j in sig_field[i]:
                if cohend_by_cond[i, j] > 0:
                    sig_field_count[2*i] += 1
                    SI[2*i] += cohend_by_cond[i, j]
                    binary[i, j] = 1
                elif cohend_by_cond[i, j] < 0:
                    sig_field_count[2*i+1] += 1
                    SI[2*i+1] += np.abs(cohend_by_cond[i, j])
                    binary[i, j] = -1

    # a condition is selective if it has field bins of either sign
    is_sig = np.zeros((4, 1))
    for i in range(4):
        is_sig[i] = int((sig_field_count[2*i] != 0) | (sig_field_count[2*i+1] != 0))

    # Selectivity map that has Cohen's d values only in significant time bins.
    cohend_sig = np.where(binary != 0, cohend_by_cond, 0.0)

    # Rank correlations (Kendall's tau) between the selectivity maps
    MV_corr = stats.kendalltau(cohend_sig[0, :], cohend_sig[1, :])[0]
    MA_corr = stats.kendalltau(cohend_sig[0, :], cohend_sig[2, :])[0]
    VA_corr = stats.kendalltau(cohend_sig[1, :], cohend_sig[2, :])[0]

    # gaussian smoothing of the traces; the statistics above use the raw traces
    if gauss_on:
        for i in range(8):
            mean_by_cond[i, :] = gaussian_filter(mean_by_cond[i, :], sigma=gauss_sigma)
            sem_by_cond[i, :] = gaussian_filter(sem_by_cond[i, :], sigma=gauss_sigma)

    y_max = np.round(mean_by_cond.max() + sem_by_cond.max())
    y_min = np.round(mean_by_cond.min() - sem_by_cond.max())

    fig = plt.figure(figsize=(4, 12))
    fig.suptitle(cell_name, fontsize=9)
    # SDF + selectivity-map panels for the first three conditions, then a text panel
    gs = mpl.gridspec.GridSpec(4, 1, height_ratios=[1, 1, 1, 0.5], hspace=0.5)

    for i in range(3):
        inner = mpl.gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[i],
                                                     height_ratios=[3, 1], hspace=0.1)
        ax = fig.add_subplot(inner[0])
        SDF_plot([i*2, i*2+1], ax, colorpalette, 2, True)
        ax.set_title(condition[i], fontsize=14)

        ax2 = fig.add_subplot(inner[1])
        selectivity_map(i, ax2)

    ax3 = fig.add_subplot(gs[3, 0])
    ax3.axis('off')
    ax3.text(0, 0.8, f'M-V correlation: {MV_corr:.2f}', fontsize=14)
    ax3.text(0, 0.5, f'M-A correlation: {MA_corr:.2f}', fontsize=14)
    ax3.text(0, 0.2, f'V-A correlation: {VA_corr:.2f}', fontsize=14)

    fig.subplots_adjust(top=0.92)
    fig.tight_layout()

    # figures go to <save_path>/<date>/<region>/
    out_dir = os.path.join(save_path, str(date.today()), region)
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, cell_name + '.svg'))
    plt.close(fig)

    # make csv output
    cell_result = {'Key':cell_id,
                   'RatID':rat_id,
                   'Session':session_id,
                   'Region':region,
                   'MV corr': MV_corr,
                   'MA corr': MA_corr,
                   'VA corr': VA_corr}

    result.append(cell_result)

### Export CSV file

In [13]:
# Build the summary table in a new variable, so `result` stays a list and the
# analysis cell can be re-run. Write it into <save_path>/<date>/.
result_df = pd.DataFrame(result)

export_dir = os.path.join(save_path, str(date.today()))
os.makedirs(export_dir, exist_ok=True)
result_df.to_csv(os.path.join(export_dir, str(date.today())+'_selectivity_corr.csv'), index=False)

result_df.head()

In [14]:
print("END")  # marks that every cell above ran to completion

END
