In [None]:
# from google.colab import drive
# drive.mount('/content/drive/')
# %cd '/content/drive/MyDrive/fly_model_shared/brian_pipeline/'

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pylab as plt
import seaborn as sns
sns.set_style('whitegrid')

from pathlib import Path
import pickle

# input files
pkl_map = './data/name_mappings_530.pickle' # name mappings
xls_exps = './data/exp_groups.xlsx' # definition of experiment groups
# wildcard matching multiple outputs
# Path.glob returns a one-shot generator: once load_exp has iterated over it, re-running the
# load cell would silently produce empty results. Store a sorted list instead so that the
# cell can be re-run and the experiment order is reproducible.
pkl_glob = sorted(Path('./data/exp_activation/').glob('default_*.pickle'))

# helper functions
def load_exp(pkl_glob):
    '''load experiments from pickle (metadata) and feather (spike times)

        pkl_glob: iterable of pathlib.Path objects pointing to the pickle files;
            for every pickle a feather file with the same name but suffix
            '.feather' is read from the same location
        returns (meta_dict, spk_dict), both keyed by experiment name, i.e. the
            file name without '.pickle' and 'default_'
            meta_dict[exp] is the unpickled object
            spk_dict[exp] is the dataframe read from the feather file
    '''

    meta_dict, spk_dict = dict(), dict()

    # sort the paths: glob results come in filesystem order, which would make the
    # order of experiments (and of everything derived from the dicts) non-reproducible
    for pkl in sorted(pkl_glob):
        # exp name
        exp = pkl.name.replace('.pickle', '').replace('default_', '')

        # load metadata from pickle
        # NOTE: unpickling can execute arbitrary code, only load files from trusted sources
        with open(pkl, 'rb') as f:
            meta_dict[exp] = pickle.load(f)

        # load spike times from feather
        fth = pkl.with_suffix('.feather')
        spk_dict[exp] = pd.read_feather(fth)

    return meta_dict, spk_dict

# Load results

In [None]:
# name mappings: the pickle holds six objects, unpacked in this fixed order
with open(pkl_map, 'rb') as fh:
    mappings = pickle.load(fh)
flyid2i, flyid2name, i2flyid, i2name, name2flyid, name2i = mappings

# meta data and spike dataframes, both keyed by experiment name
meta_dict, spk_dict = load_exp(pkl_glob)

# names of the excited neurons for every experiment
# NOTE(review): the entries of 'exc_i' are looked up in flyid2name - confirm that they
# are flywire ids and not brian indices
exc_dict = dict()
for exp, meta in meta_dict.items():
    exc_dict[exp] = [ flyid2name[k] for k in meta['exc_i'] ]

# experimental setup TODO: do not hard code
t_sim = 1           # duration of trial in s
n_run = 30          # number of runs

# summary of what was loaded
print('Loaded {} experiments...'.format(len(meta_dict)))
print('... ' + ' '.join(meta_dict))
print()
print('Loaded {} name mapping... (showing first 100)'.format(len(name2flyid)))
print('... ' + ' '.join(list(name2flyid)[:100]) + ' ...')

# Process data

In [None]:
# collect spike times of all experiments: one series (later one column) per experiment
spk_series = []

for exp, df_exp in spk_dict.items(): # cycle through experiments
    # sum up spike events
    df_exp = df_exp.dropna(how='all') # select only rows containing spikes
    # presumably every cell holds an array of spike times for one run - TODO confirm
    ds = df_exp.apply(lambda x: np.concatenate(x.dropna(how='any')), axis=1) # concatenate spk times, ignore nan
    ds.name = exp # dataseries name will convert to dataframe column
    spk_series.append(ds)

# concatenate once at the end instead of growing the dataframe inside the loop
df = pd.concat(spk_series, axis=1) if spk_series else pd.DataFrame()

# rename brian ids to flywire/custom names
df = df.rename(index=i2flyid).rename(index=flyid2name) # (1) flywire ids and (2) custom names
df.index = df.index.astype(str) # represent flywire IDs as str, not int

# sort indices: named first
idx_n = [ i for i in name2flyid.keys() if i in df.index] # named neurons
idx_i = sorted( i for i in df.index if i not in name2flyid.keys() ) # flywire ids, sorted
df = df.loc[idx_n + idx_i, :] # first named neurons, then flywire ids

# convert counts to rates
# NOTE: DataFrame.applymap is deprecated in recent pandas in favour of DataFrame.map
df_rate = df.applymap( lambda x: len(x) / ( n_run * t_sim ), na_action='ignore' )
df_rate = df_rate.fillna(0) # replace nan with 0, necessary for differences later

# save as xlsx
Path('./results').mkdir(parents=True, exist_ok=True) # make sure the output folder exists
with pd.ExcelWriter('./results/all_experiments.xlsx', mode='w', engine='xlsxwriter') as w:
    df_rate.to_excel(w, sheet_name='all_experiments')

    # formatting in the xlsx file
    wb = w.book
    fmt = wb.add_format({'num_format': '#,##0.0'}) # set floating point display precision here (excel format)
    # column 0 is the index; format all experiment columns, not only the first one
    last_col = max(1, len(df_rate.columns))
    for _, ws in w.sheets.items():
        ws.set_column(1, last_col, 10, fmt)
        ws.freeze_panes(1, 1)

# Save and plot results

In [None]:
# helper functions
def remove_exc(df):
    '''replace Poisson inputs with nan for each experiment (requires exc_dict to be defined)

        df: dataframe with neurons as index and experiments as columns
        returns a copy of df in which, for every experiment column, the entries of the
            neurons listed in exc_dict[experiment] are nan
        the dataframe passed in is left unchanged
    '''
    df = df.copy() # do not modify the dataframe of the caller

    for exp in df:
        # restrict to neurons that are rows of df: assigning to a missing label
        # with .loc would silently append a new row
        present = [ exc for exc in exc_dict[exp] if exc in df.index ]
        df.loc[present, exp] = np.nan

    return df

def prune_df(df, exps, only_named=True, exclude_exc=True):
    '''returns dataframe restricted to the experiments (columns) in exps
        rows that contain nothing but 0 or nan are dropped
        exclude_exc: set values of Poisson inputs to nan (via remove_exc) before dropping rows
        only_named: keep only neurons that are keys of the name2flyid dict,
            in the order in which they appear in that dict
    '''
    selected = df.loc[:, exps]

    # blank out Poisson inputs
    if exclude_exc:
        selected = remove_exc(selected)

    # a row is kept if at least one entry is neither 0 nor nan
    has_rate = selected.replace(to_replace=0, value=np.nan).notna().any(axis=1)
    selected = selected.loc[selected.index[has_rate.to_numpy()], :]

    if not only_named:
        return selected

    # neurons with explicit names
    named = [ name for name in name2flyid if name in selected.index ]
    return selected.loc[named, :]

def diff_df(df, exps, only_named=True):
    '''returns rate dataframe with differences relative to the control
        exps[0] is the control, its column is copied unchanged
        every other column holds (experiment - control), nan is counted as 0
        rows and columns are selected with prune_df (only_named is passed on)
    '''
    rates = prune_df(df, exps, only_named=only_named)
    control = exps[0]

    result = pd.DataFrame(index=rates.index)

    # control column: absolute rates, nan is kept
    result.loc[:, control] = rates.loc[:, control]

    # remaining columns: change relative to control
    baseline = rates.loc[:, control].fillna(0)
    for exp in exps[1:]:
        result.loc[:, exp] = rates.loc[:, exp].fillna(0) - baseline

    return result

def plot_heatmap_abs(df, exps, only_named=True, path=None):
    '''plot and save (optional) firing rates for experiments defined in exps

        df: rate dataframe, neurons as index and experiments as columns
        exps: list of experiments (columns of df) to plot
        only_named: passed on to prune_df
        path: if given, the figure is saved to this location
        returns the matplotlib figure
    '''

    # use the dataframe passed by the caller, not the global df_rate
    df = prune_df(df, exps, only_named=only_named)

    n_rows, n_cols = df.shape
    fig, ax = plt.subplots(figsize=(n_cols, n_rows/3+1))

    ax.set_title('Absolute firing rate [Hz]')
    sns.heatmap(
        ax=ax, data=df, square=True,
        xticklabels=True, yticklabels=True,
        annot=True, fmt='.1f', annot_kws={'size': 'small'},
        cbar=False, cmap='binary',
    )
    ax.tick_params(axis='x', labeltop=True, labelbottom=True, labelrotation=90)

    fig.tight_layout()

    if path:
        fig.savefig(path)

    # return the figure (as plot_heatmap_diff does) so that the caller can show or close it
    return fig


def plot_heatmap_diff(df, n, exps, only_named=True, path=None):
    '''plot heatmap where the first column is the control and all others are plotted relative to control

        df: rate dataframe, neurons as index and experiments as columns
        n: name of the control experiment (column of df), plotted as absolute rate
        exps: list of experiments (columns of df), plotted as difference to n
        only_named: passed on to diff_df
        path: if given, the figure is saved to this location
        returns the matplotlib figure
    '''

    # use the dataframe passed by the caller, not the global df_rate
    df = diff_df(df, exps=[n] + exps, only_named=only_named)

    n_rows, n_cols = df.shape # n_cols includes the control column
    fig, axarr = plt.subplots(
        ncols=2,
        gridspec_kw={'width_ratios': (8+1, 8+n_cols)},
        figsize=(n_cols+2, n_rows/3+2),
    )

    # left panel: control
    ax = axarr[0]
    ax.set_title('Absolute\nfiring rate [Hz]')
    sns.heatmap(
        ax=ax, data=df.loc[:, [n]], square=True,
        xticklabels=True, yticklabels=True,
        annot=True, fmt='.1f', annot_kws={'size': 'small'},
        cbar=False, cmap='binary',
    )
    ax.tick_params(axis='x', labeltop=True, labelbottom=True, labelrotation=90)

    # right panel: differences to control, colormap centered at 0
    ax = axarr[1]
    ax.set_title('Change in\nfiring rate [Hz]')
    sns.heatmap(
        ax=ax, data=df.loc[:, exps], square=True,
        xticklabels=True, yticklabels=True,
        annot=True, fmt='.1f', annot_kws={'size': 'small'},
        cbar=False, cmap='bwr_r', center=0
    )
    ax.tick_params(axis='y', labelright=True, labelleft=False, labelrotation=0)
    ax.tick_params(axis='x', labeltop=True, labelbottom=True, labelrotation=90)

    fig.tight_layout()

    if path:
        fig.savefig(path)

    return fig

def save_xls(df_rate, exp_sets, path, diff=False):
    '''save firing rates to an xlsx file, one sheet per experiment set

        df_rate: rate dataframe, neurons as index and experiments as columns
        exp_sets: dict mapping a name to a list of experiments; the name is used as sheet name
            and, if diff is True, also as the control experiment
        path: location of the xlsx file, missing parent folders are created
        diff: if True, write differences relative to the control (diff_df),
            otherwise write the rates of the listed experiments (prune_df)
        unnamed neurons are included in both cases (only_named=False)
    '''
    Path(path).parent.mkdir(parents=True, exist_ok=True) # make sure the output folder exists

    with pd.ExcelWriter(path, mode='w', engine='xlsxwriter') as w:
        fmt = w.book.add_format({'num_format': '#,##0.0'}) # set floating point display precision here (excel format)

        for n, exps in exp_sets.items():
            # select exps from df_rate
            if diff:
                df = diff_df(df_rate, exps=[n] + exps, only_named=False)
            else:
                df = prune_df(df_rate, exps, only_named=False)

            # one sheet per experiment set
            df.to_excel(w, sheet_name=n)

            # format each sheet with its own number of columns (column 0 is the index);
            # formatting all sheets after the loop would apply the width of the last
            # dataframe to every sheet
            ws = w.sheets[n]
            ws.set_column(1, max(1, len(df.columns)), 10, fmt)
            ws.freeze_panes(1, 1)

## Define groups of experiments

In [None]:
# load experiment set definitions
exp_gr = pd.read_excel(xls_exps,  sheet_name=None ) # definition of experiement sets
for i in exp_gr.keys():
    d = exp_gr[i].to_dict(orient='list')
    for k, v in d.items():
        d[k] = [ j for j in d[k] if type(j) == str]
    exp_gr[i] = d

## Save spreadsheets

In [None]:
# one workbook per experiment group: (group, output file, write differences to control?)
# 'control' is saved with absolute rates, the coactivation groups as differences
xls_jobs = (
    ('control',      './results/control/spike_rates.xlsx',      False),
    ('walk-stop',    './results/walk-stop/spike_rates.xlsx',    True),
    ('walk-walk',    './results/walk-walk/spike_rates.xlsx',    True),
    ('sensory-walk', './results/sensory-walk/spike_rates.xlsx', True),
    ('walk-sensory', './results/walk-sensory/spike_rates.xlsx', True),
)
for group, xls_path, as_diff in xls_jobs:
    save_xls(df_rate, exp_gr[group], xls_path, diff=as_diff)

## Plot heatmaps

In [None]:
# control
plt.ioff()  # only save, but do not show plots
for n, exps in exp_gr['control'].items():
    fig = plot_heatmap_abs(df_rate, exps, only_named=True, path='./results/control/{}.png'.format(n))
    # close the figure once it is saved, otherwise open figures pile up over the loop
    # (plt.close(None) closes the current figure, in case no figure is returned)
    plt.close(fig)

In [None]:
# walk-stop
plt.ioff()  # only save, but do not show plots
for n, exps in exp_gr['walk-stop'].items():
    fig = plot_heatmap_diff(df_rate, n, exps, only_named=True, path='./results/walk-stop/{}.png'.format(n))
    plt.close(fig)  # close the figure once it is saved, otherwise open figures pile up over the loop

In [None]:
# walk-walk
plt.ioff()  # only save, but do not show plots
for n, exps in exp_gr['walk-walk'].items():
    fig = plot_heatmap_diff(df_rate, n, exps, only_named=True, path='./results/walk-walk/{}.png'.format(n))
    plt.close(fig)  # close the figure once it is saved, otherwise open figures pile up over the loop

In [None]:
# sensory-walk
plt.ioff()  # only save, but do not show plots
for n, exps in exp_gr['sensory-walk'].items():
    fig = plot_heatmap_diff(df_rate, n, exps, only_named=True, path='./results/sensory-walk/{}.png'.format(n))
    plt.close(fig)  # close the figure once it is saved, otherwise open figures pile up over the loop

In [None]:
# walk-sensory
plt.ioff()  # only save, but do not show plots
for n, exps in exp_gr['walk-sensory'].items():
    fig = plot_heatmap_diff(df_rate, n, exps, only_named=True, path='./results/walk-sensory/{}.png'.format(n))
    plt.close(fig)  # close the figure once it is saved, otherwise open figures pile up over the loop