In [13]:
import numpy as np
import io
import matplotlib.pyplot as plt
import os
import torch
from src.generative_modelling.models.TimeDependentScoreNetworks.ClassConditionalMarkovianTSPostMeanScoreMatching import \
    ConditionalMarkovianTSPostMeanScoreMatching
from utils.drift_evaluation_functions import experiment_MLP_DDims_drifts
from configs.RecursiveVPSDE.Markovian_8DLorenz.recursive_Markovian_PostMeanScore_8DLorenz_Stable_T256_H05_tl_110data_StbleTgt import get_config as get_8dlnz_config
from configs.RecursiveVPSDE.Markovian_12DLorenz.recursive_Markovian_PostMeanScore_12DLorenz_Stable_T256_H05_tl_110data_StbleTgt import get_config as get_12dlnz_config
from configs.RecursiveVPSDE.Markovian_20DLorenz.recursive_Markovian_PostMeanScore_20DLorenz_Stable_T256_H05_tl_110data_StbleTgt import get_config as get_20dlnz_config
from configs.RecursiveVPSDE.Markovian_40DLorenz.recursive_Markovian_PostMeanScore_40DLorenz_Stable_T256_H05_tl_110data_StbleTgt import get_config as get_40dlnz_config
import pandas as pd

In [14]:
lnz_8d_config = get_8dlnz_config()
lnz_12d_config = get_12dlnz_config()
lnz_20d_config = get_20dlnz_config()
lnz_40d_config = get_40dlnz_config()
root_dir ="/Users/marcos/Library/CloudStorage/OneDrive-ImperialCollegeLondon/StatML_CDT/Year2/DiffusionModels/"

In [15]:

def get_best_epoch(config, type):
    model_dir = "/".join(config.scoreNet_trained_path.split("/")[:-1]) + "/"
    for file in os.listdir(model_dir):
        if config.scoreNet_trained_path in os.path.join(model_dir, file) and f"{type}" in file:
            best_epoch = int(file.split(f"{type}NEp")[-1])
    return best_epoch

def get_best_track_file(root_score_dir, ts_type, best_epoch_track):
    for file in os.listdir(root_score_dir):
        if ("_"+str(best_epoch_track)+"Nep") in file and "true" in file and ts_type in file and "1000FTh" in file and "125FConst" in file:
            with open(root_score_dir+file, 'rb') as f:
                buf = io.BytesIO(f.read())  # hydrates once, sequentially
            true_file = np.load(root_score_dir+file, allow_pickle=True)
        elif ("_"+str(best_epoch_track)+"Nep") in file and "global" in file and ts_type in file and "1000FTh" in file and "125FConst" in file:
            with open(root_score_dir+file, 'rb') as f:
                buf = io.BytesIO(f.read())  # hydrates once, sequentially
            global_file = np.load(root_score_dir+file, allow_pickle=True)
    print(ts_type)
    return true_file, global_file

def get_best_eval_exp_file(config, root_score_dir, ts_type):
    best_epoch_eval = get_best_epoch(config=config,type="EE")
    for file in os.listdir(root_score_dir):
        if ("_"+str(best_epoch_eval)+"Nep") in file and "MSE" in file and ts_type in file and "1000FTh" in file and "125FConst" in file:
            print(f"Starting {file}\n")
            with open(root_score_dir+file, 'rb') as f:
                buf = io.BytesIO(f.read())  # hydrates once, sequentially
            print(f"Starting {file}\n")
            mse = pd.read_parquet(root_score_dir+file, engine="fastparquet")
    return mse

In [39]:
eval_tracks = {t: np.inf for t in ["8DLnz", "12DLnz", "20DLnz", "40DLnz"]}
for config in [lnz_8d_config, lnz_12d_config, lnz_20d_config, lnz_40d_config]:
    assert config.feat_thresh == 1.
    assert config.forcing_const == 1.25
    Xshape = config.ts_length
    root_score_dir = root_dir
    label = "$\mu_{5}$"
    if "8DLnz" in config.data_path:
        root_score_dir = root_dir + f"ExperimentResults/TSPM_Markovian/8DLnzLessData/"
        ts_type = "8DLnz"
    elif "12DLnz" in config.data_path:
        root_score_dir = root_dir + f"ExperimentResults/TSPM_Markovian/12DLnzLessData/"
        ts_type = "12DLnz"
    elif "20DLnz" in config.data_path:
        root_score_dir = root_dir + f"ExperimentResults/TSPM_Markovian/20DLnzLessData/"
        ts_type = "20DLnz"
    elif "40DLnz" in config.data_path:
        root_score_dir = root_dir + f"ExperimentResults/TSPM_Markovian/40DLnzLessData/"
        ts_type = "40DLnz"
    print(f"Starting {ts_type}\n")
    model_dir = "/".join(config.scoreNet_trained_path.split("/")[:-1]) + "/"
    entered = False
    best_epoch = get_best_epoch(config=config,type="EE")
    for file in os.listdir(model_dir):
        if config.scoreNet_trained_path in os.path.join(model_dir, file) and ("EE" in file and "Trk" not in file) and str(best_epoch) in file:
            good = ConditionalMarkovianTSPostMeanScoreMatching(
        *config.model_parameters)
            entered = True
            good.load_state_dict(torch.load(os.path.join(model_dir, file)))
    assert entered
    try:
        print(best_epoch)
        all_true_paths, all_global_paths = get_best_track_file(root_score_dir=root_score_dir, ts_type=ts_type, best_epoch_track=best_epoch)
    except UnboundLocalError as e:
        all_true_paths, all_global_paths = get_best_track_file(root_score_dir=root_score_dir, ts_type=ts_type, best_epoch_track=best_epoch+1)
    all_true_paths = all_true_paths.reshape(-1, config.ts_length+1, config.ts_dims)
    all_global_paths = all_global_paths.reshape(-1, config.ts_length+1, config.ts_dims)
    all_true_states = all_true_paths.reshape(-1, config.ts_dims)
    all_global_states = torch.tensor(all_global_paths.reshape(-1, config.ts_dims), device="cpu", dtype=torch.float32)
    true_drifts = np.zeros_like(all_true_states)
    for i in range(config.ndims):
        true_drifts[:, i] = (all_true_states[:, (i + 1) % config.ndims] - all_true_states[:, i - 2]) * all_true_states[:, i - 1] - all_true_states[:,i] * config.forcing_const
    drift_ests = experiment_MLP_DDims_drifts(config=config, Xs=all_global_states, good=good, onlyGauss=False)
    drift_ests= drift_ests[:, -1, :, :].reshape(drift_ests.shape[0],drift_ests.shape[2],drift_ests.shape[
                                                                                               -1] * 1).mean(axis=1)

    mse = np.mean(np.sum(np.power(true_drifts - drift_ests,2), axis=-1))
    eval_tracks[ts_type] = mse
    break

Starting 8DLnz

2588
8DLnz


In [29]:
eval_tracks

{'8DLnz': 147.40195,
 '12DLnz': 159.96869,
 '20DLnz': 449.1757,
 '40DLnz': 822.525}