In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import json
import os

import analysis.pipelines.population_spikes as ps
from analysis import utils

# Directory names handed to the pipeline module via ps.set_variables below.
RESULT_PATH, OUTPUT_PATH = "simulation_results", "analysis_results"

save_data = True  # when True, the computed tables are written to CSV

In [2]:
# Configure the population-spikes pipeline with the paths defined above.
ps.set_variables(RESULT_PATH=RESULT_PATH, OUTPUT_PATH=OUTPUT_PATH)

# Population name lists exposed by the pipeline module (read after configuration).
PN_pop_names, ITN_pop_names = ps.PN_pop_names, ps.ITN_pop_names

#### Set trials and statistics

In [3]:
# Trial filter passed to ps.get_trials. Judging by the printed output, it keeps trials
# whose names end in 'rand' or 'div' (NOTE(review): confirm matching rule in ps.get_trials).
# Named `trial_filter` so it does not shadow the built-in `filter`.
trial_filter = ('rand', 'div')
revert_junction = False
exclude = ['baseline', 'short', 'long']

# JSON is UTF-8 by specification; don't rely on the platform's default encoding.
with open(os.path.join(OUTPUT_PATH, 'trials_ordered.json'), 'r', encoding='utf-8') as f:
    trials_ordered = json.load(f)
trials = ps.get_trials(trial_filter, trials=trials_ordered,
                       revert_junction=revert_junction, exclude=exclude)
print(trials)

['ramp_a0_t0.3_rand', 'ramp_a0_t0.3_div', 'ramp_a0_t1.0_rand', 'ramp_a0_t1.0_div', 'ramp_a0_t1.0_down_rand', 'ramp_a0_t1.0_down_div', 'ramp_a0_t3.0_rand', 'ramp_a0_t3.0_div', 'join_a0_t0.3_rand', 'join_a0_t0.3_div', 'join_a0_t1.0_rand', 'join_a0_t1.0_div', 'join_a0_t1.0_quit_rand', 'join_a0_t1.0_quit_div', 'join_a0_t3.0_rand', 'join_a0_t3.0_div', 'fade_a01_t0.1_rand', 'fade_a01_t0.1_div', 'fade_a03_t0.1_div', 'fade_a01_t0.3_rand', 'fade_a01_t0.3_div', 'fade_a03_t0.3_div', 'fade_a01_t1.0_rand', 'fade_a01_t1.0_div', 'fade_a03_t1.0_div', 'fade_a01_t3.0_rand', 'fade_a01_t3.0_div', 'fade_a03_t3.0_div']


In [4]:
wave_pop = 'ITN'  # value of the `wave_population` coordinate used when selecting PLV

waves = ['gamma', 'beta']
pop_groups = ['PN stimulated', 'PN rest'] + ITN_pop_names
# statistic name -> reducer applied to each group's PLV values
# (np.std uses its default ddof=0, i.e. the population standard deviation)
stat_func = {'mean': np.mean, 'stdev': np.std}
stats = list(stat_func)

# define function for retrieving unit ids in each group
def get_grp_ids(plv_ds, itn_pop_names=None):
    """Collect the unit ids belonging to each population group of a PLV dataset.

    PN populations are split by ``plv_ds.assy_id``. Populations whose assembly id
    is >= 0 are pooled into 'PN stimulated'; the others are pooled into 'PN rest'.
    Each ITN population forms its own group.

    Parameters
    ----------
    plv_ds : xarray.Dataset
        Dataset (as loaded with ``xr.load_dataset``) providing ``assy_id``,
        ``PN_names``, and one variable per population name holding that
        population's unit ids.
        NOTE(review): assumes ``assy_id`` is aligned element-wise with
        ``PN_names`` — confirm.
    itn_pop_names : sequence of str, optional
        ITN population names to include. Defaults to the notebook-level
        ``ITN_pop_names``.

    Returns
    -------
    dict
        Maps 'PN stimulated', 'PN rest' and each ITN name to an array of unit ids.
        A PN group with no populations maps to an empty integer array.
    """
    if itn_pop_names is None:
        itn_pop_names = ITN_pop_names

    def concat_ids(pop_names):
        # np.concatenate raises ValueError on an empty list (e.g. when every PN
        # population is stimulated), so return an empty id array instead.
        arrays = [plv_ds[p].values for p in pop_names]
        return np.concatenate(arrays) if arrays else np.array([], dtype=int)

    stimulated = plv_ds.assy_id.values >= 0
    pn_names = plv_ds.PN_names
    ids = {
        'PN stimulated': concat_ids(pn_names[stimulated].values),
        'PN rest': concat_ids(pn_names[~stimulated].values),
    }
    for p in itn_pop_names:
        ids[p] = plv_ds[p].values
    return ids

# Column layout of the per-trial statistics table: (wave, population group, statistic).
columns = pd.MultiIndex.from_product([waves, pop_groups, stats])

# Per-trial accumulators: trial labels and the matching rows of PLV statistics.
trial_label = []
plv_list = []

#### Get statistics dataframe

In [5]:
# When True, read the '_sigdur' variant of each trial's PLV file instead of the plain one.
# NOTE(review): '_sigdur' presumably means PLV computed over the significant-duration
# window only — confirm against the pipeline that writes these files.
significant_duration = True
suffix = '_sigdur' if significant_duration else ''

# (Re)initialise the accumulators in this cell so re-running it does not append
# duplicate rows to those collected by a previous run.
plv_list = []
trial_label = []
missing_trials = []

for tr in trials:
    file = os.path.join(ps.FR_ENTR_PATH, tr + suffix + '.nc')
    if not os.path.isfile(file):
        missing_trials.append(tr)
        continue
    trial_label.append(utils.get_trial_label(tr, dlm=' '))
    plv_ds = xr.load_dataset(file)
    grp_ids = get_grp_ids(plv_ds)
    # NaN rather than 0 so a (wave, group, stat) entry that is never filled is visibly
    # missing instead of looking like a genuine PLV of zero.
    plv_stat = pd.Series(np.nan, index=columns)
    for grp, ids in grp_ids.items():
        for w in waves:
            plv = plv_ds.PLV.sel(unit_id=ids, wave=w, wave_population=wave_pop)
            for stat in stats:
                plv_stat.loc[(w, grp, stat)] = stat_func[stat](plv.values)
    plv_list.append(plv_stat.values)

if missing_trials:
    # Trials without a PLV file are left out of the table; say so instead of dropping them silently.
    print(f'Skipped {len(missing_trials)} trial(s) with no PLV file: {missing_trials}')

plv_df = pd.DataFrame(plv_list, index=pd.Index(trial_label, name='trial'), columns=columns)

#### Save to file

In [6]:
display(plv_df)

if save_data:
    # Numeric table, written to the same directory the per-trial PLV files are read from.
    file_name = f'PLV_to_{wave_pop}_FR{suffix}.csv'
    out_path = os.path.join(ps.FR_ENTR_PATH, file_name)
    plv_df.to_csv(out_path)

Unnamed: 0_level_0,gamma,gamma,gamma,gamma,gamma,gamma,gamma,gamma,beta,beta,beta,beta,beta,beta,beta,beta
Unnamed: 0_level_1,PN stimulated,PN stimulated,PN rest,PN rest,FSI,FSI,LTS,LTS,PN stimulated,PN stimulated,PN rest,PN rest,FSI,FSI,LTS,LTS
Unnamed: 0_level_2,mean,stdev,mean,stdev,mean,stdev,mean,stdev,mean,stdev,mean,stdev,mean,stdev,mean,stdev
trial,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3
ramp random0 t0.3,0.072808,0.120205,0.130059,0.255202,0.166793,0.166412,0.127467,0.195009,0.057332,0.118839,0.100438,0.247596,0.083425,0.174596,0.144626,0.206434
ramp strong0 t0.3,0.074624,0.118936,0.123595,0.248811,0.19361,0.187008,0.153022,0.199068,0.061744,0.119646,0.092338,0.239786,0.087574,0.18542,0.173042,0.224528
ramp random0 t1.0,0.063497,0.107995,0.124706,0.238731,0.152004,0.15318,0.087196,0.16537,0.078262,0.100231,0.126536,0.244752,0.122296,0.177844,0.399206,0.213449
ramp strong0 t1.0,0.057219,0.095168,0.127076,0.240055,0.149628,0.146895,0.081563,0.165844,0.092979,0.108296,0.12592,0.245639,0.118616,0.161409,0.434271,0.215733
ramp down random0 t1.0,0.135692,0.270736,0.12134,0.258608,0.14872,0.228526,0.102026,0.227305,0.123675,0.131166,0.127244,0.240417,0.166805,0.244934,0.465994,0.220461
ramp down strong0 t1.0,0.131568,0.257242,0.121473,0.257163,0.109228,0.197976,0.080942,0.196983,0.13637,0.132023,0.120429,0.233004,0.163705,0.243439,0.487921,0.226608
ramp random0 t3.0,0.054204,0.09192,0.104073,0.198467,0.136045,0.122595,0.075193,0.137139,0.103351,0.085285,0.110012,0.19447,0.122893,0.133052,0.424344,0.205233
ramp strong0 t3.0,0.050681,0.084857,0.103149,0.194767,0.14722,0.124319,0.080008,0.144165,0.119697,0.090378,0.111369,0.195838,0.126183,0.141764,0.459501,0.21229
join random0 t0.3,0.079655,0.174604,0.119964,0.249873,0.14815,0.187156,0.12552,0.218012,0.063717,0.139616,0.112775,0.251855,0.10324,0.193752,0.189644,0.205749
join strong0 t0.3,0.081193,0.178837,0.126251,0.257088,0.16434,0.192418,0.115923,0.178883,0.063593,0.143134,0.113041,0.256553,0.079403,0.166724,0.183751,0.204244


#### Format print statistics

In [7]:
def _format_stat(mean, std):
    """Format a single (mean, stdev) pair as 'mean (stdev)' with three decimals."""
    return f'{mean:.3f} ({std:.3f})'


# Element-wise version for arrays of means and stdevs. `otypes` is given explicitly:
# without it np.vectorize infers the output type by calling the function on the first
# element, and raises ValueError on size-0 input (e.g. when no trial file was found).
format_stats = np.vectorize(_format_stat, otypes=[str])

pd_idx = pd.IndexSlice  # shorthand for MultiIndex slicing

# Select the mean and stdev sub-tables by label, dropping the statistic level, so they
# are paired column-for-column by (wave, group) rather than by position. The result's
# columns come from plv_df itself instead of the `columns` variable of an earlier cell.
mean_df = plv_df.xs('mean', axis=1, level=2)
std_df = plv_df.xs('stdev', axis=1, level=2).reindex(columns=mean_df.columns)
means = mean_df.values
stds = std_df.values
stat_df = pd.DataFrame(format_stats(means, stds), index=plv_df.index,
                       columns=mean_df.columns)

In [8]:
display(stat_df)

if save_data:
    # Formatted 'mean (stdev)' table, marked with a '_fmt' suffix in the file name.
    file_name = f'PLV_to_{wave_pop}_FR{suffix}_fmt.csv'
    out_path = os.path.join(ps.FR_ENTR_PATH, file_name)
    stat_df.to_csv(out_path)

Unnamed: 0_level_0,gamma,gamma,gamma,gamma,beta,beta,beta,beta
Unnamed: 0_level_1,PN stimulated,PN rest,FSI,LTS,PN stimulated,PN rest,FSI,LTS
trial,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
ramp random0 t0.3,0.073 (0.120),0.130 (0.255),0.167 (0.166),0.127 (0.195),0.057 (0.119),0.100 (0.248),0.083 (0.175),0.145 (0.206)
ramp strong0 t0.3,0.075 (0.119),0.124 (0.249),0.194 (0.187),0.153 (0.199),0.062 (0.120),0.092 (0.240),0.088 (0.185),0.173 (0.225)
ramp random0 t1.0,0.063 (0.108),0.125 (0.239),0.152 (0.153),0.087 (0.165),0.078 (0.100),0.127 (0.245),0.122 (0.178),0.399 (0.213)
ramp strong0 t1.0,0.057 (0.095),0.127 (0.240),0.150 (0.147),0.082 (0.166),0.093 (0.108),0.126 (0.246),0.119 (0.161),0.434 (0.216)
ramp down random0 t1.0,0.136 (0.271),0.121 (0.259),0.149 (0.229),0.102 (0.227),0.124 (0.131),0.127 (0.240),0.167 (0.245),0.466 (0.220)
ramp down strong0 t1.0,0.132 (0.257),0.121 (0.257),0.109 (0.198),0.081 (0.197),0.136 (0.132),0.120 (0.233),0.164 (0.243),0.488 (0.227)
ramp random0 t3.0,0.054 (0.092),0.104 (0.198),0.136 (0.123),0.075 (0.137),0.103 (0.085),0.110 (0.194),0.123 (0.133),0.424 (0.205)
ramp strong0 t3.0,0.051 (0.085),0.103 (0.195),0.147 (0.124),0.080 (0.144),0.120 (0.090),0.111 (0.196),0.126 (0.142),0.460 (0.212)
join random0 t0.3,0.080 (0.175),0.120 (0.250),0.148 (0.187),0.126 (0.218),0.064 (0.140),0.113 (0.252),0.103 (0.194),0.190 (0.206)
join strong0 t0.3,0.081 (0.179),0.126 (0.257),0.164 (0.192),0.116 (0.179),0.064 (0.143),0.113 (0.257),0.079 (0.167),0.184 (0.204)
