# Experiments using Temporal Normalizing Flow

## Training on a 2D Brownian particle

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim
import torch.nn.init as init
import torch.distributions.transforms as transform
import torch.nn.functional as functional
import matplotlib.pyplot as plt
from matplotlib import cm
# One shared seed for both random number generators used in this notebook.
SEED_ = 10
torch.manual_seed(SEED_)
np.random.seed(SEED_)
# Tensors created without an explicit dtype are float64 from here on.
torch.set_default_dtype(torch.float64)

from sklearn import datasets

# custom packages
from utils.VanillaNF import *

In [None]:
# time dimension is not transformed
# Eight three-entry masks: the two row patterns below alternate four times.
# Each mask is wrapped as a non-trainable Parameter.
# (presumably the three entries line up with the (t, x1, x2) columns — TODO confirm)
raw_mask = [
    torch.nn.Parameter(torch.Tensor(pattern), requires_grad=False)
    for _ in range(4)
    for pattern in ([1.0, 1.0, 0.0], [1.0, 0.0, 1.0])
]

# the same Parameter objects, collected in a ParameterList
masks = torch.nn.ParameterList(raw_mask)

# layer sizes handed to the flow blocks
hidden_dim = 64
in_dim = 2
out_dim = in_dim

In [None]:
# create blocks for RealNVP
def _affine_flow_wrapper(mask):
    """Build one tAffineCouplingFlow for `mask` with the notebook-level sizes."""
    return tAffineCouplingFlow(
        in_dim=in_dim,
        hidden_dim=hidden_dim,
        out_dim=out_dim,
        n_layers=6,
        activation=torch.nn.ReLU,
        mask=mask,
    )


realnvp_blocks = []
for position, mask in enumerate(raw_mask, start=1):
    realnvp_blocks.append(_affine_flow_wrapper(mask))
    # every second coupling block is followed by a tVanillaNormFlow
    if position % 2 == 0:
        realnvp_blocks.append(tVanillaNormFlow(in_dim, out_dim))

# one more tVanillaNormFlow closes the stack
realnvp_blocks.append(tVanillaNormFlow(in_dim, out_dim))

# create realnvp
realNVP = NormalizingFlow(realnvp_blocks, flow_length=1)
# use the GPU when at least one CUDA device is visible
if torch.cuda.device_count():
    realNVP = realNVP.cuda()
# remember where the parameters live so inputs can be sent to the same place
device = next(realNVP.parameters()).device

In [None]:
# simulate n-d Brownian motion
num_paths = 1000
t_start, t_end = 0.0, 5.0
# discretize
nt = 100 + 1
tgrid = np.linspace(t_start, t_end, nt)
dt = tgrid[1] - tgrid[0]

# number of dimensions
d = 2
# record trajectories along with time: paths[i, j] = (t_j, x_1, ..., x_d)
dims = [num_paths, nt, 1 + d]
paths = np.zeros(dims)
# time column for every row, j = 0 included, so the first row is correct
# even when t_start is not 0
paths[:, :, 0] = tgrid
# independent N(0, dt) increments for every path, step and dimension;
# the array is filled path-major, then step, then dimension
increments = np.sqrt(dt) * np.random.randn(num_paths, nt - 1, d)
# initial condition is dirac at the origin (row 0 keeps its zeros);
# later positions are the running sums of the increments along the time axis
paths[:, 1:, 1:] = np.cumsum(increments, axis=1)

In [None]:
# overlay the first 100 simulated trajectories, column 1 against column 2
plt.figure(1)
for path_idx in range(100):
    trajectory = paths[path_idx]
    plt.plot(trajectory[:, 1], trajectory[:, 2])
plt.show()

In [None]:
# train using RealNVP
optimizer = torch.optim.Adam(realNVP.parameters(), lr=0.0005)
num_steps = 100

## the following loop fits the flow to the simulated paths by minimizing the
## negative log-likelihood; every epoch visits all paths once, in a freshly
## shuffled order, taking one optimizer step per path
for idx_step in range(num_steps):
    # shuffle paths
    idx = np.arange(num_paths)
    np.random.shuffle(idx)
    paths = paths[idx, :, :]
    # one host-to-device copy per epoch instead of one per path
    epoch_paths = torch.Tensor(paths).to(device=device)
    iter_loss = 0
    for j in range(num_paths):
        ## get a path from all paths: one row per time step
        X = epoch_paths[j]

        ## transform data X to latent space Z
        z, logdet = realNVP.inverse(X)

        ## calculate the negative loglikelihood of X
        ## (presumably log(2*pi) is the Gaussian constant (d/2)*log(2*pi) for d = 2 — TODO confirm)
        loss = torch.log(z.new_tensor([2*np.pi])) + torch.mean(torch.sum(0.5*z**2, -1) - logdet)

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        # iter_loss is the sum (not the mean) of the per-path losses of this epoch
        iter_loss += loss.item()
        if j % 500 == 0:
            print(loss.item())
    if (idx_step + 1) % 2 == 0:
        print(f"idx_steps: {idx_step:}, loss: {iter_loss:.5f}")


In [None]:
# sampling
dims = [1000, nt, 1 + d]
new_z = np.zeros(dims)
# only the first 10 of the 1000 allocated slots are filled; the rest stay zero
# column 0 carries the time grid
new_z[:10, :, 0] = tgrid
# the remaining d columns are standard normal draws at every time step
new_z[:10, :, 1:] = np.random.randn(10, nt, d)

In [None]:
# run the model on the first 10 latent samples and store the first returned element
# (presumably the transformed sample — TODO confirm against NormalizingFlow)
new_x = np.zeros(dims)
for i in range(10):
    # the model may live on the GPU: feed it an input on `device` and bring
    # the result back to host memory before converting to NumPy
    z_i = torch.Tensor(new_z[i, :, :]).to(device=device)
    new_x[i, :, :] = realNVP(z_i)[0].detach().cpu().numpy()

In [None]:
# plot the first generated trajectory, column 1 against column 2
plt.figure(1)
for sample_idx in range(1):
    generated = new_x[sample_idx]
    plt.plot(generated[:, 1], generated[:, 2])
plt.show()

In [None]:
# evaluate density at different times
x1_min, x1_max = -6.0, 6.0
x2_min, x2_max = -6.0, 6.0
N = 200
x1_grid = np.linspace(x1_min, x1_max, N)
x2_grid = np.linspace(x2_min, x2_max, N)

dx = x1_grid[1] - x1_grid[0]
# both axes must share one spacing; compare with a tolerance because two
# floating-point differences can disagree in the last bit
assert np.isclose(dx, x2_grid[1] - x2_grid[0])
# meshgrid
x1_mesh, x2_mesh = np.meshgrid(x1_grid, x2_grid)
# get list of coordinates: one (x1, x2) row per grid node
x_data = np.stack((x1_mesh.ravel(), x2_mesh.ravel()), axis=1)
# keep the inputs on the same device as the model parameters
x_data = torch.tensor(x_data, device=device)
t = 0.5
# prepend a constant time column so that each row reads (t, x1, x2)
t_column = torch.full((x_data.shape[0], 1), t, dtype=x_data.dtype, device=device)
xt_data = torch.cat([t_column, x_data], 1)
zt_data, jac = realNVP.inverse(xt_data)
# density = exp(-(0.5*|z|^2 - jac + log(2*pi))), reshaped onto the N x N grid;
# moved to host memory so later cells can convert it to NumPy directly
log_2pi = float(np.log(2 * np.pi))
p_x = torch.exp(-(torch.sum(0.5*zt_data**2, -1) - jac + log_2pi)).reshape(N, N).cpu()

In [None]:
# invert the whole flow on the first grid point, passed as a one-row batch
realNVP.inverse(xt_data[0:1, :])

In [None]:
# invert only bijector 6 on the first grid point, passed as a one-row batch
realNVP.bijectors[6].inverse(xt_data[0:1, :])

In [None]:
# undo bijector 6 and then bijector 5 on grid point 4 (one-row batch);
# only the first element returned by bijector 6 is passed on
after_six = realNVP.bijectors[6].inverse(xt_data[4:5, :])[0]
realNVP.bijectors[5].inverse(after_six)

In [None]:
# turn the density tensor into a NumPy array, then draw filled contours
density_tensor = p_x.detach()
p_x = density_tensor.numpy()
# NaNs are left as they are here (np.nan_to_num is not applied)
plt.contourf(x1_mesh, x2_mesh, p_x);

In [None]:
# 2-D trapezoidal rule: integrate along the last axis first, then along the remaining one
row_integrals = np.trapz(p_x, dx=dx)
int_p_x = np.trapz(row_integrals, dx=dx)

In [None]:
# 3-D surface of the density divided by its numerical integral
normalized_density = p_x / int_p_x
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
ax.plot_surface(
    x1_mesh,
    x2_mesh,
    normalized_density,
    cmap=cm.coolwarm,
    linewidth=0,
    antialiased=False,
)

In [None]:
# peak value of the density after dividing by its numerical integral
np.max(p_x / int_p_x)

In [None]:
# solve the diffusion directly with dirac initial condition
# the two bare tuples only evaluate the grid and mesh names defined earlier
(x1_grid, x2_grid)
(x1_mesh, x2_mesh)
# fine time grid over [t_start, t_end] and its step size
nt_pde = 10_001
tgrid_pde = np.linspace(t_start, t_end, num=nt_pde)
dt_pde = tgrid_pde[1] - tgrid_pde[0]
# with a dirac initial condition the Fokker-Planck equation has a closed-form solution
# https://math.stackexchange.com/questions/3924499/analytical-solution-to-2d-diffusion-equation-with-a-drift-term
def analytic_solution(t, x, y):
    """Evaluate (1 / (2*pi*t)) * exp(-(x**2 + y**2) / (2*t)).

    `t`, `x` and `y` may be scalars or NumPy arrays; arrays are combined
    elementwise. No special handling is done for t == 0.
    """
    squared_radius = x**2 + y**2
    prefactor = 1 / (2 * np.pi * t)
    return prefactor * np.exp(-squared_radius / (2 * t))

In [None]:
# 3-D surface of the closed-form density at the query time
t_query = 0.5
fig2, ax2 = plt.subplots(subplot_kw=dict(projection="3d"))
analytic_density = analytic_solution(t_query, x1_mesh, x2_mesh)
ax2.plot_surface(
    x1_mesh,
    x2_mesh,
    analytic_density,
    cmap=cm.coolwarm,
    linewidth=0.0,
    antialiased=False,
)


In [None]:
# peak value of the closed-form density on the mesh at t_query
np.max(analytic_solution(t_query, x1_mesh, x2_mesh))

In [None]:
# largest pointwise squared gap between the rescaled flow density and the closed form
pointwise_gap = (p_x / int_p_x) - analytic_solution(t_query, x1_mesh, x2_mesh)
(pointwise_gap**2).max()

In [None]:
# compute L^2 error for a range of times
# NOTE(review): the upper limit 6.0 lies beyond t_end = 5.0 of the simulated
# paths — confirm that evaluating past the simulated horizon is intended
err_tgrid = np.arange(0.025, 6.025, 0.025)
err_nt = len(err_tgrid)
l2_error = np.zeros(err_nt)
rel_l2_error = np.zeros(err_nt)

# the spatial grid does not depend on t, so it is built once, before the loop
x1_min, x1_max = -6.0, 6.0
x2_min, x2_max = -6.0, 6.0
N = 200
x1_grid = np.linspace(x1_min, x1_max, N)
x2_grid = np.linspace(x2_min, x2_max, N)

dx = x1_grid[1] - x1_grid[0]
# both axes must share one spacing; compare with a tolerance because two
# floating-point differences can disagree in the last bit
assert np.isclose(dx, x2_grid[1] - x2_grid[0])
# meshgrid
x1_mesh, x2_mesh = np.meshgrid(x1_grid, x2_grid)
# get list of coordinates: one (x1, x2) row per grid node, on the model's device
x_data = np.stack((x1_mesh.ravel(), x2_mesh.ravel()), axis=1)
x_data = torch.tensor(x_data, device=device)
log_2pi = float(np.log(2 * np.pi))

# un-normalized predictions for every time, kept for later plotting
all_preds = np.zeros([err_nt, N, N])
for i in range(err_nt):
    t = err_tgrid[i]
    if i % 10 == 0:
        print("Time step = {}".format(t))
    # predict with NF: prepend a constant time column so each row reads (t, x1, x2)
    t_column = torch.full((x_data.shape[0], 1), float(t), dtype=x_data.dtype, device=device)
    xt_data = torch.cat([t_column, x_data], 1)
    zt_data, jac = realNVP.inverse(xt_data)
    p_x = torch.exp(-(torch.sum(0.5*zt_data**2, -1) - jac + log_2pi)).reshape(N, N)
    # back to host memory before converting to NumPy
    p_x = p_x.detach().cpu().numpy()
    # replace NaN's with 0.0
    p_x = np.nan_to_num(p_x)
    # save un-normalized prediction
    all_preds[i, :, :] = p_x
    # divide by constant to integrate to 1
    int_p_x = np.trapz(np.trapz(p_x, dx=dx), dx=dx)
    p_x = p_x / int_p_x

    # compare with analytic
    p_x_analytic = analytic_solution(t, x1_mesh, x2_mesh)
    # divide by constant to integrate to 1
    int_p_x_analytic = np.trapz(np.trapz(p_x_analytic, dx=dx), dx=dx)
    p_x_analytic = p_x_analytic / int_p_x_analytic
    # error between the predicted and the analytic density; both quantities
    # are sums of squares times dx*dx, i.e. no square root is taken
    l2_error[i] = ((p_x - p_x_analytic)**2).sum() * dx * dx
    rel_l2_error[i] = l2_error[i] / ((p_x_analytic**2).sum() * dx * dx)
    if i % 10 == 0:
        print(rel_l2_error[i])

In [None]:
# absolute error curve, leaving out the first 39 time points
late = slice(39, None)
plt.plot(err_tgrid[late], l2_error[late])

In [None]:
# relative error curve, leaving out the first 39 time points
late = slice(39, None)
plt.plot(err_tgrid[late], rel_l2_error[late])

In [None]:
# plot all graphs
%matplotlib inline
import time
from IPython import display
for i in range(err_nt):
    # a fresh 1x3 figure per frame: prediction, closed form, error curve
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    # normalize the density
    plot_density = all_preds[i, :, :]
    int_density = np.trapz(np.trapz(plot_density, dx=dx), dx=dx)
    plot_density = plot_density / int_density
    frame_title = r"$t = {}$".format(np.round(err_tgrid[i], 3))
    # display predicted solution i
    ax[0].contourf(x1_mesh, x2_mesh, plot_density)
    ax[0].set_title(frame_title)
    ax[1].contourf(x1_mesh, x2_mesh, analytic_solution(err_tgrid[i], x1_mesh, x2_mesh))
    ax[1].set_title(frame_title)
    # error history up to and including the frame being shown
    ax[2].plot(err_tgrid[: i + 1], rel_l2_error[: i + 1], color='red', lw=1.5)
    display.clear_output(wait=True)
    display.display(fig)
    # close each frame once it has been shown; otherwise one figure per
    # time step stays open for the rest of the session
    plt.close(fig)

In [None]:
# persist the trained weights; create the target directory first so the save
# cannot fail on a missing folder after training has finished
import os

os.makedirs("trained_models", exist_ok=True)
torch.save(realNVP.state_dict(), "trained_models/Brownian_Motion_100")