# Experiments using Temporal Normalizing Flow

In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim
import torch.nn.init as init
import torch.distributions.transforms as transform
import torch.nn.functional as functional
import matplotlib.pyplot as plt
from matplotlib import cm
torch.set_default_dtype(torch.float64)
# set random seed
SEED_ = 10
np.random.seed(SEED_)
torch.manual_seed(SEED_)

from sklearn import datasets

import scipy
import scipy.integrate
import h5py

# custom packages
from utils.VanillaNF import *

1 dimensional autoregressive flow (not working)

In [None]:
# One frozen (non-trainable) mask parameter per coupling layer; both layers
# use the single-entry mask [0.0].
raw_mask = [
    torch.nn.Parameter(torch.Tensor(entries), requires_grad=False)
    for entries in ([0.0], [0.0])
]
masks = torch.nn.ParameterList(raw_mask)

hidden_dim = 64
in_dim = 1
out_dim = in_dim


def _affine_flow_wrapper(mask):
    """Build one AffineCouplingFlow layer for the given mask."""
    return AffineCouplingFlow(
        in_dim=in_dim,
        hidden_dim=hidden_dim,
        out_dim=out_dim,
        n_layers=4,
        activation=torch.nn.ReLU,
        mask=mask,
    )


# Coupling layers first (in mask order), then the final normalisation layer.
# Disabled alternative for the final layer:
#     VanillaNormFlow(in_dim, out_dim, scaling=5.0)
realnvp_blocks = [_affine_flow_wrapper(layer_mask) for layer_mask in raw_mask]
realnvp_blocks.append(tBatchNormFlow(in_dim))

realNVP = NormalizingFlow(realnvp_blocks, flow_length=1)

# Use the GPU when at least one CUDA device is visible.
if torch.cuda.device_count() > 0:
    realNVP = realNVP.cuda()
# Remember where the model lives so data batches can be moved there.
device = next(realNVP.parameters()).device

In [None]:
# Draw n_samples scalars from an equally weighted three-component Gaussian
# mixture: N(-2, 0.5^2), N(0, 1) and N(2, 0.5^2).
n_samples = 2**15
mixture_data = np.zeros([n_samples, 1])

for row in range(n_samples):
    # One uniform draw picks the component, one normal draw gives the noise;
    # the draw order (rand, then randn) is kept so seeded runs reproduce.
    selector = np.random.rand()
    noise = np.random.randn()
    if selector < 1/3:
        sample = 0.5*noise + (-2.0)
    elif selector <= 2/3:
        sample = noise
    else:
        sample = 0.5*noise + (2.0)
    mixture_data[row] = sample

In [None]:
# Fit the 1-D flow by minimising the negative log-likelihood of the mixture
# samples under a standard-normal latent density.
optimizer = torch.optim.Adam(realNVP.parameters(), lr = 0.0001)
num_steps = 100
batch_size = 2**7
# Trailing samples that do not fill a whole batch are dropped.
num_batches = int(np.floor(n_samples/batch_size))
for idx_step in range(num_steps):
    # Reshuffle the (global) data set at the start of every epoch.
    idx = np.arange(n_samples)
    np.random.shuffle(idx)
    mixture_data = mixture_data[idx, :]
    iter_loss = 0
    for j in range(num_batches):
        batch_idx = np.arange(j*batch_size, (j+1)*batch_size)
        # get batch
        X = mixture_data[batch_idx, :]
        X = torch.Tensor(X).to(device = device)

        ## transform data X to latent space Z
        z, logdet = realNVP.inverse(X)

        ## Negative log-likelihood of X under a standard-normal latent:
        ##   NLL = d/2 * log(2*pi) + mean(0.5*||z||^2 - logdet),
        ## with d the number of latent coordinates summed over. The original
        ## used log(2*pi) regardless of d, which over-reports the loss by
        ## 0.5*log(2*pi) in 1-D; the constant does not affect the gradients.
        latent_dim = z.shape[-1]
        loss = 0.5*latent_dim*np.log(2*np.pi) + torch.mean(torch.sum(0.5*z**2, -1) - logdet)

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        iter_loss += loss.item()
        if j % 500 == 0:
            print(loss.item())
    # Report the summed batch loss every second epoch.
    if (idx_step + 1) % 2 == 0:
        print(f"idx_steps: {idx_step:}, loss: {iter_loss:.5f}")

2d stochastic Van der Pol

In [2]:
# Eight coupling layers alternating between two masks over (t, x1, x2).
# The first (time) entry is 1.0 in every mask; per the original note, the
# time dimension is not transformed.
_mask_pair = [[1.0, 1.0, 0.0],
              [1.0, 0.0, 1.0]]
raw_mask = [
    torch.nn.Parameter(torch.Tensor(row), requires_grad=False)
    for row in _mask_pair * 4
]
masks = torch.nn.ParameterList(raw_mask)

hidden_dim = 64
in_dim = 2
out_dim = in_dim


def _affine_flow_wrapper(mask):
    """Build one tAffineCouplingFlow block of the RealNVP for the given mask."""
    return tAffineCouplingFlow(
        in_dim=in_dim,
        hidden_dim=hidden_dim,
        out_dim=out_dim,
        n_layers=6,
        activation=torch.nn.ReLU,
        mask=mask,
    )


# Coupling blocks in mask order, followed by the final tVanillaNormFlow layer.
realnvp_blocks = [_affine_flow_wrapper(layer_mask) for layer_mask in raw_mask]
realnvp_blocks.append(tVanillaNormFlow(in_dim, out_dim, scaling=5.0))

# Assemble the RealNVP and move it to the GPU when one is visible.
realNVP = NormalizingFlow(realnvp_blocks, flow_length=1)
if torch.cuda.device_count() > 0:
    realNVP = realNVP.cuda()
# Remember where the model lives so data can be moved there.
device = next(realNVP.parameters()).device

In [None]:
# simulate Van der Pol with random normal initial conditions
def rhs(z, t, mu=1.0):
    """
        Vector field of the Van der Pol oscillator (formula taken from
        Tyler's nonlocal paper).

        z  : state, indexed as z[0] (position) and z[1] (velocity)
        t  : time; unused, present so the signature is f(z, t) as
             passed to scipy.integrate.odeint
        mu : damping parameter

        Returns the list [dz[0]/dt, dz[1]/dt].
    """
    position, velocity = z[0], z[1]
    return [velocity, mu * (1 - position**2)*velocity - position]

In [12]:
# Build the training set `paths` with shape (num_paths, nt, 3); in the
# simulated branch the last axis holds (t, x1, x2).
load_simulated = True
num_paths = 500
t_end = 20.0
nt = 500+1
tgrid = np.linspace(0.0, t_end, nt)
dt = tgrid[1]-tgrid[0]
paths = np.zeros([num_paths, nt, 3])
# Initial conditions are drawn from N(mean, covmat).
mean = np.array([1, 0])
covmat = 0.01*np.eye(2)
if not load_simulated:
    for i in range(num_paths):
        z0 = np.random.multivariate_normal(mean, covmat)
        sol = scipy.integrate.odeint(rhs, z0, tgrid)
        paths[i, :, 0] = tgrid
        paths[i, :, 1:] = sol
else:
    # load pre-generated .mat file
    # The dataset must be copied into memory inside the `with` block: the
    # h5py handle is unusable once the file is closed. The original kept
    # only the handle and never filled `paths`, so training saw all zeros.
    with h5py.File('VanderPol.mat', 'r') as f:
        data = np.asarray(f['data'], dtype=np.float64)
    if data.shape == paths.shape:
        paths = data
    elif data.T.shape == paths.shape:
        # MATLAB v7.3 files store arrays column-major, so h5py sees the
        # axes in reverse order.
        paths = np.ascontiguousarray(data.T)
    else:
        raise ValueError(
            "VanderPol.mat 'data' has shape {}, expected {} (or its "
            "transpose)".format(data.shape, paths.shape)
        )
    # NOTE(review): assumes the file uses the same (t, x1, x2) column layout
    # as the simulated branch -- TODO confirm against the generating script.

In [None]:
# Overlay the phase-plane trajectories (x1 vs x2) of the first ten paths
# on figure 1.
for path_no in range(10):
    plt.figure(1)
    # drop the leading time column, keeping the two state coordinates
    state = paths[path_no, :, 1:]
    plt.plot(state[:, 0], state[:, 1])

In [None]:
# train using RealNVP
optimizer = torch.optim.Adam(realNVP.parameters(), lr = 0.0005)
num_steps = 100

# Latent density over the two state coordinates: N(base_mean, I/base_precision),
# i.e. the mean [1, 0] and covariance 0.01*I used for the initial conditions.
# Created on the model's device; the original built the mean on the CPU inside
# the loss, which raises a device-mismatch error when the flow runs on CUDA.
base_precision = 100.0
base_mean = torch.tensor([1.0, 0.0], device=device)
# Constant part of the 2-D Gaussian NLL: log(2*pi) + 0.5*log det(Sigma)
# = log(2*pi) - log(base_precision). The original added +log(100); the sign
# only shifts the reported loss, not the gradients.
nll_const = np.log(2*np.pi) - np.log(base_precision)

## the following loop learns the RealNVP_2D model from the Van der Pol paths
## in each epoch the paths are reshuffled and visited one trajectory at a time
for idx_step in range(num_steps):
    # shuffle paths
    idx = np.arange(num_paths)
    np.random.shuffle(idx)
    paths = paths[idx, :, :]
    iter_loss = 0
    for j in range(num_paths):
        ## get a path from all paths; rows are time points
        X = paths[j, :, :]
        X = torch.Tensor(X).to(device = device)

        ## transform data X to latent space Z
        z, logdet = realNVP.inverse(X)

        ## negative log-likelihood of X, averaged over the rows of the path;
        ## column 0 of z is skipped, only the remaining two enter the density
        quad = torch.sum(0.5*base_precision*(z[:, 1:]-base_mean)**2, -1)
        loss = nll_const + torch.mean(quad - logdet)

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        iter_loss += loss.item()
        if j % 500 == 0:
            print(loss.item())
    # Report the summed per-path loss every second epoch.
    if (idx_step + 1) % 2 == 0:
        print(f"idx_steps: {idx_step:}, loss: {iter_loss:.5f}")


In [None]:
# Evaluate the learned density p(x, t) on a regular N x N grid at one time
# and draw it as a filled contour plot.
x1_min, x1_max = -5.0, 5.0
x2_min, x2_max = -5.0, 5.0
N = 200
x1_grid = np.linspace(x1_min, x1_max, N)
x2_grid = np.linspace(x2_min, x2_max, N)

dx = x1_grid[1]-x1_grid[0]
# Both axes must share one spacing. Compare with a tolerance: exact float
# equality between two linspace differences can fail from rounding alone.
assert np.isclose(dx, x2_grid[1]-x2_grid[0])
# meshgrid
x1_mesh, x2_mesh = np.meshgrid(x1_grid, x2_grid)
# get list of coordinates, one (x1, x2) row per grid point
x_data = torch.tensor(np.column_stack((x1_mesh.ravel(), x2_mesh.ravel())))


t = 2.0
# Prepend a constant time column and move the batch to the model's device
# (the original left it on the CPU, which fails when the flow is on CUDA).
flow_device = next(realNVP.parameters()).device
t_col = torch.full((x_data.shape[0], 1), t, dtype=x_data.dtype)
xt_data = torch.cat([t_col, x_data], 1).to(flow_device)
zt_data, jac = realNVP.inverse(xt_data)

# Change of variables, matching the training objective (NLL = quad - logdet):
#   log p(x) = log N(z; [1, 0], 0.01*I) + logdet
#            = log(100) - log(2*pi) - 50*||z - [1, 0]||^2 + logdet.
# The original subtracted the log-determinant and added log(2*pi).
base_mean = zt_data.new_tensor([1.0, 0.0])
log_p_z = np.log(100.0) - np.log(2*np.pi) - \
            torch.sum(0.5*100.0*(zt_data[:, 1:]-base_mean)**2, -1)
p_x = torch.exp(log_p_z + jac).reshape(N, N).detach().cpu().numpy()

plt.contourf(x1_mesh, x2_mesh, p_x);

In [None]:
# evaluate solution for a range of times
err_tgrid = np.arange(0.0, 10.0, 0.1)
err_nt = len(err_tgrid)

# The spatial grid does not depend on time, so build it once up front. This
# also defines N before all_preds is allocated, instead of relying on a
# value left over from an earlier cell.
x1_min, x1_max = -5.0, 5.0
x2_min, x2_max = -5.0, 5.0
N = 200
x1_grid = np.linspace(x1_min, x1_max, N)
x2_grid = np.linspace(x2_min, x2_max, N)

dx = x1_grid[1]-x1_grid[0]
# Both axes must share one spacing. Compare with a tolerance: exact float
# equality between two linspace differences can fail from rounding alone.
assert np.isclose(dx, x2_grid[1]-x2_grid[0])
# meshgrid
x1_mesh, x2_mesh = np.meshgrid(x1_grid, x2_grid)
# get list of coordinates, one (x1, x2) row per grid point
x_data = torch.tensor(np.column_stack((x1_mesh.ravel(), x2_mesh.ravel())))

# Evaluate on whatever device the model lives on (the original left the
# inputs on the CPU, which fails when the flow is on CUDA).
flow_device = next(realNVP.parameters()).device

all_preds = np.zeros([err_nt, N, N])
for i in range(err_nt):
    t = err_tgrid[i]
    if i % 10 == 0:
        print("Time step = {}".format(t))

    # Prepend a constant time column to the grid coordinates.
    t_col = torch.full((x_data.shape[0], 1), float(t), dtype=x_data.dtype)
    xt_data = torch.cat([t_col, x_data], 1).to(flow_device)
    zt_data, jac = realNVP.inverse(xt_data)

    # Change of variables, matching the training objective
    # (NLL = quad - logdet):
    #   log p(x) = log(100) - log(2*pi) - 50*||z - [1, 0]||^2 + logdet.
    # The original subtracted the log-determinant and added log(2*pi).
    base_mean = zt_data.new_tensor([1.0, 0.0])
    log_p_z = np.log(100.0) - np.log(2*np.pi) - \
                torch.sum(0.5*100.0*(zt_data[:, 1:]-base_mean)**2, -1)
    p_x = torch.exp(log_p_z + jac).reshape(N, N).detach().cpu().numpy()
    # replace NaN's with 0.0
    #p_x = np.nan_to_num(p_x)
    # save un-normalized prediction
    all_preds[i, :, :] = p_x
    # divide by constant to integrate to 1 (only the last time step's
    # normalised p_x survives the loop; all_preds keeps the raw values)
    int_p_x = np.trapz(np.trapz(p_x, dx=dx), dx=dx)
    p_x = p_x / int_p_x

In [None]:
# plot all graphs
%matplotlib inline
import time
from IPython import display
for i in range(0, err_nt):
    plt.figure(1)
    # display predicted solution i
    plot_density = all_preds[i, :, :]
    plt.contourf(x1_mesh, x2_mesh, plot_density);
    plt.title("Time = {}".format(err_tgrid[i]))
    if err_tgrid[i]>20:
        plt.title("Time = {}, out of sample".format(err_tgrid[i]))
    display.clear_output(wait=True)
    display.display(plt.gcf());
    plt.close();