#  Plot a model with many observables (no external parameter dependence)


We will generate a large number of datapoints using our trained model, and compare them with the MG5 events loaded from the input file.


In [1]:
#  Required imports

print("Importing standard library")
import os, sys, time

print("Importing python data libraries")
import numpy as np
from   matplotlib import pyplot as plt, colors

print("Importing third party libraries")
import dill as pickle   #  dill is used under the name "pickle" in the rest of this notebook

print("Importing custom backends")
#  NOTE(review): hard-coded absolute path to a local checkout - must be edited to run on another machine
sys.path.append("/Users/Ste/PostDoc/git-with-DP/SM-sandbox/proj5.4-EFT-Density-Estimation")
from backends.density_model    import DensityModel
from backends.plot             import histo_to_line, plot_data, plot_ratio, plot_pull, get_ratio_1D
from backends.stats            import whiten_axes, unwhiten_axes

#  Module aliases: plotting helpers are reached as plot.<name> and the analysis configuration as VBFZ.<name>
#  NOTE(review): the name density_model is rebound to a DensityModel instance in the "Load model" cell,
#                after which it no longer refers to the backends.density_model module
from backends import plot as plot, density_model as density_model, VBFZ_analysis as VBFZ


Importing standard library
Importing python data libraries
Importing third party libraries
Importing custom backends


In [7]:
#  Inputs config
#  -  input_fname is the events file passed to VBFZ.load_table in the "Load and format the data" cell

input_fname = "../Data/SM_EWK_1M_rivet_output.pickle"

#    v4
###  IF REGENERATING DATA MAKE SURE THAT TRANSFORMATIONS ARE OFF (AS MISSING OBSERVABLES)
###    AND OBSERVABLE ORDER IS REVERSED
#load_whitening_funcs = ".whitening_funcs_paper_0D_v4.pickle"
#load_model_dir       = ".EWK_density_model_paper_0D_v4"
#remove_observables   = ["rap_jj", "pT_jj", "m_ll", "m_jj", "N_jets", "N_gap_jets", "Dy_j_j", "Dphi_j_j"]

#    v5
###  IF REGENERATING DATA MAKE SURE THAT TRANSFORMATIONS ARE OFF (AS NO pT_jj)
###    AND OBSERVABLE ORDER IS REVERSED
#load_whitening_funcs = ".whitening_funcs_paper_0D_v5.pickle"
#load_model_dir       = ".EWK_density_model_paper_0D_v5"
#remove_observables   = ["pT_jj", "m_ll", "Dy_j_j", "Dphi_j_j"]

#    v7
###  IF REGENERATING DATA MAKE SURE THAT TRANSFORMATIONS ARE OFF (AS NO pT_jj)
###    AND OBSERVABLE ORDER IS REVERSED
#load_whitening_funcs = ".whitening_funcs_paper_0D_v7.pickle"
#load_model_dir       = ".EWK_density_model_paper_0D_v7"
#remove_observables   = ["pT_jj"]

#    v8
#   relatively wide, 30 gauss
###  IF REGENERATING DATA MAKE SURE THAT TRANSFORMATIONS ARE OFF (AS NO pT_jj)
###    AND OBSERVABLE ORDER IS REVERSED
#load_whitening_funcs = ".whitening_funcs_paper_0D_v8.pickle"
#load_model_dir       = ".EWK_density_model_paper_0D_v8"
#remove_observables = ["pT_jj", "N_jets", "N_gap_jets", "m_ll", "Dy_j_j"]

#    v9
#   relatively wide, 50 gauss, first 4 obs trained
###  IF REGENERATING DATA MAKE SURE THAT TRANSFORMATIONS ARE OFF (AS NO pT_jj)
###    AND OBSERVABLE ORDER IS REVERSED
#load_whitening_funcs = ".whitening_funcs_paper_0D_v10.pickle"
#load_model_dir       = ".EWK_density_model_paper_0D_v10"
#remove_observables = ["pT_jj", "N_jets", "N_gap_jets", "m_ll", "Dy_j_j"]

#  Tag selecting which trained model to load; it is also embedded in the file names of the saved figures
tag = "GQR1p0"

#  load_whitening_funcs : file from which the whitening functions are un-pickled
#  remove_observables   : observables passed to VBFZ.configure for removal from the analysis
#  load_model_dir       : directory passed to DensityModel.from_dir (name depends on tag)
load_whitening_funcs = ".whitening_funcs_paper_0D_nominal.pickle"
remove_observables = ["pT_jj", "N_jets", "N_gap_jets", "m_ll", "Dy_j_j"]
load_model_dir       = f".EWK_density_model_paper_0D_{tag}"

In [3]:
#  Configure VBFZ observables
#  -  remove_observables (set in the inputs config cell) is handed to VBFZ.configure
#  -  afterwards VBFZ.observables / VBFZ.num_observables are read by the plotting functions below
#
VBFZ.configure(remove_observables)
print(f"Configured with {VBFZ.num_observables} observables: " + ", ".join(VBFZ.observables))

#  Configure plot functions with observable information
#  -  these are module-level attributes of the plot backend, so the assignment affects every later call into it
#
plot.int_observables   = VBFZ.int_observables
plot.observable_limits = VBFZ.transformed_observable_limits


Configured with 7 observables: rap_ll, rap_jj, pT_ll, pT_j2, pT_j1, m_jj, Dphi_j_j


In [4]:
#  Load and format the data
#  -  the printed log below shows that VBFZ.load_table also filters events on the observable ranges
#
data_table = VBFZ.load_table(input_fname)


Loading events from file ../Data/SM_EWK_1M_rivet_output.pickle
 -- Table created with 1000000 events
 -- filtering observable m_ll between 75 and 105
 -- 660799 events survived
 -- filtering observable pT_ll between 0 and 900
 -- 660766 events survived
 -- filtering observable theta_ll between 0 and 3.141592653589793
 -- 660766 events survived
 -- filtering observable rap_ll between 0 and 2.2
 -- 652255 events survived
 -- filtering observable m_jj between 150 and 5000
 -- 643177 events survived
 -- filtering observable pT_jj between 0 and 900
 -- 643177 events survived
 -- filtering observable theta_jj between 0 and 3.141592653589793
 -- 643177 events survived
 -- filtering observable rap_jj between 0 and 4.4
 -- 643177 events survived
 -- filtering observable pT_j1 between 40 and 1200
 -- 643054 events survived
 -- filtering observable pT_j2 between 40 and 1200
 -- 641867 events survived
 -- filtering observable Dy_j_j between 0 and 8.8
 -- 641867 events survived
 -- filtering observ

In [5]:
#  Load the whitening funcs from file (the file name is set in the inputs config cell)
#  -  this is faster when re-running with the same data and whitening settings later on
#  -  NOTE: un-pickling can execute arbitrary code, so only load files from a trusted source

print(f"Loading whitening functions from file {load_whitening_funcs}")
with open(load_whitening_funcs, "rb") as whitening_file :   #  context manager guarantees the file handle is closed
    whitening_funcs = pickle.load(whitening_file)


#  Separate data from weights
true_data, true_data_weights = data_table.get_observables_and_weights()


#  Transform data
print("Transforming data")
transformed_data = VBFZ.transform_observables_fwd(true_data, data_table.keys)


#  Whiten data (the second return value of whiten_axes is not needed here and is discarded)
print("Projecting data onto latent space")
white_data, _ = whiten_axes(transformed_data, data_table.types, whitening_funcs=whitening_funcs)


Loading whitening functions from file .whitening_funcs_paper_0D_nominal.pickle
Transforming data
Projecting data onto latent space


In [6]:
#  Load model
#  -  NOTE(review): this rebinds the name density_model, which the imports cell had bound to the
#                   backends.density_model module; from here on it refers to the loaded DensityModel instance
#

print(f"Loading density model from file {load_model_dir}")
density_model = DensityModel.from_dir(load_model_dir)


Loading density model from file .EWK_density_model_paper_0D_GQR0p1


RuntimeError: Required entry 'learning_rate_evo_factor' not found in file '.EWK_density_model_paper_0D_GQR0p1/density_model.pickle'

In [None]:
#  Generate a large number of datapoints at 0.
#  -  NOTE(review): the header says "at 0." but the second argument passed to sample() is [1.] - confirm which is intended
#  -  NOTE(review): no random seed is set in this notebook, so the samples presumably differ between runs - TODO confirm
#

n_gen = 1000000

print(f"Generating {n_gen} fake datapoints")
start = time.time()
fake_white_datapoints = density_model.sample(n_gen, [1.], n_processes=8)
end = time.time()
print(f"{n_gen} datapoints generated in {int(end-start):.0f}s")

#  Unwhiten generated data
#  -  first undo the whitening, then undo the observable transformation, mirroring the forward steps
#     applied to the true data in the whitening cell
#

print("Unwhitening fake datapoints")
start = time.time()
fake_transformed_datapoints = unwhiten_axes(fake_white_datapoints, whitening_funcs)
fake_datapoints = VBFZ.transform_observables_back(fake_transformed_datapoints, data_table.keys)
end = time.time()
print(f"{n_gen} datapoints unwhitened in {int(end-start):.0f}s")


In [None]:

#  Observables that get non-uniform binning in get_bins_physical and a non-linear axis scale
#  (log_axis_functions) in the physical-space plots
log_observables = "pT_j1", "pT_j2", "pT_jj", "pT_ll", "m_jj", "rap_jj"

#  Override the tick positions and tick labels used for rap_jj (this mutates the VBFZ module state)
VBFZ.obs_ticks ["rap_jj"] = [1, 2.5, 4]
VBFZ.obs_ticklabels ["rap_jj"] = ["1", "2.5", "4"]

#  (forward, inverse) pair passed to set_xscale/set_yscale("function", ...) for the log_observables.
#  Despite the name, the forward transform is a cube root and the inverse is a cube.
#  np.cbrt is used so that the forward transform stays real-valued for negative inputs,
#  where x**(1./3.) would give NaN (arrays) or a complex number (python floats).
log_axis_functions = (lambda x : np.cbrt(x), lambda x : x*x*x)
        
        
def get_bins_latent (obs, num_bins=20) :
    """Return the histogram bin edges used for observable `obs` in latent space.

    Integer observables (those in VBFZ.int_observables) get unit-width bins centred on each
    integer between the limits stored in VBFZ.observable_limits (not the transformed limits).
    All other observables get `num_bins` uniform bins between -5 and +5.

    Parameters
    ----------
    obs      : observable name
    num_bins : number of bins for non-integer observables

    Returns
    -------
    1D numpy array of bin edges
    """
    if obs in VBFZ.int_observables :
        obs_lims = VBFZ.observable_limits[obs]
        #  np.linspace requires an integer number of points, so cast in case the limits are stored as floats
        num_edges = 2 + int(round(obs_lims[1] - obs_lims[0]))
        return np.linspace(obs_lims[0]-0.5, obs_lims[1]+0.5, num_edges)
    return np.linspace(-5, 5, num_bins+1)

def get_bins_physical (obs, num_bins=20, base=np.e) :
    """Return the histogram bin edges used for observable `obs` in physical space.

    - Integer observables (in VBFZ.int_observables): unit-width bins centred on each integer
      between the limits.
    - Observables listed in log_observables: `num_bins` bins that are uniform in log(x + 10).
    - Everything else: `num_bins` uniform bins between the limits.

    Parameters
    ----------
    obs      : observable name, must be a key of VBFZ.observable_limits
    num_bins : number of bins for non-integer observables
    base     : base of the logarithm used for the log-spaced bins; it cancels between the
               forward and backward steps, so it only affects the result through rounding

    Returns
    -------
    1D numpy array of bin edges
    """
    obs_lims = VBFZ.observable_limits[obs]
    if obs in VBFZ.int_observables :
        #  np.linspace requires an integer number of points, so cast in case the limits are stored as floats
        num_edges = 2 + int(round(obs_lims[1] - obs_lims[0]))
        return np.linspace(obs_lims[0]-0.5, obs_lims[1]+0.5, num_edges)
    if obs in log_observables :
        offset = 10   #  shift applied before taking the log, so that a lower limit of 0 is allowed
        log_physical_limits = np.log(np.array(obs_lims) + offset) / np.log(base)
        bins = np.exp(np.linspace(log_physical_limits[0], log_physical_limits[1], num_bins+1)*np.log(base)) - offset
        #  remove the floating point residue left when the lower limit is exactly 0
        if np.fabs(bins[0]) < 1e-15 : bins[0] = 0
        return bins
    return np.linspace(obs_lims[0], obs_lims[1], num_bins+1)
            
            
def get_bins (obs, is_latent=False, num_bins=20) :
    """Return the bin edges for `obs`, using the latent-space or physical-space binning as requested."""
    binning_function = get_bins_latent if is_latent else get_bins_physical
    return binning_function (obs, num_bins=num_bins)



def get_obs_label (obs) :
    """Return the axis label for observable `obs` by delegating to VBFZ.get_obs_label."""
    return VBFZ.get_obs_label(obs)
    
def get_obs_ticks (obs, is_latent=False) :
    """Return the axis tick positions for `obs`.

    In latent space every non-integer observable shares the ticks [-3, 0, 3]; in all other
    cases the ticks come from VBFZ.get_obs_ticks.
    """
    integer_obs = VBFZ.int_observables
    if is_latent and obs not in integer_obs :
        return np.array([-3, 0, 3])
    return VBFZ.get_obs_ticks(obs)

def get_obs_ticklabels (obs, is_latent=False) :
    """Return the axis tick labels for `obs`.

    In latent space every non-integer observable shares the labels ["-3", "0", "3"]; in all
    other cases the labels come from VBFZ.get_obs_ticklabels.
    """
    integer_obs = VBFZ.int_observables
    if is_latent and obs not in integer_obs :
        return np.array(["-3", "0", "3"])
    return VBFZ.get_obs_ticklabels(obs)


def get_obs_for_2D_plot (observables) :
    """List the observable pairs drawn in the triangular grid of 2D plots.

    Returns a list of tuples (index_x, name_x, index_y, name_y) with index_y > index_x, ordered
    by index_x and then by index_y. The last observable therefore never appears on the x axis
    and the first never appears on the y axis.
    """
    last_idx = len(observables) - 1
    return [(idx_x, name_x, idx_y, name_y)
            for idx_x, name_x in enumerate(observables) if idx_x != last_idx
            for idx_y, name_y in enumerate(observables) if idx_y > idx_x]


In [None]:
def plot_2D_projections (datapoints, weights=None, label="", savefig="", is_latent=False, num_bins=20, vmin=1e-5) :
    """Plot the 2D projections of the datapoints provided, as a triangular grid of panels.

    One panel is drawn for every observable pair returned by get_obs_for_2D_plot. Each panel is
    a weighted 2D histogram divided by its own largest bin content, so that all panels share a
    logarithmic colour scale running from `vmin` up to 1.

    Parameters
    ----------
    datapoints : 2D array indexed as datapoints[:, observable index]; the column order is
                 expected to follow VBFZ.observables
    weights    : 1D array with one weight per datapoint (uniform weights if None)
    label      : text drawn above the top-left panel (nothing is drawn if empty)
    savefig    : file name to save the figure to (nothing is saved if empty)
    is_latent  : if True use the latent-space binning and ticks, and keep all axes linear
    num_bins   : number of bins per axis for non-integer observables
    vmin       : lower edge of the colour scale

    Raises
    ------
    ValueError : if fewer than two observables are configured, so that no pair can be plotted
    """
    observables, num_observables = VBFZ.observables, VBFZ.num_observables
    #
    #  If no weights provided then assume uniform
    #
    if weights is None :
        weights = np.ones(shape=(datapoints.shape[0],))
    #
    #  Save the list of indices to plot (to make sure all loops are over consistent sets)
    #
    obs_to_plot = get_obs_for_2D_plot (observables)
    if len(obs_to_plot) == 0 :
        raise ValueError("plot_2D_projections requires at least two observables to form a pair")
    #
    #  Work out the binning once per observable, and the normalisation once per panel
    #
    bins = {obs : get_bins(obs, is_latent=is_latent, num_bins=num_bins) for obs in observables}
    norm_const = {}
    for obs_idx_x, obs_x, obs_idx_y, obs_y in obs_to_plot :
        vals, _, _ = np.histogram2d(datapoints[:,obs_idx_x], datapoints[:,obs_idx_y], weights=weights, bins=(bins[obs_x], bins[obs_y]))
        peak = np.nanmax(vals)
        #  an empty panel has no positive peak: fall back to 1 to avoid dividing the weights by zero
        norm_const[(obs_idx_x, obs_idx_y)] = peak if peak > 0 else 1.
    #
    #  Make plot
    #
    fig = plt.figure(figsize=(20, 14))
    for obs_idx_x, obs_x, obs_idx_y, obs_y in obs_to_plot :
        xlo    = obs_idx_x / (num_observables-1)    #  Get axis x coordinates
        xwidth = 1.        / (num_observables-1)
        ylo     = (num_observables-obs_idx_y-1) / (num_observables-1)   #  Get axis y coordinates
        yheight = 1.                            / (num_observables-1)
        #
        #  Create axis
        #
        ax = fig.add_axes([xlo, ylo, 0.95*xwidth, 0.95*yheight])
        #
        #  Format log axes
        #
        if not is_latent :
            if obs_x in log_observables : ax.set_xscale("function", functions=log_axis_functions )
            if obs_y in log_observables : ax.set_yscale("function", functions=log_axis_functions )
        #
        #  Draw axis ticks and labels (tick labels only on the bottom row and the left column)
        #
        ax.set_xticks(get_obs_ticks(obs_x, is_latent=is_latent))
        ax.set_yticks(get_obs_ticks(obs_y, is_latent=is_latent))
        if obs_idx_y == num_observables-1 :
            ax.get_xaxis().set_ticklabels(get_obs_ticklabels(obs_x, is_latent=is_latent))
            ax.set_xlabel(get_obs_label(obs_x).replace("  [","\n["), fontsize=19, labelpad=20, va="top", ha="center")
        else :
            ax.get_xaxis().set_ticklabels([])
        if obs_idx_x == 0 :
            ax.get_yaxis().set_ticklabels(get_obs_ticklabels(obs_y, is_latent=is_latent))
            ax.set_ylabel(get_obs_label(obs_y).replace("  [","\n["), fontsize=19, labelpad=20, rotation=0, va="center", ha="right")
        else :
            ax.get_yaxis().set_ticklabels([])
        #
        #  Format tick params
        #
        ax.tick_params(which="both", right=True, top=True, direction="in", labelsize=15)
        #
        #  Draw histogram
        #  -  the colour limits are set on the norm itself, because matplotlib does not accept
        #     vmin/vmax together with an explicit norm
        #
        _, _, _, image = ax.hist2d(datapoints[:,obs_idx_x], datapoints[:,obs_idx_y], weights=weights/norm_const[(obs_idx_x, obs_idx_y)],
                                   bins=(bins[obs_x], bins[obs_y]), norm=colors.LogNorm(vmin=vmin, vmax=1), cmap="inferno")
        #
        #  Draw label
        #
        if (obs_idx_x==0) and (obs_idx_y==1) and len(label) > 0 :
            ax.text(0, 1.2, label, weight="bold", ha="left", va="bottom", transform=ax.transAxes, fontsize=21)
    #
    #  Draw colour bar (every panel uses the same colour scale, so the last image drawn is representative)
    #
    cbar_ax = fig.add_axes([0.76, 0.5, 0.03, 0.45])
    cbar    = fig.colorbar(image, cax=cbar_ax)
    cbar_ax.tick_params(labelsize=14)
    cbar   .set_ticks([1, 0.1, 0.01, 0.001, 0.0001, 1e-5])
    cbar   .set_label(r"$\frac{p(x)}{{\rm max}~p(x)}$", fontsize=25, labelpad=50, rotation=0, va="center")
    #
    #  Save and show plot
    #
    if len(savefig) > 0 :
        plt.savefig(savefig, bbox_inches="tight")
    plt.show()

In [None]:
import math
from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

#  Build the colour map used for the 2D ratio plots: blue-white-red with the central third replaced by green.
#  plt.get_cmap is used because cm.get_cmap is deprecated and has been removed from recent matplotlib releases.
cmap_bwr  = plt.get_cmap('bwr', 256)
newcolors = cmap_bwr(np.linspace(0, 1, 256))
#  Entries 85..168 (of 0..255) are overwritten with a single green RGBA value.
#  NOTE(review): this band sits one entry below the centre of the map (a centred band would be 86..169) - confirm intended.
newcolors [math.ceil(256*2/6)-1:math.floor(256*4/6)-1] = np.array([68/256, 223/256, 68/256, 1])
custom_colormap = ListedColormap(newcolors, name='BlueToRed')


In [None]:

def plot_2D_ratios (datapoints_num, datapoints_den, weights_num=None, weights_den=None, label="", savefig="", is_latent=False, num_bins=20) :
    """Plot (numerator / denominator - 1) of the 2D projections, as a triangular grid of panels.

    One panel is drawn for every observable pair returned by get_obs_for_2D_plot. The ratio in
    each panel is computed by plot.get_ratio_2D and drawn with custom_colormap on a fixed colour
    scale from -0.3 to +0.3.

    Parameters
    ----------
    datapoints_num : 2D array of numerator datapoints, indexed as [:, observable index]
    datapoints_den : 2D array of denominator datapoints, with the same column layout
    weights_num    : 1D array with one weight per numerator datapoint (uniform if None)
    weights_den    : 1D array with one weight per denominator datapoint (uniform if None)
    label          : text drawn above the top-left panel (nothing is drawn if empty)
    savefig        : file name to save the figure to (nothing is saved if empty)
    is_latent      : if True use the latent-space binning and ticks, and keep all axes linear
    num_bins       : number of bins per axis for non-integer observables

    Raises
    ------
    ValueError : if fewer than two observables are configured, so that no pair can be plotted
    """
    observables, num_observables = VBFZ.observables, VBFZ.num_observables
    #
    #  If no weights provided then assume uniform
    #
    if weights_num is None : weights_num = np.ones(shape=(datapoints_num.shape[0],))
    if weights_den is None : weights_den = np.ones(shape=(datapoints_den.shape[0],))
    #
    #  Save the list of indices to plot (to make sure all loops are over consistent sets)
    #
    obs_to_plot = get_obs_for_2D_plot (observables)
    if len(obs_to_plot) == 0 :
        raise ValueError("plot_2D_ratios requires at least two observables to form a pair")
    #
    #  Make plot
    #
    fig = plt.figure(figsize=(20, 14))
    for obs_idx_x, obs_x, obs_idx_y, obs_y in obs_to_plot :
        xlo    = obs_idx_x / (num_observables-1)    #  Get axis x coordinates
        xwidth = 1.        / (num_observables-1)
        ylo     = (num_observables-obs_idx_y-1) / (num_observables-1)   #  Get axis y coordinates
        yheight = 1.                            / (num_observables-1)
        #
        #  Create axis
        #
        ax = fig.add_axes([xlo, ylo, 0.95*xwidth, 0.95*yheight])
        #
        #  Format log axes
        #
        if not is_latent :
            if obs_x in log_observables : ax.set_xscale("function", functions=log_axis_functions )
            if obs_y in log_observables : ax.set_yscale("function", functions=log_axis_functions )
        #
        #  Draw axis ticks and labels (tick labels only on the bottom row and the left column)
        #
        ax.set_xticks(get_obs_ticks(obs_x, is_latent=is_latent))
        ax.set_yticks(get_obs_ticks(obs_y, is_latent=is_latent))
        if obs_idx_y == num_observables-1 :
            ax.get_xaxis().set_ticklabels(get_obs_ticklabels(obs_x, is_latent=is_latent))
            ax.set_xlabel(get_obs_label(obs_x).replace("  [","\n["), fontsize=19, labelpad=20, va="top", ha="center")
        else :
            ax.get_xaxis().set_ticklabels([])
        if obs_idx_x == 0 :
            ax.get_yaxis().set_ticklabels(get_obs_ticklabels(obs_y, is_latent=is_latent))
            ax.set_ylabel(get_obs_label(obs_y).replace("  [","\n["), fontsize=19, labelpad=20, rotation=0, va="center", ha="right")
        else :
            ax.get_yaxis().set_ticklabels([])
        #
        #  Format tick params
        #
        ax.tick_params(which="both", right=True, top=True, direction="in", labelsize=15)
        #
        #  Draw ratio (the returned ratio uncertainty is not drawn)
        #
        bins_x, bins_y = get_bins(obs_x, is_latent=is_latent, num_bins=num_bins), get_bins(obs_y, is_latent=is_latent, num_bins=num_bins)
        X, Y, ratio, _ = plot.get_ratio_2D (datapoints_num[:,obs_idx_x], datapoints_num[:,obs_idx_y],
                                            datapoints_den[:,obs_idx_x], datapoints_den[:,obs_idx_y],
                                            bins_x, bins_y, weights1=weights_num, weights2=weights_den)
        im = ax.pcolormesh(X, Y, ratio.transpose()-1, cmap=custom_colormap, vmin=-0.3, vmax=0.3)
        #
        #  Draw label
        #
        if (obs_idx_x==0) and (obs_idx_y==1) and len(label) > 0 :
            ax.text(0, 1.2, label, weight="bold", ha="left", va="bottom", transform=ax.transAxes, fontsize=21)
    #
    #  Draw colour bar (every panel uses the same colour scale, so the last mesh drawn is representative)
    #  -  the label states the plotted quantity, ratio minus one, rather than a normalised density
    #
    cbar_ax = fig.add_axes([0.76, 0.5, 0.03, 0.45])
    cbar    = fig.colorbar(im, cax=cbar_ax)
    cbar_ax.tick_params(labelsize=14)
    cbar   .set_ticks([-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3])
    cbar   .set_label(r"$\frac{p_{\rm num}(x)}{p_{\rm den}(x)} - 1$", fontsize=25, labelpad=50, rotation=0, va="center")
    #
    #  Save and show plot
    #
    if len(savefig) > 0 :
        plt.savefig(savefig, bbox_inches="tight")
    plt.show()
    

In [None]:

#  Latent-space 2D plots: model / MG5 ratio, then the MG5 and model distributions separately.
#  The generated samples carry no weights, so uniform weights are used for them.
plot_2D_ratios(fake_white_datapoints, white_data, weights_den=true_data_weights, is_latent=True, 
               label="Samples from density model / MG5 events (latent space)",
               savefig=f"figures/paper0D_model/2D_ratios_latent_{tag}.pdf")

plot_2D_projections(white_data, weights=true_data_weights, is_latent=True, 
                    label="MG5 events (latent space)",
                    savefig=f"figures/paper0D_model/2D_dist_MG5_latent_{tag}.pdf")

plot_2D_projections(fake_white_datapoints, is_latent=True, 
                    label="Samples from density model (latent space)",
                    savefig=f"figures/paper0D_model/2D_dist_model_latent_{tag}.pdf")


In [None]:
#  Physical-space 2D plots: model / MG5 ratio, then the MG5 and model distributions separately.
#  The distribution plots raise the lower edge of the colour scale to vmin=1e-4.
#  (original note: "black below cbar threshold" - presumably bins below vmin are drawn in the
#   darkest colour of the colour map; TODO confirm)


plot_2D_ratios(fake_datapoints, true_data, weights_den=true_data_weights, is_latent=False, 
               label="Samples from density model / MG5 events (physical space)",
               savefig=f"figures/paper0D_model/2D_ratios_physical_{tag}.pdf")

plot_2D_projections(true_data, weights=true_data_weights, is_latent=False, 
                    label="MG5 events (physical space)",
                    savefig=f"figures/paper0D_model/2D_dist_MG5_physical_{tag}.pdf",
                    vmin=1e-4)

plot_2D_projections(fake_datapoints, is_latent=False, 
                    label="Samples from density model (physical space)",
                    savefig=f"figures/paper0D_model/2D_dist_model_physical_{tag}.pdf",
                    vmin=1e-4)


In [None]:


def _normalised_histogram (values, bins, weights) :
    """Return (contents, errors) of a weighted 1D histogram, both divided by the sum of the weights."""
    contents, _ = np.histogram(values, bins=bins, weights=weights        )
    sum_w2  , _ = np.histogram(values, bins=bins, weights=weights*weights)
    sum_w       = np.sum(weights)
    return contents/sum_w, np.sqrt(sum_w2)/sum_w


def plot_1D_projections (datapoints_num, datapoints_den, weights_num=None, weights_den=None, savefig="", is_latent=False, num_bins=20, max_cols=7) :
    """Plot the 1D projections of the datapoints provided, one column of two panels per observable.

    The top panel shows both normalised distributions on a log y axis: the numerator in black
    with the legend entry "MG5 events" and the denominator in dark red with the legend entry
    "Samples from density model". The bottom panel shows the relative difference
    (numerator - denominator) / denominator, with the statistical bands of both inputs.

    Parameters
    ----------
    datapoints_num : 2D array of numerator datapoints, indexed as [:, observable index]
    datapoints_den : 2D array of denominator datapoints, with the same column layout
    weights_num    : 1D array with one weight per numerator datapoint (uniform if None)
    weights_den    : 1D array with one weight per denominator datapoint (uniform if None)
    savefig        : file name to save the figure to (nothing is saved if empty)
    is_latent      : if True use the latent-space binning and ticks, and keep all axes linear
    num_bins       : number of bins for non-integer observables
    max_cols       : maximum number of observables drawn side by side before starting a new row

    Raises
    ------
    ValueError : if no observables are configured or max_cols is smaller than 1
    """
    observables, num_observables = VBFZ.observables, VBFZ.num_observables
    #
    #  If no weights provided then assume uniform
    #
    if weights_num is None : weights_num = np.ones(shape=(datapoints_num.shape[0],))
    if weights_den is None : weights_den = np.ones(shape=(datapoints_den.shape[0],))
    #
    #  Calculate plot dimensions and create figure
    #
    num_cols = int(min(max_cols, num_observables))
    if num_cols < 1 :
        raise ValueError("plot_1D_projections requires at least one observable and max_cols >= 1")
    num_rows = int(np.ceil(num_observables/num_cols))
    fig = plt.figure(figsize=(2*num_cols, 6*num_rows))
    #
    #  Loop over subplots
    #
    axes1, axes2 = [], []
    ymins, ymaxs = [], []
    fallback_ymins = []   #  smallest positive bin content per panel, used if no panel has a positive lower band
    for obs_idx in range(num_observables) :
        row_idx, col_idx = divmod(obs_idx, num_cols)
        observable = observables[obs_idx]
        #
        #  Get axis co-ordinates
        #
        xlo, xwidth  = col_idx/num_cols, 1./num_cols
        ylo, yheight = 1. - (1+row_idx)/num_rows, 1./num_rows
        #
        #  Get values of distributions
        #
        bins = get_bins(observable, is_latent=is_latent, num_bins=num_bins)
        hvals_num, herrs_num = _normalised_histogram(datapoints_num[:,obs_idx], bins, weights_num)
        hvals_den, herrs_den = _normalised_histogram(datapoints_den[:,obs_idx], bins, weights_den)
        #  histograms expressed as lines
        plot_x, plot_y_num, plot_ey_num = plot.histo_to_line(bins, hvals_num, herrs_num)
        _     , plot_y_den, plot_ey_den = plot.histo_to_line(bins, hvals_den, herrs_den)
        #
        #  Create absolute distribution plot (top panel of each observable)
        #
        ax1 = fig.add_axes([xlo, ylo+0.6*yheight, 0.95*xwidth, 0.38*yheight])
        ax1.plot(plot_x, plot_y_num, "-", color="k"      , linewidth=2, label="MG5 events")
        ax1.fill_between(plot_x, plot_y_num-plot_ey_num, plot_y_num+plot_ey_num, color="lightgrey", alpha=1)
        ax1.plot(plot_x, plot_y_den, "-", color="darkred", linewidth=2, label="Samples from density model")
        ax1.fill_between(plot_x, plot_y_den-plot_ey_den, plot_y_den+plot_ey_den, color="red", alpha=0.2)
        ax1.set_yscale("log")
        #
        #  Save ymin, ymax and top axis for this observable
        #
        ymins.append(np.min([plot_y_num-plot_ey_num, plot_y_den-plot_ey_den]))
        ymaxs.append(np.max([plot_y_num+plot_ey_num, plot_y_den+plot_ey_den]))
        axes1.append(ax1)
        central_values = np.concatenate([np.ravel(plot_y_num), np.ravel(plot_y_den)])
        if np.any(central_values > 0) :
            fallback_ymins.append(np.min(central_values[central_values > 0]))
        #
        #  Create ratio plot (bottom panel of each observable) and save it
        #  -  every curve is expressed relative to the denominator, so that the grey band is centred on the black line
        #  -  empty denominator bins give inf/NaN points, which are not drawn; the numpy warnings are silenced
        #
        with np.errstate(divide="ignore", invalid="ignore") :
            rel_err_den = plot_ey_den / plot_y_den
            rel_diff    = (plot_y_num               - plot_y_den) / plot_y_den
            rel_diff_lo = (plot_y_num - plot_ey_num - plot_y_den) / plot_y_den
            rel_diff_hi = (plot_y_num + plot_ey_num - plot_y_den) / plot_y_den
        ax2 = fig.add_axes([xlo, ylo+0.2*yheight, 0.95*xwidth, 0.38*yheight])
        ax2.axhline(0, c="darkred", linewidth=2)
        ax2.fill_between(plot_x, -rel_err_den, rel_err_den, color="red", alpha=0.2)
        ax2.plot(plot_x, rel_diff, c="k", linewidth=2)
        ax2.fill_between(plot_x, rel_diff_lo, rel_diff_hi, color="lightgrey", alpha=0.5)
        axes2.append(ax2)
        #
        #  Set ylim and draw horizontal reference lines
        #
        ax2.set_ylim([-0.12, 0.12])
        for h in [-0.1, -0.05, 0.05, 0.1] :
            ax2.axhline(h, linestyle="--", c="grey", linewidth=0.5)
        #
        #  Set x limits and scale
        #
        ax1.set_xlim([bins[0], bins[-1]])
        ax2.set_xlim([bins[0], bins[-1]])
        if (not is_latent) and (observable in log_observables) :
            ax1.set_xscale("function", functions=log_axis_functions )
            ax2.set_xscale("function", functions=log_axis_functions )
        #
        #  Set axis ticks (y tick labels only in the first column, x tick labels only on the ratio panel)
        #
        if col_idx > 0 :
            ax1.get_yaxis().set_ticklabels([])
            ax2.get_yaxis().set_ticklabels([])
        ax1.set_xticks(get_obs_ticks(observable, is_latent=is_latent))
        ax2.set_xticks(get_obs_ticks(observable, is_latent=is_latent))
        ax1.get_xaxis().set_ticklabels([])
        ax2.get_xaxis().set_ticklabels(get_obs_ticklabels(observable, is_latent=is_latent))
        #
        #  Set axis labels
        #
        ax2.set_xlabel(get_obs_label(observable), fontsize=19, labelpad=20)
        if col_idx == 0 :
            ax1.set_ylabel("Normalised\nentries", fontsize=19, labelpad=75, rotation=0, va="center")
            ax2.set_ylabel("Ratio to\ndensity\nmodel", fontsize=19, labelpad=65, rotation=0, va="center")
            ax2.set_yticks     ([-0.1, -0.05, 0, 0.05, 0.1])
            ax2.set_yticklabels([r"$-10\%$", r"$-5\%$", r"$0$", r"$+5\%$", r"$+10\%$"])
        #
        #  Set tick params
        #
        ax1.tick_params(which="both", right=True, top=True, direction="in", labelsize=15)
        ax2.tick_params(which="both", right=True, top=True, direction="in", labelsize=15)
    #
    #  Set consistent axis y lims
    #  -  a log axis needs a positive lower limit, so only positive per-panel minima are considered
    #  -  if every panel has a non-positive minimum (e.g. an empty bin), use the smallest positive bin content instead
    #  -  if there is no positive content at all, leave the limits to matplotlib
    #
    positive_ymins = [y for y in ymins if y > 0] or fallback_ymins
    if len(positive_ymins) > 0 :
        ymin, ymax = np.min(positive_ymins)/2., 2.*np.max(ymaxs)
        for ax in axes1 :
            ax.set_ylim([ymin, ymax])
    #
    #  Draw legend above the first panel
    #
    axes1[0].legend(loc=(0, 1.05), frameon=True, edgecolor="white", facecolor="white", ncol=2, fontsize=17)
    #
    #  Save and show figure
    #
    if len(savefig) > 0 :
        plt.savefig(savefig, bbox_inches="tight")
    plt.show()


In [None]:

#  1D comparison in physical space: MG5 events (numerator, weighted) against the model samples (denominator, unweighted)
plot_1D_projections(true_data, fake_datapoints, weights_num=true_data_weights, is_latent=False,
                    savefig=f"figures/paper0D_model/1D_dist_physical_{tag}.pdf")


In [None]:

#  1D comparison in latent space: whitened MG5 events (numerator, weighted) against the raw model samples (denominator, unweighted)
plot_1D_projections(white_data, fake_white_datapoints, weights_num=true_data_weights, is_latent=True,
                    savefig=f"figures/paper0D_model/1D_dist_latent_{tag}.pdf")
