In [1]:
import os
import numpy as np
import torch
from configs.RecursiveVPSDE.Markovian_4DLorenz.recursive_Markovian_PostMeanScore_4DLorenz_T256_H05_tl_110data import \
    get_config as get_config
from configs import project_config
from src.generative_modelling.models.ClassVPSDEDiffusion import VPSDEDiffusion
from src.generative_modelling.models.TimeDependentScoreNetworks.ClassConditionalMarkovianTSPostMeanScoreMatching import \
    ConditionalMarkovianTSPostMeanScoreMatching


In [4]:
config = get_config()
device = "cpu"

In [14]:
diffusion = VPSDEDiffusion(beta_max=config.beta_max, beta_min=config.beta_min)

Nepoch = 960  # config.max_epochs[0]
# Fix the number of training epochs and training loss objective loss
score_model = ConditionalMarkovianTSPostMeanScoreMatching(*config.model_parameters).to(device)
score_model.load_state_dict(torch.load(config.scoreNet_trained_path + "_NEp" + str(Nepoch)))

<All keys matched successfully>

In [39]:
num_paths = 10
t0 = 0.
num_time_steps = 256
deltaT = 1. / 256
t1 = num_time_steps * deltaT
initial_state = np.repeat(np.array(config.initState)[np.newaxis, np.newaxis, :], num_paths, axis=0)
assert (initial_state.shape == (num_paths, 1, config.ndims))
true_states = np.zeros(shape=(num_paths, 1 + num_time_steps, config.ndims))
true_states[:, [0], :] = initial_state

In [40]:
def true_drift(prev, num_paths, config):
    assert (prev.shape == (num_paths, config.ndims))
    drift_X = np.zeros((num_paths, config.ndims))
    for i in range(config.ndims):
        drift_X[:,i] = (prev[:, (i + 1) % config.ndims] - prev[:, i - 2]) * prev[:, i - 1] - prev[:,
                                                                                           i] + config.forcing_const
    return drift_X[:, np.newaxis, :]
# Euler-Maruyama Scheme for Tracking Errors
for i in range(1, num_time_steps+1):
    eps = np.random.randn(num_paths, 1, config.ndims) * np.sqrt(deltaT)
    assert (eps.shape == (num_paths, 1, config.ndims))
    true_states[:, [i], :] = true_states[:, [i - 1], :] \
                             + true_drift(true_states[:, i - 1, :], num_paths=num_paths, config=config) * deltaT \
                             + eps

In [41]:
def score_error_eval(score_model, diffusion, num_paths, Z_0s, prev, config, device):
    num_taus = 50
    Ndiff_discretisation = config.max_diff_steps
    assert (prev.shape == (num_paths, config.ndims))
    vec_Z_0s = torch.stack([Z_0s for _ in range(num_taus)], dim=0).reshape(num_taus*num_paths, 1, config.ndims).to(device)
    conditioner = torch.Tensor(prev[:, np.newaxis, :]).to(device) # TODO: Check this is how we condition wheen D>1
    vec_conditioner = torch.stack([conditioner for _ in range(num_taus)], dim=0).reshape(num_taus*num_paths, 1, config.ndims).to(device)
    vec_Z_taus = diffusion.prior_sampling(shape=(num_taus*num_paths, 1, config.ndims)).to(device)

    # We are going to be evaluating the score ONLY at \tau_{S-1} --> no need for full reverse diffusion!!
    # diffusion_times = torch.linspace(config.sample_eps, 1., config.max_diff_steps)
    # d = diffusion_times[Ndiff_discretisation - 1].to(device)
    # diff_times = torch.Tensor([d]).to(device)
    difftime_idx = Ndiff_discretisation - 1
    errs = np.zeros((num_paths, Ndiff_discretisation, config.ndims))
    while difftime_idx >= 9900:
        diff_times = torch.Tensor([1.]).to(device)
        eff_times = diffusion.get_eff_times(diff_times).to(device)
        vec_diff_times = torch.concat([diff_times for _ in range(num_taus*num_paths)], dim=0).to(device)
        vec_eff_times = torch.concat([torch.concat([eff_times.unsqueeze(-1).unsqueeze(-1) for _ in range(num_taus*num_paths)], dim=0) for _ in range(config.ndims)], dim=-1).to(device)
        score_model.eval()
        with torch.no_grad():
            vec_predicted_score = score_model.forward(times=vec_diff_times, eff_times=vec_eff_times, conditioner=vec_conditioner, inputs=vec_Z_taus)
        vec_scores, vec_drift, vec_diffParam = diffusion.get_conditional_reverse_diffusion(x=vec_Z_taus,
                                                                                    predicted_score=vec_predicted_score,
                                                                                           diff_index=torch.Tensor([int(0)]).to(device),
                                                                                    max_diff_steps=Ndiff_discretisation)
        vec_scores = vec_scores.reshape((num_taus, num_paths, 1, config.ndims)).permute((1,0,2,3))
        assert (vec_scores.shape == (num_paths, num_taus, 1, config.ndims))

        beta_taus = torch.exp(-0.5 * eff_times).to(device)
        sigma_taus = torch.pow(1. - torch.pow(beta_taus, 2), 0.5).to(device)
        exp_upper_score = -(vec_Z_taus - beta_taus*vec_Z_0s) / sigma_taus
        exp_upper_score = exp_upper_score.reshape((num_taus, num_paths, 1, config.ndims)).permute((1,0,2,3))
        assert (exp_upper_score.shape == (num_paths, num_taus, 1, config.ndims))

        errs[:,[difftime_idx], :] = torch.mean(torch.pow(vec_scores-exp_upper_score,2), dim=1)
        vec_z = torch.randn_like(vec_drift).to(device)
        vec_Z_taus = vec_drift + vec_diffParam * vec_z
    return errs

In [42]:
from tqdm import tqdm
errors = np.zeros((num_paths, 1 + num_time_steps, config.max_diff_steps, config.ndims))
for i in tqdm(range(1, num_time_steps+1)):
    erri = score_error_eval(score_model=score_model, diffusion=diffusion, num_paths=num_paths,Z_0s=torch.Tensor(true_states[:, [i],:]).to(device), prev=true_states[:, i-1,:], config=config, device=device)
    errors[:, i, :, :] = erri

  0%|          | 0/256 [50:30<?, ?it/s]


KeyboardInterrupt: 