# Baselines & Filtering

In [None]:
from overcast.datasets import JASMIN

import pandas as pd
import numpy as np
from scipy import stats

from pathlib import Path

import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.pipeline import make_pipeline
from sklearn.linear_model import BayesianRidge, RidgeCV
from sklearn.preprocessing import PolynomialFeatures, StandardScaler
from sklearn.neural_network import MLPRegressor

In [None]:
# Project root (relative to the current working directory) and the data
# folder inside it; every dataset and figure path below is built from these.
project_dir = Path("MR-MLforACI")
data_dir = project_dir.joinpath("data")

In [None]:
def make_ds(dataset, covariates, treatment, outcomes, target_keys):
    """Load the train/valid/test splits of one JASMIN dataset.

    Args:
        dataset: Dataset name forwarded to ``JASMIN(dataset=...)``.
        covariates: Forwarded as ``x_vars``.
        treatment: Forwarded as ``t_var``.
        outcomes: Forwarded as ``y_vars``.
        target_keys: Stored unchanged under the ``"TARGET_KEYS"`` key.

    Returns:
        A dict with the ``JASMIN`` objects under ``"test"``, ``"valid"`` and
        ``"train"``, plus ``"TARGET_KEYS"``.
    """

    def load_split(split):
        # All splits share the same variable selection and read from the
        # module-level ``data_dir``; only the split name differs.
        return JASMIN(
            data_dir=data_dir,
            dataset=dataset,
            split=split,
            x_vars=covariates,
            t_var=treatment,
            y_vars=outcomes,
            t_bins=1,
        )

    # Splits are constructed in train -> valid -> test order.
    train_split = load_split('train')
    valid_split = load_split('valid')
    test_split = load_split('test')
    return {
        "test": test_split,
        "valid": valid_split,
        "train": train_split,
        "TARGET_KEYS": target_keys,
    }

In [None]:
def make_xy(ds):
    """Assemble model inputs and targets from a dataset dict.

    The treatment column(s) are appended to the covariates of every split,
    and the train and validation splits are stacked into one fitting set.

    Args:
        ds: Mapping with ``'train'``, ``'valid'`` and ``'test'`` entries, each
            exposing ``data``, ``treatments`` and ``targets`` arrays.

    Returns:
        ``(x_train_valid, y_train_valid, x_test, y_test)``; ``y_test`` is the
        test split's ``targets`` array itself, not a copy.
    """

    def with_treatment(split):
        # Column-wise join: covariates first, then treatments.
        return np.concatenate((split.data, split.treatments), axis=1)

    fit_splits = (ds['train'], ds['valid'])
    x_train_valid = np.concatenate([with_treatment(s) for s in fit_splits], axis=0)
    y_train_valid = np.concatenate([s.targets for s in fit_splits], axis=0)

    held_out = ds['test']
    return x_train_valid, y_train_valid, with_treatment(held_out), held_out.targets

In [None]:
def make_model_from_name(name):
    """Build a fresh, unfitted baseline regression pipeline by name.

    Every pipeline standardises its inputs with ``StandardScaler`` first.

    Args:
        name: One of ``'ridge'``, ``'bayridge'``, ``'polyridge2'``,
            ``'polyridge3'`` or ``'mlp1'``.

    Returns:
        A new scikit-learn pipeline for the requested baseline.

    Raises:
        ValueError: If ``name`` is not one of the supported model names.
    """

    def poly_ridge(degree):
        # Polynomial feature expansion of the given degree, then RidgeCV.
        return make_pipeline(StandardScaler(), PolynomialFeatures(degree), RidgeCV())

    def mlp_one_hidden_layer():
        # Single hidden layer of 100 ReLU units; early stopping holds out
        # 10% of the fitting data as an internal validation set.
        return make_pipeline(
            StandardScaler(),
            MLPRegressor(
                hidden_layer_sizes=(100, ),
                activation="relu",
                solver="adam",
                early_stopping=True,
                validation_fraction=0.1,
            )
        )

    # Factories are called lazily so only the requested model is constructed.
    factories = {
        'ridge': lambda: make_pipeline(StandardScaler(), RidgeCV()),
        'bayridge': lambda: make_pipeline(StandardScaler(), BayesianRidge()),
        'polyridge2': lambda: poly_ridge(2),
        'polyridge3': lambda: poly_ridge(3),
        'mlp1': mlp_one_hidden_layer,
    }
    try:
        factory = factories[name]
    except KeyError:
        # Fail with a clear message instead of an UnboundLocalError on return.
        raise ValueError(
            f"Unknown model name {name!r}; expected one of {sorted(factories)}"
        ) from None
    return factory()

In [None]:
def predict_outcomes(model, ds):
    """Fit ``model`` once per outcome column and predict the test split.

    The model is refit on the combined train+valid data for each outcome
    index in ``range(len(ds['TARGET_KEYS']))``, so after the call it holds
    the fit for the last outcome only.

    Args:
        model: Estimator exposing ``fit(x, y)`` and ``predict(x)``.
        ds: Dataset dict as consumed by ``make_xy``; must also provide
            ``'TARGET_KEYS'`` and ``ds['test'].targets_xfm``.

    Returns:
        ``(predicted_outcomes, observed_outcomes)``, both passed through
        ``ds['test'].targets_xfm.inverse_transform``.
    """
    x_fit, y_fit, x_eval, y_eval = make_xy(ds)
    undo_target_transform = ds['test'].targets_xfm.inverse_transform

    observed_outcomes = undo_target_transform(y_eval)
    # Predictions are collected in the transformed target space first and
    # inverted in one go once every column has been filled.
    raw_predictions = np.zeros(shape=observed_outcomes.shape)
    n_outcomes = len(ds['TARGET_KEYS'])
    for column in range(n_outcomes):
        model.fit(x_fit, y_fit[:, column])
        raw_predictions[:, column] = model.predict(x_eval)
    return undo_target_transform(raw_predictions), observed_outcomes

In [None]:
def scatter_plot(TARGET_KEYS, predicted_outcomes, observed_outcomes, color, savepath=None):
    """Draw one predicted-vs-observed scatter panel per outcome.

    Each panel shows the raw scatter, the identity line (colour ``"C2"``) and
    the least-squares fit of predicted on observed, labelled with its
    ``r^2``. Both axes are clipped to the 1st-99th percentile of the observed
    values.

    Args:
        TARGET_KEYS: Axis-label text indexed by outcome column
            (``TARGET_KEYS[i]`` labels column ``i``).
        predicted_outcomes: 2-D array, one column per outcome.
        observed_outcomes: 2-D array with the same layout.
        color: Matplotlib colour for the scatter points and the fitted line.
        savepath: Optional path without extension; ``.png`` is appended.
    """
    n_outcomes = len(TARGET_KEYS)
    # squeeze=False keeps `axs` two-dimensional, so a single outcome still
    # yields an indexable array rather than a bare Axes object.
    fig, axs = plt.subplots(1, n_outcomes, figsize=(n_outcomes * 6, 6), squeeze=False)
    for idx_outcome, ax in enumerate(axs[0]):
        observed = observed_outcomes[:, idx_outcome]
        predicted = predicted_outcomes[:, idx_outcome]
        qs = np.quantile(observed, [0.01, 0.99])
        # x-values at which the identity and regression lines are drawn.
        domain = np.arange(qs[0], qs[1], 0.01)
        slope, intercept, r, _, _ = stats.linregress(observed, predicted)
        ax.scatter(x=observed, y=predicted, s=0.01, c=color)
        ax.plot(domain, domain, c="C2")
        ax.plot(domain, domain * slope + intercept, c=color, label=f"$r^2$={r**2:.03f}")
        ax.set_xlim(qs)
        ax.set_ylim(qs)
        ax.set_xlabel(f"{TARGET_KEYS[idx_outcome]} true")
        ax.set_ylabel(f"{TARGET_KEYS[idx_outcome]} predicted")
        ax.legend(loc="upper left")
    if savepath is not None:
        # Save through the figure handle so the right figure is written even
        # if pyplot's "current figure" has changed.
        fig.savefig(f'{savepath}.png', format="png", bbox_inches='tight')

In [None]:
# Number of quantile levels (points per curve) used by ``apo_curves``.
PLOT_PRECISION = 32


def apo_curves(ds, predicted_outcomes, savepath=None):
    """Plot quantiles of each predicted outcome against treatment quantiles.

    The training treatments and every column of ``predicted_outcomes`` are
    reduced to the same ``PLOT_PRECISION`` quantile levels
    ``0, 1/P, ..., (P-1)/P`` and drawn against each other, one panel per
    outcome, in a two-column grid.

    Args:
        ds: Dataset dict providing ``'TARGET_KEYS'`` and a ``'train'`` entry
            with ``treatments``, ``treatments_xfm`` and ``treatment_names``.
        predicted_outcomes: 2-D array, one column per outcome.
        savepath: Optional path handed to ``plt.savefig`` unchanged.
    """
    TARGET_KEYS = ds['TARGET_KEYS']
    # linspace gives exactly PLOT_PRECISION levels in [0, 1). A float-step
    # arange can overshoot 1.0 for some precisions (e.g. 10), which
    # np.quantile rejects.
    quantile_levels = np.linspace(0.0, 1.0, PLOT_PRECISION, endpoint=False)

    # Joining the rows along axis 0 flattens the treatment array.
    # NOTE(review): assumes treatments is 2-D (n, k) - confirm against JASMIN.
    treatments = np.concatenate(ds['train'].treatments, axis=0)
    treatments = ds['train'].treatments_xfm.inverse_transform(treatments)
    treatments = np.quantile(treatments, q=quantile_levels)
    predicted_outcomes_ = np.quantile(predicted_outcomes, q=quantile_levels, axis=0)

    n_outcomes = len(TARGET_KEYS)
    # Keep the original 2x2 layout for up to four outcomes and add rows beyond
    # that, instead of indexing past the grid.
    n_rows = max(2, (n_outcomes + 1) // 2)
    _, ax = plt.subplots(n_rows, 2, figsize=(12, 6 * n_rows))
    for idx_outcome in range(n_outcomes):
        row, col = divmod(idx_outcome, 2)
        panel = ax[row][col]
        sns.lineplot(x=treatments, y=predicted_outcomes_[:, idx_outcome], ax=panel)
        panel.set_xlabel(ds['train'].treatment_names[0][0])
        panel.set_ylabel(TARGET_KEYS[idx_outcome])
    if savepath is not None:
        plt.savefig(savepath)
    plt.show()

## Figures

In [None]:
# Each experiment is a tuple of
#   (model name, dataset, covariates, treatment, outcomes, target_keys)
# where target_keys maps the outcome column index to its axis-label text.
experiments = [
    # LRP: ridge, polyridge & mlp1
    (
        'ridge',
        'four_outputs_liqcf_pacific',
        ['RH900', 'RH850', 'RH700', 'LTS', 'EIS', 'W500', 'SST'],
        'AOD',
        ['re', 'COD', 'CWP', 'LPC'],
        {0: r'$r_e$', 1: r'$\tau$', 2: 'CWP', 3: 'CF'},
    ),
    (
        'polyridge3',
        'four_outputs_liqcf_pacific',
        ['RH900', 'RH850', 'RH700', 'LTS', 'EIS', 'W500', 'SST'],
        'AOD',
        ['re', 'COD', 'CWP', 'LPC'],
        {0: r'$r_e$', 1: r'$\tau$', 2: 'CWP', 3: 'CF'},
    ),
    (
        'mlp1',
        'four_outputs_liqcf_pacific',
        ['RH900', 'RH850', 'RH700', 'LTS', 'EIS', 'W500', 'SST'],
        'AOD',
        ['re', 'COD', 'CWP', 'LPC'],
        {0: r'$r_e$', 1: r'$\tau$', 2: 'CWP', 3: 'CF'},
    ),
    # HRP treatments
    (
        'polyridge3',
        'MERRA_25kmres_2003_08',
        ['RH950', 'RH850', 'RH700', 'LTS', 'W500', 'SST'],
        'AOD',
        ['Nd', 're', 'COD', 'CWP'],
        {0: r'$N_d$', 1: r'$r_e$', 2: r'$\tau$', 3: 'CWP'},
    ),
    (
        'polyridge3',
        'MERRA_25kmres_2003_08',
        ['RH950', 'RH850', 'RH700', 'LTS', 'W500', 'SST', 'AOD'],
        'Nd',
        ['re', 'COD', 'CWP'],
        {0: r'$r_e$', 1: r'$\tau$', 2: 'CWP'},
    ),
    (
        'polyridge3',
        'MERRA_25kmres_2003_08',
        ['RH950', 'RH850', 'RH700', 'LTS', 'W500', 'SST', 'AOD', 'Nd'],
        're',
        ['COD', 'CWP'],
        {0: r'$\tau$', 1: 'CWP'},
    ),
    # LRP 2004
    (
        'polyridge3',
        'four_outputs_liqcf_pacific_2004',
        ['RH900', 'RH850', 'RH700', 'LTS', 'EIS', 'W500', 'SST'],
        'AOD',
        ['re', 'COD', 'CWP', 'LPC'],
        {0: r'$r_e$', 1: r'$\tau$', 2: 'CWP', 3: 'CF'},
    ),
    # LRP & HRP with the same x and y
    (
        'polyridge3',
        'four_outputs_liqcf_pacific',
        ['RH850', 'RH700', 'LTS', 'W500', 'SST'],
        'AOD',
        ['re', 'COD', 'CWP'],
        {0: r'$r_e$', 1: r'$\tau$', 2: 'CWP'},
    ),
    (
        'polyridge3',
        'MERRA_25kmres_2003',
        ['RH850', 'RH700', 'LTS', 'W500', 'SST'],
        'AOD',
        ['re', 'COD', 'CWP'],
        {0: r'$r_e$', 1: r'$\tau$', 2: 'CWP'},
    ),
    # HRP 2003
    (
        'polyridge3',
        'MERRA_25kmres_2003_08',
        ['RH950', 'RH850', 'RH700', 'LTS', 'W500', 'SST'],
        'AOD',
        ['re', 'COD', 'CWP'],
        {0: r'$r_e$', 1: r'$\tau$', 2: 'CWP'},
    ),
    (
        'polyridge3',
        'MERRA_25kmres_2003',
        ['RH950', 'RH850', 'RH700', 'LTS', 'W500', 'SST'],
        'AOD',
        ['re', 'COD', 'CWP'],
        {0: r'$r_e$', 1: r'$\tau$', 2: 'CWP'},
    ),
]

# Fit every baseline, report its configuration and save the scatter figure.
for model_name, dataset, covariates, treatment, outcomes, target_keys in experiments:
    # NOTE(review): this is an exact match, so the '_2004' LRP dataset is
    # drawn in 'C1' like the MERRA datasets - confirm that is intended.
    if dataset == 'four_outputs_liqcf_pacific':
        color = 'C0'
    else:
        color = 'C1'
    cov_str = '_'.join(covariates)
    out_str = '_'.join(outcomes)
    ds = make_ds(dataset, covariates, treatment, outcomes, target_keys)
    model = make_model_from_name(model_name)
    predicted_outcomes, observed_outcomes = predict_outcomes(model, ds)
    print(f'Model: {model_name}, Dataset: {dataset}, X: {cov_str}, T: {treatment}, Y: {out_str}')
    # scatter_plot appends the '.png' extension itself.
    savepath = f'{project_dir}/figures/baselines/{model_name}-{dataset}-{treatment}-{cov_str}-{out_str}-scatter'
    scatter_plot(
        ds['TARGET_KEYS'],
        predicted_outcomes,
        observed_outcomes,
        color,
        savepath=savepath,
    )

## Histograms

In [None]:
# Variable selection used for both datasets in the histogram comparison below.
outcomes = ["re", "COD", "CWP"]
treatment = "AOD"
covariates = ["RH850", "RH700", "LTS", "W500", "SST"]

In [None]:
# Filter switches forwarded to JASMIN below (filter_aod / filter_re /
# filter_cwp). Note that `re` here is a plain flag, not the stdlib module.
aod, re, cwp = False, False, True

# File-name prefix listing the active filters in aod, re, cwp order, e.g.
# 'aod_cwp_filter'; 'no_filter' when every switch is off.
enabled_tags = [
    tag
    for tag, enabled in (('aod_', aod), ('re_', re), ('cwp_', cwp))
    if enabled
]
prefix = ''.join(enabled_tags or ['no_']) + 'filter'

In [None]:
# Both training sets are loaded with the same variable selection and filter
# switches; only the dataset name differs.
train_kwargs = dict(
    data_dir=data_dir,
    split='train',
    x_vars=covariates,
    t_var=treatment,
    y_vars=outcomes,
    filter_aod=aod,
    filter_re=re,
    filter_cwp=cwp,
    t_bins=1,
)

# Labelled 'High Resolution' in the histograms below.
ds_hr_train = JASMIN(dataset='MERRA_25kmres_2003', **train_kwargs)

# Labelled 'Low Resolution' in the histograms below.
ds_lr_train = JASMIN(dataset='four_outputs_liqcf_pacific', **train_kwargs)

In [None]:
# Treatment (AOD) histograms of the two training sets, after undoing each
# dataset's own treatment transform.
hr_t = ds_hr_train.treatments_xfm.inverse_transform(ds_hr_train.treatments)
lr_t = ds_lr_train.treatments_xfm.inverse_transform(ds_lr_train.treatments)

# Shared bins spanning the 1st-99th percentile of the high-resolution values.
qs = np.quantile(hr_t, [0.01, 0.99])
bins = np.linspace(qs[0], qs[1], 100)

# plt.plot(figsize=...) draws nothing and silently ignores figsize; open a new
# 6x6 figure so the size is applied and histograms never share an axes.
plt.figure(figsize=(6, 6))
plt.hist(lr_t, bins, density=True, alpha=0.5, label='Low Resolution')
plt.hist(hr_t, bins, density=True, alpha=0.5, label='High Resolution')
plt.xlabel('AOD')
plt.ylabel('Frequency')
plt.legend(loc='upper right')
plt.savefig(f'{project_dir}/figures/filtering/{prefix}-aod-hist.png')

In [None]:
# Histograms of outcome column 0 (labelled r_e; first entry of `outcomes`)
# for the two training sets, after undoing each dataset's target transform.
hr_y = ds_hr_train.targets_xfm.inverse_transform(ds_hr_train.targets)[:, 0]
lr_y = ds_lr_train.targets_xfm.inverse_transform(ds_lr_train.targets)[:, 0]

# Shared bins spanning the 1st-99th percentile of the high-resolution values.
qs = np.quantile(hr_y, [0.01, 0.99])
bins = np.linspace(qs[0], qs[1], 100)

# plt.plot(figsize=...) draws nothing and silently ignores figsize; open a new
# 6x6 figure so the size is applied and histograms never share an axes.
plt.figure(figsize=(6, 6))
plt.hist(lr_y, bins, density=True, alpha=0.5, label='Low Resolution')
plt.hist(hr_y, bins, density=True, alpha=0.5, label='High Resolution')
plt.xlabel(r'$r_e$')
plt.ylabel('Frequency')
plt.legend(loc='upper right')
plt.savefig(f'{project_dir}/figures/filtering/{prefix}-re-hist.png')

In [None]:
# Histograms of outcome column 1 (labelled tau; second entry of `outcomes`)
# for the two training sets, after undoing each dataset's target transform.
hr_y = ds_hr_train.targets_xfm.inverse_transform(ds_hr_train.targets)[:, 1]
lr_y = ds_lr_train.targets_xfm.inverse_transform(ds_lr_train.targets)[:, 1]

# Shared bins spanning the 1st-99th percentile of the high-resolution values.
qs = np.quantile(hr_y, [0.01, 0.99])
bins = np.linspace(qs[0], qs[1], 100)

# plt.plot(figsize=...) draws nothing and silently ignores figsize; open a new
# 6x6 figure so the size is applied and histograms never share an axes.
plt.figure(figsize=(6, 6))
plt.hist(lr_y, bins, density=True, alpha=0.5, label='Low Resolution')
plt.hist(hr_y, bins, density=True, alpha=0.5, label='High Resolution')
plt.xlabel(r'$\tau$')
plt.ylabel('Frequency')
plt.legend(loc='upper right')
plt.savefig(f'{project_dir}/figures/filtering/{prefix}-tau-hist.png')

In [None]:
# Histograms of outcome column 2 (labelled CWP; third entry of `outcomes`)
# for the two training sets, after undoing each dataset's target transform.
# NOTE: hr_t / lr_t hold an outcome column here, not the treatment, and
# overwrite the AOD arrays assigned to the same names earlier.
hr_t = ds_hr_train.targets_xfm.inverse_transform(ds_hr_train.targets)[:, 2]
lr_t = ds_lr_train.targets_xfm.inverse_transform(ds_lr_train.targets)[:, 2]

# Shared bins spanning the 1st-99th percentile of the high-resolution values.
qs = np.quantile(hr_t, [0.01, 0.99])
bins = np.linspace(qs[0], qs[1], 100)

# plt.plot(figsize=...) draws nothing and silently ignores figsize; open a new
# 6x6 figure so the size is applied and histograms never share an axes.
plt.figure(figsize=(6, 6))
plt.hist(lr_t, bins, density=True, alpha=0.5, label='Low Resolution')
plt.hist(hr_t, bins, density=True, alpha=0.5, label='High Resolution')
plt.xlabel(r'$CWP$')
plt.ylabel('Frequency')
plt.legend(loc='upper right')
plt.savefig(f'{project_dir}/figures/filtering/{prefix}-cwp-hist.png')