In [1]:
# Configuration of the joint network that is benchmarked against the separately trained networks below.
from config.batched_dropout3.dropout_0_4_2 import config
config

{'threshold': 1.5,
 'datapath': '../../../climate/sim-data/preproc/',
 'filenames_sims': ['tas_anual_preproc_ssp126_ACCESS-ESM1-5.nc',
  'tas_anual_preproc_ssp126_CanESM5.nc',
  'tas_anual_preproc_ssp126_MIROC-ES2L.nc',
  'tas_anual_preproc_ssp126_MIROC6.nc',
  'tas_anual_preproc_ssp126_UKESM1-0-LL.nc',
  'tas_anual_preproc_ssp245_ACCESS-ESM1-5.nc',
  'tas_anual_preproc_ssp245_CNRM-ESM2-1.nc',
  'tas_anual_preproc_ssp245_CanESM5.nc',
  'tas_anual_preproc_ssp245_GISS-E2-1-G.nc',
  'tas_anual_preproc_ssp245_IPSL-CM6A-LR.nc',
  'tas_anual_preproc_ssp245_MIROC-ES2L.nc',
  'tas_anual_preproc_ssp370_ACCESS-ESM1-5.nc',
  'tas_anual_preproc_ssp370_CESM2.nc',
  'tas_anual_preproc_ssp370_CanESM5.nc',
  'tas_anual_preproc_ssp370_GISS-E2-1-G.nc',
  'tas_anual_preproc_ssp370_IPSL-CM6A-LR.nc',
  'tas_anual_preproc_ssp370_MIROC-ES2L.nc',
  'tas_anual_preproc_ssp370_UKESM1-0-LL.nc'],
 'context': 'ssp&model&prior',
 'informative_prior': {'type': 'truncated-normal', 'mean': 10, 'std': 10},
 'year_bounds

In [2]:
import importlib
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import sys
import time

sys.path.append(os.path.abspath(os.path.join('../../../BayesFlow/')))
sys.path.append(os.path.abspath(os.path.join('../..')))  # access sibling directories
from bayesflow.computational_utilities import posterior_calibration_error
from sklearn.metrics import r2_score
from src.python.helpers import _configure_input, format_names, estimate_data_means_and_stds
from src.python.settings import plotting_update
#sns.set_theme(style='white', rc={'axes.facecolor': (0, 0, 0, 0)})
plt.rcParams.update(plotting_update)

from setup import *

2024-05-06 11:33:55.624686: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from tqdm.autonotebook import tqdm


In [3]:
# Seeded generator shared by the joint and all separate generative models (reproducible simulations).
RNG = np.random.default_rng(config['rng_seed'])

# Load data and build the joint model

In [4]:
# Directory with the preprocessed simulation files. Read from the config (which lists it as
# 'datapath') so the notebook and the config cannot drift apart; the literal is only a fallback.
DATAPATH = config.get('datapath', '../../../climate/sim-data/preproc/')

datasets = load_datasets(config, data_path=DATAPATH)
model_names = list(datasets.keys())

# Joint generative model over all settings listed in config['filenames_sims'].
model = build_generative_model(config, datasets, RNG)

# Standardization statistics: data means/stds from the training members' anomalies
# (TAS minus TAS_baseline), prior means/stds estimated from the joint prior.
train_data_dict = OrderedDict((key, datasets[key].TAS.isel(member=config['member_split']['train']) - datasets[key].TAS_baseline) for key in datasets.keys())
data_means, data_stds = estimate_data_means_and_stds(train_data_dict)
joint_prior_means, joint_prior_stds = model.prior.estimate_means_and_stds()

# Configurator turning simulator output into network inputs (presumably standardized with the
# statistics above; the benchmark below standardizes with the same means/stds).
configure_input = partial(_configure_input, prior_means=joint_prior_means, prior_stds=joint_prior_stds, data_means=data_means, data_stds=data_stds, context=config['context'])

amortizer = build_amortizer(config)

# Trainer wrapping the joint amortizer; presumably restores the trained weights from
# config['checkpoint_path'] (as the logs show for the separate models) — confirm.
joint_trainer = Trainer(
    amortizer=amortizer, configurator=configure_input, checkpoint_path=config['checkpoint_path'],
    generative_model=model, memory=True, reuse_optimizer=True,
)

INFO:root:Using uniform and truncated normal, (10, 10), prior on time-to-threshold between -40 and 41.
INFO:root:Using the following climate models: ['ssp126_ACCESS-ESM1-5', 'ssp126_CanESM5', 'ssp126_MIROC-ES2L', 'ssp126_MIROC6', 'ssp126_UKESM1-0-LL', 'ssp245_ACCESS-ESM1-5', 'ssp245_CNRM-ESM2-1', 'ssp245_CanESM5', 'ssp245_GISS-E2-1-G', 'ssp245_IPSL-CM6A-LR', 'ssp245_MIROC-ES2L', 'ssp370_ACCESS-ESM1-5', 'ssp370_CESM2', 'ssp370_CanESM5', 'ssp370_GISS-E2-1-G', 'ssp370_IPSL-CM6A-LR', 'ssp370_MIROC-ES2L', 'ssp370_UKESM1-0-LL']
INFO:root:Performing 2 pilot runs with the anonymous model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 2)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 72, 144)
INFO:root:No optional prior non-batchable context provided.
INFO:root:Could not determine shape of prior batchable context. Type appears to be non-array: <class 'list'>,                                    so make sure your input configurat

In [5]:
metrics = {}    # keep track of metrics for this run

# To speed up the validation process we reuse pre-simulated train and validation data stored as
# presim_train.pkl / presim_val.pkl in config['presimulate_path']. If that path is not in the
# config, or either file is missing, fresh data are simulated here instead.
# Availability is checked up front (rather than wrapping the whole computation in
# try/except KeyError) so that genuine errors raised while configuring the data or computing
# the loss are not silently mistaken for "no pre-simulated data".


def _simulate_and_score(state, n_batches, batch_size, keep_raw=False):
    """Simulate `n_batches` batches for `state` and compute the joint amortizer's loss on each.

    Returns (configured, raw, losses):
      configured -- configurator outputs, concatenated over batches along axis 0
      raw        -- dict with 'prior_draws' and 'sim_data' concatenated over ALL batches,
                    or None unless keep_raw=True
      losses     -- list with one loss value per batch
    """
    losses, configured_batches, raw_batches = [], [], []
    for _ in range(n_batches):
        sims = joint_trainer.generative_model(batch_size=batch_size, sim_args={'state': state})
        configured = joint_trainer.configurator(sims)
        losses.append(amortizer.compute_loss(configured).numpy())
        configured_batches.append(configured)
        if keep_raw:
            raw_batches.append({key: sims[key] for key in ('prior_draws', 'sim_data')})

    configured_all = {
        key: np.concatenate([batch[key] for batch in configured_batches], axis=0)
        for key in configured_batches[0].keys()
    }
    raw_all = None
    if keep_raw:
        raw_all = {
            key: np.concatenate([batch[key] for batch in raw_batches], axis=0)
            for key in ('prior_draws', 'sim_data')
        }
    return configured_all, raw_all, losses


presim_dir = config.get('presimulate_path')
presim_paths = {}
if presim_dir is not None:
    presim_paths = {state: os.path.join(presim_dir, f'presim_{state}.pkl') for state in ['train', 'val']}
missing_presims = [path for path in presim_paths.values() if not os.path.exists(path)]

if presim_dir is None:
    print("No presimulate_path in config; simulating train and validation data.")
elif missing_presims:
    print(f"Could not find pre-simulated data at {missing_presims}; simulating train and validation data.")

if presim_dir is not None and not missing_presims:
    for state, sims_path in presim_paths.items():
        sims = np.load(sims_path, allow_pickle=True)  # pickled data: only load trusted files
        data = joint_trainer.configurator(sims)
        loss = amortizer.compute_loss(data).numpy()

        metrics[f'{state}_loss'] = {'mean': loss, 'std': None}
        print(f"{state.capitalize()} loss: {loss:.3f}")

        # name in same format as freshly simulated data
        if state == 'train':
            train_data = data
        else:
            sims_val = sims
            val_data = data
else:
    # Generate and compute loss for train data
    train_data, _, train_loss_batches = _simulate_and_score('train', n_batches=10, batch_size=100)
    metrics['train_loss'] = {'mean': np.mean(train_loss_batches), 'std': np.std(train_loss_batches)}
    print(f"Train loss: {metrics['train_loss']['mean']:.3f} ± {metrics['train_loss']['std']:.3f}")

    # Generate and compute loss for validation data. The raw draws of ALL batches are kept
    # (previously only the last batch survived, and was then deleted) so that masks computed on
    # val_data line up row-for-row with sims_val in the benchmark cell. This keeps all simulated
    # validation maps in memory.
    val_data, sims_val, val_loss_batches = _simulate_and_score('val', n_batches=10, batch_size=500, keep_raw=True)
    metrics['val_loss'] = {'mean': np.mean(val_loss_batches), 'std': np.std(val_loss_batches)}
    print(f"Val loss:   {metrics['val_loss']['mean']:.3f} ± {metrics['val_loss']['std']:.3f}")

Train loss: 0.754
Val loss: 1.049


In [6]:
# Shapes of the configured validation inputs, one entry per key of val_data
[i.shape for i in val_data.values()]

[(5000, 72, 144), (5000, 19), (5000, 2)]

In [7]:
# First three condition vectors: a one-hot setting indicator (one entry per simulation file)
# followed by a final prior-selector entry
val_data['direct_conditions'][0:3,:]

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.]], dtype=float16)

# Load the separately trained networks

In [8]:
def import_config(name):
    """Import ``config.<name>`` and return its module-level ``config`` object.

    Parameters
    ----------
    name : str
        Dotted module path below the ``config`` package,
        e.g. ``'batched_separate_dropout2.separate_0_1'``.

    Returns
    -------
    The module's ``config`` attribute, or None (after printing a message) if the module
    itself does not exist or has no ``config`` attribute.

    Raises
    ------
    ModuleNotFoundError
        If the config module exists but one of *its own* imports is missing. Reporting that
        as "No module named 'config.<name>'" would hide the real missing dependency.
    """
    module_name = f"config.{name}"
    try:
        # Dynamically import the module
        module = importlib.import_module(module_name)
    except ModuleNotFoundError as err:
        # err.name is the module that could not be found. Only treat this as "config missing"
        # when it is the requested module or one of its parent packages.
        missing = err.name
        if missing is not None and missing != module_name and not module_name.startswith(missing + '.'):
            raise
        print(f"No module named 'config.{name}'")
        return None

    # Only the attribute access is guarded, so AttributeErrors raised while executing the
    # module are not misreported as a missing 'config' attribute.
    try:
        return module.config
    except AttributeError:
        print(f"Module 'config.{name}' does not have a 'config' attribute")
        return None

In [9]:
class SeparateModel:
    """
    Bundles one separately trained network with everything needed to evaluate it: its config,
    the datasets listed in that config, its generative model, its input configurator, its
    amortizer, and a Trainer built with ``config['checkpoint_path']`` (the logs show the
    networks being loaded from there).
    """
    def __init__(self, config_name, RNG, data_path=None):
        """
        Parameters
        ----------
        config_name : str
            Module path below the ``config`` package, e.g. 'batched_separate_dropout2.separate_0_1'.
        RNG : np.random.Generator
            Random number generator handed to the generative model.
        data_path : str, optional
            Directory with the preprocessed simulation files; defaults to the notebook-level DATAPATH.

        Raises
        ------
        ValueError
            If the config cannot be loaded (import_config returned None).
        """
        self.config_name = config_name

        self.config = import_config(config_name)
        if self.config is None:
            # import_config only prints a message and returns None; fail loudly here instead
            # of with an obscure error further down in load_datasets.
            raise ValueError(f"Could not load config 'config.{config_name}'")

        self.datasets = load_datasets(self.config, data_path=DATAPATH if data_path is None else data_path)
        self.model = build_generative_model(self.config, self.datasets, RNG)

        # Standardization statistics, computed the same way as for the joint model: data
        # means/stds from the training members' anomalies (TAS minus TAS_baseline), prior
        # means/stds from this model's prior.
        self.train_data_dict = OrderedDict((key, self.datasets[key].TAS.isel(member=self.config['member_split']['train']) - self.datasets[key].TAS_baseline) for key in self.datasets.keys())
        self.data_means, self.data_stds = estimate_data_means_and_stds(self.train_data_dict)
        self.prior_means, self.prior_stds = self.model.prior.estimate_means_and_stds()

        self.configure_input = partial(_configure_input, prior_means=self.prior_means, prior_stds=self.prior_stds, data_means=self.data_means, data_stds=self.data_stds, context=self.config['context'], context_aware=self.config['context_aware'])

        self.amortizer = build_amortizer(self.config)

        self.trainer = Trainer(
            amortizer=self.amortizer, configurator=self.configure_input, checkpoint_path=self.config['checkpoint_path'],
            generative_model=self.model, memory=True, reuse_optimizer=True,
        )

In [10]:
# Initialize separate models (n = # ensemble members)
sep_model_names = [f'separate_{config_i}_{n}' for n in [1] for config_i in range(18)]
sep_models = OrderedDict((name, SeparateModel(f'batched_separate_dropout2.{name}', RNG)) for name in sep_model_names if name in os.listdir('checkpoints/batched_separate_dropout2'))

INFO:root:Using uniform prior on time-to-threshold between -40 and 41.
INFO:root:Using the following climate models: ['ssp126_ACCESS-ESM1-5']


INFO:root:Performing 2 pilot runs with the anonymous model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 2)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 72, 144)
INFO:root:No optional prior non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:Could not determine shape of simulation batchable context. Type appears to be non-array: <class 'list'>,                                    so make sure your input configurator takes cares of that!
INFO:root:Built generative model for temperature maps parametrized by the years before the warming threshold 1.5°C is reached.
INFO:root:Loaded loss history from checkpoints/batched_separate_dropout2/separate_0_1/history_15.pkl.
INFO:root:Loaded simulation memory from checkpoints/batched_separate_dropout2/separate_0_1/memory.pkl
INFO:root:Networks loaded from checkpoints/batc

In [11]:
# member_split assigns trajectories to train, val, and test sets
sep_models['separate_0_1'].config['member_split']

{'train': [0, 1, 2, 3, 4, 5, 6], 'val': [7, 8], 'test': [9]}

# Calculate metrics

In [12]:
def get_metrics(parameters, draws, agg_fun_draws=np.median, agg_fun=np.median, override_prior_vars=None):
    """
    Get various approximation performance metrics for a given set of parameters and posterior draws.
    Intermediary output shapes vary by metric, but if aggregation is necessary, agg_fun is applied.

    Parameters:
    ----------
    parameters : np.array of shape (n_data_sets, n_parameters)
        The true parameter values, on the same scale as `draws`.
    draws : np.array of shape (n_data_sets, n_draws, n_parameters)
        The posterior draws.
    agg_fun_draws : function, optional
        Aggregation over the posterior draws (axis 1) giving a point estimate per data set and
        parameter; used for both R2 and the correlation. Default is np.median.
    agg_fun : function, optional
        Reduces the per-parameter / per-data-set values of the correlation, ECE and posterior
        contraction to a single number. Default is np.median.
    override_prior_vars : scalar or array-like, optional
        Prior variance(s) used for the posterior contraction instead of the empirical variance
        of `parameters`. A scalar or length-1 array is used for every parameter; an array of
        length n_parameters gives one variance per parameter.

    Returns
    -------
    rmse : float
        The root mean squared error over all data sets, draws and parameters.
    mae : float
        The mean absolute error over all data sets, draws and parameters.
    r2 : float
        The R-squared coefficient of the point estimates.
    corr : float
        The aggregated per-parameter correlation between true values and point estimates.
    ece : float
        The aggregated expected calibration error.
    post_contraction : float
        The aggregated posterior contraction.
    """

    def root_mean_squared_error(x_true, x_pred):
        """ RMSE between true parameters and posterior draws, pooled over data sets, draws and parameters. """
        return np.sqrt(np.mean(np.square(x_true[:, np.newaxis, :] - x_pred)))

    def mean_absolute_error(x_true, x_pred):
        """ MAE between true parameters and posterior draws, pooled over data sets, draws and parameters. """
        return np.mean(np.abs(x_true[:, np.newaxis, :] - x_pred))

    def post_cont(x_true, x_pred, override_prior_vars=None):
        """ Posterior contraction per data set and parameter, shape (n_data_sets, n_parameters). """
        post_vars = x_pred.var(axis=1, ddof=1)
        if override_prior_vars is None:  # `!= None` is ambiguous for arrays with >1 element
            prior_vars = x_true.var(axis=0, keepdims=True, ddof=1)
        else:
            # Scalar / length-1 input is repeated for every parameter (as before); a length
            # n_parameters array gives one prior variance per parameter.
            prior_vars = np.broadcast_to(np.asarray(override_prior_vars).reshape(-1), (x_pred.shape[-1],))
        return 1 - (post_vars / prior_vars)

    agg_draws = agg_fun_draws(draws, axis=1)  # point estimates for r2 and correlation

    rmse = root_mean_squared_error(parameters, draws)
    mae = mean_absolute_error(parameters, draws)
    r2 = r2_score(parameters, agg_draws)
    # Same point estimates as R2 (previously a hard-coded np.median ignored agg_fun_draws).
    corr = agg_fun([np.corrcoef(parameters[:, i], agg_draws[:, i])[0, 1] for i in range(parameters.shape[-1])])
    ece = agg_fun(
        posterior_calibration_error(prior_samples=parameters, posterior_samples=draws)
    ) # shape before aggregation: (num_params)
    post_contraction = agg_fun(post_cont(parameters, draws, override_prior_vars)) # shape before aggregation: (num_datasets, num_params)

    return rmse, mae, r2, corr, ece, post_contraction

In [13]:
# Names of the separate models whose checkpoints were found and loaded
sep_models.keys()

odict_keys(['separate_0_1', 'separate_1_1', 'separate_2_1', 'separate_3_1', 'separate_4_1', 'separate_5_1', 'separate_6_1', 'separate_7_1', 'separate_8_1', 'separate_9_1', 'separate_10_1', 'separate_11_1', 'separate_12_1', 'separate_13_1', 'separate_14_1', 'separate_15_1', 'separate_16_1', 'separate_17_1'])

In [14]:

metrics_names = ['RMSE', 'MAE', 'R2', 'Correlation', 'ECE', 'Posterior Contraction']

# Evaluate every separately trained network and the joint network on the same validation data:
# the data sets belonging to that separate network's SSP/climate-model setting.
use_presimulated_sep_data = True   # use checkpoints/presims-{i}/presim_val.pkl; else filter the joint val data
plot_diagnostics = False           # per-setting loss / recovery / SBC plots
n_posterior_samples = 100

metric_rows = []   # one row per (setting, 'separate' | 'joint'); not named `metrics` to avoid reusing that name
sep_inference_times_per_setting = []
joint_inference_times_per_setting = []
n_sim_files = len(config['filenames_sims'])

for i, sep_model_key in enumerate(sep_models.keys()): # Iterate over all models

    # --- Separate model

    sep_model = sep_models[sep_model_key]
    # A separate network is trained on exactly one simulation file; find that file's index in
    # the joint config to build the matching condition vector for the joint network.
    fns = sep_model.config['filenames_sims'] # legacy from joint training over all filenames
    assert len(fns) == 1, f"{sep_model_key} should be trained on a single file, got {fns}"
    sim_ind = config['filenames_sims'].index(fns[0])

    # One-hot selector of the setting, plus a trailing 0 that always selects the uninformative prior.
    condition = np.concatenate([np.eye(n_sim_files)[sim_ind], [0]])

    if use_presimulated_sep_data:
        presim_dir = f'checkpoints/presims-{i}/'
        # Guard against misalignment between the loop index and the presims-{i} directories
        # (e.g. when a separate checkpoint was skipped) BEFORE loading anything.
        assert presim_dir == sep_model.config['presimulate_path'], (presim_dir, sep_model.config['presimulate_path'])
        with open(os.path.join(presim_dir, 'presim_val.pkl'), 'rb') as f:
            val_sims = pickle.load(f)  # trusted local checkpoint; pickle can execute arbitrary code
        true_params = val_sims['prior_draws']
        data = val_sims['sim_data']
        # Every data set in presims-{i} was simulated for this setting, so all share `condition`.
        direct_conditions = np.repeat(condition[np.newaxis, :], data.shape[0], axis=0)
    else:
        # Filter the joint validation data for data sets simulated under this setting.
        mask = np.all(val_data['direct_conditions'] == condition, axis=1)
        true_params = sims_val['prior_draws'][mask, :]
        data = sims_val['sim_data'][mask, :, :]
        direct_conditions = val_data['direct_conditions'][mask]

    # Standardize params + data with the separate model's own statistics
    sep_params_stand = (true_params - sep_model.prior_means) / sep_model.prior_stds
    sep_data_stand = (data - sep_model.data_means) / sep_model.data_stds

    sep_start_time = time.perf_counter()
    sep_samples = sep_model.amortizer.sample(
        {
            'summary_conditions': sep_data_stand,
        },
        n_samples=n_posterior_samples,
        to_numpy=True,
    )
    sep_inference_times_per_setting.append(time.perf_counter() - sep_start_time)

    # [18:-3] strips the 'tas_anual_preproc_' prefix and the '.nc' suffix of the file name
    sep_name = format_names(sep_model.config['filenames_sims'][0][18:-3]) + ' ' + sep_model_key.split('_')[-1]

    # Metrics for the first parameter (index 0), mapped back to its original scale.
    metric_rows.append([sep_name, 'separate', *get_metrics(
        sep_params_stand[..., 0][..., np.newaxis]*sep_model.prior_stds[:, 0]+sep_model.prior_means[:, 0],
        sep_samples[..., 0][..., np.newaxis]*sep_model.prior_stds[:, 0]+sep_model.prior_means[:, 0],
        override_prior_vars=sep_model.prior_stds[:, 0]**2,
    )])

    # --- Joint model

    # Standardize params + data with the joint model's statistics
    joint_params_stand = (true_params - joint_prior_means) / joint_prior_stds
    joint_data_stand = (data - data_means) / data_stds

    joint_start_time = time.perf_counter()
    joint_samples = joint_trainer.amortizer.sample(
        {
            'summary_conditions': joint_data_stand,
            'direct_conditions': direct_conditions,
        },
        n_samples=n_posterior_samples,
        to_numpy=True,
    )
    joint_inference_times_per_setting.append(time.perf_counter() - joint_start_time)

    # NOTE(review): contraction for the joint model also uses the SEPARATE model's prior
    # variance, so both rows share one reference — confirm this is intended.
    metric_rows.append([sep_name, 'joint', *get_metrics(
        joint_params_stand[..., 0][..., np.newaxis]*joint_prior_stds[:, 0]+joint_prior_means[:, 0],
        joint_samples[..., 0][..., np.newaxis]*joint_prior_stds[:, 0]+joint_prior_means[:, 0],
        override_prior_vars=sep_model.prior_stds[:, 0]**2,
    )])

    if plot_diagnostics:
        # TODO: Uses old means and stds, should be updated to use the new ones
        # --- plots for separate model

        sep_history = sep_model.trainer.loss_history.get_plottable()
        f = diag.plot_losses(sep_history["train_losses"], sep_history["val_losses"])
        f.axes[0].set_ylim(-1, 3)
        f.set_figwidth(6)
        f.set_figheight(3)

        f = diag.plot_recovery(sep_samples*sep_model.prior_stds+sep_model.prior_means,
                        sep_params_stand*sep_model.prior_stds+sep_model.prior_means)
        f.delaxes(f.axes[1])
        f.suptitle(sep_model_key)

        f = diag.plot_sbc_ecdf(sep_samples*sep_model.prior_stds+sep_model.prior_means,
                        sep_params_stand*sep_model.prior_stds+sep_model.prior_means, difference=True)
        f.delaxes(f.axes[1])
        f.suptitle(sep_model_key)

        # --- plots for joint model

        f = diag.plot_recovery(joint_samples*joint_prior_stds+joint_prior_means,
                        joint_params_stand*joint_prior_stds+joint_prior_means)
        f.delaxes(f.axes[1])
        f.suptitle('joint')

        f = diag.plot_sbc_ecdf(joint_samples*joint_prior_stds+joint_prior_means,
                        joint_params_stand*joint_prior_stds+joint_prior_means, difference=True)
        f.delaxes(f.axes[1])
        f.suptitle('joint')


columns = ['SSP + Climate Model + n', 'separate or joint'] + metrics_names
# Built from a list of rows so the metric columns get a numeric dtype (not object).
benchmark_results = pd.DataFrame(metric_rows, columns=columns)

# Extract SSP, Climate Model, and n from the first column for convenience
name_parts = benchmark_results['SSP + Climate Model + n'].str.split(' ')
benchmark_results['SSP'] = name_parts.str[0]
benchmark_results['Climate Model'] = name_parts.str[1]
benchmark_results['n'] = name_parts.str[-1]
benchmark_results

Unnamed: 0,SSP + Climate Model + n,separate or joint,RMSE,MAE,R2,Correlation,ECE,Posterior Contraction,SSP,Climate Model,n
0,SSP1-2.6 ACCESS-ESM1-5 1,separate,5.80056,4.609503,0.956586,0.982407,0.142495,0.981011,SSP1-2.6,ACCESS-ESM1-5,1
1,SSP1-2.6 ACCESS-ESM1-5 1,joint,5.57007,4.195546,0.966862,0.98366,0.055995,0.984344,SSP1-2.6,ACCESS-ESM1-5,1
2,SSP1-2.6 CanESM5 1,separate,5.352142,4.288163,0.968183,0.987523,0.090763,0.979274,SSP1-2.6,CanESM5,1
3,SSP1-2.6 CanESM5 1,joint,4.970421,3.916087,0.970042,0.991192,0.133363,0.985867,SSP1-2.6,CanESM5,1
4,SSP1-2.6 MIROC-ES2L 1,separate,6.724132,5.300568,0.935558,0.972314,0.227368,0.980544,SSP1-2.6,MIROC-ES2L,1
5,SSP1-2.6 MIROC-ES2L 1,joint,5.627969,4.276042,0.963483,0.985152,0.071689,0.98252,SSP1-2.6,MIROC-ES2L,1
6,SSP1-2.6 MIROC6 1,separate,8.567691,6.914181,0.892326,0.949299,0.245074,0.974019,SSP1-2.6,MIROC6,1
7,SSP1-2.6 MIROC6 1,joint,11.382192,8.863627,0.873671,0.948471,0.065347,0.922927,SSP1-2.6,MIROC6,1
8,SSP1-2.6 UKESM1-0-LL 1,separate,4.739322,3.827223,0.97687,0.989317,0.079084,0.981134,SSP1-2.6,UKESM1-0-LL,1
9,SSP1-2.6 UKESM1-0-LL 1,joint,5.009264,3.914821,0.970561,0.986815,0.079342,0.985415,SSP1-2.6,UKESM1-0-LL,1


In [15]:
# Paired MAE difference (separate - joint) per setting, aligned on the setting name;
# positive values mean the joint network has the lower MAE.
key_col = 'SSP + Climate Model + n'
sep = benchmark_results.loc[benchmark_results['separate or joint'] == 'separate'].set_index(key_col, drop=False)
joi = benchmark_results.loc[benchmark_results['separate or joint'] == 'joint'].set_index(key_col, drop=False)
(sep['MAE'] - joi['MAE']).agg(['mean', 'std'])


mean    0.403009
std     0.894387
Name: MAE, dtype: float64

In [16]:
# Same paired MAE difference (separate - joint), restricted to SSP3-7.0
key_col = 'SSP + Climate Model + n'
benchmark_results_ssp3 = benchmark_results.loc[benchmark_results['SSP'] == 'SSP3-7.0']
is_separate = benchmark_results_ssp3['separate or joint'] == 'separate'
is_joint = benchmark_results_ssp3['separate or joint'] == 'joint'
sep = benchmark_results_ssp3.loc[is_separate].set_index(key_col, drop=False)
joi = benchmark_results_ssp3.loc[is_joint].set_index(key_col, drop=False)
(sep['MAE'] - joi['MAE']).agg(['mean', 'std'])

mean    0.797411
std     0.866987
Name: MAE, dtype: float64

In [17]:
# Summary-network settings of the last separate model processed in the benchmark loop
sep_model.config['summary_net']

{'type': 'dense',
 'kwargs': {'hidden_units': [25, 25],
  'output_dim': 4,
  'zeroth_layer': {'dropout': 0.4}}}

In [18]:
# Per-setting benchmark rows restricted to SSP3-7.0
benchmark_results[benchmark_results['SSP'] == 'SSP3-7.0']

Unnamed: 0,SSP + Climate Model + n,separate or joint,RMSE,MAE,R2,Correlation,ECE,Posterior Contraction,SSP,Climate Model,n
22,SSP3-7.0 ACCESS-ESM1-5 1,separate,4.111844,3.290683,0.986561,0.994975,0.027032,0.982573,SSP3-7.0,ACCESS-ESM1-5,1
23,SSP3-7.0 ACCESS-ESM1-5 1,joint,3.841137,3.053249,0.987657,0.993876,0.028121,0.986249,SSP3-7.0,ACCESS-ESM1-5,1
24,SSP3-7.0 CESM2 1,separate,4.493834,3.5727,0.984237,0.99225,0.058374,0.978384,SSP3-7.0,CESM2,1
25,SSP3-7.0 CESM2 1,joint,4.096689,3.223844,0.985508,0.993135,0.018547,0.985104,SSP3-7.0,CESM2,1
26,SSP3-7.0 CanESM5 1,separate,7.133901,5.541117,0.940573,0.976266,0.065542,0.969918,SSP3-7.0,CanESM5,1
27,SSP3-7.0 CanESM5 1,joint,4.394861,3.406394,0.979584,0.993289,0.010389,0.986019,SSP3-7.0,CanESM5,1
28,SSP3-7.0 GISS-E2-1-G 1,separate,4.276523,3.411009,0.985229,0.9931,0.038037,0.981037,SSP3-7.0,GISS-E2-1-G,1
29,SSP3-7.0 GISS-E2-1-G 1,joint,3.704969,2.918144,0.988825,0.994429,0.052842,0.986627,SSP3-7.0,GISS-E2-1-G,1
30,SSP3-7.0 IPSL-CM6A-LR 1,separate,7.036563,5.452364,0.939032,0.972179,0.109074,0.97414,SSP3-7.0,IPSL-CM6A-LR,1
31,SSP3-7.0 IPSL-CM6A-LR 1,joint,4.42849,3.482682,0.97929,0.989958,0.019368,0.985651,SSP3-7.0,IPSL-CM6A-LR,1


In [28]:
# Mean and standard deviation of every metric across settings, separate vs. joint networks
agg_dict = {metric: ['mean', 'std'] for metric in metrics_names}
benchmark_results_agg = benchmark_results.groupby('separate or joint').agg(agg_dict)
benchmark_results_agg.iloc[::-1] # groupby sorts keys ('joint' < 'separate'); reverse to have separate first

Unnamed: 0_level_0,RMSE,RMSE,MAE,MAE,R2,R2,Correlation,Correlation,ECE,ECE,Posterior Contraction,Posterior Contraction
Unnamed: 0_level_1,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std
separate or joint,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
separate,5.334577,1.347631,4.238463,1.058117,0.965888,0.025994,0.985028,0.011862,0.082822,0.067427,0.979131,0.003611
joint,4.906359,1.714001,3.835454,1.321066,0.97234,0.025804,0.98821,0.010427,0.044847,0.036872,0.981861,0.014759


In [29]:
# Compact view of the headline metrics (separate first), rounded to 4 decimals
benchmark_results_agg[['MAE', 'ECE', 'Posterior Contraction']].iloc[::-1].round(4)

Unnamed: 0_level_0,MAE,MAE,ECE,ECE,Posterior Contraction,Posterior Contraction
Unnamed: 0_level_1,mean,std,mean,std,mean,std
separate or joint,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
separate,4.2385,1.0581,0.0828,0.0674,0.9791,0.0036
joint,3.8355,1.3211,0.0448,0.0369,0.9819,0.0148


In [21]:
# Just for SSP3-7.0
agg_dict = {metric: ['mean', 'sem'] for metric in metrics_names}
benchmark_results_ssp3_agg = benchmark_results[benchmark_results['SSP'] == 'SSP3-7.0'].groupby('separate or joint').agg(agg_dict)
benchmark_results_ssp3_agg.iloc[::-1] # Reverse row order to have separate first


Unnamed: 0_level_0,RMSE,RMSE,MAE,MAE,R2,R2,Correlation,Correlation,ECE,ECE,Posterior Contraction,Posterior Contraction
Unnamed: 0_level_1,mean,sem,mean,sem,mean,sem,mean,sem,mean,sem,mean,sem
separate or joint,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
separate,5.122444,0.517725,4.041501,0.384712,0.971727,0.008355,0.987473,0.00354,0.055101,0.01219,0.978273,0.001767
joint,4.122629,0.124055,3.244091,0.090414,0.983517,0.001971,0.992725,0.0007,0.026667,0.006106,0.986007,0.000284


In [22]:
# Correlation matrix: Agreement between metrics
metric_corr_matrix = np.corrcoef(benchmark_results.loc[:, metrics_names].astype(float).transpose())
pd.DataFrame(metric_corr_matrix, columns=metrics_names, index=metrics_names)

Unnamed: 0,RMSE,MAE,R2,Correlation,ECE,Posterior Contraction
RMSE,1.0,0.998328,-0.984285,-0.965603,0.556007,-0.851479
MAE,0.998328,1.0,-0.987374,-0.97038,0.576504,-0.849219
R2,-0.984285,-0.987374,1.0,0.990424,-0.633719,0.783235
Correlation,-0.965603,-0.97038,0.990424,1.0,-0.64503,0.746081
ECE,0.556007,0.576504,-0.633719,-0.64503,1.0,-0.177849
Posterior Contraction,-0.851479,-0.849219,0.783235,0.746081,-0.177849,1.0


# Get times

In [23]:
def _load_training_time(checkpoint_path):
    """Load the training time pickled at <checkpoint_path>/training_time.pkl (closing the file)."""
    with open(os.path.join(checkpoint_path, 'training_time.pkl'), 'rb') as f:
        return pickle.load(f)  # trusted local checkpoint; pickle can execute arbitrary code


# Rows of (setting, 'separate' | 'joint', 'training' | 'inference', time in seconds)
times = []
for i, (sep_model_key, sep_model_i) in enumerate(sep_models.items()): # Iterate over all models

    # [18:-3] strips the 'tas_anual_preproc_' prefix and the '.nc' suffix of the file name
    sep_name = format_names(sep_model_i.config['filenames_sims'][0][18:-3]) + ' ' + sep_model_key.split('_')[-1]

    # Training time of this separate network
    times.append([sep_name, 'separate', 'training', _load_training_time(sep_model_i.config['checkpoint_path'])])

    # Inference times measured in the benchmark loop (same model order)
    times.append([sep_name, 'separate', 'inference', sep_inference_times_per_setting[i]])
    times.append([sep_name, 'joint', 'inference', joint_inference_times_per_setting[i]])

# The joint network is trained once for all settings
times.append(['all', 'joint', 'training', _load_training_time(config['checkpoint_path'])])

In [24]:
# One row per timing measurement; minutes derived from seconds.
# NOTE(review): 'seperate or joint' is a typo (benchmark_results uses 'separate or joint'), but the
# next cell groups by this exact column name, so both must be renamed together.
times_df = pd.DataFrame(times, columns=['SSP + Climate Model + n', 'seperate or joint', 'training or inference', 'time [s]'])
times_df['time [min]'] = times_df['time [s]'] / 60

In [25]:
# Total time per (seperate|joint, training|inference). The numeric columns are selected BEFORE
# aggregating so the string column 'SSP + Climate Model + n' is not needlessly summed (concatenated).
times_df.groupby(['seperate or joint', 'training or inference'])[['time [s]', 'time [min]']].agg(['sum']).round(0)

Unnamed: 0_level_0,Unnamed: 1_level_0,time [s],time [min]
Unnamed: 0_level_1,Unnamed: 1_level_1,sum,sum
seperate or joint,training or inference,Unnamed: 2_level_2,Unnamed: 3_level_2
joint,inference,26.0,0.0
joint,training,4041.0,67.0
separate,inference,25.0,0.0
separate,training,18753.0,313.0


In [26]:
# Compare the last separate model's config (from the benchmark loop) with the joint config.
# checkpoint_path[12:] drops the leading 'checkpoints/' prefix (12 characters).
# (Previously a redundant outer loop rebuilt this table twice and overwrote `data`.)
config_overview_rows = [
    [c['checkpoint_path'][12:], c['summary_net']['kwargs']['zeroth_layer']['dropout'], c['epochs']]
    for c in [sep_model.config, config]
]
df = pd.DataFrame(config_overview_rows, columns=['exp_name', 'dropout', 'epochs'])
df

Unnamed: 0,exp_name,dropout,epochs
0,batched_separate_dropout2/separate_17_1,0.4,15
1,batched_dropout3/dropout_0_4_2,0.4,80
