In [None]:
%load_ext autoreload
%autoreload 1
# https://stackoverflow.com/a/52897289
# Seeding recipe from the answer above. Its step 1 (PYTHONHASHSEED) is not applied here;
# that variable only has an effect when set before the Python process starts.

import os

# Iteration index of this run (ITERATION env var, default 0). It doubles as the RNG seed,
# so each iteration is reproducible while different iterations use different randomness.
iteration_idx = int(os.environ.get('ITERATION', 0))

# Seed value
seed_value = iteration_idx

# 2. Set `python` built-in pseudo-random generator at a fixed value
import random
random.seed(seed_value)

# 3. Set `numpy` pseudo-random generator at a fixed value
import numpy as np
np.random.seed(seed_value)

# 4. Set the `tensorflow` pseudo-random generator at a fixed value
import tensorflow as tf
tf.compat.v1.set_random_seed(seed_value)

import hashlib
import math
from typing import NamedTuple
from pathlib import Path, PosixPath
import sys
from multiprocessing import cpu_count
import tempfile
import shutil
import pickle
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow import keras
from tensorflow.keras import layers, regularizers
from IPython.display import Image

# autoreload mode 1 only reloads modules registered with %aimport, i.e. `utils` here.
%aimport utils

# Note: Start jupyter with CUDA_VISIBLE_DEVICES=-1 env var to force CPU mode.
# tf.test.gpu_device_name() returns an empty string when no GPU is visible.
if tf.test.gpu_device_name():
    print('Using GPU')
else:
    print('Not using GPU')

In [None]:
import synthia as syn

## Constants

In [None]:
# Pipeline switches.
extract_datasets = False  # run the unxz extraction step in the data-preparation section
enhance_data = True  # build the synthetic dataset and rerun ecRad on it
cloudy_profiles_only = True  # keep only columns with some cloud (see sel_cloudy_profiles)
is_windows = os.name == 'nt'

# PBS batch-system context; job id / array index fall back to '1' / '0' outside PBS.
pbs_id = str(os.environ.get('PBS_JOBID', '1'))
pbs_array_idx = str(os.environ.get('PBS_ARRAY_INDEX', '0'))
is_pbs_job = 'PBS_JOBID' in os.environ 

# On Windows, the shell commands in this notebook are run through WSL by prefixing them with `wsl`.
if is_windows:
    wsl = 'wsl'
else:
    wsl = ''

# The project root is taken to be the parent of the notebook's working directory.
path_proj = Path.cwd().parent
path_data = path_proj / 'data'

# PBS jobs write logs to the system temp dir instead of <project>/logs.
path_logs = Path(tempfile.gettempdir()) if is_pbs_job else path_proj / 'logs'
path_logs.mkdir(exist_ok=True)

path_results = path_proj / 'results'
path_results.mkdir(exist_ok=True)

path_common = path_data / 'common'
path_saf = path_data / 'saf'
# Generated files (synthetic inputs, ecRad outputs) also go to the temp dir under PBS.
path_saf_out_dir = Path(tempfile.gettempdir()) if is_pbs_job else path_saf

path_saf_in =  path_saf / 'nwp_saf_profiles_in.nc'
path_saf_in_synth = path_saf_out_dir / 'nwp_saf_profiles_in_synth.nc'

path_ecrad = path_proj / 'ecrad'
path_ecrad_bin = path_ecrad / 'bin' / 'ecrad'

path_triple_nml = path_saf / 'configCY47R1_tripleclouds.nam'
path_sparta_nml = path_saf / 'configCY47R1_spartacus.nam'
path_triple_out_synth = path_saf_out_dir / 'ecrad_nwp_saf_profiles_tripleclouds_out_synth.nc'
path_sparta_out_synth = path_saf_out_dir / 'ecrad_nwp_saf_profiles_spartacus_out_synth.nc'
# Paths handed to the ecRad shell commands. posix_paths[i] is the POSIX form of
# window_paths[i] (translated with `wslpath` on Windows). The ecRad cells index both
# lists by position, so the order of this list must not change.
window_paths = [path_ecrad, path_ecrad_bin, path_triple_nml, path_sparta_nml, path_saf_in_synth, path_triple_out_synth, path_sparta_out_synth, path_logs]
posix_paths = []
for path in window_paths:
    if is_windows:
        x = ! wsl wslpath "{path}"
        x = x[0]  # `!` capture yields a list of output lines; keep the first
    else:
        x = path.as_posix()
    posix_paths.append(x)

# Note: See first cell for ITERATION env var.
    
# USE_DIFF=1: the targets are SPARTACUS minus Tripleclouds outputs; 0: SPARTACUS outputs.
use_diff = int(os.environ.get('USE_DIFF', 1))

# With USE_DIFF, only levels 0..k are modelled, where k is the level whose column-mean
# full-level pressure is closest to USE_PRESSURE_LEVEL (Pa). 0 disables level subsetting.
if use_diff:
    use_pressure_level = int(os.environ.get('USE_PRESSURE_LEVEL', 5000)) # Pa
else:
    use_pressure_level = 0

ncores = int(os.environ.get('NCORES', cpu_count()))

# Copula vars (synthetic data generation via utils.compute_augmented_dataset)
copula_type = os.environ.get('COPULA_TYPE', 'gaussian')
var_synth = os.environ.get('VAR_SYNTH', 'sw_albedo,cos_solar_zenith_angle').split(',')
unif_ratio = float(os.environ.get('UNIF_RATIO', 1.0))
stretch_factor = float(os.environ.get('STRETCH_FACTOR', 1.0))
synth_mul_factor = int(os.environ.get('SYNTH_MUL_FACTOR', 0))

# ML vars
save_model = os.environ.get('SAVE_MODEL', '1') == '1'
var_ml = os.environ.get('VAR_ML', 'optical_depth_fl,cos_solar_zenith_angle,sw_albedo,skin_temperature,cloud_fraction,temperature_fl,q,dz').split(',')
epochs = int(os.environ.get('EPOCHS', 2))
model_type = os.environ.get('MODEL_TYPE', 'MLP') # 'RNN', 'MLP'
rnn_type = os.environ.get('RNN_TYPE', 'GRU') # 'GRU', SimpleRNN', 'LSTM'
rnn_direction = os.environ.get('RNN_DIRECTION', 'bi') # 'fwd', 'bwd', 'bi'
n_hidden_layers = int(os.environ.get('N_HIDDEN_LAYERS', 3))
# Multiplier applied to each model's flattened input width to get its hidden width.
hidden_size = float(os.environ.get('HIDDEN_SIZE', 1))
activation = str(os.environ.get('ACTIVATION', 'elu'))
loss = str(os.environ.get('LOSS', 'mse'))
l1_penalty = float(os.environ.get('L1_PENALTY', 1e-5)) 
l2_penalty = float(os.environ.get('L2_PENALTY', 1e-5))
var_regularizer_factor = float(os.environ.get('VAR_REGULARIZER_FACTOR', 0))
dropout_ratio_input = float(os.environ.get('DROPOUT_RATIO_INPUT', 0))
dropout_ratio_hidden = float(os.environ.get('DROPOUT_RATIO_HIDDEN', 0))
use_heating_rates = int(os.environ.get('USE_HEATING_RATES', '1'))

# Groups result directories (job_stats_<job_name>, job_failures_<job_name>).
job_name = str(os.environ.get('JOB_NAME', 'default'))

input_path = path_saf_in_synth
triple_output_path = path_triple_out_synth
sparta_output_path = path_sparta_out_synth
shuffle = False
    
# 'split' trains separate LW and SW models; 'combined' trains a single model.
case_name = str(os.environ.get('CASE_NAME', 'split'))

if case_name == 'combined':
    is_lw_sw_split = 0
elif case_name == 'split':
    is_lw_sw_split = 1
else:
    raise RuntimeError('Invalid case name')

# Human-readable description of this configuration. Its MD5 hash below gives a
# stable, filesystem-safe identifier for the result files of this run.
_case_name_parts = [
    ('case', case_name),
    ('use_diff', use_diff),
    ('use_heating_rates', use_heating_rates),
    ('copula_type', copula_type),
    ('synth_mul_factor', synth_mul_factor),
    ('unif_ratio', unif_ratio),
    ('stretch_factor', stretch_factor),
    ('var_synth', ",".join(var_synth)),
    ('var_ml', ",".join(var_ml)),
    ('loss', loss),
    ('activation', activation),
    ('l1_penalty', l1_penalty),
    ('l2_penalty', l2_penalty),
    ('var_regularizer_factor', var_regularizer_factor),
    ('dropout_ratio_input', dropout_ratio_input),
    ('dropout_ratio_hidden', dropout_ratio_hidden),
    ('model_type', model_type),
    ('rnn_type', rnn_type),
    ('rnn_direction', rnn_direction),
    ('layers', n_hidden_layers),
    ('size', hidden_size),
]
case_name_without_iteration = ','.join(f'{key}={value}' for key, value in _case_name_parts)
case_name = f'{case_name_without_iteration},iteration={iteration_idx}'
case_name_hashed = hashlib.md5(case_name.encode('ascii')).hexdigest()

# Quantities for which error statistics are reported. The heating-rate names are also
# kept in their own lists so they can be dropped for the flux-only metrics.
stat_quantities_sw = ['flux_dn_direct_sw', 'flux_dn_sw', 'flux_up_sw', 'heating_rate_sw']
_heating_rate_sw = ['heating_rate_sw']
stat_quantities_lw = ['flux_dn_lw', 'flux_up_lw', 'heating_rate_lw']
_heating_rate_lw = ['heating_rate_lw']

(case_name, case_name_hashed)

In [None]:
stats_pkl_path = path_results / f"job_stats_{job_name}" / (case_name_hashed + '.pkl')

# Under PBS, skip configurations whose stats were already written by an earlier job.
# Raising is what aborts the remaining cells of the run.
if is_pbs_job:
    if stats_pkl_path.exists():
        print('job results exist already, exiting')
        raise RuntimeError('early stop')

In [None]:
# Create an empty marker file to keep track of job failures.
if is_pbs_job:
    f_job_failure = path_results / f"job_failures_{job_name}" / f'{case_name_hashed}_{pbs_id}_{pbs_array_idx}'
    # exist_ok: several PBS array jobs can start at once. With an exists() check followed
    # by mkdir(), they race and all but one fail with FileExistsError.
    f_job_failure.parent.mkdir(parents=True, exist_ok=True)
    f_job_failure.write_text('')

## Data Preparation and Enhancement

In [None]:
def sel_cloudy_profiles(ds, cloudy=True):
    """Return the cloudy columns of ``ds`` (or the clear ones when ``cloudy=False``).

    A column counts as cloudy when its ``cloud_fraction`` is positive on at least one
    level.
    """
    has_cloud = ds['cloud_fraction'].max('level') > 0
    column_mask = has_cloud if cloudy else ~has_cloud
    return ds.sel(column=column_mask)

In [None]:
if extract_datasets:
    ! {wsl} find -name "*.nc.xz" -exec unxz -k '\{\}' \

if enhance_data:
    # Load profiles
    ds_true = xr.open_dataset(path_saf_in)
    
    if cloudy_profiles_only:
        ds_true = sel_cloudy_profiles(ds_true)
    
    display(ds_true)

In [None]:
if enhance_data:
    # Shuffled, reproducible (seed=42) split: train_size=0.6 for training, then the
    # remainder is halved into test and validation parts.
    ds_train, ds_temp = utils.train_test_split_dataset(
        ds_true, train_size=0.6, dim='column', shuffle=True, seed=42
    )
    ds_test, ds_validation = utils.train_test_split_dataset(
        ds_temp, test_size=0.5, dim='column', shuffle=True, seed=42
    )
    print(*(len(part.column) for part in (ds_train, ds_validation, ds_test)))

In [None]:
if enhance_data:
    # Generate extra training profiles. Per the original note, only albedo and solar
    # angle (the variables in `var_synth`) are synthetic; other variables are duplicated.
    ds_train_synth = utils.compute_augmented_dataset(
        ds_train,
        var_synth,
        copula_type=copula_type,
        synth_mul_factor=synth_mul_factor,
        uniformization_ratio=unif_ratio,
        stretch_factor=stretch_factor,
        num_threads=ncores,
    )

    # Quick visual check of the synthetic distributions.
    ds_train_synth['sw_albedo'].plot.hist()
    plt.show()
    ds_train_synth['cos_solar_zenith_angle'].plot.hist(alpha=0.5)

In [None]:
if enhance_data:
    # Column order matters: the "Load data" cell splits this file positionally as
    # [original train + synthetic train | validation | test].
    ds_synth = xr.concat(
        [ds_train, ds_train_synth, ds_validation, ds_test],
        dim='column',
    )
    # xr.concat gives the 1-D o2_vmr a 'column' dim; restore the original variable.
    ds_synth['o2_vmr'] = ds_true['o2_vmr']
    print(len(ds_synth.column))

In [None]:
if enhance_data:
    # Persist the augmented inputs; the ecRad runs read this file (posix_paths[4]).
    ds_synth.to_netcdf(path_saf_in_synth)

In [None]:
if enhance_data:
    # Spartacus
    # Runs `ecrad <namelist> <input.nc> <output.nc>` from the ecRad directory, using the
    # posix_paths entries 1 (binary), 3 (SPARTACUS namelist), 4 (synthetic input) and
    # 6 (SPARTACUS output). OMP_NUM_THREADS is set to an empty value; stdout goes to
    # sparta_synth.log in the logs directory (window_paths[7]).
    ! cd {path_ecrad} && {wsl} OMP_NUM_THREADS= {posix_paths[1]} {posix_paths[3]} {posix_paths[4]} {posix_paths[6]} > {window_paths[7]}/sparta_synth.log

In [None]:
if enhance_data:
    # Tripleclouds
    # Same invocation as the SPARTACUS run, with posix_paths entries 2 (Tripleclouds
    # namelist) and 5 (Tripleclouds output); stdout goes to triple_synth.log.
    ! cd {path_ecrad} && {wsl} OMP_NUM_THREADS= {posix_paths[1]} {posix_paths[2]} {posix_paths[4]} {posix_paths[5]} > {window_paths[7]}/triple_synth.log

## Load data

In [None]:
# Number of leading columns of the synthetic input file that form the training set
# (original training columns followed by their synthetic copies).
if not enhance_data:
    # Reusing a previously generated file: per the original note, this is the formula
    # below evaluated for 5x synthetic profiles. It must match the file on disk.
    new_train_size_with_synthetic = 13703
else:
    new_train_size_with_synthetic = len(ds_train.column) + len(ds_train_synth.column)

print(new_train_size_with_synthetic)

In [None]:
# Reload the synthetic inputs (written above, or left over from a previous run).
ds_synth = utils.load_scheme_inputs(path_saf_in_synth, only_relevant=False)

# Pick the level whose column-mean full-level pressure is closest to use_pressure_level
# and keep levels 0..that level. NOTE(review): mean(axis=0) assumes pressure_fl is laid
# out as (column, level) — confirm.
if use_pressure_level:
    level = int(np.fabs(ds_synth.pressure_fl.mean(axis=0) -use_pressure_level).argmin())
    level = range(0, int(level) + 1) # End of range is exclusive
else:
    level = None
display(level)

ds_synth = utils.add_derived_inputs(ds_synth)

input_quantities = var_ml

# Select input quantities
ds_synth = ds_synth[input_quantities]

# Subset input levels
if level:
    ds_synth = ds_synth.sel(level=level)

# Normalize inputs
# NOTE(review): stats are computed on the whole file (train, validation and test columns
# together) before the split below.
normalize_inputs = True
norm_root = 1
if normalize_inputs:
    norm_stats_in = utils.compute_norm_stats(ds_synth, norm_root)
    ds_synth = utils.normalize_inputs(ds_synth, norm_stats_in)

utils.plot_inputs(ds_synth, normalize_inputs)


# Positional split (shuffle=False): the file was written as
# [train + synthetic train | validation | test], so the first
# new_train_size_with_synthetic columns are training data and the rest is halved.
ds_train_synth, ds_temp = utils.train_test_split_dataset(ds_synth, train_size=new_train_size_with_synthetic, dim='column', shuffle=False)
ds_validation, ds_test = utils.train_test_split_dataset(ds_temp, test_size=0.5, dim='column', shuffle=False)

print(len(ds_train_synth.column), len(ds_validation.column), len(ds_test.column))

# Column ranges into the full file; used by sel_train_val / sel_train_test.
column_training_range = slice(0, len(ds_train_synth.column))
column_training_validation_range = slice(0, len(ds_train_synth.column) + len(ds_validation.column))
column_test_range = slice(len(ds_train_synth.column) + len(ds_validation.column), None)
print(column_training_validation_range)
print(column_test_range)

# Physical scheme inputs

In [None]:
def xr_1d_to_2d_vars(ds, var_names):
    """Broadcast per-column variables onto the (column, level) grid.

    Each variable named in ``var_names`` (expected to hold one value per column) is
    replaced by a float64 ``DataArray`` with dims ``('column', 'level')`` in which the
    column's value is repeated on every level.

    Returns a new dataset (via ``Dataset.assign``); ``ds`` is not modified.
    """
    # Dataset.sizes is the stable dim -> length mapping; indexing Dataset.dims by name is
    # deprecated in recent xarray releases.
    n_columns = ds.sizes['column']
    n_levels = ds.sizes['level']
    for var_name in var_names:
        column_values = np.asarray(ds[var_name].values, dtype=np.float64).reshape(n_columns, 1)
        # np.repeat yields a writable copy (np.broadcast_to would be a read-only view).
        arr = np.repeat(column_values, n_levels, axis=1)
        ds = ds.assign({var_name: xr.DataArray(arr, dims=['column', 'level'])})
    return ds

In [None]:
# Show all available inputs. We load them so that we can save them with the data later
ds_inputs_all = utils.load_scheme_inputs(path_saf_in_synth, only_relevant=False)

# For training, we use a subset of inputs only.
if is_lw_sw_split:
    ds_inputs_lw = ds_synth.drop(['cos_solar_zenith_angle', 'sw_albedo'])
    ds_inputs_sw = ds_synth.drop(['skin_temperature', 'temperature_fl'])
    if model_type == 'RNN':
        ds_inputs_lw = xr_1d_to_2d_vars(ds_inputs_lw, ['skin_temperature'])
        ds_inputs_sw = xr_1d_to_2d_vars(ds_inputs_sw, ['cos_solar_zenith_angle', 'sw_albedo'])
    display(ds_inputs_lw, ds_inputs_sw)
else:
    ds_inputs = ds_synth
    display(ds_inputs)

# Physical scheme outputs

In [None]:
normalize_outputs = True

if normalize_outputs:
    if use_diff:
        # Differences can be negative, so no root transform is applied to them.
        norm_root_out = 1
    else:
        norm_root_out = norm_root


def _prepare_outputs(ds_sparta_part, ds_triple_part):
    """Build training targets from SPARTACUS / Tripleclouds outputs.

    Applies, in order: the SPARTACUS - Tripleclouds difference (if ``use_diff``), heating
    rates (if ``use_heating_rates``), the level subset used for the inputs (if ``level``)
    and output normalisation (if ``normalize_outputs``).

    Returns:
        ``(ds_outputs, norm_stats)``; ``norm_stats`` is None without normalisation.
    """
    if use_diff:
        # Difference between fast and slow physical scheme outputs
        ds_out = ds_sparta_part - ds_triple_part
    else:
        ds_out = ds_sparta_part

    if use_heating_rates:
        ds_out = utils.add_heating_rates(ds_out, ds_inputs_all)

    # Subset levels so the targets cover the same levels as the inputs.
    if level:
        ds_out = ds_out.sel(half_level=level)
        if use_heating_rates:
            ds_out = ds_out.sel(level=level)

    norm_stats = None
    if normalize_outputs:
        norm_stats = utils.compute_norm_stats(ds_out, norm_root_out)
        ds_out = utils.normalize_outputs(ds_out, norm_stats)
    return ds_out, norm_stats


if is_lw_sw_split:
    # Load LW outputs
    lw_fluxes = ['flux_up_lw', 'flux_dn_lw']
    ds_triple_lw = utils.load_scheme_outputs(triple_output_path)[lw_fluxes]
    ds_sparta_lw = utils.load_scheme_outputs(sparta_output_path)[lw_fluxes]
    ds_outputs_lw, norm_stats_out_lw = _prepare_outputs(ds_sparta_lw, ds_triple_lw)

    # Load SW outputs
    sw_fluxes = ['flux_up_sw', 'flux_dn_sw', 'flux_dn_direct_sw']
    ds_triple_sw = utils.load_scheme_outputs(triple_output_path)[sw_fluxes]
    ds_sparta_sw = utils.load_scheme_outputs(sparta_output_path)[sw_fluxes]
    ds_outputs_sw, norm_stats_out_sw = _prepare_outputs(ds_sparta_sw, ds_triple_sw)

    display(ds_outputs_lw, ds_outputs_sw)

else:
    # Load outputs
    ds_triple = utils.load_scheme_outputs(triple_output_path)
    ds_sparta = utils.load_scheme_outputs(sparta_output_path)

    # Same processing as the split case, so the targets also get heating rates (needed by
    # the metrics cells) and cover the same levels as the level-subset inputs.
    ds_outputs, norm_stats_out = _prepare_outputs(ds_sparta, ds_triple)

    display(ds_outputs)

# Machine Learning

In [None]:
def create_train_data(ds_x_inputs, ds_y_inputs, shuffle, test_size, apply_pca=False):
    """Wrap ``utils.create_training_data`` for one input/target pair.

    Args:
        ds_x_inputs: input dataset.
        ds_y_inputs: target dataset.
        shuffle: forwarded to the helper. The callers in this notebook pass the
            notebook-level ``shuffle`` setting (False).
        test_size: size of the held-out part; callers here pass a column count.
        apply_pca: forwarded to the helper.

    Returns:
        The object returned by ``utils.create_training_data``; the shape of its
        ``x_train_flat`` is printed.
    """
    train_data = utils.create_training_data(
        ds_x_inputs, ds_y_inputs,
        apply_pca=apply_pca,
        shuffle=shuffle,
        test_size=test_size,
        # NOTE(review): presumably the number of trailing x axes kept per sample
        # (2 for the RNN, 1 for the MLP) — confirm against utils.
        x_flat_dims=2 if model_type == 'RNN' else 1
    )
    print(train_data.x_train_flat.shape)

    return train_data

def sel_train_val(ds):
    """Return the training + validation columns (the leading part of the file)."""
    train_val_columns = column_training_validation_range
    return ds.sel(column=train_val_columns)

def sel_train_test(ds):
    """Return the training columns followed by the test columns (validation skipped)."""
    parts = [ds.sel(column=cols) for cols in (column_training_range, column_test_range)]
    return xr.concat(parts, dim='column')

# train_val_* hold out the validation block; train_test_* hold out the test block.
if is_lw_sw_split:
    train_val_data_lw = create_train_data(sel_train_val(ds_inputs_lw), sel_train_val(ds_outputs_lw), shuffle, test_size=len(ds_validation.column))
    train_val_data_sw = create_train_data(sel_train_val(ds_inputs_sw), sel_train_val(ds_outputs_sw), shuffle, test_size=len(ds_validation.column))
    train_test_data_lw = create_train_data(sel_train_test(ds_inputs_lw), sel_train_test(ds_outputs_lw), shuffle, test_size=len(ds_test.column))
    train_test_data_sw = create_train_data(sel_train_test(ds_inputs_sw), sel_train_test(ds_outputs_sw), shuffle, test_size=len(ds_test.column))
else:   
    train_val_data = create_train_data(sel_train_val(ds_inputs), sel_train_val(ds_outputs), shuffle, test_size=len(ds_validation.column))
    # The held-out part of the train+test selection is the test block (as in the split case).
    train_test_data = create_train_data(sel_train_test(ds_inputs), sel_train_test(ds_outputs), shuffle, test_size=len(ds_test.column))

In [None]:
# Hidden-layer width = HIDDEN_SIZE multiplier x the model's flattened input width.
if is_lw_sw_split:
    hidden_size_lw = int(hidden_size * train_val_data_lw.x_train_flat.shape[1])
    hidden_size_sw = int(hidden_size * train_val_data_sw.x_train_flat.shape[1])
    print(hidden_size_lw, hidden_size_sw)
else:
    hidden_size_combined = int(hidden_size * train_val_data.x_train_flat.shape[1])
    # Alias for code written for the split case that reads the SW width.
    hidden_size_sw = hidden_size_combined
    print(hidden_size_combined)

In [None]:
# %%capture

# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

def train_and_save_model(case_name,
                         train_val_data, train_test_data,
                         ds_inputs_all_train_test, ds_triple_train_test, ds_sparta_train_test,
                         model_type, rnn_type, rnn_direction,
                         n_hidden_layers, hidden_size, activation, l1_penalty,
                         l2_penalty, var_regularizer_factor, dropout_ratio_input=0.0, dropout_ratio_hidden=0.0, 
                         learning_rate=0.001, epochs=epochs, save_model=False):
    """Train one network and collect its train/test results.

    The network is built by ``utils.build_rnn_model`` (``model_type == 'RNN'``) or
    ``utils.build_dense_nn_model`` (``'MLP'``) and trained on ``train_val_data``
    (batch size 256, early stopping, keeping the best model). ``utils.save_training_results``
    then evaluates it on ``train_test_data``; saving to ``path_results`` is skipped
    when ``save_model`` is False.

    Note: the loss comes from the notebook-level ``loss`` setting, and the ``epochs``
    default is the global value at the time this cell runs.

    Returns:
        ``(results, trained_model)`` from the two ``utils`` helpers.

    Raises:
        ValueError: if ``model_type`` is neither 'RNN' nor 'MLP'.
    """
    hp = {
        'model_type': model_type,
        'rnn_type': rnn_type,
        'rnn_direction': rnn_direction,
        'n_hidden_layers': n_hidden_layers,
        'hidden_size': hidden_size,
        'activation': activation,
        'dropout_ratio_input': dropout_ratio_input,
        'dropout_ratio_hidden': dropout_ratio_hidden,
        'l1_penalty': l1_penalty,
        'l2_penalty': l2_penalty,
        'var_regularizer_factor': var_regularizer_factor,
        'learning_rate': learning_rate,
        'loss': loss,
    }

    if model_type != 'RNN' and model_type != 'MLP':
        raise ValueError('unknown model type: ' + model_type)
    model_fn = utils.build_rnn_model if model_type == 'RNN' else utils.build_dense_nn_model

    trained_model = utils.train_model(
        model_fn, train_val_data, hp,
        epochs=epochs, batch_size=256, early_stopping=True, verbose=2, save_best=True,
    )

    results = utils.save_training_results(
        case_name, trained_model, train_test_data,
        ds_inputs_all_train_test, ds_triple_train_test, ds_sparta_train_test,
        skip_save=not save_model,
        save_dir=str(path_results),
    )
    return results, trained_model

if is_lw_sw_split:
    d_lw, trained_model_lw = train_and_save_model(
        'model_lw',
        train_val_data_lw, train_test_data_lw,
        sel_train_test(ds_inputs_all), sel_train_test(ds_triple_lw), sel_train_test(ds_sparta_lw),
        model_type, rnn_type, rnn_direction,
        n_hidden_layers, hidden_size_lw, activation, l1_penalty, l2_penalty, var_regularizer_factor, dropout_ratio_input, dropout_ratio_hidden, epochs=epochs, save_model=save_model)
    
    d_sw, trained_model_sw = train_and_save_model(
        'model_sw',
        train_val_data_sw, train_test_data_sw,
        sel_train_test(ds_inputs_all), sel_train_test(ds_triple_sw), sel_train_test(ds_sparta_sw),
        model_type, rnn_type, rnn_direction,
        n_hidden_layers, hidden_size_sw, activation, l1_penalty, l2_penalty, var_regularizer_factor, dropout_ratio_input, dropout_ratio_hidden, epochs=epochs, save_model=save_model)
else:
    # Width of the combined model, derived from its own flattened input in the same way
    # as the LW/SW widths (computed here instead of relying on a split-case variable).
    hidden_size_combined = int(hidden_size * train_val_data.x_train_flat.shape[1])
    d, trained_model = train_and_save_model(
        'model',
        train_val_data, train_test_data,
        sel_train_test(ds_inputs_all), sel_train_test(ds_triple), sel_train_test(ds_sparta),
        model_type, rnn_type, rnn_direction,
        n_hidden_layers, hidden_size_combined, activation, l1_penalty, l2_penalty, var_regularizer_factor, dropout_ratio_input, dropout_ratio_hidden, epochs=epochs, save_model=save_model)

In [None]:
# Training curves of the trained model(s).
if not is_lw_sw_split:
    trained_model_history = trained_model.history
    trained_model_history.plot()
else:
    trained_model_lw.history.plot()
    trained_model_sw.history.plot()

In [None]:
# Undo the output normalisation on targets and predictions.
vars_to_unnormalize = ['y_true_train', 'y_true_test', 'y_pred_train', 'y_pred_test']

if normalize_outputs:
    if is_lw_sw_split:
        for name in vars_to_unnormalize:
            d_lw[name] = utils.unnormalize_outputs(d_lw[name], norm_stats_out_lw)
            d_sw[name] = utils.unnormalize_outputs(d_sw[name], norm_stats_out_sw)
    else:
        for name in vars_to_unnormalize:
            d[name] = utils.unnormalize_outputs(d[name], norm_stats_out)

## Postprocessing

If the NNs were trained on a subset of pressure levels, we recover the full profiles here.

In [None]:
# Merge the results from the two NN
def merge_lw_sw_split(ds_1, ds_2):
    # At the moment this is a bit hacky as we create a new ds based on ds_1
    # This means that the the combilne dataset will onyl have the NN specs from ds_1
    # We do not need to merge 'x_train' and 'x_test' as these are all the inputs and therefore same for lw and sw 
    vars_to_merge = ['y_triple_train', 'y_triple_test',
     'y_sparta_train', 'y_sparta_test', 
     'y_true_train', 'y_true_test',
     'y_pred_train', 'y_pred_test']
 
    d = ds_1 # This is a hack
    for var in vars_to_merge:
        if d.has_test_data:
            d[var] = xr.merge([ds_1[var], ds_2[var]])
        else:
            if '_test' in var:
                continue
            else:
                d[var] = xr.merge([ds_1[var], ds_2[var]])
    return d

# `d` is the single results object used from here on; the combined case already has it.
if is_lw_sw_split:
    d = merge_lw_sw_split(d_lw, d_sw)

In [None]:
# Attach heating rates to the SPARTACUS and Tripleclouds reference outputs, using the
# matching train/test inputs.
for _attr, _inputs in (('y_sparta_train', d.x_train), ('y_sparta_test', d.x_test),
                       ('y_triple_train', d.x_train), ('y_triple_test', d.x_test)):
    setattr(d, _attr, utils.add_heating_rates(getattr(d, _attr), _inputs))
d.y_sparta_train

In [None]:
def recover_levels(d_subset, d_template):
    """Pad level-subset predictions back to the full vertical grid of ``d_template``.

    Args:
        d_subset: dataset covering only the first ``len(d_subset.half_level)`` levels.
        d_template: full-level dataset whose layout the result takes (via
            ``xr.zeros_like``).

    Returns:
        A dataset restricted to ``d_subset``'s variables. Its first
        ``len(d_subset.half_level)`` level positions hold the subset values. Above them,
        variables whose name contains ``_up_`` or ``heating_rate`` repeat the last subset
        value, and every other variable stays zero.

    NOTE(review): the positional ``[:, ...]`` writes assume each variable is laid out as
    (column, level-like dim) — confirm. Heating rates live on ``level`` but the cut-off
    comes from ``half_level``; this relies on both having been subset with the same range.
    """
    # Fill the missing values with zeros for downwelling fluxes -- i.e. the diff should be zero
    # And propagate the value from the last calculated level from the inference to TOA for upwelling fluxes.
    last_valid_level = len(d_subset.half_level) - 1
    diff_full = xr.zeros_like(d_template)
    for var in d_subset.keys():
        # We only update values that have been calculated from the inference  -- all the others remain zero
        # (writes go through diff_full[var] into diff_full's underlying variable)
        diff_full[var][:, :last_valid_level + 1] = d_subset[var]
        # In the case of upwelling fluxes, we need to propagate the last value calculated from the inference
        # to the TOA to keep the physical meaning.
        if '_up_' in var or 'heating_rate' in var:
            diff_full[var][:, last_valid_level + 1:] = diff_full[var][:, last_valid_level:last_valid_level + 1]
    # Drop template variables that the subset did not predict.
    diff_full = diff_full[[z for z in d_subset.variables]]
    return diff_full

In [None]:
if use_pressure_level:
    # Predictions only cover the trained levels; pad them back to full profiles.
    diff_pred_train, diff_pred_test = (
        recover_levels(d.y_pred_train, d.y_sparta_train),
        recover_levels(d.y_pred_test, d.y_sparta_test),
    )
    d.y_pred_train, d.y_pred_test = diff_pred_train, diff_pred_test
    # The references were saved on all levels, so the true difference can be recomputed
    # directly from them.
    d.y_true_train, d.y_true_test = (
        d.y_sparta_train - d.y_triple_train,
        d.y_sparta_test - d.y_triple_test,
    )
d.y_true_train

In [None]:
# Without heating rates in the targets, derive them from the fluxes for the metrics below.
if not use_heating_rates:
    for _attr, _inputs in (('y_true_train', d.x_train), ('y_true_test', d.x_test),
                           ('y_pred_train', d.x_train), ('y_pred_test', d.x_test)):
        setattr(d, _attr, utils.add_heating_rates(getattr(d, _attr), _inputs))
d.y_pred_test

In [None]:
def _lw_metric_view(ds):
    """Return the LW flux quantities of ``ds`` (heating rate dropped), restricted to the
    trained levels when a level subset is in use."""
    if level:
        level_slice = slice(level.start, level.stop + 1)
        ds = ds.sel(half_level=level_slice, level=level_slice).rename(level='half_level')
    return ds[stat_quantities_lw].drop(_heating_rate_lw)


# LW error metrics on training and test columns, side by side. `axis` is passed by
# keyword: pandas 2 made every pd.concat argument after `objs` keyword-only.
m_lw = pd.concat([
    utils.compute_error_metrics(_lw_metric_view(d.y_true_train),
                                _lw_metric_view(d.y_pred_train)).add_suffix('_train'),
    utils.compute_error_metrics(_lw_metric_view(d.y_true_test),
                                _lw_metric_view(d.y_pred_test)).add_suffix('_test'),
], axis='columns')
m_lw

In [None]:
def _sw_metric_view(ds):
    """Return the SW flux quantities of ``ds`` (heating rate dropped), restricted to the
    trained levels when a level subset is in use."""
    if level:
        level_slice = slice(level.start, level.stop + 1)
        ds = ds.sel(half_level=level_slice, level=level_slice).rename(level='half_level')
    return ds[stat_quantities_sw].drop(_heating_rate_sw)


# SW error metrics on training and test columns, side by side. `axis` is passed by
# keyword: pandas 2 made every pd.concat argument after `objs` keyword-only.
m_sw = pd.concat([
    utils.compute_error_metrics(_sw_metric_view(d.y_true_train),
                                _sw_metric_view(d.y_pred_train)).add_suffix('_train'),
    utils.compute_error_metrics(_sw_metric_view(d.y_true_test),
                                _sw_metric_view(d.y_pred_test)).add_suffix('_test'),
], axis='columns')
m_sw

In [None]:
# Inspect the target SW and LW quantities on the training columns.
d.y_true_train[stat_quantities_sw + stat_quantities_lw]

In [None]:
# Inspect the SPARTACUS reference outputs on the training columns.
d.y_sparta_train

In [None]:
# Save stats to file to allow comparison between different runs.
# See grid-stats.ipynb.

case_stats = pd.concat([
  # One row with this run's metadata and hyperparameters.
  pd.DataFrame({
      'name': [case_name_without_iteration], # Useful for easy aggregation across iterations.
      'pbs_id': [pbs_id],
      'pbs_array_idx': [pbs_array_idx],
      'job_name': [job_name],
      'iteration': [iteration_idx],
      'use_diff': [use_diff],
      'use_heating_rates': [use_heating_rates],
      'model_type': [model_type],
      'rnn_type': [rnn_type],
      'rnn_direction': [rnn_direction],
      'hidden_size': [hidden_size],
      'n_hidden_layers': [n_hidden_layers],
      'copula_type': [copula_type],
      'synth_mul_factor': [synth_mul_factor],
      'unif_ratio': [unif_ratio],
      'stretch_factor': [stretch_factor],
      'loss': [loss],
      'activation': [activation],
      'l1_penalty': [l1_penalty],
      'l2_penalty': [l2_penalty],
      'var_regularizer_factor': [var_regularizer_factor],
      'dropout_ratio_input': [dropout_ratio_input],
      'dropout_ratio_hidden': [dropout_ratio_hidden],
      'var_synth': [','.join(var_synth)],
      'var_ml': [','.join(var_ml)],
  }),
  # Headline metrics: mean of the SW and LW 'all' rows.
  ((m_sw.loc[['all']] + m_lw.loc[['all']]) / 2).reset_index().drop(columns='quantity'),
  # The same 'all' rows per band, with sw_ / lw_ column prefixes.
  m_sw.loc[['all']].reset_index().drop(columns='quantity').add_prefix('sw_'),
  m_lw.loc[['all']].reset_index().drop(columns='quantity').add_prefix('lw_')
], axis=1)
display(case_stats)

# parents/exist_ok: concurrent PBS array jobs may race to create this directory, and a
# separate exists() check followed by mkdir() would fail for all but the first job.
stats_pkl_path.parent.mkdir(parents=True, exist_ok=True)

with open(stats_pkl_path, 'wb') as f:
    pickle.dump(case_stats, f)

In [None]:
def derive_boa_toa_quantities(ds, toa_level):
    """Extract BOA downwelling and TOA upwelling fluxes as separate variables.

    BOA values come from the selection ``slice(0, 1)`` and TOA values from
    ``slice(toa_level, toa_level + 1)``, both applied to ``half_level`` and ``level``
    (with ``level`` renamed to ``half_level``). The flux variables get ``_boa`` /
    ``_toa`` suffixes and are merged into one dataset.
    """
    bottom = ds.sel(half_level=slice(0, 1), level=slice(0, 1)).rename(level='half_level')
    top = ds.sel(half_level=slice(toa_level, toa_level + 1),
                 level=slice(toa_level, toa_level + 1)).rename(level='half_level')
    boa_names = {name: f'{name}_boa' for name in ('flux_dn_lw', 'flux_dn_sw', 'flux_dn_direct_sw')}
    toa_names = {name: f'{name}_toa' for name in ('flux_up_lw', 'flux_up_sw')}
    return xr.merge([
        bottom[list(boa_names)].rename(boa_names),
        top[list(toa_names)].rename(toa_names),
    ])
    
def subset_levels(ds, toa_level):
    """Select ``slice(0, toa_level + 1)`` on both vertical dims and put them on ``half_level``."""
    up_to_top = slice(0, toa_level + 1)
    subset = ds.sel(half_level=up_to_top, level=up_to_top)
    return subset.rename(level='half_level')

# Index used as "top" for the TOA diagnostics: level.stop (one past the last trained
# full-level index) with a level subset, otherwise the last full level.
if level:
    toa_level = level.stop
else:
    toa_level = d.y_sparta_test.level[-1].item()
# Row order of the summary tables.
ordering = [
    'flux_dn_lw', 'flux_up_lw',
    'flux_dn_sw', 'flux_dn_direct_sw', 'flux_up_sw',
    'heating_rate_lw', 'heating_rate_sw',
    'flux_dn_lw_boa', 'flux_up_lw_toa',
    'flux_dn_sw_boa', 'flux_dn_direct_sw_boa', 'flux_up_sw_toa'
]
stats_cols = ['mbe', 'mae', 'rmse', 'std']

# Baseline: Tripleclouds error against SPARTACUS on the test columns, per quantity
# (profiles up to toa_level plus the BOA/TOA diagnostics).
# NOTE(review): with use_diff == 0 the reference is y_sparta_test + y_triple_test —
# confirm this is intended.
m_baseline = pd.concat([
    utils.compute_error_metrics(
        subset_levels(d.y_sparta_test if use_diff else d.y_sparta_test + d.y_triple_test, toa_level),
        subset_levels(d.y_triple_test, toa_level))[stats_cols].drop(index='all'),
    utils.compute_error_metrics(
        derive_boa_toa_quantities(d.y_sparta_test if use_diff else d.y_sparta_test + d.y_triple_test, toa_level),
        derive_boa_toa_quantities(d.y_triple_test, toa_level))[stats_cols].drop(index='all')
]).reindex(ordering)

# Model error: prediction against target on the same test columns.
m_error = pd.concat([
    utils.compute_error_metrics(
        subset_levels(d.y_true_test, toa_level),
        subset_levels(d.y_pred_test, toa_level))[stats_cols].drop(index='all'),
    utils.compute_error_metrics(
        derive_boa_toa_quantities(d.y_true_test, toa_level),
        derive_boa_toa_quantities(d.y_pred_test, toa_level))[stats_cols].drop(index='all')
]).reindex(ordering)

# Model error as a percentage of the baseline error.
m_percent = m_error / m_baseline * 100

def merge_multiindex_with_label(dfs, labels):
    """Place ``dfs`` side by side under (column, label) MultiIndex columns.

    Each frame's columns are tagged with the matching entry of ``labels``. Levels are
    swapped so the original column name is the outer level, then columns are sorted.
    """
    stacked = pd.concat(dfs, axis=1, keys=labels)
    column_first = stacked.swaplevel(0, 1, axis=1)
    return column_first.sort_index(axis=1)

m_table = (
    merge_multiindex_with_label([m_baseline, m_error, m_percent], ['baseline', 'error', 'percent'])
    # Prefixing 'mbe' with '_' makes it sort ahead of the other metric names.
    .rename(columns={'mbe': '_mbe'})
    .sort_index(axis=1)
)

def fmt(v):
    """Format ``v`` with two significant digits for display.

    Values that ``.2g`` prints in scientific notation are expanded to plain notation
    when Python's ``str`` does so (e.g. ``12345 -> '12000'``). Values that stay in
    exponent form are kept intact (e.g. ``'1e+20'``, ``'1.5e-10'``).
    """
    s = '{0:.2g}'.format(v)
    if 'e' in s:
        s = str(float(s))
        # Only trim plain decimals such as '12000.0'. Stripping '0' from an exponent
        # form like '1e+20' would corrupt it to '1e+2'.
        if 'e' not in s and '.' in s:
            s = s.rstrip('0').rstrip('.')
    return s

with pd.option_context('display.float_format', fmt):
    display(m_table)
    # Flatten the (metric, label) column MultiIndex into 'metric_label' CSV headers.
    header = ['_'.join([metric, label]) for metric, label in m_table.columns]
    print(m_table.to_csv(header=header, float_format='%.2g'))

In [None]:
def plot_random_profiles(ds_out_true, ds_out_pred, n_rand_profiles):
    """Plot reference vs predicted profiles for randomly chosen columns.

    One row per sampled column and one subplot per variable of ``ds_out_true``. The
    reference is drawn dashed black and the prediction solid red.

    Args:
        ds_out_true: reference dataset; its variables and ``column`` values drive the plot.
        ds_out_pred: predicted dataset with the same variables.
        n_rand_profiles: number of columns to draw (sampled with replacement, so
            repeats are possible).
    """
    # Ensure same results across multiple runs. A private RandomState(0) draws the same
    # columns as seeding the global RNG with 0 would, without resetting the global RNG
    # that the first cell seeded.
    rng = np.random.RandomState(0)
    idxs = rng.choice(ds_out_true.column, n_rand_profiles)
    quantities = list(ds_out_true.keys())
    # squeeze=False keeps `ax` 2-D even for a single row or a single quantity.
    fig, ax = plt.subplots(len(idxs), len(quantities),
                           figsize=(6 * len(quantities), 3 * len(idxs)), squeeze=False)
    for c_count, c_idx in enumerate(idxs):
        for q_idx, q_name in enumerate(quantities):
            ds_out_true[q_name].sel(column=c_idx).plot(ax=ax[c_count, q_idx], label='true', c='k', linestyle='--')
            ds_out_pred[q_name].sel(column=c_idx).plot(ax=ax[c_count, q_idx], label='pred', c='r')
            ax[c_count, q_idx].set_title(f'Column Index: {c_idx}')
            ax[c_count, q_idx].legend()

In [None]:
# Three randomly chosen test columns: target vs prediction.
plot_random_profiles(d.y_true_test, d.y_pred_test, 3)

In [None]:
def plot_scatter(ds_3d_true, ds_3d_pred, ds_sparta, plot_x_absolute=False):
    """Scatter reference vs. predicted fluxes at one boundary half level each.

    Draws one panel per flux listed in ``q_l_map`` below, each evaluated at the
    half level given there (0 is labelled BOA, -1 TOA). Each panel has equal
    x/y limits spanning both data sets and a dashed 1:1 line.

    Args:
        ds_3d_true: Reference dataset (x-axis values).
        ds_3d_pred: Predicted dataset (y-axis values).
        ds_sparta: SPARTACUS output. Only read when ``plot_x_absolute`` is
            True, in which case it is added to ``ds_3d_true`` for the x values.
        plot_x_absolute: Plot ``ds_3d_true + ds_sparta`` on the x-axis and add
            a horizontal zero line. Forced to False when the global
            ``use_diff`` is set.
    """
    if use_diff:
        plot_x_absolute = False

    d_names = {
        'flux_dn_lw': 'Downwelling longwave',
        'flux_up_lw': 'Upwelling longwave',
        'flux_dn_sw': 'Total downwelling shortwave',
        'flux_dn_direct_sw': 'Direct downwelling shortwave',
        'flux_up_sw': 'Upwelling shortwave'
    }

    # Half level at which each quantity is compared.
    q_l_map = {
        'flux_dn_lw': 0,
        'flux_up_lw': -1,
        'flux_dn_sw': 0,
        'flux_dn_direct_sw': 0,
        'flux_up_sw': -1,
    }
    _, ax = plt.subplots(1, len(q_l_map), figsize=(6 * len(q_l_map), 5))
    for q_idx, (q_name, half_level) in enumerate(q_l_map.items()):
        panel = ax[q_idx]
        x = ds_3d_true[q_name].sel(half_level=half_level)
        if plot_x_absolute:
            x = x + ds_sparta[q_name].sel(half_level=half_level)
        y = ds_3d_pred[q_name].sel(half_level=half_level)
        panel.scatter(x=x, y=y, s=10, facecolors='none', edgecolors='r', alpha=0.5)

        # Shared limits so that the 1:1 line is the panel diagonal.
        x_and_y = xr.concat([x, y], dim='column')
        x_y_lim_min = float(x_and_y.min())
        x_y_lim_max = float(x_and_y.max())

        if plot_x_absolute:
            # Zero reference across the whole panel. Its extent must not
            # depend on a fixed quantity: every panel shows a different flux.
            panel.axhline(0, c='k', linestyle='--')
            panel.set_xlabel('Reference flux in W m⁻²')
        else:
            panel.set_xlabel(f'Reference {d_names[q_name].lower()} in W m⁻²')

        panel.set_ylabel(f'Predicted {d_names[q_name].lower()} in W m⁻²')

        l_name = 'BOA' if half_level == 0 else 'TOA'
        if use_diff:
            panel.set_title(f'3D radiative effect flux at {l_name}', fontweight='bold')
        else:
            panel.set_title(f'Flux at {l_name}', fontweight='bold')
        panel.set_xlim(x_y_lim_min, x_y_lim_max)
        panel.set_ylim(x_y_lim_min, x_y_lim_max)
        panel.plot(panel.get_xlim(), panel.get_ylim(), ls="--", c=".3")

In [None]:
plot_scatter(ds_3d_true=d.y_true_test, ds_3d_pred=d.y_pred_test, ds_sparta=d.y_sparta_test)

In [None]:
def compute_atmos_abs(ds, flux_type):
    """Return net flux at half level -1 minus net flux at half level 0.

    The net flux is ``flux_dn_<flux_type> - flux_up_<flux_type>``. With half
    level -1 treated as TOA and 0 as BOA (as elsewhere in this notebook), this
    is the flux absorbed by the atmospheric column.

    Args:
        ds: Dataset providing ``flux_dn_<flux_type>`` and ``flux_up_<flux_type>``
            with a ``half_level`` dimension.
        flux_type: Band suffix, e.g. ``'sw'`` or ``'lw'``.
    """
    down = ds[f'flux_dn_{flux_type}']
    up = ds[f'flux_up_{flux_type}']

    def net_at(half_level):
        return down.sel(half_level=half_level) - up.sel(half_level=half_level)

    return net_at(-1) - net_at(0)

In [None]:
# Atmospheric absorption per band for the reference and the prediction,
# computed in the order sw/true, sw/pred, lw/true, lw/pred.
abs_sw_true, abs_sw_pred, abs_lw_true, abs_lw_pred = (
    compute_atmos_abs(ds, flux_type)
    for flux_type in ('sw', 'lw')
    for ds in (d.y_true_test, d.y_pred_test)
)

In [None]:
def plot_net_toa_boa(ds, flux_type):
    """Relate the net BOA flux to the net TOA flux and to atmospheric absorption.

    Net flux is ``flux_dn_<flux_type> - flux_up_<flux_type>``. Half level -1 is
    treated as TOA and 0 as BOA.
    Left panel: net TOA vs. net BOA.
    Right panel: atmospheric absorption (net TOA minus net BOA) vs. net BOA.
    Both panels have equal x/y limits and a red 1:1 line.

    Args:
        ds: Dataset with ``flux_dn_<flux_type>``/``flux_up_<flux_type>`` over a
            ``half_level`` dimension.
        flux_type: Band suffix, e.g. ``'sw'`` or ``'lw'``; also used in labels.
    """
    down = ds[f'flux_dn_{flux_type}']
    up = ds[f'flux_up_{flux_type}']
    net_toa = down.sel(half_level=-1) - up.sel(half_level=-1)
    net_boa = down.sel(half_level=0) - up.sel(half_level=0)
    # Same quantity as compute_atmos_abs(ds, flux_type), derived from `ds` so
    # that the right panel always matches the dataset and band being plotted.
    atmos_abs = net_toa - net_boa

    band = flux_type.upper()
    boa_label = f'3D effect on net BOA {band} flux in W/m2'
    panels = [
        (net_toa, f'3D effect on net TOA {band} flux in W/m2'),
        (atmos_abs, f'3D effect on net {band} atmospheric absorption in W/m2'),
    ]

    _, ax = plt.subplots(1, 2, figsize=(10, 4))
    for panel, (y, y_label) in zip(ax, panels):
        # Limits cover both plotted series so no point is clipped.
        lim_min = min(float(net_boa.min()), float(y.min()))
        lim_max = max(float(net_boa.max()), float(y.max()))
        panel.scatter(net_boa, y, alpha=0.1, color='k')
        panel.set_xlabel(boa_label)
        panel.set_ylabel(y_label)
        panel.set_ylim([lim_min, lim_max])
        panel.set_xlim([lim_min, lim_max])
        panel.axline((0, 0), (1, 1), linewidth=1, color='r')

plot_net_toa_boa(d.y_true_test, 'sw')

In [None]:
# Atmospheric absorption (see compute_atmos_abs): reference 3D signal vs. NN
# prediction, shortwave on the left and longwave on the right.
_, ax = plt.subplots(1, 2, figsize=(10, 4))
absorption_panels = [
    (ax[0], abs_sw_true, abs_sw_pred, 'Atmospheric Absorption (Shortwave)'),
    (ax[1], abs_lw_true, abs_lw_pred, 'Atmospheric Absorption (Longwave)'),
]
for panel, abs_ref, abs_nn, panel_title in absorption_panels:
    lim_min = min(abs_ref.min(), abs_nn.min())
    lim_max = max(abs_ref.max(), abs_nn.max())
    panel.scatter(abs_ref, abs_nn, alpha=0.1, color='k')
    panel.set_xlabel('3D signal in W/m2')
    panel.set_ylabel('3D prediction in W/m2')
    panel.set_title(panel_title)
    panel.set_ylim([lim_min, lim_max])
    panel.set_xlim([lim_min, lim_max])
    panel.axline((0, 0), (1, 1), linewidth=1, color='r')

plt.show()

In [None]:
# Does the upwelling SW flux at BOA follow (downwelling SW flux x surface albedo)?
# Left: NN predictions; right: reference 3D signal. The red line is 1:1.
# NOTE(review): assumes d.x_test['sw_albedo'] is in physical (unnormalized)
# units — confirm when normalize_inputs is set.
_, ax = plt.subplots(1, 2, figsize=(10, 4))
pred_ref_dn_sw_boa = d.y_pred_test['flux_dn_sw'].sel(half_level=0) * d.x_test['sw_albedo']
pred_com_dn_sw_boa = d.y_pred_test['flux_up_sw'].sel(half_level=0)
true_ref_dn_sw_boa = d.y_true_test['flux_dn_sw'].sel(half_level=0) * d.x_test['sw_albedo']
true_com_dn_sw_boa = d.y_true_test['flux_up_sw'].sel(half_level=0)

# In both panels the y-axis is the upwelling flux (flux_up_sw), so both get the
# "S UP" label.
albedo_panels = [
    (ax[0], pred_ref_dn_sw_boa, pred_com_dn_sw_boa, '3D predictions'),
    (ax[1], true_ref_dn_sw_boa, true_com_dn_sw_boa, '3D signal (SPARTACUS - Tripleclouds)'),
]
for panel, dn_times_albedo, up_flux, panel_title in albedo_panels:
    lim_min = min(dn_times_albedo.min(), up_flux.min())
    lim_max = max(dn_times_albedo.max(), up_flux.max())
    panel.scatter(dn_times_albedo, up_flux, alpha=0.1, color='k')
    panel.set_xlabel('3D S DN at BOA x Surface Albedo in W/m2')
    panel.set_ylabel('3D S UP at BOA in W/m2')
    panel.set_title(panel_title)
    panel.set_ylim([lim_min, lim_max])
    panel.set_xlim([lim_min, lim_max])
    panel.axline((0, 0), (1, 1), linewidth=1, color='r')


In [None]:
from utils import multi_plot


def _half_to_full_levels(ds):
    """Select ``half_level=slice(0, -1)`` and rename that dimension to 'level'."""
    return ds.sel(half_level=slice(0, -1)).rename(half_level='level')


y_sparta = _half_to_full_levels(d.y_sparta_test)
y_triple = _half_to_full_levels(d.y_triple_test)
y_nn = _half_to_full_levels(d.y_pred_test)
if use_diff:
    # The prediction is an offset relative to Tripleclouds.
    y_nn = y_triple + y_nn
# Pressure from Pa to hPa to match the axis label.
x_true = d.x_test.assign(pressure_fl=d.x_test['pressure_fl'] / 100)

utils.multi_plot(
    y_sparta, y_triple, y_nn, x_true,
    y_axis='pressure_fl',
    y_axis_label='Pressure in hPa',
    variant='fluxes',
    is_diff=use_diff,
)
utils.plt_show_svg()

In [None]:
# Same comparison as the flux plot, but for the 'hr' variant (heating rates).
utils.multi_plot(
    y_sparta, y_triple, y_nn, x_true,
    y_axis='pressure_fl',
    y_axis_label='Pressure in hPa',
    variant='hr',
    is_diff=use_diff,
)
utils.plt_show_svg()

## Era5slice

In [None]:
path_era5 = path_data / 'era5slice'
path_era5_in =  path_era5 / 'era5slice.nc'

path_ecrad = path_proj / 'ecrad'
path_ecrad_bin = path_ecrad / 'bin' / 'ecrad'

path_triple_era5_nml = path_era5 / 'config_era5slice_tripleclouds.nam'
path_sparta_era5_nml = path_era5 / 'config_era5slice_spartacus.nam'

path_era5_out_dir = Path(tempfile.gettempdir()) if pbs_id else path_era5
path_triple_era5_out = path_era5_out_dir / 'ecrad_nwp_era5slice_tripleclouds_out.nc'
path_sparta_era5_out = path_era5_out_dir / 'ecrad_nwp_era5slice_spartacus_out.nc'

nn_out_path = path_era5_out_dir /  'era5slice_nn_out.nc'

windows_paths_era5 = [path_ecrad, path_ecrad_bin, path_triple_era5_nml, path_sparta_era5_nml, 
                      path_era5_in, path_triple_era5_out, path_sparta_era5_out, path_logs]
posix_paths_era5 = []    
for path in windows_paths_era5:
    if is_windows:
        x = ! wsl wslpath "{path}"
        x = x[0]
    else:
        x = path.as_posix()
    posix_paths_era5.append(x)

plots_path = path_era5_out_dir / 'plots'
ecrad_plots_path = plots_path / 'ecrad'
nn_plots_path = plots_path / 'nn'
shutil.rmtree(ecrad_plots_path, ignore_errors=True)
ecrad_plots_path.mkdir(parents=True)
shutil.rmtree(nn_plots_path, ignore_errors=True)
nn_plots_path.mkdir(parents=True)

In [None]:
if enhance_data:
    # Spartacus: run ecRad with the SPARTACUS namelist on the ERA5 slice.
    # posix_paths_era5 indices: [1] ecRad binary, [3] SPARTACUS namelist,
    # [4] input file, [6] output file. stdout is redirected to a log file in
    # windows_paths_era5[7] (path_logs); {wsl} is 'wsl' on Windows, else empty.
    # OMP_NUM_THREADS is set to an empty value for the ecRad process.
    # NOTE(review): this run is gated on `enhance_data` rather than an
    # ERA5-specific flag — confirm that is intended.
    ! cd {path_ecrad} && {wsl} OMP_NUM_THREADS= {posix_paths_era5[1]} {posix_paths_era5[3]} {posix_paths_era5[4]} {posix_paths_era5[6]} > {windows_paths_era5[7]}/sparta_synth__era_era5.log

In [None]:
if enhance_data:
    # Tripleclouds: run ecRad with the Tripleclouds namelist on the ERA5 slice.
    # posix_paths_era5 indices: [1] ecRad binary, [2] Tripleclouds namelist,
    # [4] input file, [5] output file. stdout is redirected to a log file in
    # windows_paths_era5[7] (path_logs); {wsl} is 'wsl' on Windows, else empty.
    ! cd {path_ecrad} && {wsl} OMP_NUM_THREADS= {posix_paths_era5[1]} {posix_paths_era5[2]} {posix_paths_era5[4]} {posix_paths_era5[5]} > {windows_paths_era5[7]}/triple_synth_era5.log

In [None]:
# Generic names for the ERA5 slice files, used by the cells below.
triple_out_path, sparta_out_path, in_path = (
    path_triple_era5_out,
    path_sparta_era5_out,
    path_era5_in,
)

In [None]:
# Load all ERA5 slice inputs, add derived inputs and keep only the model's
# input quantities. Optionally restrict to `level`, and normalize when
# normalize_inputs is set.
in_data_all = utils.load_scheme_inputs(path_era5_in, only_relevant=False)

ds_inputs = utils.add_derived_inputs(in_data_all)[input_quantities]
if level:
    ds_inputs = ds_inputs.sel(level=level)
if normalize_inputs:
    ds_inputs = utils.normalize_inputs(ds_inputs, norm_stats_in)

utils.plot_inputs(ds_inputs, normalize_inputs)

In [None]:
# Inspect the ERA5 slice model inputs (normalized if normalize_inputs is set).
ds_inputs

In [None]:
def _predict_fluxes(model, model_inputs, triple_ref, stats_out):
    """Run `model` on `model_inputs` and post-process the prediction.

    Args:
        model: Object exposing ``to_model_input`` and ``predict`` (e.g. ``d.model``).
        model_inputs: Input dataset handed to ``model.to_model_input``.
        triple_ref: Tripleclouds output for the same columns. Only read when
            the global ``use_diff`` is set.
        stats_out: Output normalization statistics. Only read when the global
            ``normalize_outputs`` is set (callers pass None otherwise).

    Returns:
        The prediction, unnormalized when ``normalize_outputs`` is set and
        added onto ``triple_ref`` when ``use_diff`` is set.
    """
    y = model.predict(model.to_model_input(model_inputs))
    if normalize_outputs:
        y = utils.unnormalize_outputs(y, stats_out)
    if use_diff:
        # The network output is an offset relative to Tripleclouds.
        y = triple_ref + recover_levels(y, triple_ref)
    return y


# Load each scheme's output once; the LW/SW subsets are selected from it.
triple_out = utils.load_scheme_outputs(triple_out_path, only_relevant=True)
sparta_out = utils.load_scheme_outputs(sparta_out_path, only_relevant=True)

if is_lw_sw_split:
    # Longwave: drop the inputs not fed to the LW model.
    in_data_lw = ds_inputs.drop(['sw_albedo', 'cos_solar_zenith_angle'])
    if model_type == 'RNN':
        in_data_lw = xr_1d_to_2d_vars(in_data_lw, ['skin_temperature'])
    triple_out_lw = triple_out[lw_fluxes]
    if use_heating_rates:
        triple_out_lw = utils.add_heating_rates(triple_out_lw, in_data_all)
    # The conditional keeps norm_stats_out_lw from being evaluated unless needed.
    y_pred_lw = _predict_fluxes(d_lw.model, in_data_lw, triple_out_lw,
                                norm_stats_out_lw if normalize_outputs else None)

    # Shortwave: drop the inputs not fed to the SW model.
    in_data_sw = ds_inputs.drop(['skin_temperature', 'temperature_fl'])
    if model_type == 'RNN':
        in_data_sw = xr_1d_to_2d_vars(in_data_sw, ['cos_solar_zenith_angle', 'sw_albedo'])
    triple_out_sw = triple_out[sw_fluxes]
    if use_heating_rates:
        triple_out_sw = utils.add_heating_rates(triple_out_sw, in_data_all)
    y_pred_sw = _predict_fluxes(d_sw.model, in_data_sw, triple_out_sw,
                                norm_stats_out_sw if normalize_outputs else None)

    y_pred = xr.merge([y_pred_lw, y_pred_sw])
    in_data = utils.load_scheme_inputs(in_path, only_relevant=True)
else:
    # NOTE(review): unlike the split branch, no RNN reshaping
    # (xr_1d_to_2d_vars) is applied here — confirm this is intended.
    in_data = ds_inputs
    y_pred = _predict_fluxes(d.model, in_data, triple_out,
                             norm_stats_out if normalize_outputs else None)

display(y_pred)

In [None]:
# Copy over clear-sky fluxes from Tripleclouds.
# Since 3D effects arise only when clouds are present, the clear-sky fluxes
# are identical between Tripleclouds and SPARTACUS.

triple_out_temp = utils.load_scheme_outputs(triple_out_path, only_relevant=False)
if use_heating_rates:
    triple_out_temp = utils.add_heating_rates(triple_out_temp, in_data_all)

# NOTE(review): 'cloud_cover_sw' is not a flux but is copied along with the
# clear-sky fluxes.
clear_sky_fluxes = ['flux_up_lw_clear', 'flux_dn_lw_clear', 'flux_up_sw_clear', 'flux_dn_sw_clear', 'cloud_cover_sw']
y_pred_temp = y_pred.merge(triple_out_temp[clear_sky_fluxes])

# Copy over profiles from Tripleclouds for all non-cloudy input profiles.
# Every variable of y_pred_temp must also exist in triple_out_temp (KeyError otherwise).
# NOTE(review): the .loc assignment writes in place; if merge() shares the
# underlying arrays with y_pred, y_pred is modified as well — confirm whether
# the validation statistics further down should see these replaced columns.
if cloudy_profiles_only:
    non_cloudy_idx = sel_cloudy_profiles(in_data_all, cloudy=False).column
    for var_name in y_pred_temp:
        y_pred_temp[var_name].loc[dict(column=non_cloudy_idx)] = triple_out_temp[var_name].sel(column=non_cloudy_idx)

# Store as NetCDF (levels reversed via utils.reverse_levels) at nn_out_path,
# where the ecRad plotting scripts in the following cells read it.
y_pred_reversed = utils.reverse_levels(y_pred_temp)
y_pred_reversed.to_netcdf(nn_out_path)

In [None]:
# Inputs: plot the scalar input fields with ecRad's plot_input.py.
# add --include-t to plot T as second y axis
# The figure path is taken as the last token of the script's first output
# line (assumes the script prints it first — verify if output changes).
stdout = ! python {path_ecrad}/practical/plot_input.py {in_path} --mode scalars --dstdir {ecrad_plots_path}
fig_path = stdout[0].split()[-1]
Image(filename=fig_path)

In [None]:
# 3D effect: True and Predicted.
# compare_output.py gets the input, the Tripleclouds and SPARTACUS outputs,
# and the NN output via --output2.
# The figure path is taken as the last token of the script's first output line.
stdout = ! python {path_ecrad}/practical/compare_output.py {in_path} {triple_out_path} {sparta_out_path} --output2 {nn_out_path} --dstdir {ecrad_plots_path}
fig_path = stdout[0].split()[-1]
Image(filename=fig_path)

In [None]:
# NN error: compare the NN output directly against SPARTACUS.
# The figure path is taken as the last token of the script's first output line.
stdout = ! python {path_ecrad}/practical/compare_output.py {in_path} {sparta_out_path} {nn_out_path} --dstdir {nn_plots_path}
fig_path = stdout[0].split()[-1]
Image(filename=fig_path)

In [None]:
# TOA / scalar comparison of Tripleclouds, SPARTACUS and the NN output,
# rendered by compare_output_scalar.py in its 'paper' mode (SVG shown inline).
# The figure path is taken as the last token of the script's first output line.
stdout = ! python {path_ecrad}/practical/compare_output_scalar.py {in_path} {triple_out_path} {sparta_out_path} {nn_out_path} --mode paper --dstdir {nn_plots_path}
fig_path = stdout[0].split()[-1]
utils.plt_show_svg(fig_path)

## Validation Statistics

In [None]:
# Add heating rates to the prediction and both reference schemes (not
# idempotent: re-running this cell applies add_heating_rates again).
y_pred, triple_out, sparta_out = (
    utils.add_heating_rates(ds, in_data_all)
    for ds in (y_pred, triple_out, sparta_out)
)

In [None]:
# 3D effects (scheme minus Tripleclouds): SPARTACUS as truth vs. NN.
plot_random_profiles(
    ds_out_true=sparta_out - triple_out,
    ds_out_pred=y_pred - triple_out,
    n_rand_profiles=3,
)

In [None]:
plot_scatter(
    ds_3d_true=sparta_out - triple_out,
    ds_3d_pred=y_pred - triple_out,
    ds_sparta=sparta_out,
)

In [None]:
_metric_vars = stat_quantities_sw + stat_quantities_lw
_dropped_vars = _heating_rate_sw + _heating_rate_lw


def _select_for_metrics(ds):
    """Restrict `ds` to the metric variables (minus heating rates), on `level` if set."""
    if level:
        bounds = slice(level.start, level.stop + 1)
        ds = ds.sel(half_level=bounds, level=bounds)
    return ds[_metric_vars].drop(_dropped_vars)


# Error metrics of the NN output against SPARTACUS on the ERA5 slice.
m_era5slice = utils.compute_error_metrics(
    _select_for_metrics(y_pred),
    _select_for_metrics(sparta_out),
).add_prefix('era5slice_')
m_era5slice

In [None]:
# Presumably reached only when every earlier cell succeeded, so the PBS job's
# failure marker (f_job_failure, defined earlier in the notebook) is removed.
if is_pbs_job:
    os.remove(f_job_failure)