# Imports

In [1]:
from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_style("ticks")
sns.set_context("talk")

import arviz as az
import warnings
# NOTE(review): this silences *all* Python warnings for the whole kernel
# session, which can hide sampler/convergence problems.
warnings.filterwarnings("ignore")

from pathlib import Path
# Assumes the notebook is started from a direct subdirectory of the project
# root (the project root is taken to be the parent of the CWD).
PROJECT_ROOT = Path.cwd().parents[0]
import sys
sys.path.append(str(PROJECT_ROOT))
# Star import: presumably the source of names used below without an explicit
# import here (pd, np, os, pm, load, get_model_status, load_model,
# interpolation_functions, inference_model, weighted_trace,
# plot_gmsl_inference) -- TODO confirm and import them explicitly.
from bahamas_lig.utils import *

model_dir = PROJECT_ROOT / "model_outputs/"
# Holocene traces, posterior predictions and model weights live under here.
inference_dir = PROJECT_ROOT / "model_outputs/holocene"
data_dir = PROJECT_ROOT / "data/"

from IPython.display import clear_output
import time
import ipywidgets as widgets
from ipywidgets import Layout
from IPython.display import display
from ipywidgets import Button, HBox, VBox, Label

The following JavaScript cell runs the cell at the bottom of the notebook, which defines the helper functions and widgets, and then re-runs the display cells so that the IPython widget displays are refreshed.

In [2]:
%%javascript 
// Re-execute the helper/widget-definition cell (index -1, the last cell) and
// then, presumably, the three widget-display cells above (-5, -4, -3; negative
// indices count from the end of the notebook -- update them if cells move).
// NOTE(review): relies on the classic-Notebook `Jupyter.notebook` JS API,
// which is presumably unavailable in JupyterLab -- verify.
Jupyter.notebook.execute_cells([-1,-5,-4,-3])

<IPython.core.display.Javascript object>

In [4]:
# Selection widgets / sliders (one column each), then the row of action buttons.
print("Select models to inspect:")
display(h_box, HBox(buttons))

Select models to inspect:


HBox(children=(VBox(children=(Label(value='Lithosphere'), SelectMultiple(layout=Layout(height='16vh', width='5…

HBox(children=(Button(description='Plot weighted inference for selected', layout=Layout(width='33%'), style=Bu…

In [5]:
# Log area written by the "plot" and "rerun inference" button callbacks.
display(output_simulation)

Output()

In [6]:
# Filtered model-status table, redrawn by update_table when the selection changes.
display(filtered_df_output)

Output()

## Helper functions and widget definitions

In [3]:
def load_holocene_data():
    """Read the processed Holocene GMSL data set.

    Loads ``processed/gmsl_holocene_data.csv`` under ``data_dir`` and divides
    the ``age`` and ``age_uncertainty`` columns by 1000 (presumably years ->
    kyr; confirm against the CSV).

    Returns
    -------
    pandas.DataFrame
        The data with both age columns rescaled.
    """
    csv_path = data_dir / "processed/gmsl_holocene_data.csv"
    holocene = pd.read_csv(csv_path)
    for column in ("age", "age_uncertainty"):
        holocene[column] = holocene[column] / 1000
    return holocene

def clear_and_run(click):
    """Back up and rerun the GMSL inference for every model matching the selection.

    Button callback (``click`` is unused). The number of tuning steps/draws and
    the sampler's ``target_accept`` come from ``samples_slider`` and
    ``accept_slider``; all output goes to ``output_simulation``.
    """
    samp = samples_slider.value
    accept = accept_slider.value
    output_simulation.clear_output()
    with output_simulation:
        # Re-read the model status: the module-level ``models_df`` is computed
        # only once when the helper cell runs and goes stale as soon as traces
        # are (re)generated, so filtering on posterior_trace /
        # posterior_predict against it selected the wrong models.
        current_df = get_model_status(inference_dir, model_dir / 'get_GIA/output_Hol_new/')

        # Pass the selections as ``@`` variables rather than interpolating their
        # repr into the query string (NumPy>=2 scalars repr as ``np.int64(..)``,
        # which ``DataFrame.query`` cannot parse). ``col in @x`` == ``col.isin(x)``.
        post = list(widge_post.value)
        predict = list(widge_predict.value)
        lith = list(widge_lith.value)
        umv = list(widge_umv.value)
        lmv = list(widge_lmv.value)
        ice = list(widge_ice.value)
        gmsl = list(widge_gmsl.value)
        to_run = list(current_df.query(
            "posterior_trace in @post and posterior_predict in @predict"
            " and Lithosphere in @lith and UMV in @umv and LMV in @lmv"
            " and ice_history in @ice and esl_curve in @gmsl"
        ).index)

        if not to_run:
            print('No models match the current selection.')
            return

        backup(to_run)
        run_inferences(to_run, samp, accept)
        
def plot_on_click(click):
    """Plot the weighted GMSL inference for the models matching the selection.

    Button callback (``click`` is unused). Refuses to plot (with a message) if
    any selected model has ``posterior_trace == False``; output goes to
    ``output_simulation``.
    """
    models_df = get_model_status(inference_dir, model_dir / 'get_GIA/output_Hol_new/')
    output_simulation.clear_output(wait=True)
    with output_simulation:
        # Selections are referenced with ``@`` instead of being repr-interpolated
        # into the query string (NumPy>=2 scalar reprs are not parseable by
        # ``DataFrame.query``). ``col in @x`` is equivalent to ``col.isin(x)``.
        post = list(widge_post.value)
        predict = list(widge_predict.value)
        lith = list(widge_lith.value)
        umv = list(widge_umv.value)
        lmv = list(widge_lmv.value)
        ice = list(widge_ice.value)
        gmsl = list(widge_gmsl.value)
        filtdb = models_df.query(
            "posterior_trace in @post and posterior_predict in @predict"
            " and Lithosphere in @lith and UMV in @umv and LMV in @lmv"
            " and ice_history in @ice and esl_curve in @gmsl"
        )
        to_run = list(filtdb.index)

        # Element-wise comparison is intended here (pandas Series).
        if filtdb['posterior_trace'].eq(False).any():
            print('Selected models have no traces, please run inference and generate weights..')
            return None

        weighted_inference_plot(to_run)
        plt.show()
    return
    
def backup(to_run):
    """Move existing output files of the given models into ``*_backup`` directories.

    For each model name, ``<inference_dir>/arviz_traces_2021/<name>.nc`` is
    moved to ``arviz_traces_2021_backup/`` and
    ``<inference_dir>/pymc3_post_predict_2021/<name>.pkl`` to
    ``pymc3_post_predict_2021_backup/``, overwriting any older backup. Files
    that do not exist are skipped.

    Parameters
    ----------
    to_run : iterable of str
        Model names whose outputs should be backed up.
    """
    output_dirs = (
        (Path(inference_dir) / 'arviz_traces_2021', '.nc'),
        (Path(inference_dir) / 'pymc3_post_predict_2021', '.pkl'),
    )
    for src_dir, ext in output_dirs:
        backup_dir = src_dir.with_name(src_dir.name + '_backup')
        for f in to_run:
            src = src_dir / (f + ext)
            # Each file is handled on its own. Previously a missing ``.nc``
            # skipped the ``.pkl`` backup of the same model, and a missing
            # backup directory made every move fail silently (the
            # FileNotFoundError was swallowed) before the rerun overwrote the
            # original trace.
            if not src.exists():
                continue
            backup_dir.mkdir(parents=True, exist_ok=True)
            # ``Path.replace`` overwrites an existing target on every platform
            # (``os.rename`` raises on Windows if the target exists).
            src.replace(backup_dir / (f + ext))

def run_inferences(to_run, samp, accept):
    """Fit the Holocene GMSL inference model once per GIA model and save the traces.

    For every model name the GIA predictions are loaded (``load_model`` with
    ``rsl_dir='output_Hol_new'``), ordered by age, interpolated to the data
    locations and wrapped in ``inference_model``. The model is then sampled
    with ``pm.sample`` (one chain, ``samp`` tuning steps and ``samp`` draws).
    The ``posterior`` and ``log_likelihood`` groups are written to
    ``<inference_dir>/arviz_traces_2021/<model>.nc``.

    Parameters
    ----------
    to_run : list of str
        Model names to run.
    samp : int
        Number of tuning steps and of posterior draws.
    accept : float
        ``target_accept`` passed to the sampler.
    """
    if not to_run:
        return

    data = load_holocene_data().sort_values(['type'])
    keys = list(data['type'].unique())

    # Loop-invariant: make sure the output directory exists *before* spending
    # time sampling, otherwise ``to_netcdf`` fails after the run.
    model_posterior_dir = Path(inference_dir) / 'arviz_traces_2021'
    model_posterior_dir.mkdir(parents=True, exist_ok=True)

    n_models = len(to_run)
    for count, model_name in enumerate(to_run, start=1):
        clear_output(wait=True)
        print("running simulation number " + str(count) + " of " + str(n_models))
        print("running model: " + model_name)

        ## Build the statistical model
        GIA_MODEL, age, model_dims = load_model(model_name, rsl_dir='output_Hol_new')
        # Reorder the GIA slices so they line up with the sorted ages.
        order = np.argsort(age)
        GIA_MODEL = [GIA_MODEL[a] for a in order]
        age = np.sort(age)
        z_functions = interpolation_functions(data["lat"], data["lon"], GIA_MODEL, age, model_dims)
        model, _gp = inference_model(data, z_functions, keys=keys, holocene=True)

        with model:
            ## The Hamiltonian Monte-Carlo sampling step, ie the inference button
            az_trace = pm.sample(
                tune=samp,
                draws=samp,
                init='adapt_full',
                progressbar=True,
                cores=1,
                target_accept=accept,
                chains=1,
                return_inferencedata=True,
            )

        # Only the groups needed for model comparison (az.compare) are saved.
        az_trace.to_netcdf(
            str(model_posterior_dir / (model_name + '.nc')),
            groups=["posterior", "log_likelihood"],
        )
        print('Success')
        
def re_weight(click):
    """Recompute the model weights from every saved trace and refresh the table.

    Button callback. Loads every ``*.nc`` trace in
    ``<inference_dir>/arviz_traces_2021``, runs ``az.compare`` (LOO,
    ``BB-pseudo-BMA`` weights) and writes the comparison table to
    ``<inference_dir>/model_weights/model_weights.csv``. It then redraws the
    filtered table via ``update_table``.
    """
    filtered_df_output.clear_output()
    with filtered_df_output:
        print('Recalculating model weights...')

    model_posterior_dir = Path(inference_dir) / 'arviz_traces_2021'
    # Match the real ``.nc`` suffix: the old ``'.nc' in name`` test also
    # picked up files such as ``foo.nc.bak`` and then mangled their name with
    # ``[:-3]``. Sorted for a deterministic model order.
    all_traces = {
        p.stem: az.from_netcdf(str(p))
        for p in sorted(model_posterior_dir.glob('*.nc'))
    }

    if not all_traces:
        # az.compare would raise here and leave the "Recalculating" message up.
        with filtered_df_output:
            print('No saved traces found, please run inference first..')
        return

    comp = az.compare(all_traces, ic="loo", method='BB-pseudo-BMA', b_samples=50000, alpha=1)

    weights_dir = Path(inference_dir) / 'model_weights'
    weights_dir.mkdir(parents=True, exist_ok=True)
    comp.to_csv(weights_dir / 'model_weights.csv')

    update_table(click)

def weighted_inference_plot(to_run):
    """Plot the weight-averaged GMSL posterior of the selected models.

    Loads the saved posterior-predictive output (``*.pkl`` in
    ``<inference_dir>/pymc3_post_predict_2021``) of every model in ``to_run``.
    It combines them with ``weighted_trace``, using the weights that
    ``re_weight`` writes (``<inference_dir>/model_weights/model_weights.csv``),
    renormalised over the models actually plotted, and draws the result.

    Parameters
    ----------
    to_run : str or iterable of str
        Model name(s) to include.

    Returns
    -------
    matplotlib.figure.Figure or None
        The figure, or ``None`` (after printing a message) when no selected
        model has both saved predictions and a weight.
    """
    # Wrap a single name; take any other iterable (list, tuple, Index, ...)
    # element-wise. The old ``type(x) != type([])`` check wrapped a tuple or
    # Index as if it were one model name.
    to_run = [to_run] if isinstance(to_run, str) else list(to_run)

    model_predict_dir = Path(inference_dir) / 'pymc3_post_predict_2021'
    # Match the real ``.pkl`` suffix (``'.pkl' in name`` also matched e.g.
    # ``foo.pkl.bak``); glob on a missing directory simply yields nothing.
    available = {p.stem for p in model_predict_dir.glob('*.pkl')}
    # NOTE(review): ``load`` presumably unpickles -- only use on trusted files.
    preds = {
        f: load(str(model_predict_dir / (f + '.pkl')))
        for f in to_run
        if f in available
    }
    if not preds:
        print('Selected models have no traces, please run inference and generate weights..')
        return None

    # Read the weights from the file ``re_weight`` writes. The previous path
    # (``<model_dir>/model_weights``) lies outside the Holocene
    # ``inference_dir``, so recalculated weights were never used here.
    weights_path = Path(inference_dir) / 'model_weights' / 'model_weights.csv'
    model_weights = pd.read_csv(weights_path, index_col=0)

    # Keep only models that have both predictions and a weight, so that
    # ``preds`` and ``model_weights`` describe the same set of models.
    selected = [m for m in preds if m in model_weights.index]
    if not selected:
        print('Selected models have no weights, please recalculate weights..')
        return None
    preds = {m: preds[m] for m in selected}
    model_weights = model_weights.loc[selected].copy()

    total = model_weights['weight'].sum()
    if total == 0:
        # All weights are zero: fall back to equal weights that, like the
        # normalised branch, sum to one.
        model_weights['weight'] = 1 / len(model_weights)
    else:
        model_weights['weight'] = model_weights['weight'] / total

    gmsl = weighted_trace(preds, model_weights, iters=10000)

    # NOTE(review): the grid, axis limits and title below are the Last
    # Interglacial settings (115-130 kyr), although this notebook runs the
    # Holocene inference -- confirm they match the grid the stored posterior
    # predictions were generated on.
    X_new = np.linspace(115, 130, 200)[:, np.newaxis]
    f_size = 18

    sns.set_style(
        "ticks",
        {
            "axes.edgecolor": ".3",
            "xtick.color": ".3",
            "ytick.color": ".3",
            "text.color": ".3",
            "axes.facecolor": "(.98,.98,.98)",
            "axes.grid": True,
            "grid.color": ".95",
            "grid.linestyle": "--",
        },
    )
    flatui = ["#D08770", "#BF616A", "#A3BE8C", "#B48EAD", "#34495e", "#5E81AC"]
    cs = sns.color_palette(flatui)

    ## Figure
    scale = 1.5
    fig, ax1 = plt.subplots(figsize=(11 * scale, 4 * scale))

    plot_gmsl_inference(X_new, gmsl, cs[4], ax1, False)
    ax1.set_title("A. Last Interglacial GMSL", fontsize=f_size)
    ax1.set_aspect(1 / 2)
    ax1.set_ylim([-2, 6])
    ax1.set_xlim(117, 128)
    # Oldest ages on the left.
    ax1.invert_xaxis()
    xticks = np.arange(128, 116, -1)
    ax1.set_xticks(xticks)
    ax1.set_xticklabels(xticks, fontsize=f_size)
    ax1.legend(loc="best", frameon=True, fontsize=f_size * .66)

    ax1.set_ylabel("Global Mean Sea Level\n(m above MSL)", fontsize=f_size)
    ax1.set_xlabel("Age (kya)", fontsize=f_size)
    ax1.grid(linewidth=1)

    fig.tight_layout(w_pad=0, h_pad=0)
    return fig

# Output areas: one for the inference / plotting log, one for the model table.
output_simulation = widgets.Output()
filtered_df_output = widgets.Output()
# Layout shared by every selection widget and slider.
fmt = Layout(width="5vw", height="16vh")
# Model status at start-up; used to populate the selection widgets' options.
models_df = get_model_status(inference_dir, model_dir / 'get_GIA/output_Hol_new/')


def update_table(change):
    """Redisplay the model-status table filtered by the current widget selection.

    Used as a widget ``observe`` callback and called directly by
    ``re_weight``; ``change`` is unused. The status is re-read with
    ``get_model_status`` on every call, so newly written files show up. Rows
    are sorted by descending ``weight``.
    """
    models_df = get_model_status(inference_dir, model_dir / 'get_GIA/output_Hol_new/')
    filtered_df_output.clear_output(wait=True)
    with filtered_df_output:
        # Selections are referenced with ``@`` rather than repr-interpolated
        # into the query string (NumPy>=2 scalar reprs such as
        # ``np.int64(..)`` are not parseable by ``DataFrame.query``).
        post = list(widge_post.value)
        predict = list(widge_predict.value)
        lith = list(widge_lith.value)
        umv = list(widge_umv.value)
        lmv = list(widge_lmv.value)
        ice = list(widge_ice.value)
        gmsl = list(widge_gmsl.value)
        display(
            models_df.query(
                "posterior_trace in @post and posterior_predict in @predict"
                " and Lithosphere in @lith and UMV in @umv and LMV in @lmv"
                " and ice_history in @ice and esl_curve in @gmsl"
            ).sort_values('weight', ascending=False)
        )


def _multi_select(options):
    """Build one filter ``SelectMultiple`` widget with the shared ``fmt`` layout."""
    return widgets.SelectMultiple(options=options, disabled=False, layout=fmt)


def _sorted_int_options(column):
    """Sorted, integer-cast unique values of ``models_df[column]``."""
    return np.sort(models_df[column].unique().astype(int))


# Earth-model parameters (integer-valued options, sorted ascending).
widge_lith = _multi_select(_sorted_int_options("Lithosphere"))
widge_umv = _multi_select(_sorted_int_options("UMV"))
widge_lmv = _multi_select(_sorted_int_options("LMV"))
# Ice history and GMSL curve, in order of first appearance in models_df.
widge_ice = _multi_select(models_df["ice_history"].unique())
widge_gmsl = _multi_select(models_df["esl_curve"].unique())
# Whether a posterior trace / posterior prediction exists for the model.
widge_post = _multi_select([True, False])
widge_predict = _multi_select([True, False])

# Number of tuning steps (and draws) per chain used when rerunning inference.
samples_slider = widgets.IntSlider(
    value=500,
    min=100,
    step=50,
    max=2000,
    layout=fmt,
    orientation='vertical',
)

# Sampler target acceptance rate used when rerunning inference.
accept_slider = widgets.FloatSlider(
    value=.95,
    min=.8,
    step=0.01,
    max=1,
    layout=fmt,
    orientation='vertical',
)

# Column headers, paired positionally with ``widget_list`` below.
label_list = [
    "Lithosphere",
    "UMV",
    "LMV",
    "ice_history",
    "esl_curve",
    "posterior_trace",
    "posterior_prediction",
    "Posterior samples",
    "Acceptance Target"
]
widget_list = [
    widge_lith,
    widge_umv,
    widge_lmv,
    widge_ice,
    widge_gmsl,
    widge_post,
    widge_predict,
    samples_slider,
    accept_slider
]

wv_list = [VBox([Label(l), w]) for l, w in zip(label_list, widget_list)]
h_box = HBox(wv_list)

# Refresh the table only when a *selection* changes:
# - ``names="value"``: without it ``observe`` fires for every trait change
#   (``index``/``label`` alongside ``value``), so each click re-read the model
#   status from disk several times;
# - the two sliders only parameterise a rerun and do not affect the query.
for w in (widge_lith, widge_umv, widge_lmv, widge_ice, widge_gmsl, widge_post, widge_predict):
    w.observe(update_table, names="value")

# Three equal-width action buttons laid out in one row.
N_button = 3
pct = 100 / N_button
bt_layout = Layout(width=f"{int(pct)}%")

plot_inference_button = Button(description="Plot weighted inference for selected", layout=bt_layout)
rerun_weights = Button(description="Recalculate weights for all", layout=bt_layout)
rerun_inference_button = Button(description="Rerun GMSL inference for selected", layout=bt_layout)

buttons = [plot_inference_button, rerun_weights, rerun_inference_button]

# Wire each button to its callback.
plot_inference_button.on_click(plot_on_click)
rerun_weights.on_click(re_weight)
rerun_inference_button.on_click(clear_and_run)