# Imports

In [1]:
from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_style("ticks")
sns.set_context("talk")

import arviz as az
import warnings
warnings.filterwarnings("ignore")

from pathlib import Path
PROJECT_ROOT = Path.cwd().parents[0]
import sys
sys.path.append(str(PROJECT_ROOT))
from bahamas_lig.utils import *
import importlib
importlib.reload(sys.modules['bahamas_lig'])

model_dir = PROJECT_ROOT / "model_outputs/"
inference_dir = PROJECT_ROOT / "model_outputs/"
data_dir = PROJECT_ROOT / "data/"
year = '2023'

import os
from IPython.display import clear_output
import time
import ipywidgets as widgets
from ipywidgets import Layout
from IPython.display import display
from ipywidgets import Button, HBox, VBox, Label
import pandas as pd
import numpy as np
import numpy as np

from scipy.ndimage import gaussian_filter as gaussian
from matplotlib import pyplot as plt
import pickle
import pandas as pd
from scipy.interpolate import interp1d
from scipy.stats import chi2
import os
import datetime as dt
import matplotlib.dates as mdates
import pymc3 as pm
import theano.tensor as tt
from theano import shared
from pymc3.distributions.dist_math import SplineWrapper
from scipy.interpolate import interp1d, UnivariateSpline
from scipy.stats import gaussian_kde
import seaborn as sns

from IPython.display import display

In [69]:
# Name of the Excel workbook that load_data() reads from data_dir / "raw".
data_file = "Barnett_EuropeLIG_SuppData_April2023_V2.2.xlsx"

The following JavaScript cell executes the cells at the bottom of the notebook, which define some helper functions, and forces a refresh of the IPython displays.

In [70]:
%%javascript 
Jupyter.notebook.execute_cells([-2,-1,-6,-5,-4])

<IPython.core.display.Javascript object>

In [73]:
# h_box (the selector columns) and buttons are built in the helper cell at the
# end of the notebook, so that cell must have been executed before this one.
print("Select models to inspect:")
display(h_box)
display(HBox(buttons))

Select models to inspect:


HBox(children=(VBox(children=(Label(value='Lithosphere'), SelectMultiple(layout=Layout(height='250px', width='…

HBox(children=(Button(description='Plot weighted inference for selected', layout=Layout(width='33%'), style=Bu…

In [74]:
# Output area that the plot and inference button callbacks write into.
display(output_simulation)

Output()

In [75]:
# Output area holding the filtered model table that update_table() refreshes.
display(filtered_df_output)

Output()

## Helper functions and widget definitions

In [71]:
def get_model_status(inference_dir,model_dir):
    """Summarise every GIA model found on disk together with its inference status.

    Model names are derived from the ``.dat`` files in ``model_dir`` by dropping
    the last ``_``-separated field of each file name.  Each name is then parsed
    as ``output[_new]<Lithosphere>Cp<UMV digit><LMV>_<ice_history>_Wael_<esl>``
    (e.g. ``output_new71Cp420_L6G_Wael_T``).

    Parameters
    ----------
    inference_dir: folder holding ``arviz_traces_<year>``,
        ``pymc3_post_predict_<year>`` and ``model_weights/model_weights.csv``
    model_dir: folder holding the raw GIA model ``.dat`` files

    Returns
    -------
    DataFrame indexed by model name with columns Lithosphere, UMV, LMV,
    ice_history, esl_curve, posterior_trace, posterior_predict and weight.
    Models absent from the weights file get weight 0.
    """
    model_posterior_dir = str(inference_dir)+'/'+str(f'arviz_traces_{year}')
    # Match on the real suffix: the stem is obtained by fixed-width slicing, so a
    # stray file such as "x.nc.bak" must not be accepted.
    model_posterior_list = [o[:-3] for o in os.listdir(model_posterior_dir) if o.endswith('.nc')]

    model_predict_dir = str(inference_dir)+'/'+str(f'pymc3_post_predict_{year}')
    model_predict_list = [o[:-4] for o in os.listdir(model_predict_dir) if o.endswith('.pkl')]

    model_files_raw = [o for o in os.listdir(str(model_dir)) if '.dat' in o]
    unique_models = list(set('_'.join(a.split('_')[:-1]) for a in model_files_raw))

    weights_path = str(inference_dir)+'/'+str('model_weights/model_weights.csv')
    if os.path.exists(weights_path):
        model_weights = pd.read_csv(weights_path,index_col=0)
    else:
        # No weights computed yet: every model falls back to weight 0 below.
        model_weights = pd.DataFrame(columns=['weight'])

    models = {}
    for u in unique_models:
        # '_new' is only a tag in the file name and is ignored when parsing.
        u_proc = u.replace('_new','')
        tag = u_proc.split('output')[1].split('.dat')[0]
        viscosity = tag.split('Cp')[1]
        # The saved trace / prediction files are matched on the un-stripped name.
        run_id = u.split('output')[1]
        models[u] = {
            'Lithosphere': int(tag.split('Cp')[0]),
            'UMV': int(viscosity[0]),
            'LMV': int(viscosity[1:].split('_')[0]),
            'ice_history': tag.split('_')[1],
            'esl_curve': tag.split('Wael_')[1][0],
            'posterior_trace': any(run_id in mpl for mpl in model_posterior_list),
            'posterior_predict': any(run_id in mpl for mpl in model_predict_list),
            'weight': model_weights.loc[u]['weight'] if u in model_weights.index else 0,
        }

    models_df = pd.DataFrame.from_dict(models).T
    return models_df
  
def inference_model_real_data(data, z_functions, model_name, init="adapt_full", target_accept=0.8,
                              keys = ["coral", "highstand"], cores=4, chains=4, tune=1000, draws=1000):
    """Fit the Gaussian-process GMSL model for one GIA model and save the results.

    Parameters
    ----------
    data: DataFrame with one row per sample and the columns 'type', 'age',
        'age_uncertainty', 'elevation', 'elevation_uncertainty',
        'uplift_rate (per ky)' and 'uplift_rate (std)'
    z_functions: one spline per row of ``data`` (same order), giving the GIA
        model prediction as a function of age
    model_name: file stem used for the saved trace (.nc) and predictions (.pkl)
    init, target_accept, cores, chains, tune, draws: forwarded to ``pm.sample``
    keys: unused; kept so existing callers keep working

    Raises
    ------
    ValueError: if ``data`` contains a sample type other than 'index' or
        'limiting', the only types for which both an age prior and a water
        depth term are defined here.

    Notes
    -----
    The trace is written to ``arviz_traces_<year>/`` and the GMSL predictions to
    ``pymc3_post_predict_<year>/`` under ``inference_dir``, the folder that
    get_model_status() reads from.  Nothing is returned.
    """
    X_new = np.linspace(115, 130, 200)[:, np.newaxis]
    samples = list(data.index)

    # Previously an unsupported type only printed a message and the model build
    # then failed later with a KeyError; fail early and explicitly instead.
    unsupported = [t for t in pd.unique(data["type"]) if t not in ("index", "limiting")]
    if unsupported:
        raise ValueError(
            "data type not implemented or key error, check dataframe: " + str(unsupported)
        )

    model_posterior_dir = str(inference_dir)+'/'+str(f'arviz_traces_{year}/')
    model_predict_dir = str(inference_dir)+'/'+str(f'pymc3_post_predict_{year}/')
    os.makedirs(model_posterior_dir, exist_ok=True)
    os.makedirs(model_predict_dir, exist_ok=True)

    with pm.Model() as model:

        ELEVATION = shared(data["elevation"].values)
        ELEVATION_U = shared(data["elevation_uncertainty"].values)

        ## one bounded-normal age per sample
        age = {}
        for samp in samples:
            BoundedNormal = pm.Bound(pm.Normal, lower=117, upper=128)
            age[samp] = BoundedNormal(str(samp) + "_age", mu=data.loc[samp,'age'],
                                      sd=data.loc[samp,'age_uncertainty'], shape=(1))

        ## Gaussian Process Kernels
        gp_ls = pm.Wald("gp_ls", mu=2, lam=5, shape=1)
        gp_var = pm.Normal("gp_var", mu=0, sd=500, shape=1)
        m_gmsl = pm.Normal("m_gmsl", 0, 200)
        mean_fun = pm.gp.mean.Constant(m_gmsl)

        cov1 = gp_var[0] * pm.gp.cov.ExpQuad(1, gp_ls[0])
        cov2 = pm.gp.cov.WhiteNoise(.05) #this noise helps avoid cholesky decomp failures

        gp = pm.gp.Marginal(mean_func=mean_fun,cov_func=cov1+cov2)

        ## collect ages from all samples, in row order
        ages = pm.Deterministic("ages", tt.concatenate([age[x] for x in samples]))

        ## interpolation of simulated age for GIA correction
        N = data["age"].size
        GIA = tt.zeros(N, dtype="float64")
        for i in range(N):
            GIA = tt.set_subtensor(GIA[i], SplineWrapper(z_functions[i])(ages[i]))
        pm.Deterministic("GIA", GIA)  # samples of GIA model RSL (includes Wael)

        ## add water depth to GIA by data type
        water_depth = {}
        for samp in samples:
            key = data.loc[samp,'type']
            if key == "limiting":
                depth = pm.Wald(str(samp) + "_water_depth", mu=2, lam=5, shape=(1))
                # negative to make terrestrial, 1.15 sets max_like at 0
                water_depth[samp] = -1*(depth-1.15)
            else:
                # 'index' samples carry no added water depth.  A plain length-1
                # zero tensor is used: the old per-type Deterministic reused one
                # name for every sample and wrapped a Python int.
                water_depth[samp] = tt.zeros(1, dtype="float64")

        ## long term subsidence
        N=data["elevation"].values.size
        uplift = pm.Normal("uplift_master", 0, 1)
        uplift = np.ones(N)*uplift * shared(data["uplift_rate (std)"].values) + shared(data["uplift_rate (per ky)"].values)
        uplift = uplift*ages.flatten() ##had /1000 here for last iteration
        pm.Deterministic("uplift",uplift)

        ## collect all through concat
        water_depths = pm.Deterministic(
            "water_depths", tt.concatenate([water_depth[x] for x in samples])
        )

        # master equation:
        # GMSL = Elevation observation +/- elevation uncertainty +/- water depth - GIA + SUBSIDENCE
        # keep in mind its solving for change in GMSL from the GMSL used in GIA model
        elevations_sd = pm.Normal("elev_sd", 0, 1, shape=(data['age'].size))
        elevations = pm.Deterministic("elev", ELEVATION + elevations_sd * ELEVATION_U)

        gmsl_points = pm.Deterministic(
            "gmsl_points", elevations + water_depths - GIA.flatten() - uplift #+subsidence
        )

        noise = pm.HalfCauchy("noise", beta=5)

        gp.marginal_likelihood(
            "gmsl",
            X=ages[:, np.newaxis],
            y=gmsl_points.flatten(),
            shape=(N,),
            noise=pm.gp.cov.WhiteNoise(sigma=noise),
        )  # GMSL deviation from Wael (esl)

        az_trace = pm.sample(tune=tune,draws=draws,
                init=init, progressbar=True, cores=cores, target_accept=target_accept, chains=chains,
                          return_inferencedata=True
            )

        # GMSL predictions on the regular age grid X_new
        gp.conditional("f_pred", X_new, pred_noise=False)

        try:
            pred_samples = pm.fast_sample_posterior_predictive(
                az_trace, var_names=['f_pred'],samples=1000,
            )
            az_trace.to_netcdf(model_posterior_dir+model_name+'.nc',groups=["posterior","log_likelihood","sample_stats"])
            with open(model_predict_dir+model_name+'.pkl', "wb") as buff:
                pickle.dump(pred_samples, buff)
            print('Success')
            time.sleep(1)

        except np.linalg.LinAlgError:
            print('Sampling failed, no output saved')
            time.sleep(1)
                

In [72]:
 def load_model(name):
    """Load every time slice of one GIA model from ``model_dir``.

    Parameters
    ----------
    name: model name; every file in ``output_glac_w_ice6g/`` whose name contains
        both 'output' and ``name`` is read

    Returns
    -------
    rsl: list of 2d arrays, one per time slice, ordered by increasing age
    age: list of floats, the age parsed from each file name (the last
        ``_``-separated field before 'ka')
    model_dims: [lon_min, lon_max, lat_min, lat_max] taken from the
        ``lats`` / ``lons`` files
    """
    directory = model_dir / 'output_glac_w_ice6g/'

    lats = pd.read_csv(model_dir / "lats", delimiter=",", header=None)
    lons = pd.read_csv(model_dir / "lons", delimiter=",", header=None)
    model_dims = [
        np.min(lons.values),
        np.max(lons.values),
        np.min(lats.values),
        np.max(lats.values),]

    # ignores non output files and files belonging to other models
    files = [f for f in np.sort(os.listdir(directory)) if 'output' in f and name in f]

    # Order the slices by numeric age rather than by file-name string: the
    # splines built from these ages require increasing x values.
    dated = sorted(
        ((float(f.split('ka')[0].split('_')[-1]), f) for f in files),
        key=lambda pair: pair[0],
    )

    rsl = [
        pd.read_csv(str(directory) + "/" + f, delimiter=",", header=None).values
        for _, f in dated
    ]
    age = [a for a, _ in dated]
    return rsl, age, model_dims

def interpolation_functions(LAT, LON, GIA_MODEL, age, model_dims):
    """Build one linear age -> RSL spline per site from a stack of GIA time slices.

    For every time slice in ``GIA_MODEL`` the prediction nearest each
    (lat, lon) pair is looked up with lookup_z(); the values of each site are
    then interpolated against ``age``.

    Returns
    -------
    List of UnivariateSpline objects, one per (lat, lon) pair, in input order.
    """
    per_slice = []
    for timeslice in GIA_MODEL:
        per_slice.append(
            [lookup_z(lat, lon, timeslice, model_dims) for lat, lon in zip(LAT, LON)]
        )
    z_by_time = np.array(per_slice)  # rows: time slices, columns: sites

    # k=1: linear; s=0: pass through every point;
    # ext=3: return the boundary value when extrapolating
    return [
        UnivariateSpline(age, z_by_time[:, site], k=1, ext=3, s=0)
        for site in range(z_by_time.shape[1])
    ]

def lookup_z(lat, lon, model, model_dims):
    """
    Returns the RSL prediction at a specific lat, lon, on a specific GIA model timeslice.
    Parameters
    ----------
    lat: Latitude value
    lon: Longitude value
    model: A 2d matrix from a GIA model output representing a single timeslice.
    model_dims: The real world lat/lon extent of the model,
        [lon_min, lon_max, lat_min, lat_max] (the order load_model builds).
        Row 0 of ``model`` is placed at lat_max and column 0 at lon_min.
    Returns
    -------
    The model RSL prediction nearest the lat, lon pair.
    """
    n_lat, n_lon = model.shape[0], model.shape[1]
    lon_axis = np.linspace(model_dims[0], model_dims[1], n_lon)
    # rows run from the maximum latitude down to the minimum
    lat_axis = np.linspace(model_dims[3], model_dims[2], n_lat)
    lon_id = np.argmin(np.abs(lon_axis - lon))
    lat_id = np.argmin(np.abs(lat_axis - lat))
    return model[lat_id, lon_id]

def load(file):
    """
    Read back an object that was saved with pickle.

    Parameters
    ----------
    file: str or path to file

    Returns
    -------
    The unpickled object.

    """
    # NOTE: pickle can execute arbitrary code; only use on files this project wrote.
    with open(file, "rb") as handle:
        restored = pickle.load(handle)
    return restored
    
def load_data():
    """Read the sea-level workbook and return the non-limiting rows as a tidy frame.

    The workbook is ``data_dir / 'raw' / data_file`` (second spreadsheet row is
    the header).  Rows whose indicative range is '(limiting)' are dropped and
    every remaining sample is labelled with type 'index'.

    Returns
    -------
    DataFrame with columns lon, lat, age, age_uncertainty, elevation,
    elevation_uncertainty, type, lower_limit, region, 'uplift_rate (per ky)'
    and 'uplift_rate (std)'; 'elevation' is converted to numeric.
    """
    sheet = pd.read_excel(data_dir/f'raw/{data_file}',header=1)
    filtered = sheet[sheet['Indicative \nRange (m)']!='(limiting)']

    def column(label):
        # NOTE(review): missing values are dropped per column, so columns stay
        # row-aligned only if they have no gaps (or the same gaps) - confirm.
        return filtered[label].dropna().values

    records = {
        "lon": column('Longitude'),
        "lat": column('Latitude'),
        "age": column('Age\n(ka BP)'),
        "age_uncertainty": column('Age Error\n(1σ, ka)'),
        "elevation": column('Relative Sea \nLevel (m)'),
        "elevation_uncertainty": column('RSL Error \n(1σ, m)'),
        "type": ['index' for _ in filtered['Indicative \nRange (m)'].dropna()],
        "lower_limit": column('RSL (m)\n(limiting min)'),
        "region": column('Region'),
        # keep only entries that are plain Python floats
        "uplift_rate (per ky)": [
            rate for rate in column('uplift rate *** (mean, m/kyr)') if type(rate) is float
        ],
        "uplift_rate (std)": column('uplift rate *** (1σ, m/kyr)'),
    }

    table = pd.DataFrame.from_dict(records)
    table['elevation'] = pd.to_numeric(table['elevation'])
    return table.copy()

def clear_and_run(click):
    """Button callback: back up and re-run the inference for the selected models.

    NOTE: a second definition of this function further down in the same cell
    rebinds the name, so that later definition is the one the button uses.

    Parameters
    ----------
    click: the object passed by the button's on_click (unused)
    """
    samp=samples_slider.value
    accept=accept_slider.value
    # Re-scan the output folders so the selection is made on the same, current
    # model table that update_table() displays, not on the one built at load time.
    current_models = get_model_status(inference_dir,model_dir/'output_glac_w_ice6g/')
    output_simulation.clear_output()
    with output_simulation:
        to_run = list(current_models.query(
                    f"posterior_trace == {list(widge_post.value)} & posterior_predict == {list(widge_predict.value)} & Lithosphere == {list(widge_lith.value)} & UMV == {list(widge_umv.value)} & LMV == {list(widge_lmv.value)} & `ice_history` == {list(widge_ice.value)} & esl_curve == {list(widge_gmsl.value)}"
                ).index)

        backup(to_run)
        run_inferences(to_run,samp,accept)
        
def plot_on_click(click):
    """Button callback: plot the weighted GMSL inference for the selected models.

    NOTE: a second definition of this function further down in the same cell
    rebinds the name, so that later definition is the one the button uses.

    Parameters
    ----------
    click: the object passed by the button's on_click (unused)
    """
    # Re-scan the output folders so the trace check below reflects what is on
    # disk now (the later definition of this callback does the same).
    current_models = get_model_status(inference_dir,model_dir/'output_glac_w_ice6g/')
    output_simulation.clear_output(wait=True)
    with output_simulation:
        filtdb=current_models.query(
                    f"posterior_trace == {list(widge_post.value)} & posterior_predict == {list(widge_predict.value)} & Lithosphere == {list(widge_lith.value)} & UMV == {list(widge_umv.value)} & LMV == {list(widge_lmv.value)} & `ice_history` == {list(widge_ice.value)} & esl_curve == {list(widge_gmsl.value)}")

        to_run = list(filtdb.index)

        if any(filtdb['posterior_trace']==False):
            print('Selected models have no traces, please run inference and generate weights..')
            return None

        weighted_inference_plot(to_run)
        plt.show()
    return
    
def backup(to_run):
    """Move the saved trace (.nc) and prediction (.pkl) of each model to the backup folders.

    NOTE: a second definition of this function further down in the same cell
    rebinds the name, so that later definition is the one in effect.

    Files are moved from ``arviz_traces_<year>`` / ``pymc3_post_predict_<year>``
    under ``model_dir`` into sibling folders with a ``_backup`` suffix.  Models
    without a saved file are skipped.

    Parameters
    ----------
    to_run: iterable of model names (file stems)
    """
    to_run = list(to_run)
    if not to_run:
        return

    for folder, suffix in ((f'arviz_traces_{year}', '.nc'),
                           (f'pymc3_post_predict_{year}', '.pkl')):
        source_dir = str(model_dir)+'/'+str(folder)
        backup_dir = source_dir+'_backup'
        # Without the folder every move failed with FileNotFoundError, which was
        # swallowed below, so nothing was ever backed up.
        os.makedirs(backup_dir, exist_ok=True)
        for f in to_run:
            # Each file is handled on its own so that a missing trace does not
            # prevent the prediction file from being backed up.
            try:
                os.replace(source_dir+'/'+f+suffix, backup_dir+'/'+f+suffix)
            except FileNotFoundError:
                pass  # this model has no saved output of this kind yet

#     model_predict_dir = str(model_dir)+'/'+str('pymc3_post_predict_2021')
#     model_predict_list=[o[:-4] for o in os.listdir(model_predict_dir) if '.pkl' in o]
    

# def run_inferences(to_run,samp,accept):

#     data2 = load_data()
#     N = data2["elevation"].size
#     data2['age_uncertainty']=data2['age_uncertainty']
#     data2=data2.sort_values(['type'])
#     keys = list(data2['type'].unique())

#     count=1

#     for m in to_run:
#         clear_output(wait=True)
#         print("running simulation number " + str(count) + " of " + str(len(to_run)))
#         print("running model: " + m)
#         count+=1

#         GIA_MODEL, age, model_dims = load_model(m)

#         z_functions = interpolation_functions(data2["lat"], data2["lon"], GIA_MODEL, age, model_dims)

#         inference_model_real_data(
#         data2,
#         z_functions,
#         m,
#         target_accept=accept,
#         cores=8,chains=8,
#         keys=keys,
#         init="adapt_full",
#         tune=samp,
#         draws=samp)
        
    
def re_weight(click):
    """Button callback: recompute the model weights from all saved traces.

    NOTE: a second definition of this function further down in the same cell
    rebinds the name, so that later definition is the one the button uses.

    Every ``.nc`` trace in ``arviz_traces_<year>`` under ``model_dir`` is
    loaded, compared with ``az.compare`` and the resulting table is written to
    ``model_weights/model_weights.csv``; the model table is then refreshed.

    Parameters
    ----------
    click: the object passed by the button's on_click; forwarded to update_table
    """
    filtered_df_output.clear_output()
    with filtered_df_output:
        print('Recalculating model weights...')

    model_posterior_dir = str(model_dir)+'/'+str(f'arviz_traces_{year}')
    # Match on the real suffix: the stem is obtained by fixed-width slicing, so
    # a stray file such as "x.nc.bak" would otherwise give a non-existent path.
    trace_names = [o[:-3] for o in os.listdir(model_posterior_dir) if o.endswith('.nc')]

    all_traces = {
        name: az.from_netcdf(model_posterior_dir+'/'+name+'.nc') for name in trace_names
    }

    comp = az.compare(all_traces, ic="loo", method='BB-pseudo-BMA', b_samples=1000)

    weights_dir = str(model_dir)+'/'+'model_weights'
    os.makedirs(weights_dir, exist_ok=True)  # first run: the folder may not exist yet
    comp.to_csv(weights_dir+'/model_weights.csv')

    update_table(click)
    
def weighted_inference_plot(to_run):
    """Plot the weight-averaged GMSL inference of the given models.

    NOTE: a second definition of this function further down in the same cell
    rebinds the name, so that later definition is the one in effect.

    Parameters
    ----------
    to_run: list of model names

    Returns
    -------
    The matplotlib figure, or None when none of the models has saved predictions.
    """
    model_predict_dir = str(model_dir)+'/'+str(f'pymc3_post_predict_{year}')
    model_predict_list=[o[:-4] for o in os.listdir(model_predict_dir) if o.endswith('.pkl')]

    preds = {
        f: load(model_predict_dir+'/'+f+'.pkl') for f in to_run if f in model_predict_list
    }

    if len(preds)==0:
        print('Selected models have no traces, please run inference and generate weights..')
        return None

    weights_table = pd.read_csv(str(model_dir)+'/'+str('model_weights/model_weights.csv'),index_col=0)
    known = weights_table['weight'].to_dict()
    # Weight only the models whose predictions were loaded (a weighted model
    # without predictions cannot be sampled); unknown models count as 0.
    model_weights = pd.DataFrame(
        {'weight': [float(known.get(m, 0.0)) for m in preds]}, index=list(preds)
    )
    total = model_weights['weight'].sum()
    if total==0:
        # equal weighting; the weights must sum to 1 to be used as probabilities
        model_weights['weight']=1.0/len(model_weights)
    else:
        model_weights['weight']=model_weights['weight']/total

    gmsl=weighted_trace(preds,model_weights,iters=10000)

    X_new = np.linspace(115, 130, 200)[:, np.newaxis]

    f_size=18

    sns.set_style(
        "ticks",
        {
            "axes.edgecolor": ".3",
            "xtick.color": ".3",
            "ytick.color": ".3",
            "text.color": ".3",
            "axes.facecolor": "(.98,.98,.98)",
            "axes.grid": True,
            "grid.color": ".95",
            "grid.linestyle": u"--",
        },
    )
    flatui = ["#D08770", "#BF616A", "#A3BE8C", "#B48EAD", "#34495e", "#5E81AC"]
    cs = sns.color_palette(flatui)

    ##Figure
    scale=1.5
    fig = plt.figure(figsize=(11*scale,4*scale))
    ax1=fig.add_subplot()

    plot_gmsl_inference(X_new,gmsl,cs[4],ax1,gmsl[1])
    ax1.set_title(
        "A. Last Interglacial GMSL",
        fontsize=f_size,
    )
    ax1.set_aspect(1/6)
    ax1.set_ylim([0, 12])
    ax1.set_xlim(117, 128)
    ax1.invert_xaxis()  # oldest ages on the left
    ax1.set_xticks(np.arange(128,116,-1))
    ax1.set_xticklabels(np.arange(128,116,-1),fontsize=f_size)
    ax1.legend(loc="best", frameon=True, fontsize=f_size*.66)

    ax1.set_ylabel("Global Mean Sea Level\n(m above MSL)", fontsize=f_size)
    ax1.set_xlabel("Age (kya)",fontsize=f_size)
    ax1.grid(linewidth=1)

    fig.tight_layout(w_pad=0,h_pad=0)
    return fig
    
def weighted_trace(pred_list, comp, var="f_pred", iters=20000):
    """Draw a mixed set of posterior-predictive curves across models.

    On each of ``iters`` iterations one model is picked at random with the
    probabilities in ``comp['weight']`` and one of its ``var`` draws is picked
    uniformly.

    Parameters
    ----------
    pred_list: dict mapping model name -> dict holding an array under ``var``
    comp: DataFrame indexed by model name with a 'weight' column
    var: key of the predictions to sample
    iters: number of curves to draw

    Returns
    -------
    Array with one flattened draw per iteration.
    """
    model_positions = np.arange(len(comp))
    draws = []

    for _ in range(iters):
        picked = np.random.choice(model_positions, 1, p=comp['weight'])
        model_key = comp.index[picked][0]
        candidates = pred_list[model_key][var]
        row = np.random.choice(np.arange(len(candidates)), 1)
        draws.append(candidates[row].ravel())

    return np.array(draws)

def plot_gmsl_inference(X_new,inference,color,ax,max_like):
    """Draw the 95% and 66% envelopes of a set of GMSL curves on ``ax``.

    Parameters
    ----------
    X_new: ages of the curve points (flattened before plotting)
    inference: array of curves, one per row; percentiles are taken over rows
    color: colour used for the fills and outlines
    ax: axes to draw on
    max_like: curve that is smoothed here but currently not drawn

    Returns
    -------
    ``ax``, with five filled bands and three empty legend-only lines added.
    """
    def band(lower, upper, **style):
        # every band shares the edge colour, anti-aliasing and cap style
        ax.fill_between(
            X_new.ravel(), lower, upper,
            ec=color, aa=True, capstyle="round", **style
        )

    lo95 = np.nanpercentile(inference, 2.5, axis=0)
    hi95 = np.nanpercentile(inference, 97.5, axis=0)

    # smoothed "most likely" curve; the line that used it is disabled
    max_like = gaussian(max_like,3)

    # 95% envelope: opaque white base, light tint, dashed outline
    band(lo95, hi95, fc=(1,1,1), zorder=2, alpha=1, lw=0)
    band(lo95, hi95, fc=color, zorder=3, alpha=.1, lw=0)
    band(lo95, hi95, fc='none', zorder=4, alpha=1, lw=1.5, linestyle='--')

    lo66 = np.nanpercentile(inference, 33/2, axis=0)
    hi66 = np.nanpercentile(inference, 100-33/2, axis=0)

    # 66% envelope: second tint on top, solid outline
    band(lo66, hi66, fc=color, zorder=3, alpha=.1, lw=0)
    band(lo66, hi66, fc='none', zorder=4, alpha=1, lw=1.5, linestyle='-')

    ## empty lines that only provide the legend entries
    for dash, text in (('--', '95% GMSL envelope'), ('-', '66% GMSL envelope')):
        ax.plot([],[],color=color,lw=1.5,linestyle=dash,label=text)
    ax.plot([],[],color=color,lw=4,label='Most likely GMSL')

    return ax

def load_synth_data():
    """Build a synthetic data set: real sample positions and ages, simulated elevations.

    The elevations of the real data are replaced by the RSL predicted by one
    'true' GIA model at each sample's location and age, plus a synthetic
    sinusoidal GMSL curve defined between 128 and 117 ka.

    Returns
    -------
    The frame from load_data() with an added 'rsl' column and 'elevation'
    replaced by rsl + synthetic GMSL (minus 'water depth mean (m)' for rows of
    type 'coral').
    """
    true_model = "output_new71Cp420_L6G_Wael_T" ## we'll generate synthetic data from this GIA model

    ## Synthetic GMSL
    lig_start = 128
    lig_end = 117
    lig_age = np.linspace(lig_start,lig_end,100)
    lig_dt = lig_age-lig_end
    lig_synth_gmsl = 3 * np.sin(lig_dt / (0.5 * np.pi)) + 3
    # fill_value is (value below the smallest x, value above the largest x).
    # lig_age runs from 128 down to 117, so the value for ages younger than
    # 117 ka is the LAST element and the value for ages older than 128 ka the
    # first one (they were previously swapped).
    synth_gmsl_function = interp1d(
        lig_age, lig_synth_gmsl, bounds_error=False,
        fill_value=(lig_synth_gmsl[-1], lig_synth_gmsl[0]),
    )

    ## replace real data elevations with RSL (from selected 'true' GIA model) + GMSL (from curve above)
    data=load_data()
    GIA_MODEL, age, model_dims = load_model(true_model)
    within_lig = [(lig_end <= a <= lig_start) for a in age]
    age=[a for a,check in zip(age,within_lig) if check]
    GIA_MODEL=[g for g,check in zip(GIA_MODEL,within_lig) if check]

    z_functions = interpolation_functions(data["lat"], data["lon"], GIA_MODEL, age, model_dims)

    sample_ages = data["age"].to_numpy(dtype=float)
    data["rsl"] = np.array([float(f(a)) for f, a in zip(z_functions, sample_ages)])
    # Whole-column assignment instead of element-wise chained indexing
    # (data["elevation"][i] += ...), which pandas does not guarantee to write
    # back into the frame.
    data["elevation"] = data["rsl"].to_numpy() + synth_gmsl_function(sample_ages)

    #adjust coral data for mean water depth
    is_coral = (data['type'] == 'coral').to_numpy()
    if is_coral.any():
        data.loc[is_coral, "elevation"] = (
            data.loc[is_coral, "elevation"] - data.loc[is_coral, 'water depth mean (m)']
        )

    return data

def clear_and_run(click):
    """Button callback: back up and re-run the inference for the selected models.

    The models are those rows of the current model table that match every
    selector widget; the sampling length and acceptance target are read from
    the two sliders.

    Parameters
    ----------
    click: the object passed by the button's on_click (unused)
    """
    samp=samples_slider.value
    accept=accept_slider.value
    # Re-scan the output folders so the selection is made on the same, current
    # model table that update_table() displays, not on the one built at load time.
    current_models = get_model_status(inference_dir,model_dir/'output_glac_w_ice6g/')
    output_simulation.clear_output()
    with output_simulation:
        to_run = list(current_models.query(
                    f"posterior_trace == {list(widge_post.value)} & posterior_predict == {list(widge_predict.value)} & Lithosphere == {list(widge_lith.value)} & UMV == {list(widge_umv.value)} & LMV == {list(widge_lmv.value)} & `ice_history` == {list(widge_ice.value)} & esl_curve == {list(widge_gmsl.value)}"
                ).index)

        backup(to_run)
        run_inferences(to_run,samp,accept)
        
def plot_on_click(click):
    """Button callback: plot the weighted GMSL inference for the selected models.

    The model table is rebuilt from disk, filtered by the selector widgets, and
    the plot is drawn into ``output_simulation`` unless one of the selected
    models has no saved trace.

    Parameters
    ----------
    click: the object passed by the button's on_click (unused)
    """
    models_df = get_model_status(inference_dir,model_dir/'output_glac_w_ice6g/')
    output_simulation.clear_output(wait=True)
    with output_simulation:
        selected = models_df.query(
            f"posterior_trace == {list(widge_post.value)} & posterior_predict == {list(widge_predict.value)} & Lithosphere == {list(widge_lith.value)} & UMV == {list(widge_umv.value)} & LMV == {list(widge_lmv.value)} & `ice_history` == {list(widge_ice.value)} & esl_curve == {list(widge_gmsl.value)}"
        )
        chosen_models = list(selected.index)

        lacks_trace = selected['posterior_trace'] == False
        if any(lacks_trace):
            print('Selected models have no traces, please run inference and generate weights..')
            return None

        weighted_inference_plot(chosen_models)
        plt.show()
    return None
    
def backup(to_run):
    """Move the saved trace (.nc) and prediction (.pkl) of each model to the backup folders.

    Files are moved from ``arviz_traces_<year>`` / ``pymc3_post_predict_<year>``
    under ``inference_dir`` into sibling folders with a ``_backup`` suffix.
    Models without a saved file are skipped.

    Parameters
    ----------
    to_run: iterable of model names (file stems)
    """
    to_run = list(to_run)
    if not to_run:
        return

    for folder, suffix in ((f'arviz_traces_{year}', '.nc'),
                           (f'pymc3_post_predict_{year}', '.pkl')):
        source_dir = str(inference_dir)+'/'+str(folder)
        backup_dir = source_dir+'_backup'
        # Without the folder every move failed with FileNotFoundError, which was
        # swallowed below, so nothing was ever backed up.
        os.makedirs(backup_dir, exist_ok=True)
        for f in to_run:
            # Each file is handled on its own so that a missing trace does not
            # prevent the prediction file from being backed up.
            try:
                os.replace(source_dir+'/'+f+suffix, backup_dir+'/'+f+suffix)
            except FileNotFoundError:
                pass  # this model has no saved output of this kind yet

def run_inferences(to_run,samp,accept):
    """Run the GMSL inference for each model in ``to_run`` and save the results.

    For every model the GIA time slices are loaded, a statistical model is built
    with inference_model_new() (not defined in this notebook; presumably
    provided by the ``bahamas_lig.utils`` star import - confirm), sampled, and
    the trace (.nc) and GMSL predictions (.pkl) are written under
    ``inference_dir``.

    Parameters
    ----------
    to_run: list of model names
    samp: number of tuning steps and of draws per chain
    accept: target acceptance rate passed to ``pm.sample``
    """
    data2 = load_data().sort_values(['type'])
    data2 = data2[data2['type']!='limiting']
    keys = list(data2['type'].unique())

    ## and we will collect our hard work in these subfolders
    model_posterior_dir = str(inference_dir)+'/'+str(f'arviz_traces_{year}/')
    model_predict_dir = str(inference_dir)+'/'+str(f'pymc3_post_predict_{year}/')
    # Make sure the folders exist before sampling: a missing folder would
    # otherwise only surface after the sampling had finished.
    os.makedirs(model_posterior_dir, exist_ok=True)
    os.makedirs(model_predict_dir, exist_ok=True)

    #### ages at which GMSL is predicted after fitting
    X_new = np.linspace(115, 130, 200)[:, np.newaxis]

    for count, model_name in enumerate(to_run, start=1):
        clear_output(wait=True)
        print("running simulation number " + str(count) + " of " + str(len(to_run)))
        print("running model: " + model_name)

        ## Build the statistical model
        GIA_MODEL, age, model_dims = load_model(model_name)
        z_functions = interpolation_functions(data2["lat"], data2["lon"], GIA_MODEL, age, model_dims)
        model, gp = inference_model_new(data2,z_functions,keys=keys)

        with model:
            ## The Hamiltonian Monte-Carlo sampling step, ie the inference button
            az_trace = pm.sample(tune=samp,draws=samp,target_accept=accept,
                    init='adapt_full', progressbar=True, cores=8, chains=8,
                              return_inferencedata=True
                )

            gp.conditional(
                "f_pred", X_new, pred_noise=False
            )

            ## A hacky try/except structure to keep a loop of inferences running over night if
            ## something breaks on a given run
            try:
                pred_samples = pm.fast_sample_posterior_predictive(az_trace, var_names=['f_pred'])

                ## Export arviz trace object to netcdf
                az_trace.to_netcdf(model_posterior_dir+model_name+'.nc',groups=["posterior","log_likelihood","sample_stats"])

                ## Export prediction dictionary to python pickle (dict with arrays)
                with open(model_predict_dir+model_name+'.pkl', "wb") as buff:
                    pickle.dump(pred_samples, buff)
                print('Success')
                time.sleep(1)

            except np.linalg.LinAlgError:
                print('Sampling failed, no output saved')
                time.sleep(1)
        
def re_weight(click):
    """Button callback: recompute the model weights from all saved traces.

    Every ``.nc`` trace in ``arviz_traces_<year>`` under ``inference_dir`` is
    loaded, compared with ``az.compare`` and the resulting table is written to
    ``model_weights/model_weights.csv``; the model table is then refreshed.

    Parameters
    ----------
    click: the object passed by the button's on_click; forwarded to update_table
    """
    filtered_df_output.clear_output()
    with filtered_df_output:
        print('Recalculating model weights...')

    model_posterior_dir = str(inference_dir)+'/'+str(f'arviz_traces_{year}')
    # Match on the real suffix: the stem is obtained by fixed-width slicing, so
    # a stray file such as "x.nc.bak" would otherwise give a non-existent path.
    trace_names = [o[:-3] for o in os.listdir(model_posterior_dir) if o.endswith('.nc')]

    all_traces = {
        name: az.from_netcdf(model_posterior_dir+'/'+name+'.nc') for name in trace_names
    }

    comp = az.compare(all_traces, ic="loo", method='BB-pseudo-BMA', b_samples=50000, alpha=1)

    weights_dir = str(inference_dir)+'/'+'model_weights'
    os.makedirs(weights_dir, exist_ok=True)  # first run: the folder may not exist yet
    comp.to_csv(weights_dir+'/model_weights.csv')

    update_table(click)

def weighted_inference_plot(to_run):
    """Plot the weight-averaged GMSL inference of the given models.

    Saved predictions are loaded for every requested model that has them, the
    model weights are renormalised over those models, curves are drawn with
    weighted_trace() and their 95% / 66% envelopes are plotted.

    Parameters
    ----------
    to_run: a model name or an iterable of model names

    Returns
    -------
    The matplotlib figure, or None when none of the models has saved predictions.
    """
    # accept a single name as well as any iterable of names
    to_run = [to_run] if isinstance(to_run, str) else list(to_run)

    model_predict_dir = str(inference_dir)+'/'+str(f'pymc3_post_predict_{year}')
    model_predict_list=[o[:-4] for o in os.listdir(model_predict_dir) if o.endswith('.pkl')]

    preds = {
        f: load(model_predict_dir+'/'+f+'.pkl') for f in to_run if f in model_predict_list
    }

    if len(preds)==0:
        print('Selected models have no traces, please run inference and generate weights..')
        return None

    weights_table = pd.read_csv(str(inference_dir)+'/'+str('model_weights/model_weights.csv'),index_col=0)
    known = weights_table['weight'].to_dict()
    # Weight only the models whose predictions were loaded (a weighted model
    # without predictions cannot be sampled); unknown models count as 0.
    model_weights = pd.DataFrame(
        {'weight': [float(known.get(m, 0.0)) for m in preds]}, index=list(preds)
    )
    total = model_weights['weight'].sum()
    if total==0:
        # equal weighting; the weights must sum to 1 to be used as probabilities
        model_weights['weight']=1.0/len(model_weights)
    else:
        model_weights['weight']=model_weights['weight']/total

    gmsl=weighted_trace(preds,model_weights,iters=10000)

    X_new = np.linspace(115, 130, 200)[:, np.newaxis]

    f_size=18

    sns.set_style(
        "ticks",
        {
            "axes.edgecolor": ".3",
            "xtick.color": ".3",
            "ytick.color": ".3",
            "text.color": ".3",
            "axes.facecolor": "(.98,.98,.98)",
            "axes.grid": True,
            "grid.color": ".95",
            "grid.linestyle": u"--",
        },
    )
    flatui = ["#D08770", "#BF616A", "#A3BE8C", "#B48EAD", "#34495e", "#5E81AC"]
    cs = sns.color_palette(flatui)

    ##Figure
    scale=1.5
    fig = plt.figure(figsize=(11*scale,4*scale))
    ax1=fig.add_subplot()

    plot_gmsl_inference(X_new,gmsl,cs[4],ax1,False)
    ax1.set_title(
        "A. Last Interglacial GMSL",
        fontsize=f_size,
    )
    ax1.set_aspect(1/2)
    ax1.set_ylim([0, 12])
    ax1.set_xlim(117, 128)
    ax1.invert_xaxis()  # oldest ages on the left
    ax1.set_xticks(np.arange(128,116,-1))
    ax1.set_xticklabels(np.arange(128,116,-1),fontsize=f_size)
    ax1.legend(loc="best", frameon=True, fontsize=f_size*.66)

    ax1.set_ylabel("Global Mean Sea Level\n(m above MSL)", fontsize=f_size)
    ax1.set_xlabel("Age (kya)",fontsize=f_size)
    ax1.grid(linewidth=1)

    fig.tight_layout(w_pad=0,h_pad=0)
    return fig

# Model table built once when this cell runs; the selector widgets below take
# their options from it.
models_df = get_model_status(inference_dir,model_dir/'output_glac_w_ice6g/')
# Layout shared by every selector column and slider (5vw wide, 250px tall).
fmt = Layout(width="5vw", height="250px")
# Output areas: the filtered model table, and the plots / inference logs.
filtered_df_output = widgets.Output()
output_simulation = widgets.Output()


def update_table(change):
    """Widget callback: redraw the table of models matching the current selection.

    The model table is rebuilt from disk, filtered by every selector widget and
    shown in ``filtered_df_output`` sorted by decreasing weight.

    Parameters
    ----------
    change: the object passed by the widget / button that triggered the refresh (unused)
    """
    models_df = get_model_status(inference_dir,model_dir/'output_glac_w_ice6g/')
    filtered_df_output.clear_output(wait=True)
    with filtered_df_output:
        matching = models_df.query(
            f"posterior_trace == {list(widge_post.value)} & posterior_predict == {list(widge_predict.value)} & Lithosphere == {list(widge_lith.value)} & UMV == {list(widge_umv.value)} & LMV == {list(widge_lmv.value)} & `ice_history` == {list(widge_ice.value)} & esl_curve == {list(widge_gmsl.value)}"
        )
        ranked = matching.sort_values('weight',ascending=False)
        display(ranked)


# Multi-select filters, one per column of models_df; all share the `fmt` layout.
# Numeric model parameters are offered in sorted order.
widge_lith = widgets.SelectMultiple(
    options=np.sort(models_df["Lithosphere"].unique().astype(int)), disabled=False, layout=fmt
)
widge_umv = widgets.SelectMultiple(
    options=np.sort(models_df["UMV"].unique().astype(int)), disabled=False, layout=fmt
)
widge_lmv = widgets.SelectMultiple(
    options=np.sort(models_df["LMV"].unique().astype(int)), disabled=False, layout=fmt
)

# Categorical model parameters, in order of first appearance.
widge_ice = widgets.SelectMultiple(
    options=models_df["ice_history"].unique(), disabled=False, layout=fmt
)
widge_gmsl = widgets.SelectMultiple(
    options=models_df["esl_curve"].unique(), disabled=False, layout=fmt
)

# Inference status flags.
widge_post = widgets.SelectMultiple(options=[True, False], disabled=False, layout=fmt)
widge_predict = widgets.SelectMultiple(options=[True, False], disabled=False, layout=fmt)

# Sampler settings read by clear_and_run(): tuning steps / draws, and the
# target acceptance rate.
samples_slider = widgets.IntSlider(
    value=500, min=100, max=2000, step=50, layout=fmt, orientation='vertical'
)
accept_slider = widgets.FloatSlider(
    value=.95, min=.8, max=1, step=0.01, layout=fmt, orientation='vertical'
)

# Column headings, in the same order as widget_list below.
label_list = [
    "Lithosphere",
    "UMV",
    "LMV",
    "ice_history",
    "esl_curve",
    "posterior_trace",
    "posterior_prediction",
    "Posterior samples",
    "Acceptance Target"
]
widget_list = [
    widge_lith,
    widge_umv,
    widge_lmv,
    widge_ice,
    widge_gmsl,
    widge_post,
    widge_predict,
    samples_slider,
    accept_slider
]

# One labelled column per widget, laid out side by side.
wv_list = [VBox([Label(l), w]) for l, w in zip(label_list, widget_list)]
h_box = HBox(wv_list)

# Refresh the model table when a selection changes.  Only the `value` trait is
# observed: without `names` the handler runs for every trait that changes, so
# the table would be rebuilt from disk several times per click.
for w in widget_list:
    w.observe(update_table, names="value")

# The three action buttons share the row width equally.
N_button = 3
pct = 100 / N_button
bt_layout = Layout(width=f"{int(pct)}%")

plot_inference_button = Button(
    description="Plot weighted inference for selected", layout=bt_layout
)
plot_inference_button.on_click(plot_on_click)

rerun_weights = Button(description="Recalculate weights for all", layout=bt_layout)
rerun_weights.on_click(re_weight)

rerun_inference_button = Button(
    description="Rerun GMSL inference for selected", layout=bt_layout
)
rerun_inference_button.on_click(clear_and_run)

# Display order, left to right.
buttons = [plot_inference_button, rerun_weights, rerun_inference_button]