In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim
import torch.distributions.transforms as transform
import matplotlib.pyplot as plt

from utils.ParticleDataset import *
from utils.preprocessing import *
from utils.training import *
from utils.TemporalNF import *

In [None]:
# Load the 2-D Brownian-motion training set from <cwd>/utils/data.
# NOTE(review): the file name keeps its leading "/" because ParticleDataset
# presumably concatenates root_dir + filename — confirm before changing it.
dataset = ParticleDataset(
    "/BrownianMotion2d.mat",
    root_dir=os.path.join(os.getcwd(), "utils", "data"),
)


In [None]:
# Visual sanity check: overlay sample paths (x1(t), x2(t)) from the training set.
# Columns 0/1/2 of X_train are taken as t/x1/x2, and rows are reshaped as
# N_PATHS consecutive runs of N_STEPS time steps (assumed layout — TODO confirm
# against ParticleDataset).
N_PATHS, N_STEPS = 800, 501
N_PLOTTED = 500
tgrid = dataset.raw_dataset['X_train'][:, 0]
x1 = dataset.raw_dataset['X_train'][:, 1].reshape(N_PATHS, N_STEPS, -1)
x2 = dataset.raw_dataset['X_train'][:, 2].reshape(N_PATHS, N_STEPS, -1)
# Flatten each path and transpose so every column is one trajectory; a single
# plot call then draws one line per column, in the same colour-cycle order a
# per-path loop would use.
paths_x1 = x1[:N_PLOTTED].reshape(N_PLOTTED, -1).T
paths_x2 = x2[:N_PLOTTED].reshape(N_PLOTTED, -1).T
fig, ax = plt.subplots()
ax.plot(paths_x1, paths_x2)
ax.set(xlabel="$x_1$", ylabel="$x_2$", title=f"{N_PLOTTED} of {N_PATHS} training trajectories");

In [None]:
# Build the time-conditioned normalizing flow. Seed torch first so the randomly
# initialised weights (and presumably any torch-driven randomness during
# training) are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
# Block types composed by the flow. Despite the variable name, the first block
# type here is tAutoregressive.
realnvp_blocks = [tAutoregressive, tSwapFlow, tBatchNormFlow]
realnvp_flow = tNormalizingFlow(
    space_dim=dataset.dim,
    blocks=realnvp_blocks,
    flow_length=2,
)
# Log base measure: standard bivariate normal N(0, I_2) over the spatial coordinates.
SPATIAL_DIM = 2
normal_logpdf = torch.distributions.multivariate_normal.MultivariateNormal(
    torch.zeros(SPATIAL_DIM), torch.eye(SPATIAL_DIM)
).log_prob


def log_base_measure(x):
    """Log-density of the base measure for a batch of flow outputs.

    Column 0 of ``x`` is the time coordinate and is dropped; the remaining
    ``SPATIAL_DIM`` columns are scored under the standard normal N(0, I).

    Args:
        x: 2-D tensor of shape (batch, 1 + SPATIAL_DIM).

    Returns:
        Tensor of shape (batch,) holding the log-density of each row.
    """
    return normal_logpdf(x[:, 1:])

# Training hyperparameters, gathered in one mapping so they are easy to find and tweak.
# NOTE(review): grad_clip=1e+20 looks like it makes gradient clipping a no-op —
# confirm how utils.training.train interprets it.
train_config = {
    "num_epochs": 30,
    "batch_size": 2**12,
    "verbose": True,
    "lr": 5e-3,
    "use_scheduler": True,
    "schedule_rate": 0.99999,
    "grad_clip": 1e+20,
}
# Fit the flow to the particle data, scoring outputs with log_base_measure.
report = train(dataset, realnvp_flow, log_base_measure, **train_config)

In [None]:
# Evaluate the learned density at a fixed time on a regular grid over [-10, 10]^2.
GRID_SIZE = 1000
GRID_LIMIT = 10.0
T_QUERY = 1.0
x1_grid = np.linspace(-GRID_LIMIT, GRID_LIMIT, GRID_SIZE)
x2_grid = np.linspace(-GRID_LIMIT, GRID_LIMIT, GRID_SIZE)
x1_mesh, x2_mesh = np.meshgrid(x1_grid, x2_grid)
x_data = np.column_stack([x1_mesh.ravel(), x2_mesh.ravel()])
N = x_data.shape[0]
# Prepend the query time as column 0, matching the [t, x1, x2] column order of X_train.
t = np.full((N, 1), T_QUERY)
input_data = torch.as_tensor(np.concatenate([t, x_data], axis=1), dtype=torch.float32)
# Inference only. eval() so the batch-norm blocks (presumably) use their learned
# running statistics instead of statistics of this evaluation grid; no_grad() so
# autograd does not keep a graph for all N points. The previous mode is restored.
was_training = realnvp_flow.training
realnvp_flow.eval()
with torch.no_grad():
    latent = realnvp_flow(input_data)
    # Change of variables, assuming latent = (z, per-block log|det J| terms):
    # log p(x) = log p_base(z) + sum of log|det J|.
    log_density = log_base_measure(latent[0]).reshape(-1, 1) + sum(latent[1]).reshape(-1, 1)
    estimated_density = torch.exp(log_density)
realnvp_flow.train(was_training)

In [None]:
# Learned density on the evaluation grid, drawn in physical (x1, x2) coordinates
# rather than array indices so it lines up with the trajectory plot.
fig, ax = plt.subplots()
cs = ax.contourf(x1_mesh, x2_mesh, estimated_density.reshape(x1_mesh.shape).detach().numpy())
fig.colorbar(cs, ax=ax, label="estimated density")
ax.set(xlabel="$x_1$", ylabel="$x_2$", title="Estimated density at $t = 1$");

In [None]:
# Base-measure term only: exp(log p_base(z)) at the flow outputs z for each grid
# point, i.e. the density *without* the log-det-Jacobian correction. Useful for
# seeing how much of the final shape comes from the flow's volume change.
fig, ax = plt.subplots()
base_term = torch.exp(log_base_measure(latent[0])).reshape(x1_mesh.shape)
cs = ax.contourf(x1_mesh, x2_mesh, base_term.detach().numpy())
fig.colorbar(cs, ax=ax, label="base density at flow output")
ax.set(xlabel="$x_1$", ylabel="$x_2$", title="Base term only (no Jacobian correction)");