In [61]:
import numpy as np
import io
import matplotlib.pyplot as plt
import os
import torch
from src.generative_modelling.models.TimeDependentScoreNetworks.ClassConditionalMarkovianTSPostMeanScoreMatching import \
    ConditionalMarkovianTSPostMeanScoreMatching
from utils.drift_evaluation_functions import experiment_MLP_DDims_drifts
from configs.RecursiveVPSDE.Markovian_8DLorenz.recursive_Markovian_PostMeanScore_8DLorenz_Stable_T256_H05_tl_110data_StbleTgt import get_config as get_8dlnz_config
from configs.RecursiveVPSDE.Markovian_12DLorenz.recursive_Markovian_PostMeanScore_12DLorenz_Stable_T256_H05_tl_110data_StbleTgt import get_config as get_12dlnz_config
from configs.RecursiveVPSDE.Markovian_20DLorenz.recursive_Markovian_PostMeanScore_20DLorenz_Stable_T256_H05_tl_110data_StbleTgt import get_config as get_20dlnz_config
from configs.RecursiveVPSDE.Markovian_40DLorenz.recursive_Markovian_PostMeanScore_40DLorenz_Stable_T256_H05_tl_110data_StbleTgt import get_config as get_40dlnz_config
import pandas as pd
from tqdm import tqdm
from utils.drift_evaluation_functions import multivar_score_based_MLP_drift_OOS
from src.generative_modelling.models.ClassVPSDEDiffusion import VPSDEDiffusion

In [62]:

def _get_device(device_str: str | None = None):
    if device_str is not None:
        return torch.device(device_str)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def true_drifts(device_id, config, state):
    true_drifts = np.zeros_like(state)
    for i in range(config.ndims):
        true_drifts[:, i] = (state[:, (i + 1) % config.ndims] - state[:, i - 2]) * state[:, i - 1] - state[:,i] * config.forcing_const
    return torch.tensor(true_drifts[:, np.newaxis, :], device=device_id, dtype=torch.float32)

In [63]:
lnz_8d_config = get_8dlnz_config()
lnz_12d_config = get_12dlnz_config()
lnz_20d_config = get_20dlnz_config()
lnz_40d_config = get_40dlnz_config()
device_id = _get_device()
root_dir ="/Users/marcos/Library/CloudStorage/OneDrive-ImperialCollegeLondon/StatML_CDT/Year2/DiffusionModels/"

In [71]:
def generate_synthetic_paths(config, device_id, good):
    diffusion = VPSDEDiffusion(beta_max=config.beta_max, beta_min=config.beta_min)
    num_diff_times = 1
    rmse_quantile_nums = 1
    num_paths = 200
    num_time_steps = config.ts_length
    deltaT = config.deltaT
    all_true_states = np.zeros(shape=(rmse_quantile_nums, num_paths, 1 + num_time_steps, config.ndims))
    all_global_states = np.zeros(shape=(rmse_quantile_nums, num_paths, 1 + num_time_steps, config.ndims))
    all_local_states = np.zeros(shape=(rmse_quantile_nums, num_paths, 1 + num_time_steps, config.ndims))
    for quant_idx in tqdm(range(rmse_quantile_nums)):
        good.eval()
        initial_state = np.repeat(np.atleast_2d(config.initState)[np.newaxis, :], num_paths, axis=0)
        assert (initial_state.shape == (num_paths, 1, config.ndims))

        true_states = np.zeros(shape=(num_paths, 1 + num_time_steps, config.ndims))
        global_states = np.zeros(shape=(num_paths, 1 + num_time_steps, config.ndims))
        local_states = np.zeros(shape=(num_paths, 1 + num_time_steps, config.ndims))

        # Initialise the "true paths"
        true_states[:, [0], :] = initial_state + 0.00001 * np.random.randn(*initial_state.shape)
        # Initialise the "global score-based drift paths"
        global_states[:, [0], :] = true_states[:, [0], :]
        local_states[:, [0], :] = true_states[:, [0],
                                  :]  # np.repeat(initial_state[np.newaxis, :], num_diff_times, axis=0)

        # Euler-Maruyama Scheme for Tracking Errors
        for i in range(1, num_time_steps + 1):
            eps = np.random.randn(num_paths, 1, config.ndims) * np.sqrt(deltaT) * config.diffusion

            assert (eps.shape == (num_paths, 1, config.ndims))
            true_mean = true_drifts(state=true_states[:, i - 1, :], device_id=device_id,config=config).numpy()
            denom = 1.
            global_mean = multivar_score_based_MLP_drift_OOS(score_model=good,
                                                             num_diff_times=num_diff_times,
                                                             diffusion=diffusion,
                                                             num_paths=num_paths,
                                                             ts_step=deltaT, config=config,
                                                             device=device_id,
                                                             prev=global_states[:, i - 1, :])

            true_states[:, [i], :] = (true_states[:, [i - 1], :] \
                                      + true_mean * deltaT \
                                      + eps) / denom
            global_states[:, [i], :] = (global_states[:, [i - 1], :] + global_mean * deltaT + eps) / denom

        all_true_states[quant_idx, :, :, :] = true_states
        all_global_states[quant_idx, :, :, :] = global_states
    return all_true_states, all_local_states

In [72]:

def get_best_epoch(config, type):
    model_dir = "/".join(config.scoreNet_trained_path.split("/")[:-1]) + "/"
    for file in os.listdir(model_dir):
        if config.scoreNet_trained_path in os.path.join(model_dir, file) and f"{type}" in file:
            best_epoch = int(file.split(f"{type}NEp")[-1])
    return best_epoch

def get_best_track_file(root_score_dir, ts_type, best_epoch_track):
    for file in os.listdir(root_score_dir):
        if ("_"+str(best_epoch_track)+"Nep") in file and "true" in file and ts_type in file and "1000FTh" in file and "125FConst" in file:
            with open(root_score_dir+file, 'rb') as f:
                buf = io.BytesIO(f.read())  # hydrates once, sequentially
            true_file = np.load(root_score_dir+file, allow_pickle=True)
        elif ("_"+str(best_epoch_track)+"Nep") in file and "global" in file and ts_type in file and "1000FTh" in file and "125FConst" in file:
            with open(root_score_dir+file, 'rb') as f:
                buf = io.BytesIO(f.read())  # hydrates once, sequentially
            global_file = np.load(root_score_dir+file, allow_pickle=True)
    print(ts_type)
    return true_file, global_file

def get_best_eval_exp_file(config, root_score_dir, ts_type):
    best_epoch_eval = get_best_epoch(config=config,type="EE")
    for file in os.listdir(root_score_dir):
        if ("_"+str(best_epoch_eval)+"Nep") in file and "MSE" in file and ts_type in file and "1000FTh" in file and "125FConst" in file:
            print(f"Starting {file}\n")
            with open(root_score_dir+file, 'rb') as f:
                buf = io.BytesIO(f.read())  # hydrates once, sequentially
            print(f"Starting {file}\n")
            mse = pd.read_parquet(root_score_dir+file, engine="fastparquet")
    return mse

In [73]:
eval_tracks = {t: np.inf for t in ["8DLnz", "12DLnz", "20DLnz", "40DLnz"]}
for config in [lnz_8d_config, lnz_12d_config, lnz_20d_config, lnz_40d_config]:
    assert config.feat_thresh == 1.
    assert config.forcing_const == 1.25
    Xshape = config.ts_length
    root_score_dir = root_dir
    label = "$\mu_{5}$"
    if "8DLnz" in config.data_path:
        root_score_dir = root_dir + f"ExperimentResults/TSPM_Markovian/8DLnzLessData/"
        ts_type = "8DLnz"
    elif "12DLnz" in config.data_path:
        root_score_dir = root_dir + f"ExperimentResults/TSPM_Markovian/12DLnzLessData/"
        ts_type = "12DLnz"
    elif "20DLnz" in config.data_path:
        root_score_dir = root_dir + f"ExperimentResults/TSPM_Markovian/20DLnzLessData/"
        ts_type = "20DLnz"
    elif "40DLnz" in config.data_path:
        root_score_dir = root_dir + f"ExperimentResults/TSPM_Markovian/40DLnzLessData/"
        ts_type = "40DLnz"
    print(f"Starting {ts_type}\n")
    model_dir = "/".join(config.scoreNet_trained_path.split("/")[:-1]) + "/"
    entered = False
    best_epoch = get_best_epoch(config=config,type="EE")
    for file in os.listdir(model_dir):
        if config.scoreNet_trained_path in os.path.join(model_dir, file) and ("EE" in file and "Trk" not in file) and str(best_epoch) in file:
            good = ConditionalMarkovianTSPostMeanScoreMatching(
        *config.model_parameters)
            entered = True
            good.load_state_dict(torch.load(os.path.join(model_dir, file)))
    assert entered
    all_true_paths, all_global_paths = generate_synthetic_paths(config=config, device_id=device_id, good=good)
    all_true_paths = all_true_paths.reshape(-1, config.ts_length+1, config.ts_dims)
    all_global_paths = all_global_paths.reshape(-1, config.ts_length+1, config.ts_dims)
    all_true_states = all_true_paths[:, 1:,:].reshape(-1, config.ts_dims)
    all_global_states = torch.tensor(all_global_paths[:, 1:,:].reshape(-1, config.ts_dims), device=device_id, dtype=torch.float32)
    true_drift = true_drifts(state=all_true_states, device_id=device_id,config=config).numpy()
    drift_ests = experiment_MLP_DDims_drifts(config=config, Xs=all_global_states, good=good, onlyGauss=False)
    drift_ests= drift_ests[:, -1, :, :].reshape(drift_ests.shape[0],drift_ests.shape[2],drift_ests.shape[
                                                                                               -1] * 1).mean(axis=1)
    mse = np.mean(np.sum(np.power(true_drift - drift_ests,2), axis=-1))
    eval_tracks[ts_type] = mse


Starting 8DLnz



100%|██████████| 1/1 [00:03<00:00,  3.46s/it]


Starting 12DLnz



100%|██████████| 1/1 [00:03<00:00,  3.04s/it]


Starting 20DLnz



100%|██████████| 1/1 [00:03<00:00,  3.58s/it]


Starting 40DLnz



100%|██████████| 1/1 [00:04<00:00,  4.46s/it]


In [74]:
eval_tracks

{'8DLnz': 513.88196,
 '12DLnz': 520.7596,
 '20DLnz': 1338.998,
 '40DLnz': 1332.6201}