In [1]:
import os
import numpy as np
import pandas as pd

from scipy import stats
from scipy.ndimage import gaussian_filter
from scipy.signal import find_peaks

import pingouin

import matplotlib as mpl
import matplotlib.pyplot as plt  
import seaborn as sns

from datetime import date

import random

In [2]:
# Hide the top and right axis spines in every figure this notebook draws.
mpl.rcParams.update({'axes.spines.right': False,
                     'axes.spines.top': False})

In [3]:
# Root folder of the project. Set the MULTIMODAL_ROOT environment variable to
# point elsewhere instead of editing this line. The value always ends in '/'
# because later cells build paths by plain string concatenation.
mother_path = os.environ.get('MULTIMODAL_ROOT', 'D:/Multi-modal project/').rstrip('/\\') + '/'

# Parameter settings

In [4]:
# --- Significance criteria ---------------------------------------------------
sig_alpha = 0.01      # t-test p-value threshold for a significant time bin
sig_cohend = 0.4      # |Cohen's d| a bin must exceed to belong to a significant field
cons_bin_crit = 5     # minimum number of consecutive bins forming a field

# --- Smoothing of the plotted traces ------------------------------------------
gauss_on = True
gauss_sigma = 2       # Gaussian kernel SD, in bins

# --- Neuron selection ---------------------------------------------------------
only_PER = False      # if True, analyze only PER neurons

# Colors for the multimodal, visual-only, auditory-only and control conditions.
colorpalette = ['mediumorchid', 'cornflowerblue', 'lightcoral', 'gray']

# Run date as an ISO 'YYYY-MM-DD' string.
today = str(date.today())

# Data preparation

In [5]:
# Input table of recorded units, the folder of per-cell z-scored firing-rate
# exports, and the output folder for this analysis. mother_path already ends in
# '/', so sub-paths are appended without a leading separator.
# NOTE(review): unit_summary is not referenced elsewhere in the visible notebook.
unit_summary = pd.read_csv(mother_path + 'analysis/result/1. Cluster summary/clusterSummary.csv')
save_path = mother_path+'analysis/result/3. Item-selectivity/'
data_path = mother_path+'analysis/result/zFR export/13-Apr-2022 (5 trials)/'
# The next cell lists the data files via os.curdir, so make this the working directory.
os.chdir(data_path)

In [6]:
# empty list for result csv file
# One summary dict per analyzed neuron; becomes the item-selectivity CSV at the end.
result = list()

In [7]:
# Per-cell CSV exports in the current (data) folder. The list is sorted because
# os.listdir order is arbitrary and would otherwise make the row order of the
# output tables depend on the file system. Non-CSV files are skipped.
cell_list = sorted(f for f in os.listdir(os.curdir) if f.lower().endswith('.csv'))

# Data analysis

In [8]:
def SDF_plot(cond,ax,colorpalette,linewidth,shade_on):
    """Plot mean +/- SEM z-scored firing-rate traces for one condition pair.

    Reads notebook-level globals set by the analysis loop: mean_by_cond and
    sem_by_cond (8 x 95), is_sig, peaktime, sig_field, y_min and y_max.

    Parameters
    ----------
    cond : list of int
        Rows of mean_by_cond to draw (0-7). Even rows are drawn solid and odd
        rows dotted. Rows 2k and 2k+1 belong to comparison k (= row // 2),
        whose color, significance flag, peak time and field are used.
    ax : matplotlib Axes
        Target axes.
    colorpalette : sequence
        Colors indexed by comparison k.
    linewidth : float
        Line width of the traces.
    shade_on : bool
        True draws a shaded +/-1 SEM band; False draws vertical SEM error bars.
    """
    x = np.arange(95) * 10  # 95 bins, 10 ms apart
    ax.axvline(x=400, color='black', linewidth=1, linestyle=':')
    for i in cond:
        # Keep the comparison index in its own name rather than overwriting
        # the `cond` argument that is being iterated.
        panel = i // 2
        current_color = colorpalette[panel]
        current_style = '-' if i % 2 == 0 else ':'

        ax.plot(x, mean_by_cond[i, :], color=current_color,
                linewidth=linewidth, linestyle=current_style)
        if shade_on:
            ax.fill_between(x, mean_by_cond[i, :] - sem_by_cond[i, :],
                            mean_by_cond[i, :] + sem_by_cond[i, :],
                            color=current_color, alpha=0.2)
        else:
            # The SEM is in z-score (y) units, so draw it only as a vertical error.
            ax.errorbar(x, mean_by_cond[i, :], yerr=sem_by_cond[i, :],
                        ecolor=current_color, elinewidth=1, fmt='none', alpha=0.5)

    # Peak-time tick at the bottom of the axes for significant comparisons.
    if is_sig[panel] != 0:
        ax.axvline(x=peaktime[panel] * 10, ymax=0.1, color='black', linewidth=3)
    # set_xticks actually changes the ticks; assigning ax.xticks does not.
    ax.set_xticks([0, 200, 400, 600, 800, 950])
    ax.tick_params(axis='y', which='both', direction='in')
    ax.set_xlim([0, 950])
    ax.set_yticks(np.arange(y_min, (y_max + 1)))
    ax.set_ylim([y_min, y_max])
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('z-scored FR')
    # Mark the significant-field bins just below the top of the axes.
    ax.scatter(sig_field[panel] * 10, np.full(len(sig_field[panel]), y_max - 0.2),
               s=10, marker='_', c='black', linewidth=2)

In [9]:
def pval_plot(cond,ax,linewidth):
    """Plot -ln(p) of the bin-wise left-vs-right t-test for one comparison.

    Reads notebook-level globals pval_map (4 x 95), sig_field, is_sig and
    peaktime. np.log is the natural logarithm, so the reference lines sit at
    -ln(0.05), -ln(0.01) and -ln(0.001).

    Parameters
    ----------
    cond : int
        Comparison index (row of pval_map).
    ax : matplotlib Axes
        Target axes.
    linewidth : float
        Line width of the p-value trace.
    """
    x = np.arange(95) * 10
    # Raise exact zeros (t-test underflow) to the smallest positive float, so
    # -ln(p) stays finite instead of becoming inf.
    neg_log_p = -np.log(np.clip(pval_map[cond, :], np.nextafter(0, 1), None))
    ax.plot(x, neg_log_p,
            color='darkgray', marker='o', fillstyle='none', markersize=4,
            linewidth=linewidth)
    ax.axvline(x=400, color='black', linewidth=1, linestyle=':')
    for p_ref in (0.05, 0.01, 0.001):
        y_ref = np.log(p_ref) * -1
        ax.axhline(y=y_ref, color='r', linestyle=':')
        ax.annotate('p=' + str(p_ref), (10, y_ref - 0.15))

    # Height of the significant-field marker: a fixed 6 when nothing reaches
    # the p = 0.001 line, otherwise just below the highest finite point.
    finite = neg_log_p[np.isfinite(neg_log_p)]
    m = finite.max() if finite.size else 0.0
    if m < np.log(0.001) * -1:
        m = 6
    else:
        m = round(m) - 1
    ax.scatter(sig_field[cond] * 10, np.full(len(sig_field[cond]), m),
               s=10, marker='_', c='black', linewidth=2)
    if is_sig[cond] != 0:
        ax.axvline(x=peaktime[cond] * 10, ymax=0.1, color='black', linewidth=3)
    ax.set_ylabel('-log(p-value)')
    # set_xticks actually changes the ticks; assigning ax.xticks does not.
    ax.set_xticks([0, 200, 400, 600, 800, 950])
    ax.set_title('p-value plot')
    ax.set_xlabel('Time (ms)')

In [10]:
def cohend_plot(cond,ax,linewidth):
    """Plot |Cohen's d| between left- and right-reward trials for one comparison.

    Reads notebook-level globals cohend_by_cond (4 x 95), sig_cohend,
    sig_field, is_sig and peaktime.

    Parameters
    ----------
    cond : int
        Comparison index (row of cohend_by_cond).
    ax : matplotlib Axes
        Target axes.
    linewidth : float
        Line width of the effect-size trace.
    """
    d_abs = np.abs(cohend_by_cond[cond, :])
    ax.plot(np.arange(95) * 10, d_abs,
            color='forestgreen', marker='o', fillstyle='none', markersize=4, linewidth=linewidth)
    ax.axvline(x=400, color='black', linewidth=1, linestyle=':')
    ax.axhline(y=sig_cohend, color='r', linestyle=':')
    ax.annotate(str(sig_cohend), (10, sig_cohend))

    # Field marker height: the largest finite |d| rounded to 0.1, minus 0.05.
    # Non-finite values are ignored, and an all-non-finite row falls back to 0.
    finite = d_abs[np.isfinite(d_abs)]
    m = round((finite.max() if finite.size else 0) * 10) / 10 - 0.05
    ax.scatter(sig_field[cond] * 10, np.full(len(sig_field[cond]), m),
               s=10, marker='_', c='black', linewidth=2)

    if is_sig[cond] != 0:
        ax.axvline(x=peaktime[cond] * 10, ymax=0.1, color='black', linewidth=3)
    ax.set_ylabel('Cohen\'s d')
    # set_xticks actually changes the ticks; assigning ax.xticks does not.
    ax.set_xticks([0, 200, 400, 600, 800, 950])
    ax.set_title('Cohen\'s d plot')
    ax.set_xlabel('Time (ms)')

In [11]:
def find_sig_field(sig_bin,cohend,d_crit=None,min_run=None):
    """Grow significant fields from p-significant seed bins.

    From each seed bin, the field extends left and right over contiguous bins
    whose |cohend| is strictly greater than d_crit. A field is kept only if it
    spans at least min_run bins. Seeds whose own |d| does not exceed d_crit
    produce no field (previously this raised on np.max of an empty array).

    Parameters
    ----------
    sig_bin : iterable of int
        Indices of bins that passed the p-value criterion.
    cohend : 1-D array-like
        Effect size per bin. Only its magnitude is used.
    d_crit : float, optional
        |d| threshold. Defaults to the notebook-level sig_cohend.
    min_run : int, optional
        Minimum field length in bins. Defaults to the notebook-level cons_bin_crit.

    Returns
    -------
    numpy.ndarray of int
        Sorted, unique indices of all bins belonging to a kept field.
    """
    if d_crit is None:
        d_crit = sig_cohend
    if min_run is None:
        min_run = cons_bin_crit

    above = np.abs(np.asarray(cohend, dtype=float)) > d_crit
    n_bins = above.size  # was hard-coded as 95 bins (last index 94)
    field_bin = set()
    for seed in sig_bin:
        seed = int(seed)
        if not above[seed]:
            continue
        start = seed
        while start - 1 >= 0 and above[start - 1]:
            start -= 1
        stop = seed
        while stop + 1 < n_bins and above[stop + 1]:
            stop += 1
        if stop - start + 1 >= min_run:
            field_bin.update(range(start, stop + 1))
    # Sorted so downstream argmax tie-breaking is deterministic.
    return np.array(sorted(field_bin), dtype=int)

In [12]:
# Per-comparison tables grown by one row per neuron in the analysis loop.
# Cohen's d rows: [cell_id, peak bin, d for each of 95 bins]  -> 97 columns.
cohend_result_M, cohend_result_V, cohend_result_A, cohend_result_C = (
    np.empty((0, 2 + 95)) for _ in range(4))

# Binary rows: [cell_id, +1 / -1 / 0 for each of 95 bins]      -> 96 columns.
sig_bin_result_M, sig_bin_result_V, sig_bin_result_A, sig_bin_result_C = (
    np.empty((0, 1 + 95)) for _ in range(4))

In [13]:
# Per-neuron analysis. For each cell CSV this loop:
#   1. compares left- vs right-reward correct trials bin by bin
#      (t-test p-value and Cohen's d) for the 4 trial types,
#   2. finds significant fields and their peak time,
#   3. appends rows to the cohend_result_* / sig_bin_result_* tables,
#   4. saves a 4 x 3 summary figure, and
#   5. appends a summary dict to `result`.
# mean_by_cond, sem_by_cond, pval_map, cohend_by_cond, sig_field, is_sig,
# peaktime, y_min and y_max are deliberately module-level: SDF_plot, pval_plot
# and cohend_plot read them as globals.
for cell_run in cell_list:
    # ---- cell metadata from the '-'-separated file name ---------------------
    cell_info = cell_run.split('-')
    cell_id = int(cell_info[0])
    rat_id = cell_info[1]
    session_id = cell_info[2]
    region = cell_info[5]
    # splitext removes only the extension; str.strip('.csv') would also strip
    # any leading/trailing '.', 'c', 's' or 'v' characters of the name.
    cell_name = os.path.splitext(cell_run)[0]

    if only_PER and region != 'PER':
        continue

    # ---- load correct trials -------------------------------------------------
    # Columns 9 onward are the 95 z-scored firing-rate bins (10 ms each),
    # rounded to 4 decimals.
    data = pd.read_csv(os.path.join(data_path, cell_run))
    data = data.loc[data.Correctness != 0].copy()
    data.iloc[:, 9:] = data.iloc[:, 9:].round(4)

    # Masks 2k / 2k+1 = left / right reward location for trial type k
    # (0 multimodal, 1 visual, 2 auditory, 3 elemental/control).
    cond = [(data.Type == trial_type) & (data.RWD_Loc == side)
            for trial_type in ('Multimodal', 'Visual', 'Auditory', 'Elemental')
            for side in ('Left', 'Right')]

    mean_by_cond = np.zeros((8, 95))
    sem_by_cond = np.zeros((8, 95))
    for i, mask in enumerate(cond):
        zfr = data.loc[mask].iloc[:, 9:].to_numpy()
        mean_by_cond[i, :] = zfr.mean(axis=0)
        sem_by_cond[i, :] = stats.sem(zfr)

    # ---- bin-wise left-vs-right comparison -----------------------------------
    pval_map = np.zeros((4, 95))
    sig_bin = np.zeros((4, 95))
    cohend_by_cond = np.zeros((4, 95))
    sig_field = dict()

    for i in range(4):
        left_trials = data.loc[cond[2*i]]
        right_trials = data.loc[cond[2*i + 1]]
        both_trials = data.loc[cond[2*i] | cond[2*i + 1]]
        for j in range(95):
            left_vals = left_trials.iloc[:, 9 + j]
            right_vals = right_trials.iloc[:, 9 + j]
            comp_data = both_trials.iloc[:, 9 + j]
            if comp_data.nunique(dropna=False) <= 1:
                # Same value in every trial (or no trials): nothing to test.
                pval_map[i, j] = 1
                cohend_by_cond[i, j] = 0
            else:
                pval_map[i, j] = stats.ttest_ind(left_vals, right_vals)[1]
                # Mean difference divided by the SD of all trials of the pair
                # (not the pooled within-group SD).
                cohend_by_cond[i, j] = ((mean_by_cond[2*i, j] - mean_by_cond[2*i + 1, j])
                                        / comp_data.std())
            # +1 / -1 when left / right has the larger mean and p < sig_alpha,
            # else 0. Written explicitly so a NaN p-value never becomes a NaN
            # (hence nonzero) seed bin.
            if pval_map[i, j] < sig_alpha:
                sig_bin[i, j] = np.sign(left_vals.mean() - right_vals.mean())

    pval_map[~np.isfinite(pval_map)] = 1
    cohend_by_cond[~np.isfinite(cohend_by_cond)] = 0

    # ---- significant fields, selectivity index, peak time --------------------
    # Rows 2k / 2k+1 of sig_field_count and SI hold the left- / right-preferring
    # bins of comparison k. binary marks them +1 / -1.
    sig_field_count = np.zeros((8, 1))
    SI = np.zeros((8, 1))
    peaktime = np.zeros((4, 1))
    binary = np.zeros((4, 95))
    for i in range(4):
        sig_field[i] = find_sig_field(np.nonzero(sig_bin[i, :])[0], cohend_by_cond[i, :])
        # Test emptiness by size: .any() is False for a field made only of bin 0.
        if sig_field[i].size > 0:
            peaktime[i] = sig_field[i][np.argmax(np.abs(cohend_by_cond[i, sig_field[i]]))]
            for j in sig_field[i]:
                d = cohend_by_cond[i, j]
                if d > 0:
                    sig_field_count[2*i] += 1
                    SI[2*i] += d
                    binary[i, j] = 1
                elif d < 0:
                    sig_field_count[2*i + 1] += 1
                    SI[2*i + 1] += -d
                    binary[i, j] = -1

    is_sig = ((sig_field_count[0::2] != 0) | (sig_field_count[1::2] != 0)).astype(float)

    # ---- append this neuron to the per-comparison tables ---------------------
    cohend_result_M = np.vstack([cohend_result_M,
                                 np.concatenate(([cell_id, peaktime[0, 0]], cohend_by_cond[0, :]))])
    cohend_result_V = np.vstack([cohend_result_V,
                                 np.concatenate(([cell_id, peaktime[1, 0]], cohend_by_cond[1, :]))])
    cohend_result_A = np.vstack([cohend_result_A,
                                 np.concatenate(([cell_id, peaktime[2, 0]], cohend_by_cond[2, :]))])
    cohend_result_C = np.vstack([cohend_result_C,
                                 np.concatenate(([cell_id, peaktime[3, 0]], cohend_by_cond[3, :]))])

    sig_bin_result_M = np.vstack([sig_bin_result_M, np.concatenate(([cell_id], binary[0, :]))])
    sig_bin_result_V = np.vstack([sig_bin_result_V, np.concatenate(([cell_id], binary[1, :]))])
    sig_bin_result_A = np.vstack([sig_bin_result_A, np.concatenate(([cell_id], binary[2, :]))])
    sig_bin_result_C = np.vstack([sig_bin_result_C, np.concatenate(([cell_id], binary[3, :]))])

    # ---- figure ----------------------------------------------------------------
    # Smoothing happens after all statistics above, so it only affects the plots.
    if gauss_on:
        for i in range(8):
            mean_by_cond[i, :] = gaussian_filter(mean_by_cond[i, :], sigma=gauss_sigma)
            sem_by_cond[i, :] = gaussian_filter(sem_by_cond[i, :], sigma=gauss_sigma)

    y_max = np.round(mean_by_cond.max() + sem_by_cond.max())
    y_min = np.round(mean_by_cond.min() - sem_by_cond.max())
    if cell_id == 1907:
        # Manual y-axis override for this one neuron (presumably for figure
        # layout - TODO confirm).
        y_max = 2

    fig = plt.figure(figsize=(12, 12))
    fig.suptitle(cell_name, fontsize=15)
    gs = mpl.gridspec.GridSpec(nrows=4, ncols=3)
    # One row per comparison: SDF traces | -ln(p) | |Cohen's d|.
    for k, title in enumerate(['Multimodal', 'Visual-only', 'Auditory-only', 'Control']):
        ax_sdf = fig.add_subplot(gs[k, 0])
        SDF_plot([2*k, 2*k + 1], ax_sdf, colorpalette, 2, True)
        ax_sdf.set_title(title)
        pval_plot(k, fig.add_subplot(gs[k, 1]), 1)
        cohend_plot(k, fig.add_subplot(gs[k, 2]), 1)
    fig.tight_layout()
    fig.subplots_adjust(top=0.92)

    # One date-stamped folder per run (the `today` constant, so a run that
    # crosses midnight does not split its output), one subfolder per region.
    fig_dir = save_path + today + '/' + region
    os.makedirs(fig_dir, exist_ok=True)
    fig.savefig(fig_dir + '/' + cell_name + '.svg')
    # fig.savefig(fig_dir + '/' + cell_name + '.png', dpi=100, facecolor='white')
    plt.close(fig)

    # ---- summary row for the CSV (column order is part of the output) ----------
    labels = ['M', 'V', 'A', 'C']
    cell_result = {'Key': cell_id,
                   'RatID': rat_id,
                   'Session': session_id,
                   'Region': region}
    for k, lab in enumerate(labels):
        cell_result[lab + '_sig'] = int(is_sig[k, 0])
    for k, lab in enumerate(labels):
        cell_result[lab + '_L_field_size'] = int(sig_field_count[2*k, 0])
        cell_result[lab + '_R_field_size'] = int(sig_field_count[2*k + 1, 0])
    for k, lab in enumerate(labels):
        cell_result[lab + '_L_SI'] = round(SI[2*k, 0] * 1000) / 1000
        cell_result[lab + '_R_SI'] = round(SI[2*k + 1, 0] * 1000) / 1000
    for k, lab in enumerate(labels):
        cell_result[lab + '_peaktime'] = int(peaktime[k, 0]) * 10

    result.append(cell_result)

# Export CSV files

In [14]:
# Wrap the accumulated NumPy tables (one row per neuron) in DataFrames for export.
cd_M, cd_V, cd_A, cd_C = (pd.DataFrame(table) for table in
                          (cohend_result_M, cohend_result_V, cohend_result_A, cohend_result_C))

sb_M, sb_V, sb_A, sb_C = (pd.DataFrame(table) for table in
                          (sig_bin_result_M, sig_bin_result_V, sig_bin_result_A, sig_bin_result_C))

# Summary table: one row per neuron, one column per key of the per-cell dict.
result = pd.DataFrame(result)

In [15]:
# Write the per-cell summary and the per-comparison Cohen's d / binary tables
# into the run's dated output folder. The folder is created here as well, in
# case no figure was saved (e.g. an empty cell_list). Explicit paths are used
# instead of changing the working directory.
out_dir = save_path + today
os.makedirs(out_dir, exist_ok=True)

result.to_csv(os.path.join(out_dir, today + '_item-selectivity.csv'), index=False)

for table, file_name in [(cd_M, 'Cohend_M.csv'), (cd_V, 'Cohend_V.csv'),
                         (cd_A, 'Cohend_A.csv'), (cd_C, 'Cohend_C.csv'),
                         (sb_M, 'Binary_M.csv'), (sb_V, 'Binary_V.csv'),
                         (sb_A, 'Binary_A.csv'), (sb_C, 'Binary_C.csv')]:
    table.to_csv(os.path.join(out_dir, file_name), index=False)

In [16]:
print("END")  # completion marker for a top-to-bottom run

END
