In [1]:
import os
import numpy as np
import pandas as pd

from scipy import stats
from scipy.ndimage import gaussian_filter
from scipy.signal import find_peaks

import pingouin

import matplotlib as mpl
import matplotlib.pyplot as plt  
import seaborn as sns

from datetime import date

import random

In [2]:
# no top and right spines in all plots
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.spines.top'] = False

In [3]:
mother_path = 'D:/Multi-modal project/'

# Parameter setting

In [4]:
sig_alpha = 0.01
sig_cohend = 0.4
cons_bin_crit = 5

only_PER = True   # analyze only PER neurons?

 # colors for multimodal, vis-only, aud-only conditions
colorpalette = ['purple','blue','red','gray']

cat = ['M','V','A','MV','MA','VA','MVA']

today = str(date.today())

# Data preparation

In [5]:
unit_summary = pd.read_table(mother_path+'/analysis/result'+
                             '/1. Cluster summary/clusterSummary.csv',sep=',')
save_path = mother_path+'analysis/result/3. Item-selectivity/'
data_path = mother_path+'analysis/result/zFR export/13-Apr-2022 (5 trials)/'
os.chdir(data_path)

In [6]:
# empty list for result csv file
result=[]

In [7]:
cell_list = os.listdir(os.curdir)

# Data analysis

In [8]:
def find_sig_field(sig_bin,cohend):
    cohend = np.abs(cohend)
    field_bin = []
    for i in sig_bin:
        sig_chunk = []
        now_bin = i
        goback = True
        while goback is True:
            if now_bin < 0:
                goback = False
            elif cohend[now_bin] > sig_cohend:
                sig_chunk = np.append(sig_chunk,now_bin)
                now_bin -= 1
            else:
                goback = False
                
        now_bin = i
        goforward = True
        while goforward is True:
            if now_bin > 94:
                goforward = False
            elif cohend[now_bin] > sig_cohend:
                sig_chunk = np.append(sig_chunk,now_bin)
                now_bin += 1
            else:
                goforward = False
        if (np.max(sig_chunk)-np.min(sig_chunk))<(cons_bin_crit-1):
            sig_chunk = []
        field_bin = np.append(field_bin,sig_chunk)
    result = np.array(list(set(field_bin)))
    return result.astype(int)

In [9]:
def shuffle_proportion(num_shuffle):
    
    # M,V,A,MV,MA,VA,MVA
    shuffle_result = np.zeros((num_shuffle,len(cat)))
    
    for shuffle_run in range(num_shuffle):
        
        cat_count = np.zeros((1,len(cat)))
                
        for cell_run in cell_list:
            
            cell_info = cell_run.split('-')
            region = cell_info[5]
            if only_PER:
                if region!='PER':
                    continue;
            
            data = pd.read_csv(cell_run)
    
            data.drop(data[data.Correctness==0].index,inplace=True)                
            data.iloc[:,9:] = round(data.iloc[:,9:]*10000)/10000 
            data['ShuffledChoice'] = np.random.permutation(data['Choice'])            

            cond = [(data.Type=='Multimodal')&(data.ShuffledChoice=='Left'),
                    (data.Type=='Multimodal')&(data.ShuffledChoice=='Right'),
                    (data.Type=='Visual')&(data.ShuffledChoice=='Left'),
                    (data.Type=='Visual')&(data.ShuffledChoice=='Right'),
                    (data.Type=='Auditory')&(data.ShuffledChoice=='Left'),
                    (data.Type=='Auditory')&(data.ShuffledChoice=='Right')]    
            
            mean_by_cond = np.zeros((6,95))
            sem_by_cond = np.zeros((6,95))
            for i in range(6):
                mean_by_cond[i,:]=data[cond[i]].iloc[:,9:-1].to_numpy().mean(axis=0)
                sem_by_cond[i,:]=stats.sem(data[cond[i]].iloc[:,9:-1].to_numpy())    
                
            pval_map = np.zeros((3,95))
            sig_bin = np.zeros((3,95))
    
            cohend_by_cond = np.zeros((3,95))
            sig_field = dict()
    
            for i in range(3):
                for j in range(95):
                    comp_data = data[cond[i*2]|cond[i*2+1]].iloc[:,9+j]
                    if all(x==comp_data.iloc[0] for x in comp_data):
                        pval_map[i,j] = 1
                        cohend_by_cond[i,j] = 0
                    else:
                        pval_map[i,j] = stats.ttest_ind(data[cond[i*2]].iloc[:,9+j],data[cond[i*2+1]].iloc[:,9+j])[1]
                        cohend_by_cond[i,j] = (mean_by_cond[i*2,j]-mean_by_cond[i*2+1,j])/(data[cond[i*2]|cond[i*2+1]].iloc[:,9+j]).std()
                    sig_bin[i,j] = (pval_map[i,j] < sig_alpha)*np.sign(data[cond[i*2]].iloc[:,9+j].mean()-data[cond[i*2+1]].iloc[:,9+j].mean())    

            pval_map[np.isinf(pval_map)] = 1
            pval_map[np.isnan(pval_map)] = 1
            cohend_by_cond[np.isinf(cohend_by_cond)] = 0        
            cohend_by_cond[np.isnan(cohend_by_cond)] = 0     
                
            sig_field_count = np.zeros((6,1))
            SI = np.zeros((6,1))
            peaktime = np.zeros((3,1))
            binary = np.zeros((3,95))
            for i in range(3):
                sig_field[i] = find_sig_field(np.nonzero(sig_bin[i,:])[0],cohend_by_cond[i,:])
                if sig_field[i].any():
                    peaktime[i] = sig_field[i][np.argmax(np.abs(cohend_by_cond[i,sig_field[i]]))]
                    for j in sig_field[i]:
                        if cohend_by_cond[i,j] > 0:
                            sig_field_count[2*i] += 1
                            SI[2*i] += cohend_by_cond[i,j]
                            binary[i,j] = 1
                        elif cohend_by_cond[i,j] < 0:
                            sig_field_count[2*i+1] += 1        
                            SI[2*i+1] += np.abs(cohend_by_cond[i,j])
                            binary[i,j] = -1       
                                
            M_sig = int((sig_field_count[0]!=0)|(sig_field_count[1]!=0))
            V_sig = int((sig_field_count[2]!=0)|(sig_field_count[3]!=0))
            A_sig = int((sig_field_count[4]!=0)|(sig_field_count[5]!=0))
            
            cat_info = ''
            if M_sig == 1:
                cat_info += 'M'
            if V_sig == 1:
                cat_info += 'V'
            if A_sig == 1:
                cat_info += 'A'
                
            for i,c in enumerate(cat):
                if cat_info==c:
                    cat_count[0,i]+=1
        
        shuffle_result[shuffle_run,:] = cat_count
    return shuffle_result

In [10]:
%%time
shuffle_result = shuffle_proportion(1)

CPU times: total: 3min 53s
Wall time: 4min 19s
