In [None]:
import torch
import numpy as np
from os.path import join
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
import matplotlib as mpl

In [None]:
# Load the stored arrays for the camera-model figure from the .npz archive.
# `data` stays bound to the opened archive, exactly as before.
data = np.load(join("plotting_data/camera_samples.npz"))
theta_test, obs_test, gatsbi_samples, npe_samples = (
    data[key]
    for key in ("theta_test", "obs_test", "gatsbi_samples", "npe_samples")
)

In [None]:
# Colormap names used by the figure below: `gt_colors` for the theta / x_o
# panels, `npe_colors` for the "Mean" panels and `npe_colors_std` for the "SD"
# panels. All three currently share the same greyscale map.
gt_colors = npe_colors = npe_colors_std = "Greys"

def add_cbar(cax, im, lim_min, lim_max):
    """Draw a colorbar for ``im`` inside the axes ``cax``.

    Ticks are placed only at ``lim_min`` and ``lim_max``, tick labels use font
    size 40 and the colorbar outline is hidden.

    Parameters
    ----------
    cax : axes the colorbar is drawn into.
    im : mappable (e.g. the object returned by ``imshow``) the colorbar describes.
    lim_min, lim_max : the two tick positions.

    Returns
    -------
    The colorbar object returned by ``plt.colorbar``.
    """
    tick_positions = [lim_min, lim_max]
    colorbar = plt.colorbar(im, cax=cax, ticks=tick_positions)
    # The two styling calls are independent of each other.
    colorbar.outline.set_visible(False)
    colorbar.ax.tick_params(labelsize=40)
    return colorbar

def hide_ax(ax):
    """Strip the frame and ticks from ``ax`` and return the same axes.

    All four spines are made invisible and both the x and y tick lists are
    emptied. Nothing else on the axes is modified.
    """
    for side in ('right', 'top', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    for set_ticks in (ax.set_xticks, ax.set_yticks):
        set_ticks([])
    return ax

# Figure 5: a 6 x 12 grid of image panels. Each column i shows one test case:
# rows 0-1 are theta_test[i] and obs_test[i]; rows 2-3 are the mean and standard
# deviation (taken over axis 0) of gatsbi_samples[i]; rows 4-5 are the same
# statistics for npe_samples[i].
with mpl.rc_context(fname='./matplotlibrc'):

    fig, axes = plt.subplots(6, 12, figsize=(50, 22.5))
    left_adjust = .96

    # Keep the right margin free and place one narrow colorbar axes per row of
    # panels there (listed top to bottom).
    fig.subplots_adjust(right=.95, wspace=.1)
    caxs = [fig.add_axes([left_adjust, bottom, .01, .1])
            for bottom in (.79, .655, .52, .385, .25, .115)]

    # imshow / text settings
    fontname = "Arial"
    fontsize = 50
    origin = 'lower'
    mn_vmin = 0.
    mn_vmax = 1.
    std_vmin = 0.

    # Row labels are attached to the first column only, once per row.
    row_titles = [r"$\theta$", r"$x_{o}$", "Mean", "SD", "Mean", "SD"]
    for label_ax, tit in zip(axes[:, 0], row_titles):
        label_ax.set_ylabel(tit, fontsize=fontsize, fontname=fontname)

    for i, ax in enumerate(axes.T):
        # All panels in a row share colormap and limits, so a single colorbar
        # per row is enough. Creating one per column would stack 12 identical
        # colorbars in the same colorbar axes.
        draw_cbar = i == 0

        # Plot GT theta and obs (rows 0 and 1). Each image is transposed and
        # its rows reversed before being shown with origin='lower'.
        for j, samp in enumerate([theta_test, obs_test]):
            im = ax[j].imshow((samp[i].squeeze()).T[::-1],
                              origin=origin,
                              vmax=mn_vmax,
                              vmin=mn_vmin,
                              cmap=gt_colors)
            hide_ax(ax[j])
            if draw_cbar:
                cbar = add_cbar(caxs[j], im, mn_vmin, mn_vmax)

        # Plot GATSBI (rows 2-3) and NPE (rows 4-5) summaries. The upper colour
        # limits differ per method: (mean, SD) = (1, .1) for GATSBI and
        # (10, 1000) for NPE.
        for k, (ss, mean_vmax, stdev_vmax) in enumerate(zip([gatsbi_samples, npe_samples],
                                                            [1., 10.],
                                                            [.1, 1000.])):
            mean_row = 2 + 2*k
            std_row = mean_row + 1

            # Statistics over axis 0 of the samples for test case i
            # (presumably the sample axis -- TODO confirm against the data file).
            mean = np.mean(ss[i], 0).squeeze()
            std = np.std(ss[i], 0).squeeze()

            im = ax[mean_row].imshow(mean.T[::-1],
                                     origin=origin,
                                     vmax=mean_vmax,
                                     vmin=mn_vmin,
                                     cmap=npe_colors)
            hide_ax(ax[mean_row])
            if draw_cbar:
                cbar = add_cbar(caxs[mean_row], im, mn_vmin, mean_vmax)

            im = ax[std_row].imshow(std.T[::-1],
                                    vmax=stdev_vmax,
                                    vmin=std_vmin,
                                    origin=origin,
                                    cmap=npe_colors_std)
            hide_ax(ax[std_row])
            if draw_cbar:
                cbar = add_cbar(caxs[std_row], im, std_vmin, stdev_vmax)

    # Bold vertical group labels to the left of the grid.
    for y_pos, group in ((.76, "Groundtruth"), (.5, "GATSBI"), (.23, "NPE")):
        fig.text(0.095, y_pos,
                 group,
                 rotation='vertical',
                 va='center',
                 fontsize=fontsize,
                 fontname=fontname,
                 fontweight="bold")

    # NOTE(review): plt.gca() is whichever axes pyplot treats as current --
    # presumably the last colorbar axes added above. hide_ax also empties that
    # axes' tick lists, so confirm this call is intentional.
    ax = hide_ax(plt.gca())

    # Save while the rc_context is still active.
    fig.savefig("plots/Figure5.pdf")
