# Figures

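This notebook creates Figure 3 of the manuscript (Gabor receptive-field model, saved as `fig/fig3_gabor.svg`) and the supplementary full-posterior figure (`fig/fig3_gabor_supp_posterior.svg`).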

In [None]:
%matplotlib inline

import theano
theano.config.floatX='float64'

import delfi.distribution as dd
import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pickle
import seaborn as sns

from utils import setup_sim, get_data_o

import sys; sys.path.append('../')
from common import col, svg, plot_pdf

In [None]:
import os

# Output folders: individual panels go to svg/, composed figures to fig/.
# (Pure-Python replacement for a shell `mkdir -p`, so it also works outside a POSIX shell.)
os.makedirs('svg', exist_ok=True)
os.makedirs('fig', exist_ok=True)

# Panel A: pre-made model illustration (not generated in this notebook).
PANEL_A = 'illustration/model.svg'
svg(PANEL_A)

## Gabor RF results

In [None]:
## training data and true parameters, data, statistics
seed = 42
idx_cell = 6  # load toy cell number 6 (cosine-shaped RF with ~1Hz firing rate)

# The path is built from idx_cell so that selecting another cell loads its own files.
ground_truth = np.load('results/SNPE/toycell_' + str(idx_cell) + '/ground_truth_data.npy',
                       allow_pickle=True)[()]
obs_stats, pars_true, rf = ground_truth['obs_stats'], ground_truth['pars_true'], ground_truth['rf']

sim_info = np.load('results/sim_info.npy', allow_pickle=True)[()]
d, params_ls = sim_info['d'], sim_info['params_ls']

# The last entry of the summary statistics is the spike count. For toy cell 6 it
# should be 299; update this check when selecting a different cell.
assert obs_stats[0, -1] == 299

labels_params = ['bias', 'gain', 'phase', 'freq', 'angle', 'ratio', 'width', 'xo', 'yo']

### SNPE results

In [None]:
# SNPE posterior, proposal and prior for the selected toy cell.
# The path is built from idx_cell, consistent with the MCMC results below.
snpe_file = ('results/SNPE/toycell_' + str(idx_cell)
             + '/maprf_100k_prior01_run_1_round2_param9_nosvi_CDELFI_posterior.npy')
snpe_results = np.load(snpe_file, allow_pickle=True)[()]
posterior, proposal, prior = snpe_results['posterior'], snpe_results['proposal'], snpe_results['prior']

# Wrap prior and posterior in transformed distributions for plotting. Entries are per
# parameter (bias, gain, phase, freq, angle, ratio, width, xo, yo). `flags` selects a
# transform per parameter: presumably 0 = unbounded, 1 = positive,
# 2 = bounded to [lower, upper] (TODO confirm against delfi's TransformedNormal).
plot_prior = dd.TransformedNormal(m=prior.m, S=prior.S,
                                  flags=[0, 0, 2, 1, 2, 1, 1, 2, 2],
                                  lower=[0, 0, 0, 0, 0, 0, 0, -1, -1],
                                  upper=[0, 0, np.pi, 0, 2*np.pi, 0, 0, 1, 1])

plot_post = dd.mixture.TransformedGaussianMixture.MoTG(
    ms=[posterior.xs[i].m for i in range(posterior.n_components)],
    Ss=[posterior.xs[i].S for i in range(posterior.n_components)],
    a=posterior.a,
    flags=[0, 0, 2, 1, 2, 1, 1, 2, 2],
    lower=[0, 0, 0, 0, 0, 0, 0, -1, -1],
    upper=[0, 0, np.pi, 0, 2*np.pi, 0, 0, 1, 1])

# Plot limits for the full posterior (one row per parameter, columns = lower/upper).
lims_post = np.array([[-1.5, -1.1, .001,       0.01, .001,        0.01, 0.01, -.999, -.999],
                      [ 1.5,  1.1, .999*np.pi, 2.49, 1.999*np.pi, 1.99, 3.99,  .999,  .999]]).T

### MCMC sampler

In [None]:
# MCMC reference samples for the same toy cell.
n_samples = 1000000
path = 'results/MCMC/'
savefile = '{}toycell_{}/maprf_MCMC_prior01_run_1_{}samples_param9_5min.npy'.format(
    path, idx_cell, n_samples)
mcmc_results = np.load(savefile, allow_pickle=True)[()]

T = mcmc_results['T']
params_dict_true = mcmc_results['params_dict_true']

# Column order of `samples`; 1-D entries of T become single columns.
params_ls = ['bias', 'gain', 'phase', 'freq', 'angle', 'ratio', 'width', 'xo', 'yo']
samples = np.column_stack([T[key] for key in params_ls])

def symmetrize_sample_modes(samples):
    """Mix samples with two mirrored copies of themselves, keeping the row count.

    Two reflections are applied in turn. After each one, the original and the
    reflected array are stacked and every second row is kept:

    1. angle -> angle - pi where angle > pi, angle + pi where angle < pi
       (angles exactly equal to pi are left unchanged);
    2. gain -> -gain and phase -> pi - phase.

    Parameters
    ----------
    samples : np.ndarray, shape (n, 9)
        One sample per row; columns as built from ``params_ls``
        (bias, gain, phase, freq, angle, ratio, width, xo, yo).

    Returns
    -------
    np.ndarray, shape (n, 9)
        New array; the input is not modified.

    Raises
    ------
    AssertionError
        If ``samples`` is not 2-D with 9 columns, phase lies outside [0, pi],
        angle lies outside [0, 2*pi], or freq/ratio/width are negative.
    """
    assert samples.ndim == 2 and samples.shape[1] == 9

    # Range checks compare the extreme values themselves. Taking the max of a
    # boolean mask would only require *one* sample to be in range.
    assert np.min(samples[:, 2]) >= 0. and np.max(samples[:, 2]) <= np.pi, \
        "phase must lie in [0, pi]"
    assert np.min(samples[:, 4]) >= 0. and np.max(samples[:, 4]) <= 2*np.pi, \
        "angle must lie in [0, 2*pi]"
    # freq, ratio and width must be non-negative
    assert np.all(np.min(samples[:, np.array([3, 5, 6])], axis=0) >= 0.), \
        "freq, ratio and width must be non-negative"

    # Reflection 1: shift angle by pi; the direction is decided by the original angle.
    flipped = samples.copy()
    idx = np.where(samples[:, 4] > np.pi)[0]
    flipped[idx, 4] = flipped[idx, 4] - np.pi
    idx = np.where(samples[:, 4] < np.pi)[0]
    flipped[idx, 4] = flipped[idx, 4] + np.pi
    # stack and keep every second row so the number of samples stays the same
    samples_all = np.vstack((samples, flipped))[::2, :]

    # Reflection 2: negate the gain and mirror the phase.
    flipped = samples_all.copy()
    flipped[:, 1] = -flipped[:, 1]
    flipped[:, 2] = np.pi - flipped[:, 2]
    samples_all = np.vstack((samples_all, flipped))[::2, :]

    return samples_all

# Replace the MCMC samples by a mix of original and mirrored samples (same row count).
samples = symmetrize_sample_modes(samples)

# Ground-truth parameters of the MCMC run. The entries MUST follow the column order of
# `params_ls` / `samples` (bias, gain, phase, freq, angle, ratio, width, xo, yo), because
# both are indexed with the same `idx` when plotting.
pars_raw = np.array([params_dict_true['glm']['bias'],
                     params_dict_true['kernel']['s']['gain'],
                     params_dict_true['kernel']['s']['phase'] + 0.05,  # remove phase a bit from left interval border
                     params_dict_true['kernel']['s']['freq'],   # freq precedes angle, as in params_ls
                     params_dict_true['kernel']['s']['angle'],
                     params_dict_true['kernel']['s']['ratio'],
                     params_dict_true['kernel']['s']['width'],
                     params_dict_true['kernel']['l']['xo'],
                     params_dict_true['kernel']['l']['yo']])

# Default plot limits for the samples (one row per parameter, columns = lower/upper).
lims_samples = np.array([[-.5, -1.1, .00001*np.pi, 0, 0.301*np.pi, 0, 0, -0.1, -0.0],
                         [ .1,  1.1, .99999*np.pi, 3, 1.699*np.pi, 3, 5,  0.4,  0.5]]).T

## Panel for summary statistics

In [None]:
with mpl.rc_context(fname='../.matplotlibrc'):
    fig = plt.figure(figsize=(1.9, 1.9))
    # obs_stats[0, :-1] is the spike-triggered average (the last entry is the spike count)
    plt.imshow(obs_stats[0, :-1].reshape(d, d), interpolation='nearest', cmap='gray')
    plt.text(3, 6, 'STA', color='w', fontsize=10)
    plt.tight_layout()
    plt.axis('off')

    # Option to add contours of the ground-truth RF. This needs the mapRF simulator,
    # so it is instantiated here rather than relying on a later cell.
    add_gt = False
    if add_gt:
        lvls = [0.2, 0.2]  # contour levels as fractions of the RF minimum / maximum
        g_gt = setup_sim(seed, path='')[0]
        rfm = g_gt.model.params_to_rf(pars_true.reshape(-1))[0]
        plt.contour(rfm, levels=[lvls[0]*rfm.min(), lvls[1]*rfm.max()], colors='r')

    PANEL_B_1 = 'svg/panel_b_1.svg'
    fig.savefig(PANEL_B_1, facecolor=fig.get_facecolor(), transparent=True)
    plt.close(fig)

svg(PANEL_B_1)

## Panel for parameters (ground-truth receptive field)

In [None]:
with mpl.rc_context(fname='../.matplotlibrc'):
    # Ground-truth receptive field of the toy cell as a grey-scale image.
    fig_rf = plt.figure(figsize=(1.9, 1.9))
    plt.imshow(rf, interpolation='nearest', cmap='gray')
    plt.text(3, 12, 'ground-truth \nreceptive field', color='w', fontsize=10)
    plt.tight_layout()
    plt.axis('off')

    PANEL_B_2 = 'svg/panel_b_2.svg'
    fig_rf.savefig(PANEL_B_2, facecolor=fig_rf.get_facecolor(), transparent=True)
    plt.close(fig_rf)

svg(PANEL_B_2)

## (Partial) posterior

In [None]:
with mpl.rc_context(fname='../.matplotlibrc'):
    # Plot limits (rows: bias, gain, phase, freq, angle, ratio, width, xo, yo;
    # columns: lower/upper). Angle spans [pi/3, 5*pi/3] = [60°, 300°].
    lims_samples = np.array([[-.5, -1.1, .00001*np.pi, 0, np.pi/3,   0, 0, -0.1, -0.0],
                             [ .1,  1.1, 0.9999*np.pi, 3, 5*np.pi/3, 3, 5,  0.4,  0.5]]).T

    # Only show the marginals of phase, angle, xo and yo.
    idx = np.array([2, 4, 7, 8])

    # Restrict the SNPE posterior to the selected parameters.
    plot_post_select = dd.mixture.MoTG(ms=[x.m[idx] for x in plot_post.xs],
                                      Ss=[x.S[idx][:, idx] for x in plot_post.xs],
                                      a=plot_post.a,
                                      flags=plot_post.flags[idx],
                                      lower=plot_post.lower[idx],
                                      upper=plot_post.upper[idx])

    # Restricted prior, for the optional prior overlay (pdf2) below.
    plot_prior_select = dd.TransformedNormal(m=plot_prior.m[idx], S=plot_prior.S[idx][:, idx],
                                             flags=plot_prior.flags[idx],
                                             lower=plot_prior.lower[idx],
                                             upper=plot_prior.upper[idx])

    fig, axes = plot_pdf(plot_post_select,  #pdf2=plot_prior_select,
                         lims=lims_samples[idx],
                         gt=pars_raw.reshape(-1)[idx],
                         figsize=(3.2, 3.2),
                         resolution=100,
                         samples=samples[:, idx].T,
                         levels=[0.95],
                         col1=col['MCMC'],
                         col2=col['SNPE'],
                         col4=col['GT'],
                         partial_dots=True,
                         ticks=False)

    # Diagonal labels, in the same order as idx (phase, angle, xo, yo).
    labels_params_select = np.array(['phase', 'angle', r'  $x$', r' $y$'])
    for i, label in enumerate(labels_params_select):
        axes[i, i].set_xlabel(label)

    # phase ticks, in degrees
    axes[0, 0].set_xticks([0, np.pi])
    axes[0, 0].set_xticklabels([r'$0^\degree$', r'$180^\degree$'])
    # angle ticks: 60° = pi/3 and 300° = 5*pi/3, i.e. the angle plot limits
    axes[1, 1].set_xticks([np.pi/3, 5*np.pi/3])
    axes[1, 1].set_xticklabels([r'$60^\degree$', r'$300^\degree$'])
    axes[2, 2].set_xticks([-0.1, 0.4])
    axes[2, 2].set_xticklabels([-0.1, 0.4])
    axes[3, 3].set_xticks([0, 0.5])
    axes[3, 3].set_xticklabels([0, 0.5])

    sns.despine(offset=5, left=True)

    PANEL_C = 'svg/panel_c.svg'
    fig.savefig(PANEL_C, facecolor=fig.get_facecolor(), transparent=True)
    plt.close(fig)

svg(PANEL_C)

## Posterior samples: synthetic data

In [None]:
with mpl.rc_context(fname='../.matplotlibrc'):
    # This snippet of code requires the mapRF repository (to instantiate g.model).
    # Note: it reassigns the globals g, prior, d, obs_stats, pars_true and rf.
    g, prior, d = setup_sim(seed, path='')
    filename = 'results/toy_cells/toy_cell_' + str(idx_cell) + '.npy'
    obs_stats, pars_true = get_data_o(filename, g, seed)
    rf = g.model.params_to_rf(pars_true)[0]

    # contour levels (fractions of each RF's min / max) and number of posterior draws
    lvls, n_draws = [0.2, 0.2], 10
    fig = plt.figure(figsize=(1.9, 1.9))
    plt.imshow(obs_stats[0, :-1].reshape(d, d), interpolation='nearest', cmap='gray')

    # Contours of receptive fields built from draws of the SNPE posterior.
    # (`linewidths` is the contour keyword for line width.)
    for i in range(n_draws):
        rfm = g.model.params_to_rf(posterior.gen().reshape(-1))[0]
        plt.contour(rfm, levels=[lvls[0]*rfm.min(), lvls[1]*rfm.max()],
                    colors=[col['SNPE']], linewidths=2)

    # Ground-truth receptive field contour on top.
    rfm = g.model.params_to_rf(pars_true.reshape(-1))[0]
    plt.contour(rfm, levels=[lvls[0]*rfm.min(), lvls[1]*rfm.max()],
                colors=[col['GT']], linewidths=2.5)
    plt.tight_layout()
    plt.axis('off')

    plt.text(3, 10, 'receptive field \n'+'samples', color='w', fontsize=10)

    PANEL_D = 'svg/panel_d.svg'
    fig.savefig(PANEL_D, facecolor=fig.get_facecolor(), transparent=True)
    plt.close(fig)

svg(PANEL_D)

## Posterior samples: real data

In [None]:
with mpl.rc_context(fname='../.matplotlibrc'):
    # NOTE(review): despite the section title and the in-figure label, this cell loads
    # the same synthetic toy cell (toy_cell_<idx_cell>) and SNPE posterior as the
    # synthetic-data panel. Confirm which real recording should be shown here.

    # This snippet of code requires the mapRF repository (to instantiate g.model).
    # Note: it reassigns the globals g, prior, d, obs_stats, pars_true and rf.
    g, prior, d = setup_sim(seed, path='')
    filename = 'results/toy_cells/toy_cell_' + str(idx_cell) + '.npy'
    obs_stats, pars_true = get_data_o(filename, g, seed)
    rf = g.model.params_to_rf(pars_true)[0]

    # contour levels (fractions of each RF's min / max) and number of posterior draws
    lvls, n_draws = [0.2, 0.2], 10
    fig = plt.figure(figsize=(3.4, 3.4))
    plt.imshow(obs_stats[0, :-1].reshape(d, d), interpolation='nearest', cmap='gray')

    # Contours of receptive fields built from draws of the SNPE posterior.
    # (`linewidths` is the contour keyword for line width.)
    for i in range(n_draws):
        rfm = g.model.params_to_rf(posterior.gen().reshape(-1))[0]
        plt.contour(rfm, levels=[lvls[0]*rfm.min(), lvls[1]*rfm.max()],
                    colors=[col['SNPE']], linewidths=2)

    plt.tight_layout()
    plt.axis('off')

    plt.text(1, 2.5, 'real data receptive field samples', color='w', fontsize=10)

    PANEL_E = 'svg/panel_e.svg'
    fig.savefig(PANEL_E, facecolor=fig.get_facecolor(), transparent=True)
    plt.close(fig)

svg(PANEL_E)

## Compose figure

In [None]:
import os

from svgutils.compose import Figure, Grid, Panel, SVG, Text

# > Inkscape pixel is 1/90 of an inch, other software usually uses 1/72.
# > http://www.inkscapeforum.com/viewtopic.php?f=6&t=5964
svg_scale = 1.25  # set this to 1.25 for Inkscape, 1.0 otherwise

# Panel letters in Arial, 12pt, weight 800 (extra bold)
kwargs_text = {'size': '12pt', 'font': 'Arial', 'weight': '800'}

# Layout: A across the top; below it B (STA above the ground-truth RF) on the left,
# C above D in the middle, and E on the right.
f = Figure("20.3cm", "12.3cm",

    Panel(
          SVG(PANEL_A).scale(svg_scale).move(20, -15),
          Text("A", 0, 13, **kwargs_text),
    ).move(0, 0),

    Panel(
          SVG(PANEL_B_1).scale(svg_scale).move(15, 0),
          SVG(PANEL_B_2).scale(svg_scale).move(15, 137),
          Text("B", 0, 22, **kwargs_text),
    ).move(0, 185),

    Panel(
        SVG(PANEL_C).scale(svg_scale).move(25, 2),
        Text("C", -10, 22, **kwargs_text),
    ).move(180, 185),

    Panel(
        SVG(PANEL_D).scale(svg_scale).move(0, 2),
        Text("D", -10, 22, **kwargs_text),
    ).move(180, 185+137),

    Panel(
          SVG(PANEL_E).scale(svg_scale).move(12, 2),
          Text("E", -5, 22, **kwargs_text),
    ).move(470, 185),

    #Grid(10, 10),  # uncomment to overlay a placement grid while adjusting the layout
)

# make sure the output directory exists before saving
os.makedirs('fig', exist_ok=True)
f.save("fig/fig3_gabor.svg")
svg('fig/fig3_gabor.svg')

## Supplementary figure: full posterior

In [None]:
with mpl.rc_context(fname='../.matplotlibrc'):
    # Work on a copy so this cell does not modify the shared `labels_params` list
    # in place (other cells stay independent of execution order).
    labels_supp = list(labels_params)
    labels_supp[7] = r'$x$'
    labels_supp[8] = r'$y$'

    fig, axes = plot_pdf(plot_post,
                         #pdf2=plot_prior,
                         lims=lims_post,
                         # NOTE(review): `_f` is a private delfi method, presumably the
                         # transform into the plotting space; confirm.
                         gt=plot_post._f(pars_true.reshape(1, -1)).reshape(-1),
                         figsize=(12, 12),
                         #resolution=100,
                         samples=samples.T,
                         #levels=np.asarray([0.95]),
                         col1=col['MCMC'],
                         col2=col['SNPE'],
                         #col3=col['PRIOR'],
                         col4=col['GT'],
                         labels_params=labels_supp)

    # Diagonal: x ticks only at the plot limits, no y ticks.
    for i in range(plot_post.ndim):
        axes[i, i].set_xticks([lims_post[i, 0], lims_post[i, 1]])
        axes[i, i].set_yticks([])

    sns.despine(offset=5, left=True)

    SUPP_1 = 'fig/fig3_gabor_supp_posterior.svg'
    fig.savefig(SUPP_1, transparent=True)
    plt.close(fig)

svg(SUPP_1)