In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm_notebook as tqdm
import pandas as pd
import pathlib
from collections import defaultdict
from scipy.linalg import svd
import seaborn as sns
import multiprocessing
from functools import partial

from causal_optoconnectics.graphics import regplot, scatterplot, probplot
from causal_optoconnectics.tools import conditional_probability, joint_probability, roll_pad
from causal_optoconnectics.tools import compute_trials_multi, decompress_spikes
from causal_optoconnectics.core import Connectivity
%matplotlib inline

# Process results

In [3]:
x_i, x_j = 11, 13
y_i, y_j = 12, 19
z_i, z_j = 7, 10

def process_metadata(W, stim_index, params):
    
    pairs = []
    for i in range(params['n_neurons']):
        for j in range(params['n_neurons']):
            if i==j:
                continue
            pair = f'{i}_{j}'
            pairs.append({
                'source': i,
                'target': j,
                'pair': pair,
                'weight': W[i, j, 0],
                'source_stim': W[stim_index, i, 0] > 0,
                'source_stim_strength': W[stim_index, i, 0],
                'target_stim': W[stim_index, j, 0] > 0,
            })
    return pd.DataFrame(pairs)

def process(pair, trials, W, stim_index, params, n_trials=None):
    i, j = [int(a) for a in pair.split('_')]
    
    pre, post = trials[i], trials[j]
    
    n_trials = len(pre) if n_trials is None else n_trials
    
    conn = Connectivity(pre[:n_trials], post[:n_trials], x_i, x_j, y_i, y_j, z_i, z_j)
    
    result ={
        'source': i,
        'target': j,
        'pair': pair,
        'beta_iv': conn.beta_iv, 
        'beta': conn.beta,
        'beta_iv_did': conn.beta_iv_did, 
        'beta_did': conn.beta_did, 
        'hit_rate': conn.hit_rate,
        'weight': W[i, j, 0],
        'source_stim': W[stim_index, i, 0] > 0,
        'source_stim_strength': W[stim_index, i, 0],
        'target_stim': W[stim_index, j, 0] > 0,
    }
    result.update(params)
    return result


def compute_time_dependence(i, j, step=10000):
    pre, post = trials[i], trials[j]
    results = []
    start = 0
    for stop in tqdm(range(step, len(pre) + step, step)):
        results.append(process(i,j,stop))
    return results

In [4]:
def multi_process(W_0, trials, params, pairs=None):
    import multiprocessing
    from functools import partial
    if pairs is None:
        pairs = []
        for i in range(params['n_neurons']):
            for j in range(params['n_neurons']):
                if i==j:
                    continue
                pairs.append(f'{i}_{j}')
    multiprocessing.freeze_support()
    with multiprocessing.Pool() as p:
        results = p.map(partial(process, W_0=W_0, trials=trials, params=params), pairs)
        
    return pd.DataFrame(results)

In [5]:
from scipy.linalg import norm
from scipy.optimize import minimize_scalar

def error(a, df, key):
    return df['weight'] - a * df[key]

def error_norm(a, df, key):
    return norm(error(a, df, key), ord=2)

def min_error(df, key):
    return minimize_scalar(error_norm, args=(df, key)).fun

In [6]:
data_path = pathlib.Path('/mnt/causal_optoconnectics/sweep_1/')

In [7]:
paths = list(data_path.iterdir())
data_df = pd.DataFrame({'path': paths})
for i, fname in enumerate(paths):
    a,b,c = fname.stem.split('_')
    data_df.loc[i, 'n_neurons'], data_df.loc[i, 'stim_strength'], data_df.loc[i, 'sigma'] = a.replace('n', ''), b.replace('ss', ''), c.replace('s', '')
data_df = data_df.astype({'stim_strength': float, 'sigma': float, 'n_neurons': int})

In [8]:
data_df

Unnamed: 0,path,n_neurons,stim_strength,sigma
0,/mnt/causal_optoconnectics/sweep_1/n40_ss3_s10,40,3.0,10.0
1,/mnt/causal_optoconnectics/sweep_1/n20_ss2_s6,20,2.0,6.0
2,/mnt/causal_optoconnectics/sweep_1/n40_ss1_s7,40,1.0,7.0
3,/mnt/causal_optoconnectics/sweep_1/n30_ss3_s6,30,3.0,6.0
4,/mnt/causal_optoconnectics/sweep_1/n50_ss1_s3,50,1.0,3.0
...,...,...,...,...
272,/mnt/causal_optoconnectics/sweep_1/n20_ss5_s9,20,5.0,9.0
273,/mnt/causal_optoconnectics/sweep_1/n30_ss4_s7,30,4.0,7.0
274,/mnt/causal_optoconnectics/sweep_1/n20_ss2_s10,20,2.0,10.0
275,/mnt/causal_optoconnectics/sweep_1/n20_ss3_s2,20,3.0,2.0


In [9]:
def load(path):
    data = np.load(path / 'rank_0.npz', allow_pickle=True)   
    data = {k: data[k][()] for k in data.keys()}
    data['data'] = [np.load(fn, allow_pickle=True)['data'][()] for fn in path.glob('*.npz')]
    return data
#     data['all_params'] = [np.load(fn, allow_pickle=True)['params'][()] for fn in path.glob('*.npz')]
#     data['all_W_0'] = [np.load(fn, allow_pickle=True)['W_0'][()] for fn in path.glob('*.npz')]
#     data['all_W'] = [np.load(fn, allow_pickle=True)['W'][()] for fn in path.glob('*.npz')]
    
#     assert(all([a==b for a, b in zip(data['all_params'], data['all_params'][1:])]))
#     assert(all([np.array_equal(a,b,equal_nan=True) for a, b in zip(data['all_W_0'], data['all_W_0'][1:])]))
#     assert(all([np.array_equal(a,b,equal_nan=True) for a, b in zip(data['all_W'], data['all_W'][1:])]))
#     return data

In [None]:
values = pd.DataFrame()
for i, row in tqdm(data_df.iterrows(), total=len(data_df)):
    data = load(row.path)
    X = data['data']
    W_0 = data['W_0']
    W = data['W']
    stim_index = range(len(W))[-1]
    params = data['params']
    n_neurons = params['n_neurons']
    assert n_neurons == row.n_neurons
    trials = compute_trials_multi(X, len(W_0), stim_index)
    
    np.fill_diagonal(W_0, 0)
    s_W = svd(W_0, compute_uv=False)
    data_df.loc[i, 'W_condition'] = s_W.max() / s_W.min()
    data_df.loc[i, 'W_smin'] = s_W.min()
    data_df.loc[i, 'W_smax'] = s_W.max()
    
#     x = decompress_spikes(X[0], len(W), params['n_time_step'])
#     s = svd(x[:len(W_0)], compute_uv=False)
#     data_df.loc[i, 'condition'] = s.max() / s.min()
#     data_df.loc[i, 'smin'] = s.min()
#     data_df.loc[i, 'smax'] = s.max()

    x = decompress_spikes(X[0], len(W), params['n_time_step'])
    cov_x = np.cov(x[:len(W_0)])
    s_cov = svd(cov_x, compute_uv=False)
    data_df.loc[i, 'cov_condition'] = s_cov.max() / s_cov.min()
    data_df.loc[i, 'cov_smin'] = s_cov.min()
    data_df.loc[i, 'cov_smax'] = s_cov.max()

    results_meta = process_metadata(W=W, stim_index=stim_index, params=params)
    sample_meta = results_meta.query('source_stim and not target_stim and weight >= 0')
    sample = pd.DataFrame([process(pair=pair, W=W, stim_index=stim_index, trials=trials, params=params) for pair in sample_meta.pair.values])
#     with multiprocessing.Pool() as p:
#         sample = p.map(partial(process, W=W, stim_index=stim_index, trials=trials, params=params), sample_meta.pair.values.tolist())
    sample = pd.DataFrame(sample) 
    values = pd.concat((values, sample))
    data_df.loc[i, 'error_beta'] = min_error(sample, 'beta_did')
    data_df.loc[i, 'error_beta_iv'] = min_error(sample, 'beta_iv_did')

  0%|          | 0/277 [00:00<?, ?it/s]

In [None]:
data_df.to_csv(data_path / 'summary.csv')

In [None]:
values.to_csv(data_path / 'values.csv')

In [None]:
sub_df = data_df.query('stim_strength == 4')

In [None]:
import matplotlib
fig, (ax, cax) = plt.subplots(1, 2, gridspec_kw={'width_ratios':[1,0.05], 'wspace': 0.1}, figsize=(6,5), dpi=150)
mnorm = matplotlib.colors.LogNorm()
xmin, xmax = min(sub_df.error_beta.min(), sub_df.error_beta_iv.min()), max(sub_df.error_beta.max(), sub_df.error_beta_iv.max())
sc = ax.scatter(sub_df.error_beta, sub_df.error_beta_iv, c=mnorm(sub_df.cov_condition))
ax.plot([xmin, xmax], [xmin, xmax])
ax.set_xlabel(r'$\mathrm{error}(\beta)$')
ax.set_ylabel(r'$\mathrm{error}(\beta_{IV})$')
cbar = plt.colorbar(sc, cax=cax)
cbar.ax.set_yticklabels(mnorm.inverse(cbar.ax.get_yticks()).round())
cbar.ax.set_ylabel(r'$||\mathrm{cov}|| \times ||\mathrm{cov}^{-1}||$')
sns.despine()

In [None]:
Error_beta = sub_df.pivot('sigma', 'n_neurons', 'error_beta')
Error_beta_iv = sub_df.pivot('sigma', 'n_neurons', 'error_beta_iv')
Condition = sub_df.pivot('sigma', 'n_neurons', 'cov_condition')

In [None]:
plt.rcParams.update({'figure.figsize': (5,5), 'figure.dpi': 150})

In [None]:
sns.heatmap(Error_beta, annot=sub_df.pivot('sigma', 'n_neurons', 'W_smax'))

In [None]:
fig, (ax, cax) = plt.subplots(1, 2, gridspec_kw={'width_ratios':[1,0.05], 'wspace': 0.1}, figsize=(6,5))
im = ax.imshow(Error_beta - Error_beta_iv, aspect='auto', origin='lower', cmap='PiYG', vmin=-6, vmax=6)
ax.set_yticks(range(len(stim_strengths)))
ax.set_xticks(range(len(sigmas)))
ax.set_yticklabels(stim_strengths)
ax.set_xticklabels(sigmas)
# plt.title(r'Error comparison')
cbar = plt.colorbar(im, cax=cax)
cbar.ax.set_ylabel(r'$\mathrm{error}(\beta) - \mathrm{error}(\beta_{IV})$')
ax.set_ylabel('Stimulus strength')
ax.set_xlabel(r'$\sigma$')
k = 0
for i, stim_strength in enumerate(stim_strengths):
    for j, sigma in enumerate(sigmas):
        ax.text(j - 0.2, i, condition[k].round(1))
        k += 1

In [None]:
vmin, vmax = min(Error_beta.min(), Error_beta_iv.min()), max(Error_beta.max(), Error_beta_iv.max())
fig, (ax, cax) = plt.subplots(1, 2, gridspec_kw={'width_ratios':[1,0.05], 'wspace': 0.1}, figsize=(6,5))
im = ax.imshow(Error_beta, aspect='auto', origin='lower', vmin=vmin, vmax=vmax)
ax.set_yticks(range(len(stim_strengths)))
ax.set_xticks(range(len(sigmas)))
ax.set_yticklabels(stim_strengths)
ax.set_xticklabels(sigmas)
cbar = plt.colorbar(im, cax=cax)
cbar.ax.set_ylabel(r'$\mathrm{error}(\beta)$')
ax.set_ylabel('Stimulus strength')
ax.set_xlabel(r'$\sigma$')
ax.set_title(r'$\beta$')


fig, (ax, cax) = plt.subplots(1, 2, gridspec_kw={'width_ratios':[1,0.05], 'wspace': 0.1}, figsize=(6,5))
im = ax.imshow(Error_beta_iv, aspect='auto', origin='lower', vmin=vmin, vmax=vmax)
ax.set_yticks(range(len(stim_strengths)))
ax.set_xticks(range(len(sigmas)))
ax.set_yticklabels(stim_strengths)
ax.set_xticklabels(sigmas)
cbar = plt.colorbar(im, cax=cax)
cbar.ax.set_ylabel(r'$\mathrm{error}(\beta_{IV})$')
ax.set_ylabel('Stimulus strength')
ax.set_xlabel(r'$\sigma$')
ax.set_title(r'$\beta_{IV}$')


fig, (ax, cax) = plt.subplots(1, 2, gridspec_kw={'width_ratios':[1,0.05], 'wspace': 0.1}, figsize=(6,5))
im = ax.imshow(Condition, aspect='auto', origin='lower')
ax.set_yticks(range(len(stim_strengths)))
ax.set_xticks(range(len(sigmas)))
ax.set_yticklabels(stim_strengths)
ax.set_xticklabels(sigmas)
cbar = plt.colorbar(im, cax=cax)
cbar.ax.set_ylabel(r'$||cov|| \times ||cov^{-1}||$')
ax.set_ylabel('Stimulus strength')
ax.set_xlabel(r'$\sigma$')
ax.set_title('Condition of inversion')
