# Demonstrate the high-probability path and the orthogonal path on the pyloric rhythm for experimental data

In [None]:
# Note: this notebook requires a more recent version of dill than the other
# notebooks in this repository, which need dill 0.2.7.1.
# You may have to switch between dill versions to run all notebooks.
!pip install --upgrade dill

In [1]:
import numpy as np
import matplotlib.pylab as plt
import delfi.distribution as dd
import time
from copy import deepcopy
import sys
sys.path.append("model/setup")
sys.path.append("model/simulator")
sys.path.append("model/inference")
sys.path.append("model/visualization")
sys.path.append("model/utils")

import sys; sys.path.append('../')
from common import col, svg, plot_pdf, samples_nd

import netio
import viz
import importlib
import viz_samples
import train_utils as tu

import matplotlib as mpl

%load_ext autoreload
%autoreload 2

In [2]:
# Output files of the individual panels. Panel A is a hand-made illustration;
# all other panels are rendered by this notebook into svg/.
PANEL_A = 'illustration/panel_a.svg'

_PANEL_PREFIX = 'svg/31D_panel_'
PANEL_B = _PANEL_PREFIX + 'b.svg'
PANEL_C = _PANEL_PREFIX + 'c.svg'
PANEL_C2 = _PANEL_PREFIX + 'c2.svg'
PANEL_D = _PANEL_PREFIX + 'd.svg'

# Appendix panels (parameters and summary statistics).
PANEL_X1params = _PANEL_PREFIX + 'App1_params.svg'
PANEL_X2params = _PANEL_PREFIX + 'App2_params.svg'
PANEL_X1ss = _PANEL_PREFIX + 'App1_ss.svg'
PANEL_X2ss = _PANEL_PREFIX + 'App2_ss.svg'

PANEL_X = _PANEL_PREFIX + 'x.svg'

### Load training samples

In [3]:
# Load the hyper-parameter setup of the 31-dimensional model via netio.
params = netio.load_setup('train_31D_R1_BigPaper')

In [4]:
# Training data (parameters in trn_data[0], summary statistics in trn_data[1]),
# normalised by tu.load_trn_data_normalize.
filedir = "results/31D_samples/pyloricsamples_31D_noNaN_3.npz"
pilot_data, trn_data, params_mean, params_std = tu.load_trn_data_normalize(filedir, params)
print('We use', len(trn_data[0]), 'training samples.')

# Per-feature moments of the training summary statistics.
stats = trn_data[1]
stats_mean, stats_std = np.mean(stats, axis=0), np.std(stats, axis=0)

We use 170000 training samples.


### Load network

In [5]:
import dill as pickle

date_today = '1908208'  # NOTE(review): not referenced by any later cell; looks like a leftover

# SECURITY: dill/pickle can execute arbitrary code while loading -- only open trusted files.
# The bundle also contains a params object; it is replaced right below by a fresh
# load of the setup so that `params` comes from netio.load_setup.
with open('results/31D_nets/191001_seed1_Exper11deg.pkl', 'rb') as net_file:
    inf_SNPE_MAF, log, params = pickle.load(net_file)
params = netio.load_setup('train_31D_R1_BigPaper')

In [6]:
prior = netio.create_prior(params, log=True)
# Membrane parameters in use plus 7 further parameters (presumably the synapses).
dimensions = np.sum(params.use_membrane) + 7

# A uniform distribution on [-sqrt(3), sqrt(3)] has zero mean and unit variance,
# so these are the prior bounds in standardised coordinates; shape (dimensions, 2).
half_width = np.sqrt(3) * np.ones(dimensions)
lims = np.asarray([-half_width, half_width]).T

In [7]:
prior = netio.create_prior(params, log=True)
# From here on, (un)normalisation uses the prior moments; this overwrites the
# training-data moments returned by tu.load_trn_data_normalize.
params_mean, params_std = prior.mean, prior.std

In [8]:
from find_pyloric import merge_samples, params_are_bounded

labels_ = viz.get_labels(params)

# Prior in standardised coordinates: uniform on [-sqrt(3), sqrt(3)] per dimension.
sqrt3 = np.sqrt(3)
prior_normalized = dd.Uniform(-sqrt3 * np.ones(dimensions), sqrt3 * np.ones(dimensions),
                              seed=params.seed)

### Load experimental data

In [9]:
# Summary statistics of the experimental recording (prep 845_082_0044); the
# posterior is conditioned on them further below.
with np.load('results/31D_experimental/190807_summstats_prep845_082_0044.npz') as experimental_file:
    summstats_experimental = experimental_file['summ_stats']

### Calculate posterior

In [10]:
from find_pyloric import merge_samples, params_are_bounded

labels_ = viz.get_labels(params)
all_paths, all_posteriors = [], []

# Condition the trained density network on the experimental summary statistics
# (wrapped in a one-element list). NOTE: the resulting posterior could be overfitted.
posterior_MAF = inf_SNPE_MAF.predict([summstats_experimental])

### Load samples

In [None]:
# Merge the stored sample files into one (1000 * 2520, 31) array; np.reshape
# raises if the merged count does not match.
n_total_samples = 1000 * 2520
samples_MAF = np.reshape(
    merge_samples("results/31D_samples/02_cond_vals", name='conductance_params'),
    (n_total_samples, 31),
)
print(np.shape(samples_MAF))

### Load start and end point

In [None]:
# Index of the stored parameter pair to analyse (pair 0 gave poor results).
num_to_watch = 3
infile = 'results/31D_pairs/similar_and_good/sample_pair_{}.npz'.format(num_to_watch)
# Use a context manager so the archive is closed once the two arrays are read;
# a bare `npz = np.load(...)` keeps the file handle open for the whole session.
with np.load(infile) as npz:
    start_point = npz['params1']
    end_point = npz['params2']

In [None]:
# Undo the standardisation of both endpoints.
start_point_unnorm = start_point * params_std + params_mean
end_point_unnorm = end_point * params_std + params_mean

# Flag parameters that change by more than a factor of 2 in either direction.
# NOTE(review): the prior is created with log=True; if these are log-values, a ratio
# of logs is not a ratio of conductances -- confirm this is intended.
ratio = end_point_unnorm / start_point_unnorm
run_true = (ratio > 2.0) | (ratio < 0.5)

In [None]:
# Boolean mask: which parameters differ by more than a factor of 2 between the endpoints.
print(run_true)

### Calculate the high-probability path

In [None]:
from HighProbabilityPath import HighProbabilityPath

In [None]:
# Parameterisation of the high-probability path: number of basis functions
# and number of discrete steps along the path.
num_basis_functions, num_path_steps = 2, 80

# NOTE: this object is replaced by the cached path loaded further below.
high_p_path = HighProbabilityPath(num_basis_functions, num_path_steps,
                                  use_sine_square=True)

In [None]:
# Path computation is disabled: the result is loaded from
# results/31D_paths/high_p_path.npz instead. Uncomment to recompute the path.
#print('Starting to calculate path')
#high_p_path.set_start_end(start_point, end_point)
#high_p_path.set_pdf(posterior_MAF, dimensions)
#high_p_path.find_path(posterior_MAF, prior=prior_normalized, multiply_posterior=1,
#                      non_linearity=None, non_lin_param=3.0)
#high_p_path.get_travelled_distance()
#print('Finished calculating path')

In [None]:
# Uncomment to cache a freshly computed path to the file that is loaded below.
#np.savez('results/31D_paths/high_p_path.npz', high_p_path=high_p_path)

In [None]:
# Load the cached path object (saved via np.savez, i.e. presumably as a 0-d object
# array; .tolist() unwraps it). SECURITY: allow_pickle=True can execute arbitrary
# code -- only load trusted result files.
with np.load('results/31D_paths/high_p_path.npz', allow_pickle=True) as cached_path:
    high_p_path = cached_path['high_p_path'].tolist()

In [None]:
# Standardised prior bounds [-sqrt(3), sqrt(3)] per dimension, shape (dimensions, 2).
half_width = np.sqrt(3) * np.ones(dimensions)
lims = np.asarray([-half_width, half_width]).T

# Panel B: experimental data
Note: the full experimental data is not included in the repository, so this panel cannot be regenerated here.

In [None]:
# Experimental recording shown in panel B: time axis, spike data and the
# PD/LP/PY traces. The context manager closes the archive after the arrays are
# read; previously the open NpzFile stayed bound to `npz` for the whole session.
with np.load('results/31D_experimental/trace_data_845_082_0044.npz') as npz:
    t = npz['t']
    PD_spikes = npz['PD_spikes']
    LP_spikes = npz['LP_spikes']
    PY_spikes = npz['PY_spikes']
    pdn = npz['pdn']
    lpn = npz['lpn']
    pyn = npz['pyn']

In [None]:
# Sample-index window of the recording shown in panel B.
index_shift = 2100
start_index = 219500 + index_shift
end_index = 246500 + index_shift
height_offset = 200

dt = t[1] - t[0]
shown_t = t[end_index] - t[start_index]
# NOTE(review): 0.025 looks like a simulation step in ms, which would make time_len
# the number of simulation steps covering the window if t is in seconds -- confirm.
time_len = shown_t / 0.025 * 1000

In [None]:
import matplotlib.patches as mp

In [None]:
def _draw_interval_marker(axis, x_left, x_right, length, height):
    """Draw a double-headed black arrow between *x_left* and *x_right* at *height*.

    Two arrows of *length* start at opposite ends and point towards each other,
    and a short vertical bar marks each end. All values are in data coordinates.
    """
    line_width = 0.4
    arrow_style = {'shape': 'full', 'head_width': 0.16, 'head_length': 0.06,
                   'length_includes_head': True, 'color': 'k', 'lw': line_width}
    axis.arrow(x_left, height, length, 0, **arrow_style)
    axis.arrow(x_right, height, -length, 0, **arrow_style)
    bar_half_height = 0.17
    for x_end in (x_left, x_right):
        axis.plot([x_end, x_end], [height - bar_half_height, height + bar_half_height],
                  c='k', lw=line_width * 1.5)


with mpl.rc_context(fname='../.matplotlibrc'):
    fig, ax = plt.subplots(1, 1, figsize=(2.87, 2.08*3/4))

    # PD, LP and PY traces, each vertically offset and scaled.
    window = slice(start_index, end_index)
    for offset, trace, gain in ((2.5, pdn, 0.007), (1.2, lpn, 0.25), (-0.1, pyn, 0.013)):
        ax.plot(t[window], offset + trace[window] * gain, c=col['GT'], lw=0.8)

    # Interval markers (numbered 1-4 in the assembled figure).
    t0 = t[start_index]
    _draw_interval_marker(ax, t0 + 0.6, t0 + 1.75, 1.15, 3.2)    # period
    _draw_interval_marker(ax, t0 + 0.6, t0 + 1.08, 0.48, 1.64)   # delay
    _draw_interval_marker(ax, t0 + 1.98, t0 + 2.25, 0.27, 1.64)  # gap
    _draw_interval_marker(ax, t0 + 1.33, t0 + 1.76, 0.43, 0.44)  # duration

    # Hide the frame and all ticks.
    for side in ('right', 'top', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    ax.get_yaxis().set_ticks([])
    ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_visible(False)
    ax.set_ylim([-0.95, 4.0])

    # Scale bar below the traces (labelled "500 ms" in the assembled figure).
    duration = 0.5
    number_of_timesteps = int(duration / dt)
    t_scale = np.linspace(t[start_index], t[start_index + number_of_timesteps], 2)
    ax.plot(t_scale, -0.8 * np.ones_like(t_scale), c='k', lw=1.0)

    #plt.savefig(PANEL_B, facecolor='None', transparent=True)
    plt.show()

# Panel C: posterior

In [None]:
from decimal import Decimal

# Tick labels (lower, upper) for the 31 parameter axes of panel C.
# Dimensions with index > len(params_mean) - 7.5 are synapses and get fixed labels.
# Membrane conductances are labelled with their unnormalised prior bounds divided
# by 0.628e-3 (NOTE(review): presumably a membrane area for a per-area conductance
# -- confirm), rounded to at most two digits after the decimal point.
# NOTE(review): the digit count is taken from the value *before* dividing by
# 0.628e-3. If that value is >= 10 it becomes negative, and str(Decimal) then uses
# exponent notation (e.g. '1.2E+2').
all_labels = []
for dim_i in range(31):
    if dim_i > len(params_mean) - 7.5:  # synapses
        if dim_i == 24:
            all_labels.append([r'$\mathdefault{0.01}$ ', r'$\mathdefault{10000}\;\;\;\;$  '])
        else:
            all_labels.append([r'$\;\;\mathdefault{0.01}$', r'$\mathdefault{1000}\;\;\;\;$ '])
        continue

    # membrane conductances
    upper_bound = lims[dim_i, 1] * params_std[dim_i] + params_mean[dim_i]
    num_after_digits = min(-int(np.log10(upper_bound)), 2)
    lower_label, upper_label = [
        round(Decimal((lims[dim_i, bound_i] * params_std[dim_i] + params_mean[dim_i]) / 0.628e-3),
              num_after_digits)
        for bound_i in range(2)
    ]
    # Raw strings throughout: '\;' is not a valid Python escape sequence, so a
    # non-raw literal emits a DeprecationWarning (SyntaxWarning since 3.12).
    all_labels.append([
        r'$\mathdefault{' + str(lower_label) + r'}$',
        r'$\mathdefault{' + str(upper_label) + r'}\;\;\;$ ',
    ])

In [None]:
import matplotlib.patheffects as pe

with mpl.rc_context(fname='../.matplotlibrc'):

    labels_ = viz.get_labels_8pt(params)
    labels_[9] += ''  # appends an empty suffix (a no-op for string labels)

    # Black outline under the start/end markers and under the path line:
    # a wider black stroke is drawn first, then the artist itself.
    start_end_marker_style = {
        'marker': 'o', 'markeredgecolor': 'w', 'markersize': 3.6, 'markeredgewidth': 0.5,
        'path_effects': [pe.Stroke(linewidth=1.2, foreground='k'), pe.Normal()],
    }
    path_line_style = {
        'linewidth': 1.6,
        'path_effects': [pe.Stroke(linewidth=2.4, foreground='k'), pe.Normal()],
    }

    # Posterior samples (histograms/KDE) with the high-probability path overlaid
    # on a subset of 8 dimensions.
    fig, axes = samples_nd(
        samples=[samples_MAF[:1260000], high_p_path.path_coords],
        subset=[2, 4, 10, 19, 24, 25, 26, 28],
        limits=lims,
        ticks=lims,
        tick_labels=all_labels,
        fig_size=(17.0*0.2435, 17.0*0.2435),
        labels=labels_,
        points=[start_point, end_point],
        scatter_offdiag={'rasterized': True, 'alpha': 1.0},
        points_offdiag=start_end_marker_style,
        points_colors=[col['CONSISTENT1'], col['CONSISTENT2']],
        samples_colors=[col['SNPE'], 'white'],
        diag=['kde', 'None'],
        upper=['hist', 'plot'],
        hist_offdiag={'bins': 50},
        plot_offdiag=path_line_style,
    )

#     plt.savefig(PANEL_C, facecolor='None', transparent=True)
    plt.show()

### Evaluate whether samples along path are identical according to Prinz

In [None]:
# Simulator and summary-statistics objects for the model setup.
# NOTE(review): neither name is referenced by the cells below -- possibly a leftover.
pyloric_sim = netio.create_simulators(params)
summ_stats = netio.create_summstats(params)

In [None]:
from viz import plot_posterior_over_path

In [None]:
high_p_path_mod = deepcopy(high_p_path)

# Grid of sample plots.
num_cols, num_rows = 2, 5
num_steps = num_cols * num_rows

# With 'dist' the x-axis is scaled by the travelled distance along the path;
# with any other value, steps are spread evenly over [0, 1].
scale = 'dist'
steps_end = high_p_path_mod.dists[-1] if scale == 'dist' else 1.0
steps = np.linspace(0, steps_end, num_steps)

# Inset for Panel C

In [None]:
dimensions_to_use = [24, 25]
num_paths = 10

# Indices of the path points closest to `num_paths` evenly spaced travelled distances.
high_p_path_mod = deepcopy(high_p_path)
path_start_positions = np.linspace(0, high_p_path_mod.dists[-1], num_paths)
high_p_indizes = high_p_path_mod.find_closest_index_to_dist(path_start_positions)

In [None]:
# Replace the evenly spaced indices from the previous cell with a single
# hand-picked point (index 45) on the high-probability path.
use_high_p_index = 45
high_p_indizes = [use_high_p_index]

In [None]:
from OrthogonalPath import OrthogonalPath

dimensions_to_use = [24, 25]
start_point_ind = 23  # index on the high-probability path (previously 10)
high_p_path_mod = deepcopy(high_p_path)

# The orthogonal path is loaded from disk in the next cell. Uncomment to recompute
# it from `start_point_ind` and cache it.
# ortho_path = OrthogonalPath(high_p_path_mod.path_coords, start_point_ind)
# ortho_path.find_orthogonal_path(posterior_MAF, max_distance=high_p_path_mod.dists[-1]/27, dim=dimensions, prior=prior_normalized)
# ortho_path.get_travelled_distance()
# print(len(ortho_path.path_coords))
#np.savez('results/31D_paths/ortho_path.npz', ortho_path=ortho_path)

In [None]:
# Load the cached orthogonal path object. SECURITY: allow_pickle=True can execute
# arbitrary code -- only load trusted result files.
with np.load('results/31D_paths/ortho_path.npz', allow_pickle=True) as cached_ortho:
    ortho_path = cached_ortho['ortho_path'].tolist()

In [None]:
num_path_pos = 2
ortho_path_mod = deepcopy(ortho_path)
# With two positions this is [0, total distance]: the start and end of the orthogonal path.
path_start_positions = np.linspace(0, ortho_path_mod.dists[-1], num_path_pos)
ortho_p_indizes = ortho_path_mod.find_closest_index_to_dist(path_start_positions)

In [None]:
# Keep only the index closest to the full travelled distance (end of the orthogonal path).
ortho_p_indizes = [ortho_p_indizes[-1]]

In [None]:
labels_ = viz.get_labels_8pt(params)
labels_[9] += ''  # appends an empty suffix (a no-op for string labels)

# Colour halfway between the two endpoint colours.
color_mixture = 0.5 * (np.asarray(list(col['CONSISTENT1'])) + np.asarray(list(col['CONSISTENT2'])))

# Highlighted points: one on the high-probability path, one on the orthogonal path.
p1g = high_p_path.path_coords[int(high_p_indizes[0])]
p1b = ortho_path.path_coords[int(ortho_p_indizes[0])]

with mpl.rc_context(fname='../.matplotlibrc'):

    _ = viz.plot_single_marginal_pdf(
        pdf1=posterior_MAF, prior=prior, resolution=200, lims=lims,
        samples=np.transpose(samples_MAF), figsize=(1.5, 1.5),
        ticks=False, no_contours=True, labels_params=labels_, display_axis_lims=True,
        start_point=high_p_path.start_point, end_point=high_p_path.end_point,
        start_col=col['CONSISTENT1'], end_col=col['CONSISTENT2'],
        path1=high_p_path.path_coords, path2=ortho_path.path_coords,
        path_steps1=1, path_steps2=1, pointscale=0.5,
        p1g=p1g, p1b=p1b,
        current_col1=color_mixture, current_col=col['CONSISTENT2'],
        current_col2=col['INCONSISTENT'],
        dimensions=dimensions_to_use,
    )
    #plt.savefig(PANEL_C2, facecolor='None', transparent=True, dpi=300, bbox_inches='tight')
    plt.show()

# Panel D

In [None]:
# Panel D: points on the high-probability path at evenly spaced travelled distances.
dimensions_to_use = [6, 7]

high_p_path_mod = deepcopy(high_p_path)
num_paths = 5
path_start_positions = np.linspace(0, high_p_path_mod.dists[-1], num_paths)
high_p_indizes = high_p_path_mod.find_closest_index_to_dist(path_start_positions)
# Keep a copy of all candidate indices. A plain assignment would alias the list,
# so the pops below would silently shrink `indizes_show` as well.
indizes_show = list(high_p_indizes)
# Drop the 2nd and 3rd candidates (pop the later position first).
high_p_indizes.pop(2)
high_p_indizes.pop(1)
current_point = high_p_path_mod.path_coords[high_p_indizes]
high_p_indizes = np.flip(high_p_indizes)
print(high_p_indizes)

In [None]:
# Hand-picked indices for panel D, overriding the previous cell: 79 (presumably the
# last step of the 80-step path -- confirm), 0 (start) and use_high_p_index.
high_p_indizes = [79, 0, use_high_p_index]

In [None]:
# Inspect the prior mean (displayed as cell output).
prior.mean

In [None]:
# Inspect the prior standard deviation (displayed as cell output).
prior.std

In [None]:
labels_ = viz.get_labels_8pt(params)
high_p_path_mod = deepcopy(high_p_path)
dimensions_to_use2D = [6, 7]

seeds = [8] * 5
# Per-trace offsets passed to the plotting helper. Earlier values: offsets[0] = 47000,
# offsets[1] = 75500, offsets[2] = 21000, offsets[3] = 40500.
offsets = 39000 * np.ones_like(seeds)
offsets[1:4] = [83500, 29000, 40500]

with mpl.rc_context(fname='../.matplotlibrc'):

    fig = viz.viz_path_and_samples_abstract_twoRows(
        posterior_MoG=posterior_MAF, high_p_path=high_p_path_mod, ortho_path=ortho_path_mod,
        prior=prior, lims=lims, samples=samples_MAF, figsize=(5.87, 3.0), offsets=offsets,
        linescale=1.5, ticks=False, no_contours=True, labels_params=labels_,
        start_point=high_p_path.start_point, end_point=high_p_path.end_point,
        ortho_p_indizes=ortho_p_indizes, high_p_indizes=high_p_indizes, mycols=col,
        time_len=int(time_len),
        path1=high_p_path_mod.path_coords, path_steps1=1,
        path2=ortho_path_mod.path_coords, path_steps2=1,
        dimensions_to_use=dimensions_to_use2D,
        seeds=seeds, indizes=[0], hyperparams=params, date_today='190910_80start',
        case='ortho_p', save_fig=False,
    )
    #plt.savefig(PANEL_D, facecolor='None', transparent=True, dpi=300, bbox_inches='tight')
    plt.show()

# Assemble figure

In [None]:
# Colour halfway between CONSISTENT1 and CONSISTENT2 (same expression as in the panel C inset cell).
# NOTE(review): not referenced by the assembly cells below.
color_mixture = 0.5 * (np.asarray(list(col['CONSISTENT1'])) + np.asarray(list(col['CONSISTENT2'])))

In [None]:
import time
import IPython.display as IPd

def svg(img):
    IPd.display(IPd.HTML('<img src="{}" / >'.format(img, time.time())))

In [None]:
# Assemble the final figure from the individual panel SVGs with svgutils and save it
# to fig/fig8_stg_31D.svg. Positions are hand-tuned offsets multiplied by factor_svg
# (NOTE(review): presumably a unit-conversion factor for the layout grid -- confirm).
from svgutils.compose import *

# > Inkscape pixel is 1/90 of an inch, other software usually uses 1/72.
# > http://www.inkscapeforum.com/viewtopic.php?f=6&t=5964
# NOTE(review): newer Inkscape versions (0.92+) use 1/96 of an inch -- check svg_scale.
svg_scale = 1.25  # set this to 1.25 for Inkscape, 1.0 otherwise
factor_svg=5.5

# Text styles: panel letters in Arial, 12pt, weight 800 (kwargs_text); coloured
# 10pt labels for the numbered points; 7.7pt for in-panel annotations.
kwargs_text = {'size': '12pt', 'font': 'Arial', 'weight': '800'}
kwargs_consistent = {'size': '10pt', 'font': 'Arial', 'weight': '500', 'color': '#AF99EF'}
kwargs_consistent1 = {'size': '10pt', 'font': 'Arial', 'weight': '500', 'color': '#9E7DD5'}
kwargs_inconsistent = {'size': '10pt', 'font': 'Arial', 'weight': '500', 'color': '#D73789'}
kwargs_text8pt = {'size': '7.7pt', 'font': 'Arial'}

# Start/end coordinates of the two grey connector lines drawn below
# (presumably linking the panel C inset to the main panel -- confirm).
startx1 = 492
startx2 = 594
starty1 = 204
starty2 = 307

endx1 = 642
endx2 = 673
endy1 = 159
endy2 = 191

deltax1 =  endx1-startx1
deltax2 =  endx2-startx2
deltay1 =  endy1-starty1
deltay2 =  endy2-starty2

sizefactor = 1.0
dshift = 0.5*factor_svg  # global vertical shift applied to all panels

f = Figure("20.3cm", "9.1cm",

    # grey connector lines
    Line(((startx1,starty1+dshift),(startx1+deltax1*sizefactor,starty1+dshift+deltay1*sizefactor)), width=1.5, color='grey'),
    Line(((startx2,starty2+dshift),(startx2+deltax2*sizefactor,starty2+dshift+deltay2*sizefactor)), width=1.5, color='grey'),
           
    # panel a: illustration
    Panel(
          SVG(PANEL_A).scale(svg_scale).scale(0.9).move(0, 15*factor_svg),
          Text("a", -2.7*factor_svg, 16.9*factor_svg-dshift, **kwargs_text),
    ).move(2.7*factor_svg, -14.4*factor_svg+dshift),
           
    # panel b: experimental traces with neuron names, numbered interval markers and time scale bar
    Panel(
        SVG(PANEL_B).scale(svg_scale).move(0*factor_svg, 0*factor_svg),
        Text("b", -6.0*factor_svg, 5*factor_svg-dshift, **kwargs_text),
       Text("PD", -1.*factor_svg+0.0, 8.2*factor_svg, **kwargs_text8pt),
       Text("LP", -1.*factor_svg+0.0, 13.4*factor_svg, **kwargs_text8pt),
       Text("PY", -1.*factor_svg+0.0, 18.6*factor_svg, **kwargs_text8pt),
        
        #Text("Period", 15.5*factor_svg+0.0, 2.8*factor_svg, **kwargs_text8pt),
        #Text("Delay", 11.3*factor_svg+0.0, 9.6*factor_svg, **kwargs_text8pt),
        #Text("Gap", 27.5*factor_svg+0.0, 9.6*factor_svg, **kwargs_text8pt),
        #Text("Duration", 19.2*factor_svg+0.0, 13.8*factor_svg, **kwargs_text8pt),
        Text("1", 17.45*factor_svg+0.0, 4.5*factor_svg, **kwargs_text8pt),
        Text("2", 13.1*factor_svg+0.0, 10.6*factor_svg, **kwargs_text8pt),
        Text("3", 28.75*factor_svg+0.0, 10.6*factor_svg, **kwargs_text8pt),
        Text("4", 21.7*factor_svg+0.0, 15.4*factor_svg, **kwargs_text8pt),
       #Text("50 mV", 39.4*factor_svg, 25*factor_svg, **kwargs_text8pt),
       #Text("50 mV", 32.0*factor_svg, 4.8*factor_svg, **kwargs_text8pt),
       Text("500 ms", 3.2*factor_svg, 22.5*factor_svg, **kwargs_text8pt),
    ).move(37.8*factor_svg, -2.5*factor_svg+dshift),
    
    # panel c: posterior corner plot
    Panel(
          SVG(PANEL_C).scale(svg_scale).move(-10*factor_svg,0*factor_svg),
          Text("c", -11.5*factor_svg, 2.7*factor_svg-dshift, **kwargs_text),
    ).move(90.1*factor_svg, -0.2*factor_svg+dshift),
           
    # panel c inset: 2D marginal with the two highlighted points
    Panel(
          SVG(PANEL_C2).scale(svg_scale).move(-10*factor_svg,0*factor_svg),
        #Text("1", 3.1*factor_svg, 5.2*factor_svg, **kwargs_consistent1),
        Text("1", 11.2*factor_svg, 11.3*factor_svg, **kwargs_consistent1),
        Text("2", 7.5*factor_svg, 6.7*factor_svg, **kwargs_inconsistent),
    ).move(90*factor_svg, 35.2*factor_svg+dshift),

    # panel d: paths and simulated traces
    Panel(
          SVG(PANEL_D).scale(svg_scale).move(0*factor_svg, 0*factor_svg),
          Text("d", 0*factor_svg, 3.5*factor_svg-dshift, **kwargs_text),
        #Text("1", 41.5*factor_svg, 4*factor_svg, **kwargs_consistent),
        Text("1", 4*factor_svg, 23.5*factor_svg, **kwargs_consistent1),
        Text("2", 41.5*factor_svg, 23.5*factor_svg, **kwargs_inconsistent),
        Text("50 mV", 68.4*factor_svg, 4*factor_svg, **kwargs_text8pt),
    ).move(0*factor_svg, 23.2*factor_svg+dshift)

)

# Write the assembled figure and display it inline via the notebook's svg() helper.
!mkdir -p fig
f.save("fig/fig8_stg_31D.svg")
svg('fig/fig8_stg_31D.svg')