In [None]:
import delfi.distribution as dd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

import sys; sys.path.append('../')
from common import col, svg, plot_pdf, samples_nd

In [None]:
# MCMC samples
idx_cell = 6  # toy cell number 6 (cosine-shaped RF with ~1Hz firing rate)
n_samples = 1000000
path = 'results/MCMC/'
savefile = ''.join((
    path, 'toycell_', str(idx_cell),
    '/maprf_MCMC_prior01_run_1_', str(n_samples), 'samples_param9_5min.npy',
))
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects -- only load trusted files.
# Indexing with [()] unwraps the stored object (presumably a dict saved via np.save -- confirm).
tmp = np.load(savefile, allow_pickle=True)[()]
T = tmp['T']
params_dict_true = tmp['params_dict_true']
# Column order of the sample matrix assembled below.
params_ls = ['bias', 'gain', 'phase', 'freq', 'angle', 'ratio', 'width', 'xo', 'yo']
# Turn every parameter trace into column form and place the columns side by side.
columns = [np.atleast_2d(T[key].T).T for key in params_ls]
samples = np.hstack(columns)
def symmetrize_sample_modes(samples):
    """Mix the samples with a copy whose angle column is shifted by pi.

    A copy of ``samples`` is made in which the angle (column 4) is moved by
    pi: values above pi are decreased by pi, values below pi are increased by
    pi, and values exactly equal to pi are left unchanged. The original and
    the shifted copy are stacked vertically and every second row of the stack
    is kept, so the output has as many rows as the input.

    Parameters
    ----------
    samples : np.ndarray, shape (n_samples, 9)
        Sample matrix. Column 2 (phase) must lie in [0, pi], column 4 (angle)
        in [0, 2*pi], and columns 3, 5, 6 (freq, ratio, width) must be
        non-negative.

    Returns
    -------
    np.ndarray, shape (n_samples, 9)
        Every second row of ``vstack((samples, shifted))``. The input array
        is not modified.

    Raises
    ------
    AssertionError
        If the shape or any of the value-range requirements is violated.
    """
    assert samples.ndim == 2 and samples.shape[1] == 9, \
        'samples must have shape (n_samples, 9)'

    phase = samples[:, 2]
    angle = samples[:, 4]
    # Range checks are element-wise: every sample has to be inside the bounds.
    # (The previous form `np.max(x <= bound)` only required *one* sample to be
    # in range, so violations of the upper bound went unnoticed.)
    assert np.all(phase >= 0.) and np.all(phase <= np.pi), \
        'phase (column 2) must lie in [0, pi]'
    assert np.all(angle >= 0.) and np.all(angle <= 2 * np.pi), \
        'angle (column 4) must lie in [0, 2*pi]'
    assert np.all(samples[:, np.array([3, 5, 6])] >= 0.), \
        'freq, ratio and width (columns 3, 5, 6) must be non-negative'

    # Shift the angle by pi towards the other half of [0, 2*pi]; both masks
    # are evaluated on the original angles, so no sample is shifted twice.
    shifted = samples.copy()
    shifted[:, 4] = np.where(angle > np.pi, angle - np.pi,
                             np.where(angle < np.pi, angle + np.pi, angle))

    # Further reflections (phase -> pi - phase, gain -> -gain) were disabled
    # in the original code and are intentionally not applied here either.

    # Keeping every second row of the stack preserves the number of samples.
    return np.vstack((samples, shifted))[::2, :]
samples = symmetrize_sample_modes(samples)

# SNPE results live in a per-cell directory. Deriving it from idx_cell keeps the
# posterior and the ground truth tied to the same toy cell as the MCMC samples
# (the directory used to be hard-coded to toycell_6).
snpe_dir = 'results/SNPE/toycell_' + str(idx_cell) + '/'

# Posterior
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects -- only load trusted files.
tmp = np.load(snpe_dir + 'maprf_100k_prior01_run_1_round2_param9_nosvi_CDELFI_posterior.npy', allow_pickle=True)[()]
posterior, proposal, prior = tmp['posterior'], tmp['proposal'], tmp['prior']

# Ground truth
tmp = np.load(snpe_dir + 'ground_truth_data.npy', allow_pickle=True)[()]
obs_stats, pars_true, rf = tmp['obs_stats'],  tmp['pars_true'], tmp['rf']

# Axis labels, one per parameter.
labels_params = ['bias', 'gain', 'phase', 'freq', 'angle', 'ratio', 'width', r'$x$', r'$y$']

# Re-wrap the mixture components of the posterior (means, covariances, weights)
# as a transformed Gaussian mixture for plotting.
# NOTE(review): `flags` presumably selects a per-dimension transform; the
# dimensions flagged 2 are exactly the ones with non-trivial lower/upper
# bounds below -- confirm the flag semantics against delfi.
plot_post = dd.mixture.TransformedGaussianMixture.MoTG(
                            ms= [posterior.xs[i].m for i in range(posterior.n_components)],
                            Ss =[posterior.xs[i].S for i in range(posterior.n_components)],
                            a = posterior.a,
                            flags=[0,1,2,1,2,1,1,2,2],
                            lower=[0,0,0,0,0,0,0,-1,-1], upper=[0,0,np.pi,0,2*np.pi,0,0,1,1])

# Plot limits: after the transpose each row is one parameter, columns are (lower, upper).
lims_post = np.array([[-3, 0, .001, 0,       .001, 0, 0, -.999, -.999], 
                 [ 3,  5.5, .999*np.pi, 3, 1.999*np.pi, 3, 3, .999,   .999]]).T

In [None]:
# Posterior figure drawn with plot_pdf
with mpl.rc_context(fname='../.matplotlibrc'):
    # Ground-truth parameters, passed through plot_post._f before being handed to plot_pdf.
    gt_plot = plot_post._f(pars_true.reshape(1, -1)).reshape(-1)
    # Optional plot_pdf arguments that were disabled in the original call:
    # pdf2=plot_prior, resolution=100.
    plot_kwargs = dict(
        lims=lims_post,
        gt=gt_plot,
        figsize=(12, 12),
        contours=True,
        levels=(0.95,),
        samples=samples.T,
        col1=col['MCMC'],
        col2=col['SNPE'],
        col3=col['PRIOR'],
        col4=col['GT'],
        labels_params=labels_params,
    )
    fig, axes = plot_pdf(plot_post, **plot_kwargs)

    # Diagonal panels: x ticks only at the two plot limits, no y ticks.
    for dim in range(plot_post.ndim):
        diag_ax = axes[dim, dim]
        diag_ax.set_xticks([lims_post[dim, 0], lims_post[dim, 1]])
        diag_ax.set_yticks([])

    sns.despine(offset=5, left=True)

    # Write the figure to disk and close it; it is shown from the saved file below.
    SUPP_1 = 'fig/fig3_gabor_supp_posterior.svg'
    fig.savefig(SUPP_1, transparent=True)
    plt.close()

svg(SUPP_1)