In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import json
import os

import analysis.pipelines.population_spikes as ps
from analysis import utils

# Directories handed to the pipeline module: simulation outputs are read from
# RESULT_PATH; OUTPUT_PATH holds analysis products (trials_ordered.json is read from it).
RESULT_PATH, OUTPUT_PATH = "simulation_results", "analysis_results"

# When True, the statistics tables built in this notebook are written to CSV.
save_data = True

In [2]:
# Configure the pipeline module's paths first, in case set_variables updates any
# of the module attributes read right after it.
ps.set_variables(OUTPUT_PATH=OUTPUT_PATH, RESULT_PATH=RESULT_PATH)

# Population name lists exposed by the pipeline module.
PN_pop_names, ITN_pop_names = ps.PN_pop_names, ps.ITN_pop_names

#### Set trials and statistics

In [3]:
# Trial-selection options forwarded to ps.get_trials.
filter = ('rand', 'div')  # NOTE(review): shadows the builtin `filter` in this kernel
exclude = ['baseline', 'short', 'long']
revert_junction = False

# Trial ordering stored alongside the analysis outputs.
with open(os.path.join(OUTPUT_PATH, 'trials_ordered.json')) as f:
    trials_ordered = json.load(f)

trials = ps.get_trials(filter, trials=trials_ordered, exclude=exclude,
                       revert_junction=revert_junction)
print(trials)

['ramp_a0_t0.3_rand', 'ramp_a0_t0.3_div', 'ramp_a0_t1.0_rand', 'ramp_a0_t1.0_div', 'ramp_a0_t1.0_down_rand', 'ramp_a0_t1.0_down_div', 'ramp_a0_t3.0_rand', 'ramp_a0_t3.0_div', 'join_a0_t0.3_rand', 'join_a0_t0.3_div', 'join_a0_t1.0_rand', 'join_a0_t1.0_div', 'join_a0_t1.0_quit_rand', 'join_a0_t1.0_quit_div', 'join_a0_t3.0_rand', 'join_a0_t3.0_div', 'fade_a01_t0.1_rand', 'fade_a01_t0.1_div', 'fade_a03_t0.1_div', 'fade_a01_t0.3_rand', 'fade_a01_t0.3_div', 'fade_a03_t0.3_div', 'fade_a01_t1.0_rand', 'fade_a01_t1.0_div', 'fade_a03_t1.0_div', 'fade_a01_t3.0_rand', 'fade_a01_t3.0_div', 'fade_a03_t3.0_div']


In [4]:
# Label of the PLV `wave_population` coordinate selected in the statistics loop.
wave_pop = 'ITN'

waves = ['gamma', 'beta']
pop_groups = ['PN stimulated', 'PN rest'] + ITN_pop_names

# Summary statistics, keyed by the name used in the column index.
stat_func = {'mean': np.mean, 'stdev': np.std}
stats = list(stat_func)  # ['mean', 'stdev'], in insertion order

def get_grp_ids(plv_ds, itn_names=None):
    """Collect the unit ids belonging to each population group of a PLV dataset.

    Parameters
    ----------
    plv_ds : xarray.Dataset
        Must provide ``assy_id`` and ``PN_names`` of equal length (one entry per
        PN population), plus one variable per population name whose values are
        that population's unit ids.
    itn_names : list of str, optional
        ITN population names to include as their own groups. Defaults to the
        notebook-level ``ITN_pop_names``.

    Returns
    -------
    dict
        Keys are ``'PN stimulated'``, ``'PN rest'``, then each ITN name.
        Values are numpy arrays of unit ids. PN populations with
        ``assy_id >= 0`` are pooled as "PN stimulated" and the others as
        "PN rest". A PN group with no populations gets an empty array
        instead of raising.
    """
    if itn_names is None:
        itn_names = ITN_pop_names
    is_stim = plv_ds.assy_id.values >= 0  # presumably: population belongs to an assembly

    def _pool(pn_names):
        # np.concatenate raises ValueError on an empty sequence, so return an
        # empty id array when no PN population falls in this group.
        arrays = [plv_ds[p].values for p in pn_names]
        return np.concatenate(arrays) if arrays else np.array([], dtype=int)

    ids = {
        'PN stimulated': _pool(plv_ds.PN_names[is_stim].values),
        'PN rest': _pool(plv_ds.PN_names[~is_stim].values),
    }
    for p in itn_names:
        ids[p] = plv_ds[p].values
    return ids

# Column layout of the per-trial statistics table: (wave, group, statistic).
columns = pd.MultiIndex.from_product([waves, pop_groups, stats])
plv_list, trial_label = [], []  # per-trial rows and their labels

#### Compute per-trial PLV statistics DataFrame

In [5]:
def _select_wave(plv, wave):
    """Restrict a PLV DataArray to a single wave band.

    The wave-band dimension name is not visible from this notebook, so it is
    located by label. The first indexed dimension whose coordinate contains
    ``wave`` is selected on, which drops that dimension.

    Raises
    ------
    KeyError
        If no indexed coordinate of ``plv`` contains the label ``wave``.
    """
    for dim in plv.dims:
        if dim in plv.indexes and wave in plv.indexes[dim]:
            return plv.sel({dim: wave})
    raise KeyError(f"wave band {wave!r} not found in any indexed coordinate "
                   f"of PLV (dims: {plv.dims})")


# Re-initialise the accumulators here so that re-running this cell does not
# append a second copy of every trial.
plv_list = []
trial_label = []

for tr in trials:
    file = os.path.join(ps.FR_ENTR_PATH, tr + '.nc')
    if not os.path.isfile(file):
        continue  # no PLV results saved for this trial
    trial_label.append(utils.get_trial_label(tr, dlm=' '))
    plv_ds = xr.load_dataset(file)
    grp_ids = get_grp_ids(plv_ds)
    # NaN rather than 0, so any (wave, group, stat) cell left unfilled is visibly missing.
    plv_stat = pd.Series(np.nan, index=columns)
    for grp, ids in grp_ids.items():
        plv = plv_ds.PLV.sel(unit_id=ids, wave_population=wave_pop)
        for w in waves:
            # Select the band explicitly. Without this, every wave column got the
            # same statistic computed over the un-split PLV values.
            plv_w = _select_wave(plv, w).values
            for stat in stats:
                plv_stat.loc[(w, grp, stat)] = stat_func[stat](plv_w)
    plv_list.append(plv_stat.values)

plv_df = pd.DataFrame(plv_list, index=pd.Index(trial_label, name='trial'), columns=columns)

#### Save to file

In [6]:
if save_data:
    # Write the numeric statistics table (one row per trial) to the PLV directory.
    file_name = 'PLV_to_' + wave_pop + '_FR.csv'
    plv_df.to_csv(path_or_buf=os.path.join(ps.FR_ENTR_PATH, file_name))

#### Format statistics as `mean (stdev)` strings

In [7]:
def _format_mean_std(mean, std):
    """Format one mean/std pair as ``'mean (std)'`` with three decimals."""
    return f'{mean:.3f} ({std:.3f})'


# Element-wise over broadcast arrays of means and stds. otypes is given explicitly
# so that size-0 input (e.g. no trial files found) returns an empty array;
# without it, np.vectorize raises on size-0 input.
format_stats = np.vectorize(_format_mean_std, otypes=[object])

pd_idx = pd.IndexSlice

# Mean and stdev columns as aligned 2-D arrays: trial x (wave, group).
means, stds = (plv_df.loc[:, pd_idx[:, :, s]].values for s in ('mean', 'stdev'))
stat_df = pd.DataFrame(
    format_stats(means, stds),
    index=plv_df.index,
    columns=columns.droplevel(2).drop_duplicates(),
)
                       columns=columns.droplevel(2).drop_duplicates())

In [8]:
if save_data:
    # Write the human-readable "mean (stdev)" table to the PLV directory.
    file_name = 'PLV_to_' + wave_pop + '_FR_fmt.csv'
    stat_df.to_csv(path_or_buf=os.path.join(ps.FR_ENTR_PATH, file_name))