In [55]:
%load_ext autoreload
%autoreload 2
from configs import project_config
import numpy as np
import torch
import matplotlib.pyplot as plt

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [56]:
def rmse_ignore_nans(y_true, y_pred):
    assert (y_true.shape[0] == y_pred.shape[0])
    y_true = y_true.flatten()
    y_pred = y_pred.flatten()
    mask = ~np.isnan(y_true) & ~np.isnan(y_pred)  # Ignore NaNs in both arrays
    return np.sqrt(np.mean((y_true[mask] - y_pred[mask]) ** 2))

In [57]:
def plot_drift_estimator(config,mean, stds, numpy_Xs, true_drift, ts_type, Nepoch,toSave:bool = True):
    fig, ax = plt.subplots(figsize=(14,9))
    rmse = rmse_ignore_nans(true_drift, mean).astype(np.float64)#np.power(np.mean(np.power(true_drift - mean, 2)), 0.5)
    ax.scatter(numpy_Xs, true_drift, color="red", label="True Drift")

    #ax.errorbar(numpy_Xs, mean, fmt="o",yerr=2*stds, label="Estimated Drift")
    plt.scatter(numpy_Xs, mean, label="Estimated Drift", color="blue")
    ax.set_title(rf"RMSE {round(rmse,3)} for LSTM Score Estimator", fontsize=40)
    ax.tick_params(labelsize=38)
    ax.set_xlabel("State $X$", fontsize=38)
    ax.set_ylabel("Drift Value", fontsize=38)
    ax.legend(fontsize=24)
    plt.tight_layout()
    if toSave:
        plt.savefig(f"/Users/marcos/Library/CloudStorage/OneDrive-ImperialCollegeLondon/StatML_CDT/Year2/DiffusionModelPresentationImages/{ts_type}_LSTM_{Nepoch}Nep_{config.deltaT:.3e}dT_{config.loss_factor}LFac.png",  bbox_inches='tight')
    plt.show()
    plt.close()

In [58]:
def get_rmses(config):
    rmses = {}
    for Nepoch in config.max_epochs:
        file_path = (
                    project_config.ROOT_DIR + f"experiments/results/TSPM_LSTM_fQuadSinHF_DriftEvalExp_{Nepoch}Nep_{config.t0}t0_{config.deltaT:.3e}dT_{config.quad_coeff}a_{config.sin_coeff}b_{config.sin_space_scale}c_{config.beta_max:.1e}betaMax_{config.loss_factor}LFac").replace(
                ".", "")
        try:
            muhats = torch.Tensor(np.load(file_path+"_muhats.npy"))
            Xshape = muhats.shape[0]
            if config.deltaT > 1/(32*256):
                Xs = torch.linspace(-1.2, 1.2, steps=Xshape)
            else:
                Xs = torch.linspace(-.4, .4, steps=Xshape)
            true_drifts = (-2.*config.quad_coeff * Xs + config.sin_coeff * config.sin_space_scale*np.sin(config.sin_space_scale*Xs)).numpy()
            mu_hats = muhats[:, -1, :].reshape(muhats.shape[0], muhats.shape[-1]*1).mean(dim=-1).numpy()
            rmse = rmse_ignore_nans(true_drifts, mu_hats).astype(np.float64)
            rmses.update({Nepoch: round(rmse, 4)})
        except FileNotFoundError as e:
            continue
    return rmses

In [59]:
from configs.RecursiveVPSDE.LSTM_fQuadSinHF.recursive_LSTM_PostMeanScore_fQuadSinHF_T256_H05_tl_110data import get_config
config = get_config()
assert (config.deltaT == 1./256 and config.loss_factor == 0)
rmses_0 = get_rmses(config=config)

AssertionError: 

In [None]:
rmses_0

In [None]:
%load_ext autoreload
%autoreload 2
from configs.RecursiveVPSDE.LSTM_fQuadSinHF.recursive_LSTM_PostMeanScore_fQuadSinHF_T256_H05_tl_110data import get_config
config = get_config()
assert (config.deltaT == 1./256 and config.loss_factor == 2)
rmses_2 = get_rmses(config=config)

In [None]:
rmses_2

In [None]:
# Convert to DataFrame (pandas automatically fills missing values with NaN)
import pandas as pd
df = pd.DataFrame([rmses_0, rmses_2], ["LFac0", "LFac2"])

# Compute statistics, ignoring NaNs
stats = df.T.describe()
print(stats)  # Summary statistics

In [None]:
print(df)

In [None]:
Nepoch = 2920
%load_ext autoreload
%autoreload 2
from configs.RecursiveVPSDE.LSTM_fQuadSinHF.recursive_LSTM_PostMeanScore_fQuadSinHF_T256_H05_tl_110data import get_config
config = get_config()
file_path = (project_config.ROOT_DIR + f"experiments/results/TSPM_LSTM_fQuadSinHF_DriftEvalExp_{Nepoch}Nep_{config.t0}t0_{config.deltaT:.3e}dT_{config.quad_coeff}a_{config.sin_coeff}b_{config.sin_space_scale}c_{config.beta_max:.1e}betaMax_{config.loss_factor}LFac").replace(
                ".", "")
muhats = torch.Tensor(np.load(file_path+"_muhats.npy"))
Xshape = muhats.shape[0]
if config.deltaT > 1/(32*256):
    Xs = torch.linspace(-1.2, 1.2, steps=Xshape)
else:
    Xs = torch.linspace(-.4, .4, steps=Xshape)
true_drifts = (-2.*config.quad_coeff * Xs + config.sin_coeff * config.sin_space_scale*np.sin(config.sin_space_scale*Xs)).numpy()
mu_hats = muhats[:, -1, :].reshape(muhats.shape[0], muhats.shape[-1]*1).mean(dim=-1).numpy()
print(np.mean(mu_hats), np.std(mu_hats))
stds = muhats[:, -1, :].reshape(muhats.shape[0], muhats.shape[-1]*1).std(dim=-1).numpy()
plot_drift_estimator(mean=mu_hats, stds=stds, numpy_Xs=Xs, toSave=True, true_drift=true_drifts, ts_type="fQuadSinHF", Nepoch=Nepoch, config=config)
print(np.mean(stds))
del muhats, mu_hats, stds, true_drifts