# SBI Sketch plot

This notebook creates some of the panels for the sketch of the SBI workflow (Figure 2). The panels are then combined in an image editor to produce the final figure.

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern
from sbi_ice.simulators import Layer_Tracing_Sim as lts
from tueplots import figsizes
from sbi_ice.utils import plotting_utils,misc,posterior_utils,modelling_utils
from sbi_ice.loaders import ShortProfileLoader
from omegaconf import OmegaConf
import pickle 

# Project-wide locations: the data directory, the output directory (figures are
# saved under it and SBI results are read from it) and the repository root
# (used later to resolve the relative shelf setup paths).
(data_dir, output_dir) = misc.get_data_output_dirs()
root_dir = misc.get_project_root()
# Shared plot configuration; the returned dict provides the colour palette
# under color_opts["colors"].
color_opts = plotting_utils.setup_plots()

## Some Plots for Prior and Observation

In [None]:
# --- Reference flowline -----------------------------------------------------
data_file = Path(data_dir, "Synthetic_long", "setup_files", "uneven_mb_flowline.csv")
df = pd.read_csv(data_file)
xmb = df["x_coord"].to_numpy()  # x-coordinates of the flowline domain
tmb = df["tmb"].to_numpy()      # total mass balance along the flowline

# --- Simulation settings ----------------------------------------------------
nx_iso = 500     # grid size passed to lts.Geom
time_init = 0    # [yr] starting time (kept for reference; not used below)
dt = 0.5         # [yr] timestep

# --- Gaussian-process prior over the surface mass balance -------------------
# The regressor is never fitted, so sample_y draws from the GP prior defined by
# the Matern(nu=2.5) kernel; draws are scaled by `var` and shifted by const_mean.
seed = 42
ker = Matern(length_scale=1e4, nu=2.5)
gpr = GaussianProcessRegressor(kernel=ker)
var = 0.5
const_mean = 1.0
smb_samples = var * gpr.sample_y(xmb.reshape(-1, 1), n_samples=5000, random_state=seed) + const_mean

smb_mean = np.mean(smb_samples, axis=1)
smb_std = np.std(smb_samples, axis=1)
print(smb_std.mean())    # sanity check: average pointwise spread of the draws
smb = smb_samples[:, 0]  # the first draw drives the forward simulation

geom = lts.Geom(nx_iso=nx_iso, ny_iso=1)
smb_regrid, bmb_regrid = lts.init_geom_from_fname(geom, data_file, smb=smb)

# --- Time schedule ----------------------------------------------------------
time_phase = 500
n_surface_phase = 2
n_base_phase = 10
sched1 = lts.Scheduele(time_phase, dt, n_surface_phase, n_base_phase)
print(sched1.total_iterations)



In [None]:
# Sketch panel of the spatial prior: the top axis shows a-dot (the SMB draws),
# the bottom axis shows b-dot = a-dot - tmb. Each panel shows the constant mean,
# a +/- 2*var band and five prior draws. Saved as spatial_prior.svg.
plt.rcParams.update(figsizes.icml2022_half(height_to_width_ratio=0.8))
fig, axs = plt.subplots(2, 1, sharex=True)
ax_smb, ax_bmb = axs

x_km = xmb / 1e3
prior_color = color_opts["colors"]["prior"]
ones = np.ones_like(xmb)
prior_mean_line = const_mean * ones
prior_lower = (const_mean - 2 * var) * ones
prior_upper = (const_mean + 2 * var) * ones

ax_smb.plot(x_km, prior_mean_line, color=prior_color)
ax_smb.fill_between(x_km, prior_lower, prior_upper, color=prior_color,
                    alpha=plotting_utils.prior_alpha, linewidth=0.0)
ax_smb.set_ylabel(r"$\dot{a}$", size=16, rotation=0)
ax_smb.yaxis.set_label_coords(0.0, 0.45)

ax_bmb.plot(x_km, prior_mean_line - tmb, color=prior_color)
ax_bmb.fill_between(x_km, prior_lower - tmb, prior_upper - tmb, color=prior_color,
                    alpha=plotting_utils.prior_alpha, linewidth=0.0)
ax_bmb.set_ylabel(r"$\dot{b}$", size=16, rotation=0)
ax_bmb.set_xlabel("Distance", size=16)
ax_bmb.yaxis.set_label_coords(0.0, 0.45)

for draw in range(5):
    ax_smb.plot(x_km, smb_samples[:, draw], color=prior_color,
                alpha=plotting_utils.samples_alpha)
    ax_bmb.plot(x_km, smb_samples[:, draw] - tmb, color=prior_color,
                alpha=plotting_utils.samples_alpha)

for ax in axs:
    # Sketch styling: no frame at all.
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(False)
    # Hide tick marks on every edge and tick labels on both axes.
    ax.tick_params(axis="both", which="both",
                   bottom=False, top=False, left=False, right=False,
                   labelleft=False, labelbottom=False)

fig_name = Path(output_dir, "paper_figures", "SBI_sketch", "spatial_prior.svg")
fig_name.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_name)



In [None]:
# Run the forward layer-tracing simulation on `geom`. It is driven by
# smb_regrid/bmb_regrid (returned by lts.init_geom_from_fname) and the schedule sched1.
# NOTE(review): the meaning of the second argument to initialize_layers (10) is
# defined in lts.Geom — confirm there before changing it.
geom.initialize_layers(sched1,10)
lts.sim(geom,smb_regrid,bmb_regrid,sched1)

In [None]:
# Diagnostic overview of the simulation result. It uses every 10th slice along
# the last axis of dsum_iso/age_iso and the active trackers. Not saved to disk.
plt.rcParams.update(figsizes.icml2022_full(height_to_width_ratio=1.0))
layer_stride = 10
fig, axs = plotting_utils.plot_isochrones_1d(
    geom.x,
    geom.bs.flatten(),
    geom.ss.flatten(),
    geom.dsum_iso[:, 0, ::layer_stride],
    geom.age_iso[::layer_stride],
    bmb_regrid,
    smb_regrid,
    real_layers=None,
    trackers=geom.extract_active_trackers(),
)


## Now plot the layers and pick one layer as the simulated observation

In [None]:
# Sketch panel of the simulated layers (every 40th slice), drawn in black on a
# frameless axis and saved as simulation.svg.
plt.rcParams.update(figsizes.icml2022_half(height_to_width_ratio=0.8))

fig, ax = plt.subplots(1, 1)
for side in ("top", "right", "bottom", "left"):
    ax.spines[side].set_visible(False)
# Hide tick marks on every edge and tick labels on both axes.
ax.tick_params(axis="both", which="both",
               bottom=False, top=False, left=False, right=False,
               labelleft=False, labelbottom=False)

layer_stride = 40
plotting_utils.plot_layers(
    geom.x,
    geom.bs.flatten(),
    geom.ss.flatten(),
    geom.dsum_iso[:, 0, ::layer_stride],
    geom.age_iso[::layer_stride],
    ax=ax,
    color="black",
)

fig_name = Path(output_dir, "paper_figures", "SBI_sketch", "simulation.svg")
fig_name.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_name)


In [None]:
# Sketch panel of the single layer picked as the synthetic observation. It is
# drawn as bs + dsum_iso[:, 0, obs_layer], with the bs/ss outlines in black
# and a light fill between them. Saved as simulated_obs.svg.
obs_layer = -100  # index into the last axis of dsum_iso (counted from the end)
fig, ax = plt.subplots(1, 1)
for side in ("top", "right", "bottom", "left"):
    ax.spines[side].set_visible(False)
# Hide tick marks on every edge and tick labels on both axes.
ax.tick_params(axis="both", which="both",
               bottom=False, top=False, left=False, right=False,
               labelleft=False, labelbottom=False)

base_line = geom.bs.flatten()
surface_line = geom.ss.flatten()
observed_line = base_line + geom.dsum_iso[:, 0, obs_layer]

# Draw order is kept: observation first, then the outlines, then the fill.
ax.plot(geom.x, observed_line, color=color_opts["colors"]["observation"], linewidth=1.0)
ax.plot(geom.x, surface_line, color="black", linewidth=1.5)
ax.plot(geom.x, base_line, color="black", linewidth=1.5)
ax.fill_between(geom.x, base_line, surface_line, color="black", alpha=0.075, linewidth=0.0)

fig_name = Path(output_dir, "paper_figures", "SBI_sketch", "simulated_obs.svg")
fig_name.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_name)

## Now Load Posterior and Make Posterior Plots

In [None]:
"""
Read the config file with the arguments for the SBI analysis.
Load the relevant results files.
Define which folders to output to.
"""

# Select the experiment whose posterior and posterior-predictive results are plotted.
shelf = "Synthetic_long"
exp = "exp3"
name = "all_final"
seed = "layer_0_seed_1100"
# Alternative real-data setup:
# shelf = "Ekstrom"
# exp = "exp2"
# name = "advanced_noise"
# seed = "layer_3_seed_2100"
fol = Path(output_dir, shelf, exp, "sbi_sims/post_predictives", name, seed)
cfg_path = Path(fol, "config.yaml")
cfg = OmegaConf.load(cfg_path)
config_fol = cfg.posterior_config_fol
# The posterior folder recorded in the config is replaced by the equivalent path
# under the local output_dir.
# NOTE(review): presumably because the stored path is machine-specific — confirm.
config_fol = Path(output_dir, shelf, exp, "sbi_sims/posteriors", name)
n_post_samples = cfg.n_post_samples
n_predictive_sims = cfg.n_predictive_sims
overwrite = cfg.overwrite_saved_sims
posterior_config = OmegaConf.load(Path(config_fol, str(cfg.name), "config.yaml"))
paths = posterior_config.paths
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

shelf_setup_path = paths.shelf_setup_path
print(shelf_setup_path)
shelf_folder_path = paths.shelf_folder_path
exp_path = paths.exp_path

loader = ShortProfileLoader.ShortProfileLoader(
    Path(root_dir, shelf_setup_path),
    Path(root_dir, shelf_folder_path),
    exp_path,
    gt_version=posterior_config.gt_version,
)
# Training-data bookkeeping: jobs 1..199, each with the same number of simulations.
# total_sims is derived from the per-job counts so the three attributes agree.
# A hard-coded 1000*200 would count one job more than the 199 listed.
# NOTE(review): if 200 jobs actually exist, extend the range to range(1, 201).
sims_per_job = 1000
loader._jobs = list(range(1, 200))
loader._num_sims = [sims_per_job] * len(loader._jobs)
loader.total_sims = sum(loader._num_sims)


In [None]:

# Load the trained inference object. On a CPU-only machine the project's
# CPU_Unpickler is used instead of pickle.load.
# NOTE(review): presumably it remaps CUDA tensors to the CPU — confirm in posterior_utils.
# pickle can execute arbitrary code on load: only open trusted result files.
with open(Path(config_fol, str(cfg.name), "inference.p"), "rb") as f:
    if device.type == "cpu":
        out = posterior_utils.CPU_Unpickler(f).load()
    else:
        out = pickle.load(f)


inference = out["inference"]
prior = out["prior"]
layer_mask = out["layer_mask"]
smb_mask = out["smb_mask"]
# Observed layer (restricted to layer_mask) that the posterior is conditioned on.
true_layer = torch.tensor(loader.real_layers[cfg.layer_idx][layer_mask]).float()
posterior = inference.build_posterior(inference._neural_net.to("cpu"))
posterior.set_default_x(true_layer)
samples = posterior.sample((n_post_samples,))
# Ground-truth SMB regridded onto loader.x. It is best-effort: it falls back to
# None if the loader does not provide the _true_smb_* attributes (presumably only
# synthetic setups do). `except Exception` instead of a bare `except:` lets
# KeyboardInterrupt/SystemExit propagate.
try:
    true_smb = loader._true_smb_const_mean + loader._true_smb_var * loader._true_smb_unperturbed
    true_smb = modelling_utils.regrid(loader._true_xmb, true_smb, loader.x)
except Exception:
    true_smb = None
spatial_samples = samples
print(spatial_samples.shape)

# Saved posterior-predictive results for the same experiment.
with open(Path(fol, "post_predictive.p"), "rb") as f:
    out = pickle.load(f)
    bmb_samples = out["bmb_samples"]
    best_layers = out["best_layers"]
    norms = out["norms"]
    ages = out["ages"]
    active_trackers = out["active_trackers"]



In [None]:

# Full training set from the loader.
(
    contour_arrays,
    norm_arrays,
    age_arrays,
    smb_unperturbed_all,
    smb_cnst_means_all,
    smb_sds_all,
    smb_all,
    bmb_all,
) = loader.load_training_data("all_layers_final.p", "all_mbs_final.p")

# Random subset of 1000 SMB training fields, restricted to smb_mask.
# NOTE(review): the permutation is unseeded, and a later cell reassigns
# prior_samples to the full smb_all.
subset_idx = torch.randperm(smb_all.size(0))[:1000]
prior_samples = smb_all[subset_idx][:, smb_mask]
prior_spatial_samples = prior_samples

# NOTE(review): this adds smb and bmb, whereas the plotting cells use
# tmb = smb - bmb; tmb is recomputed before it is plotted. Confirm the loader's
# bmb sign convention.
tmb = smb_all[0].numpy() + bmb_all[0].numpy()


In [None]:
"""
Calculate the means and percentiles of the prior and posterior SMB samples.
"""
# Parameter-space statistics: mean plus the 2.3 / 97.7 percentiles of the prior
# (the full, unmasked training set) and of fresh posterior draws. For Gaussian
# values these percentiles correspond to a +/- 2 sigma band.
prior_samples = smb_all
posterior_samples = posterior.sample((20000,))
percentiles = [2.3, 97.7]
lower_q = percentiles[0] / 100
upper_q = percentiles[1] / 100

post_mean_smb = posterior_samples.mean(dim=0)
post_uq_smb = torch.quantile(posterior_samples, upper_q, dim=0)
post_lq_smb = torch.quantile(posterior_samples, lower_q, dim=0)
prior_mean_smb = prior_samples.mean(dim=0)
prior_uq_smb = torch.quantile(prior_samples, upper_q, dim=0)
prior_lq_smb = torch.quantile(prior_samples, lower_q, dim=0)

In [None]:
# Sketch panel of the spatial posterior on loader.x[smb_mask]. The top axis shows
# a-dot; the bottom shows b-dot = a-dot - tmb, with tmb taken from the simulated
# setup. Each panel shows the posterior mean, the percentile band and five
# posterior draws. Saved as spatial_post.svg.
tmb = (smb_regrid - bmb_regrid).flatten()
fig, axs = plt.subplots(2, 1, sharex=True)
ax_smb, ax_bmb = axs

x_km = loader.x[smb_mask] / 1e3
post_color = color_opts["colors"]["posterior"]
tmb_masked = tmb[smb_mask]

ax_smb.plot(x_km, post_mean_smb, color=post_color, label="Posterior Mean", linewidth=1.0)
ax_smb.fill_between(x_km, post_lq_smb, post_uq_smb, color=post_color,
                    alpha=plotting_utils.prior_alpha, label="Posterior 95% CI", linewidth=0.0)
ax_smb.set_ylabel(r"$\dot{a}$", size=16, rotation=0)
ax_smb.yaxis.set_label_coords(0.0, 0.3)

ax_bmb.plot(x_km, post_mean_smb - tmb_masked, color=post_color, linewidth=1.0)
ax_bmb.fill_between(x_km, post_lq_smb - tmb_masked, post_uq_smb - tmb_masked, color=post_color,
                    alpha=plotting_utils.prior_alpha, linewidth=0.0)
ax_bmb.set_ylabel(r"$\dot{b}$", size=16, rotation=0)
ax_bmb.yaxis.set_label_coords(0.0, 0.35)
ax_bmb.set_xlabel("Distance", size=16)

for ax in axs:
    # Sketch styling: no frame at all.
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(False)
    # Hide tick marks on every edge and tick labels on both axes.
    ax.tick_params(axis="both", which="both",
                   bottom=False, top=False, left=False, right=False,
                   labelleft=False, labelbottom=False)

for draw in range(5):
    ax_smb.plot(x_km, posterior_samples[draw], color=post_color,
                alpha=plotting_utils.samples_alpha)
    ax_bmb.plot(x_km, posterior_samples[draw] - tmb_masked, color=post_color,
                alpha=plotting_utils.samples_alpha)

fig_name = Path(output_dir, "paper_figures", "SBI_sketch", "spatial_post.svg")
fig_name.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_name)
