# Common code for NIPS figures

In [2]:
%matplotlib inline
from __future__ import division

import colormaps as cmaps
import IPython.display as IPd
import likelihoodfree.io as io
import likelihoodfree.viz as viz
import likelihoodfree.PDF as lfpdf
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import scipy.signal as ss
import socket
import svgutil

from math import factorial
from mpl_toolkits.axes_grid1 import make_axes_locatable
from svgutil.compose import Unit

In [None]:
# Colour per method (RGB tuples scaled to [0, 1]); used by the plotting code.
COL = {
    'GT': (35/255, 86/255, 167/255),
    'SNPE': (0, 174/255, 239/255),
    'ESS': (244/255, 152/255, 25/255),
    'IBEA': (102/255, 179/255, 46/255),
    'EFREE': (105/255, 105/255, 105/255),
}

# LaTeX axis labels for the 12 Hodgkin-Huxley model parameters, in order.
LABELS_HH = [
    r'$g_{Na}$', r'$g_{K}$', r'$g_{l}$',
    r'$E_{Na}$', r'$-E_{K}$', r'$-E_{l}$',
    r'$g_{M}$', r'$t_{max}$', r'$k_{b_{n1}}$',
    r'$k_{b_{n2}}$', r'$V_{T}$', r'$noise$',
]

In [None]:
# shorthands
sc = svgutil.compose
st = svgutil.transform

# conversion
def cm2cm(cm):
    return cm * 1

def cm2in(cm):
    return cm * 1 / Unit.per_inch['cm']

def cm2mm(cm):
    return cm * Unit.per_inch['mm'] / Unit.per_inch['cm']

def cm2pt(cm):
    return cm * Unit.per_inch['pt'] / Unit.per_inch['cm']

def cm2px(cm):
    return cm * Unit.per_inch['px'] / Unit.per_inch['cm']

In [None]:
def svg(img):
    """Display `img` inline in the notebook via IPython.display.SVG."""
    rendered = IPd.SVG(img)
    IPd.display(rendered)

In [None]:
SCALE_PX = 6        # pixels per millimetre used for the svg figures
SCALE_IN = 2.25
FIG_WIDTH_MM = 160  # full figure width in mm


def mm2px(mm, scale=SCALE_PX):
    """Convert millimetres to pixels, multiplying by `scale` (px per mm)."""
    return mm * scale


def mm2inch(mm):
    """Convert millimetres to inches."""
    inches_per_mm = 0.0393701
    return mm * inches_per_mm


# alias kept for callers using the plural spelling
mm2inches = mm2inch

def create_fig(width_mm, height_mm):
    """Create an empty svgutil figure of the requested size.

    Parameters
    ----------
    width_mm, height_mm : float
        Figure width and height in mm; converted to px with `mm2px`.

    Returns
    -------
    st.SVGFigure
    """
    # Use the arguments: the globals FIG_WIDTH_MM / FIG_HEIGHT_MM must not be
    # relied on here (FIG_HEIGHT_MM is not defined in this module).
    return st.SVGFigure(mm2px(width_mm), mm2px(height_mm))

def add_label(fig, letter, x_pos_mm=0, y_pos_mm=0, font_size_px=18, weight='bold'):
    """Append a text label (e.g. a panel letter) to `fig` and return `fig`.

    The position is given in mm and converted to px with `mm2px`;
    `font_size_px` and `weight` are forwarded to `st.TextElement`.
    """
    x_px = mm2px(x_pos_mm)
    y_px = mm2px(y_pos_mm)
    label = st.TextElement(x_px, y_px, letter, size=font_size_px, weight=weight)
    fig.append(label)
    return fig

def add_grid(fig, x_spacing_mm=10, y_spacing_mm=10, font_size_px=10, width_px=1):
    """Overlay a labelled grid spanning the whole of `fig`; returns `fig`.

    Grid spacing is given in mm and converted to px with `mm2px`; the grid
    extends to `int(fig.width)` x `int(fig.height)`. `font_size_px` is the
    label font size and `width_px` the grid line width, both in px.
    `multiply=1/SCALE_PX` is passed to `sc.Grid`, presumably so the labels
    read in mm rather than px (TODO confirm against svgutil.compose.Grid).
    """
    grid = sc.Grid(
        mm2px(x_spacing_mm),
        mm2px(y_spacing_mm),
        size=font_size_px,
        width=width_px,
        xmax=int(fig.width),
        ymax=int(fig.height),
        multiply=1/SCALE_PX,
    )
    fig.append(grid)
    return fig

def add_svg(fig, filename, x_pos_mm=0, y_pos_mm=0, scale=1, verbose=False):
    """Load the svg at `filename`, place it on `fig` and return `fig`.

    The svg's root is moved to (`x_pos_mm`, `y_pos_mm`) (converted to px
    with `mm2px`) and scaled by `scale`. With `verbose=True` the size
    reported by `get_size()` is printed.
    """
    loaded = st.fromfile(filename)
    w, h = loaded.get_size()
    if verbose:
        print('size of svg of {} : {}'.format(filename, (w, h)))
    root = loaded.getroot()
    root.moveto(mm2px(x_pos_mm), mm2px(y_pos_mm), scale=scale)
    fig.append([root])
    return fig

def get_num(x):
    """Extract a float from a string such as '12.5pt' or '100px'.

    Every character that is not a digit or '.' is dropped, and the rest is
    parsed with `float`. This means a minus sign is discarded, and a string
    with no digits raises ValueError.

    Returns None if `x` is not a string.
    """
    if not isinstance(x, str):
        return None
    return float(''.join(ch for ch in x if ch.isdigit() or ch == '.'))

In [None]:
HOSTNAME = socket.gethostname()

INKSCAPE = 'inkscape'  # default inkscape executable

# hostname -> shared manuscript folder for that machine
_DROPBOX_BY_HOST = {
    'nsa3004':  # jm workstation
        '/home/jm/Mackelab/team/Write/Manuscripts/2017_NIPS_NeuralModelInference/',
    'nsa3010':  # pedro workstation
        '/home/pedro/Mackelab/team/Write/Manuscripts/2017_NIPS_NeuralModelInference/',
    'Pep.local':  # pedro macbook
        '/Users/pedro/Mackelab/team/Write/Manuscripts/2017_NIPS_NeuralModelInference/',
    'jml.local':  # jm macbook
        '/Users/jm/Mackelab/team/Write/Manuscripts/2017_NIPS_NeuralModelInference/',
    'nsa2002.local':  # no, not the nsa
        '/Users/kaan/Dropbox/2017_NIPS_NeuralModelInference/',
}

# hosts whose inkscape is not simply `inkscape` on the PATH
_INKSCAPE_BY_HOST = {
    'Pep.local': '/Applications/Inkscape.app/Contents/Resources/script',
}

if HOSTNAME not in _DROPBOX_BY_HOST:
    # new machines must be registered in _DROPBOX_BY_HOST above
    raise ValueError('Unknown hostname {}, add in if-else block'.format(HOSTNAME))
PATH_DROPBOX = _DROPBOX_BY_HOST[HOSTNAME]
INKSCAPE = _INKSCAPE_BY_HOST.get(HOSTNAME, INKSCAPE)

PATH_DROPBOX_FIGS = PATH_DROPBOX + 'figs/'

MPL_RC = 'NIPS2017.rc'

In [None]:
# Result directories per model: dirs['dir_<kind>_<model>'] -> '../results/<model>/<kind>/'
dirs = {}
for model in ['gauss', 'mog', 'glm', 'autapse', 'hh']:
    dirs.update({
        'dir_{}_{}'.format(kind, model): '../results/{}/{}/'.format(model, kind)
        for kind in ('nets', 'sampler', 'genetic')
    })

In [None]:
# Per-figure result folders inside the shared manuscript folder:
# dirs_dropbox['dir_nets_figK'] -> PATH_DROPBOX + 'results/figK/'
dirs_dropbox = {}
for fig in ('fig%d' % k for k in range(1, 6)):
    dirs_dropbox['dir_nets_' + fig] = '{}results/{}/'.format(PATH_DROPBOX, fig)

In [None]:
def plot_pdf(pdf, lims, gt=None, contours=False, levels=(0.68, 0.95),
        resolution=500, labels_params=None, ticks = False, diag_only=False, diag_only_cols=4,
        diag_only_rows=4, figsize=(5,5), fontscale=1, partial=False, samples=None):
    """Plots marginals of a pdf, for each variable and pair of variables.

    Diagonal panels show 1-D marginals. Upper-triangle panels show 2-D
    marginals, as an image or as contours. Lower-triangle panels are hidden.

    Parameters
    ----------
    pdf : object
        Distribution with an `ndim` attribute and an
        `eval(x, ii=..., log=False)` method returning densities.
    lims : array
        Plot limits: either one [min, max] pair used for every dimension,
        or an array of shape (ndim, 2).
    gt : array, optional
        Ground-truth values. They are drawn as red vertical lines on 1-D
        panels and red dots on 2-D panels.
    contours : bool
        If True, 2-D marginals are drawn as contours at `levels`.
        Otherwise they are drawn as images.
    levels : tuple
        For contours
    resolution : int
        Number of points per dimension at which `pdf` is evaluated.
    labels_params : array of strings
        x-axis labels of the diagonal panels, one per dimension.
    ticks: bool
        If True, includes ticks in the diagonal plots
    diag_only : bool
        If True, only 1-D marginals are drawn, placed row by row.
    diag_only_cols : int
        Number of grid columns if only the diagonal is plotted
    diag_only_rows : int
        Number of grid rows if only the diagonal is plotted
    figsize : tuple
        Passed to `plt.subplots`.
    fontscale: float
        Multiplier for tick-label, axis-label and ellipsis font sizes.
    partial: bool
        If True, plots partial posterior with at most 3 parameters.
        Only available if `diag_only` is False
    samples: array, shape (ndim, n_samples)
        If given, samples of a distribution are plotted along `pdf`.
        If given, `contours` is forced to True, and `levels` defaults to
        (0.68, 0.95) if it is None.
        If given, `lims` is overwritten and taken to be the respective
        limits of the samples in each dimension.

    Returns
    -------
    fig : matplotlib Figure
    ax : a single Axes if `pdf.ndim == 1`, else an array of Axes with
        shape (rows, cols)

    Todo
    ----
    - Post NIPS: merge back into likelihoodfree.viz
    """

    if samples is not None:
        # samples are indexed (dimension, sample): take per-dimension ranges
        contours = True
        if levels is None:
            levels = (0.68, 0.95)
        lims_min = np.min(samples, axis=1)
        lims_max = np.max(samples, axis=1)
        lims = np.concatenate((lims_min.reshape(-1, 1), lims_max.reshape(-1, 1)), axis=1)
    else:
        lims = np.asarray(lims)
        # a single [min, max] pair is repeated for every dimension
        lims = np.tile(lims, [pdf.ndim, 1]) if lims.ndim == 1 else lims

    if pdf.ndim == 1:

        fig, ax = plt.subplots(1, 1, facecolor='white', figsize=figsize)

        if samples is not None:
            # only one dimension exists, so index it explicitly
            # NOTE(review): `normed` was removed in newer matplotlib
            # (use `density`); kept for compatibility with the pinned version.
            ax.hist(samples[0, :], bins=100, normed = True,
                     color=COL['ESS'],
                     edgecolor=COL['ESS'])

        xx = np.linspace(lims[0, 0], lims[0, 1], resolution)

        pp = pdf.eval(xx[:, np.newaxis], log=False)
        ax.plot(xx, pp, color=COL['SNPE'])
        ax.set_xlim(lims[0])
        ax.set_ylim([0, ax.get_ylim()[1]])
        if gt is not None: ax.vlines(gt, 0, ax.get_ylim()[1], color='r')

        if ticks:
            ax.get_yaxis().set_tick_params(which='both', direction='out')
            ax.get_xaxis().set_tick_params(which='both', direction='out')
            ax.set_xticks(np.linspace(lims[0, 0], lims[0, 1], 2))
            ax.set_yticks(np.linspace(min(pp), max(pp), 2))
            ax.xaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))
            ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))
        else:
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])

    else:

        if not diag_only:
            if partial:
                rows = min(3, pdf.ndim)
                cols = min(3, pdf.ndim)
            else:
                rows = pdf.ndim
                cols = pdf.ndim
        else:
            cols = diag_only_cols
            rows = diag_only_rows
            # (r, c) is the panel the next diagonal marginal is drawn into
            r = 0
            c = -1
            # NOTE(review): the loops below only visit i == j for
            # i < min(rows, cols), so at most min(rows, cols) marginals are
            # drawn in diag_only mode, not all pdf.ndim — confirm intended.

        fig, ax = plt.subplots(rows, cols, facecolor='white', figsize=figsize)
        ax = ax.reshape(rows, cols)

        for i in range(rows):
            for j in range(cols):

                if i == j:
                    if samples is not None:
                        # NOTE(review): `normed` is `density` in newer matplotlib
                        ax[i, j].hist(samples[i, :], bins=100, normed = True,
                                 color=COL['ESS'],
                                 edgecolor=COL['ESS'])
                    xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
                    pp = pdf.eval(xx, ii=[i], log=False)

                    if diag_only:
                        c += 1
                    else:
                        r = i
                        c = j

                    ax[r, c].plot(xx, pp, color=COL['SNPE'])
                    ax[r, c].set_xlim(lims[i])
                    ax[r, c].set_ylim([0, ax[r, c].get_ylim()[1]])

                    if gt is not None: ax[r, c].vlines(gt[i], 0, ax[r, c].get_ylim()[1], color='r')

                    if ticks:
                        ax[r, c].get_yaxis().set_tick_params(which='both', direction='out',
                                                             labelsize = fontscale*20)
                        ax[r, c].get_xaxis().set_tick_params(which='both', direction='out',
                                                             labelsize = fontscale*20)
                        # two ticks: the lower and upper limit of dimension i
                        ax[r, c].set_xticks(np.linspace(lims[i, 0], lims[i, 1], 2))
                        ax[r, c].set_yticks(np.linspace(min(pp), max(pp), 2))
                        ax[r, c].xaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))
                        ax[r, c].yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))
                    else:
                        ax[r, c].get_xaxis().set_ticks([])
                        ax[r, c].get_yaxis().set_ticks([])

                    if labels_params is not None:
                        ax[r, c].set_xlabel(labels_params[i], fontsize=fontscale*15)
                    else:
                        # empty string: set_xlabel([]) would render the text "[]"
                        ax[r, c].set_xlabel('')

                    # square panel in display units
                    x0, x1 = ax[r, c].get_xlim()
                    y0, y1 = ax[r, c].get_ylim()
                    ax[r, c].set_aspect((x1-x0)/(y1-y0))

                    if partial and i == rows-1:
                        # ellipses hinting at the omitted dimensions
                        ax[i, j].text(x1+(x1-x0)/6., (y0+y1)/2., '...', fontsize=fontscale*25)
                        plt.text(x1+(x1-x0)/8.4, y0-(y1-y0)/6., '...', fontsize=fontscale*25, rotation=-45)

                else:
                    if diag_only:
                        continue

                    if i > j:
                        # lower triangle stays empty
                        ax[i, j].get_yaxis().set_visible(False)
                        ax[i, j].get_xaxis().set_visible(False)
                        ax[i, j].set_axis_off()
                        continue

                    if samples is not None:
                        H, xedges, yedges = np.histogram2d(samples[i, :], samples[j, :], bins=30, normed = True)
                        ax[i, j].imshow(H.T, origin='lower',
                                   extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]])

                    # evaluate the (i, j) marginal on a resolution x resolution grid
                    xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
                    yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
                    X, Y = np.meshgrid(xx, yy)
                    xy = np.concatenate([X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)
                    pp = pdf.eval(xy, ii=[i, j], log=False)
                    pp = pp.reshape(list(X.shape))
                    if contours:
                        ax[i, j].contour(X, Y, viz.probs2contours(pp, levels), levels, colors=('w', 'y'))
                    else:
                        ax[i, j].imshow(pp, origin='lower', cmap = cmaps.parula,
                                        extent=[lims[i, 0], lims[i, 1], lims[j, 0], lims[j, 1]],
                                        aspect='auto', interpolation='none')
                    ax[i, j].set_xlim(lims[i])
                    ax[i, j].set_ylim(lims[j])

                    if gt is not None: ax[i, j].plot(gt[i], gt[j], 'r.', ms=10, markeredgewidth=0.0)

                    ax[i, j].get_xaxis().set_ticks([])
                    ax[i, j].get_yaxis().set_ticks([])
                    ax[i, j].set_axis_off()

                    x0, x1 = ax[i, j].get_xlim()
                    y0, y1 = ax[i, j].get_ylim()
                    ax[i, j].set_aspect((x1-x0)/(y1-y0))

                    if partial and j == cols-1:
                        ax[i, j].text(x1+(x1-x0)/6., (y0+y1)/2., '...', fontsize=fontscale*25)

                # diag_only: wrap to the next row after the last column
                if diag_only and c == cols-1:
                    c = -1
                    r += 1

    return fig, ax