In [1]:
import os
import numpy as np
import torch
from configs.RecursiveVPSDE.Markovian_fQuadSin.recursive_Markovian_fQuadSinWithPosition_T256_H05_tl_5data import get_config as get_config
from configs.RecursiveVPSDE.Markovian_fQuadSin.recursive_Markovian_PostMeanScore_fQuadSin_T256_H05_tl_5data import get_config as get_config
from tqdm import tqdm
from src.generative_modelling.models.ClassVPSDEDiffusion import VPSDEDiffusion
from src.generative_modelling.models.TimeDependentScoreNetworks.ClassConditionalMarkovianTSPostMeanScoreMatching import \
    ConditionalMarkovianTSPostMeanScoreMatching
from src.generative_modelling.models.TimeDependentScoreNetworks.ClassConditionalMarkovianTSScoreMatching import \
    ConditionalMarkovianTSScoreMatching
import time

In [2]:
config = get_config()
print("Beta Min : ", config.beta_min)
if config.has_cuda:
    device = int(os.environ["LOCAL_RANK"])
else:
    print("Using CPU\n")
    device = torch.device("cpu")

diffusion = VPSDEDiffusion(beta_max=config.beta_max, beta_min=config.beta_min)

max_diff_steps = config.max_diff_steps
sample_eps = config.sample_eps
ts_step = 1 / config.ts_length

Nepoch = 1440  # config.max_epochs[0]
# Fix the number of training epochs and training loss objective loss
if "PM" in config.scoreNet_trained_path:
    PM = ConditionalMarkovianTSPostMeanScoreMatching(*config.model_parameters).to(device)
else:
    PM = ConditionalMarkovianTSScoreMatching(*config.model_parameters).to(device)
PM.load_state_dict(torch.load(config.scoreNet_trained_path + "_NEp" + str(Nepoch)))

# In[23]:


Xshape = config.ts_length
num_taus = 3
num_diff_times = config.max_diff_steps
Ndiff_discretisation = config.max_diff_steps
diffusion_times = torch.linspace(start=config.sample_eps, end=config.end_diff_time,
                                 steps=Ndiff_discretisation).to(device)
mu_hats_mean = np.zeros((Xshape, num_taus))
mu_hats_std = np.zeros((Xshape, num_taus))

Xs = torch.linspace(-3, 3, steps=Xshape).unsqueeze(-1).unsqueeze(-1).permute(1, 0, 2).to(device)
conditioner = torch.stack([Xs for _ in range(1)], dim=0).reshape(Xshape * 1, 1, -1)
B, T = Xshape, 1
mu_hats = np.zeros((Xshape, num_diff_times, num_taus))  # Xvalues, DiffTimes, Ztaus
vec_mu_hats = np.zeros((Xshape, num_diff_times, num_taus))  # Xvalues, DiffTimes, Ztaus
final_vec_mu_hats = np.zeros((Xshape, num_diff_times, num_taus))
PM.eval()
print(mu_hats.shape)

Beta Min :  0.0
Using CPU

(256, 1000, 3)


In [154]:
vec_Z_taus = diffusion.prior_sampling(shape=(Xshape*num_taus, 1, 1)).to(device)
vec_conditioner = torch.stack([conditioner for _ in range(num_taus)], dim=0).reshape(num_taus*Xshape, 1, 1)
for k in tqdm(range(num_taus)):
    difftime_idx = num_diff_times - 1
    #Z_taus = diffusion.prior_sampling(shape=(Xshape, 1, 1)).to(device)
    Z_taus = vec_Z_taus[k*Xshape:(k+1)*Xshape, :, :]
    ts = []
    while difftime_idx >= 0:
        t0 = time.time()
        d = diffusion_times[Ndiff_discretisation - (num_diff_times - 1 - difftime_idx) - 1].to(device)
        # I (will) have a RV for each x (there are B of them) and hence need a diffusion time for each one
        diff_times = torch.stack([d for _ in range(B)]).reshape(B * T).to(device)
        eff_times = diffusion.get_eff_times(diff_times=diff_times).unsqueeze(-1).unsqueeze(-1).to(device)
        vec_diff_times = torch.stack([diff_times for _ in range(num_taus)], dim=0).reshape(num_taus*Xshape)
        vec_eff_times = torch.stack([eff_times for _ in range(num_taus)], dim=0).reshape(num_taus*Xshape, 1, 1)
    
        with torch.no_grad():
            if "PM" in config.scoreNet_trained_path:
                vec_predicted_score = PM.forward(inputs=vec_Z_taus, times=vec_diff_times, conditioner=vec_conditioner,
                                             eff_times=vec_eff_times)
                predicted_score = PM.forward(inputs=Z_taus, times=diff_times, conditioner=conditioner,
                                             eff_times=eff_times)
            else:
                predicted_score = PM.forward(inputs=Z_taus, times=diff_times, conditioner=conditioner)
                vec_predicted_score = PM.forward(inputs=vec_Z_taus, times=vec_diff_times, conditioner=vec_conditioner)
            scores, drift, diffParam = diffusion.get_conditional_reverse_diffusion(x=Z_taus,
                                                                                   predicted_score=predicted_score,
                                                                                   diff_index=torch.Tensor(
                                                                                       [int((
                                                                                               num_diff_times - 1 - difftime_idx))]).to(
                                                                                       device),
                                                                                   max_diff_steps=Ndiff_discretisation)
            vec_scores, vec_drift, vec_diffParam = diffusion.get_conditional_reverse_diffusion(x=vec_Z_taus,
                                                                                   predicted_score=vec_predicted_score,
                                                                                   diff_index=torch.Tensor(
                                                                                       [int((
                                                                                               num_diff_times - 1 - difftime_idx))]).to(
                                                                                       device),
                                                                                   max_diff_steps=Ndiff_discretisation)
            scores = vec_scores[k*Xshape:(k+1)*Xshape,:,:]
            drift = vec_drift[k*Xshape:(k+1)*Xshape,:,:]
            diffParam = vec_diffParam
        # assert np.allclose((scores- predicted_score).detach(), 0)
        beta_taus = torch.exp(-0.5 * d).to(device)
        sigma_taus = torch.pow(1. - torch.pow(d, 2), 0.5).to(device)
        for i in range(Xshape):
            Zts = Z_taus[i, :, :]
            Ss = scores[i, :, :]
            mu_hat = Zts / (ts_step * beta_taus) + (
                    (torch.pow(sigma_taus, 2) + (torch.pow(beta_taus, 2) * ts_step)) / (ts_step * beta_taus)) * Ss
            mu_hats[i, difftime_idx, k] = mu_hat[0, 0].cpu().detach().numpy()

        vec_per_tau_Zs = (Z_taus/(ts_step * beta_taus)) + ( (
                    (torch.pow(sigma_taus, 2) + (torch.pow(beta_taus, 2) * ts_step)) / (ts_step * beta_taus)) * scores)
        assert np.allclose(mu_hats[:, difftime_idx, k] , vec_per_tau_Zs[:, 0, 0])
        for s in range(num_taus):
            vec_Zts = vec_Z_taus[s*Xshape:(s+1)*Xshape, :, :]
            vec_Ss = vec_scores[s*Xshape:(s+1)*Xshape, :, :]
            vec_per_tau_Zs = (vec_Zts/(ts_step * beta_taus)) + ( (
                    (torch.pow(sigma_taus, 2) + (torch.pow(beta_taus, 2) * ts_step)) / (ts_step * beta_taus)) * vec_Ss)
            vec_mu_hats[:, difftime_idx, k] = vec_per_tau_Zs[:, 0, 0]
            if s == k:
                break
        assert np.allclose(vec_mu_hats[:, difftime_idx, k] , mu_hats[:, difftime_idx, k])
        final_mu_hats = (vec_Z_taus/(ts_step * beta_taus)) + ( (
                    (torch.pow(sigma_taus, 2) + (torch.pow(beta_taus, 2) * ts_step)) / (ts_step * beta_taus)) * vec_scores)
        assert np.allclose(mu_hats[:, difftime_idx, k] , final_mu_hats[k*Xshape:(k+1)*Xshape, 0, 0])
        final_vec_mu_hats[:, difftime_idx, :] = final_mu_hats.reshape((num_taus, Xshape)).T.cpu().numpy()
        assert np.allclose(mu_hats[:, difftime_idx, k] , final_vec_mu_hats[:, difftime_idx, k])
        z = torch.randn_like(drift).to(device)
        vec_z = torch.randn_like(vec_drift).to(device)
        z = vec_z[k*Xshape:(k+1)*Xshape, :, :]
        Z_taus = drift + diffParam * z
        vec_Z_taus = vec_drift + vec_diffParam * vec_z
        difftime_idx -= 1
        ts.append(time.time()-t0)
        print(np.mean(ts), difftime_idx)


  0%|          | 0/3 [00:00<?, ?it/s]

0.05420398712158203 998
0.04590559005737305 997
0.04370776812235514 996
0.04269975423812866 995
0.04180941581726074 994
0.04088219006856283 993
0.04081099373953683 992
0.041411370038986206 991
0.041004339853922524 990
0.040916037559509275 989
0.040608210997147995 988
0.04042768478393555 987
0.040728477331308216 986
0.040788003376552036 985
0.04053867657979329 984
0.04028363525867462 983
0.04005712621352252 982
0.03988801108466254 981
0.03979132049962094 980
0.039827609062194826 979
0.039707728794642856 978
0.03956287557428533 977
0.039684565170951515 976
0.03993533054987589 975
0.04008262634277344 974
0.04019632706275353 973
0.04028464246679236 972
0.04022344521113804 971
0.040209564669378875 970
0.040232006708780924 969
0.04024604059034778 968
0.04019559919834137 967
0.04016900785041578 966
0.04011474637424245 965
0.040118013109479635 964
0.04024598333570692 963
0.040192095009056295 962
0.04012491201099597 961
0.04002117499327048 960
0.03997281789779663 959
0.039952167650548424 958
0.

 33%|███▎      | 1/3 [00:40<01:20, 40.10s/it]

0.04000504596813305 0
0.040005228996276856 -1
0.045031070709228516 998
0.04196012020111084 997
0.04150676727294922 996
0.04072660207748413 995
0.04131011962890625 994
0.04137277603149414 993
0.04133994238717215 992
0.04107847809791565 991
0.04106073909335666 990
0.04126617908477783 989
0.0411381721496582 988
0.040907482306162514 987
0.040971572582538314 986
0.040900894573756626 985
0.041247288386027016 984
0.04110740125179291 983
0.04114560519947725 982
0.0409980747434828 981
0.04099212194743909 980
0.04090770483016968 979
0.040749913170224146 978
0.04063291983170943 977
0.04048566196275794 976
0.04039145509401957 975
0.04055720329284668 974
0.04054023669316219 973
0.04053310994748716 972
0.04047114508492606 971
0.04047075633344979 970
0.04049873352050781 969
0.04060932897752331 968
0.04066768288612366 967
0.040725419015595406 966
0.040818340638104605 965
0.040881899424961635 964
0.040929820802476674 963
0.04088109248393291 962
0.04081840891587107 961
0.04080350582416241 960
0.04074442

 33%|███▎      | 1/3 [00:45<01:31, 45.60s/it]


KeyboardInterrupt: 

In [4]:
vec_Z_taus = diffusion.prior_sampling(shape=(Xshape*num_taus, 1, 1)).to(device)
difftime_idx = num_diff_times - 1
ts = []
while difftime_idx >= 0:
    d = diffusion_times[Ndiff_discretisation - (num_diff_times - 1 - difftime_idx) - 1].to(device)
    # I (will) have a RV for each x (there are B of them) and hence need a diffusion time for each one
    diff_times = torch.stack([d for _ in range(B)]).reshape(B * T).to(device)
    eff_times = diffusion.get_eff_times(diff_times=diff_times).unsqueeze(-1).unsqueeze(-1).to(device)
    vec_diff_times = torch.stack([diff_times for _ in range(num_taus)], dim=0).reshape(num_taus*Xshape)
    vec_eff_times = torch.stack([eff_times for _ in range(num_taus)], dim=0).reshape(num_taus*Xshape, 1, 1)
    vec_conditioner = torch.stack([conditioner for _ in range(num_taus)], dim=0).reshape(num_taus*Xshape, 1, 1)

    with torch.no_grad():
        if "PM" in config.scoreNet_trained_path:
            vec_predicted_score = PM.forward(inputs=vec_Z_taus, times=vec_diff_times, conditioner=vec_conditioner,
                                         eff_times=vec_eff_times)
        else:
            predicted_score = PM.forward(inputs=vec_Z_taus, times=vec_diff_times, conditioner=vec_conditioner)
        vec_scores, vec_drift, vec_diffParam = diffusion.get_conditional_reverse_diffusion(x=vec_Z_taus,
                                                                               predicted_score=vec_predicted_score,
                                                                               diff_index=torch.Tensor(
                                                                                   [int((
                                                                                           num_diff_times - 1 - difftime_idx))]).to(
                                                                                   device),
                                                                               max_diff_steps=Ndiff_discretisation)
    # assert np.allclose((scores- predicted_score).detach(), 0)
    beta_taus = torch.exp(-0.5 * d).to(device)
    sigma_taus = torch.pow(1. - torch.pow(d, 2), 0.5).to(device)
    final_mu_hats = (vec_Z_taus/(ts_step * beta_taus)) + ( (
                (torch.pow(sigma_taus, 2) + (torch.pow(beta_taus, 2) * ts_step)) / (ts_step * beta_taus)) * vec_scores)
    #print(vec_Z_taus.shape, vec_scores.shape)
    final_vec_mu_hats[:, difftime_idx, :] = final_mu_hats.reshape((num_taus, Xshape)).T.cpu().numpy()
    vec_z = torch.randn_like(vec_drift).to(device)
    vec_Z_taus = vec_drift + vec_diffParam * vec_z
    difftime_idx -= 1

