## Variational Inference with Tensor-Train
The general problem of Variational Inference (VI) is outlined as follows. Consider the joint density:
$$
    p(z,x) = p(z)\cdot p(x|z)
$$ where $z$ denotes the latent variables and $x$ the observations. We would like to approximate the posterior $p(z|x)$.

We consider a family of probability densities $q(z)$ as approximations to the true posterior $p(z|x)$, and minimize the KL (Kullback-Leibler) divergence over that family:
$$
    \min_q \text{KL}(q(z) || p(z|x))
$$

In the normalizing flow setting, $p(z|x)$ is the target density and $q(z)$ is the approximate density, parametrized by $\theta$. Renaming the variable to $x$ and writing the target as $p(x)$, we rewrite:
$$
    \min_{\theta} \text{KL}(q(x;\theta) || p(x)) = \min_{\theta}\mathbb{E}_{x\sim q(x;\theta)}(\log q(x;\theta) - \log p(x))
$$








We assume we have the ability to evaluate $p(x)$, but not necessarily sample from it. Furthermore, to get $x\sim q(x;\theta)$, the normalizing flow defines the transformation:
$$
    x = T(z) 
$$ with $z\sim s(z)$ for some simple base distribution $s(z)$.

By the change of variables formula:
$$
    q(x;\theta) = s(z)\cdot |\det J_T(z)|^{-1}
$$ where $J_T$ is the Jacobian of $T$ evaluated at $z$. Consequently, the only sampling ability we need is drawing $z\sim s(z)$.
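
As a quick sanity check of the change of variables formula, the sketch below pushes a standard normal through a one-dimensional planar-style map and verifies numerically that the resulting $q$ integrates to one. The map and its parameters `u`, `w`, `b` are hypothetical choices for illustration, and plain Python is used so the check runs without PyTorch.

```python
import math

def base_pdf(z):
    """Standard normal density s(z)."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)

# Hypothetical 1-D planar-style map; u * w > -1 keeps T strictly increasing.
u, w, b = 0.8, 1.5, -0.3

def T(z):
    return z + u * math.tanh(w * z + b)

def dT(z):
    """Derivative of T, i.e. the 1x1 Jacobian J_T(z)."""
    return 1.0 + u * w * (1.0 - math.tanh(w * z + b) ** 2)

def q_at(z):
    """Density of x = T(z), using q(x) = s(z) * |det J_T(z)|^{-1}."""
    return base_pdf(z) / abs(dT(z))

# Trapezoid rule for the integral of q over x, on the image of a uniform z-grid.
n = 4001
zs = [-8.0 + 16.0 * i / (n - 1) for i in range(n)]
xs = [T(z) for z in zs]
qs = [q_at(z) for z in zs]
total_mass = sum(0.5 * (qs[i] + qs[i + 1]) * (xs[i + 1] - xs[i]) for i in range(n - 1))
print(total_mass)  # should be close to 1
```

Note that the density is evaluated at $x = T(z)$ without ever inverting $T$.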

We have:
$$
    \min_{\theta}\mathbb{E}_{z\sim s(z)}\big(\log s(z) - \log|\det J_{T_{\theta}}(z)| - \log p(T_{\theta}(z))\big) \approx \min_{\theta}\frac{1}{N}\sum_{i=1}^N\bigg[\big(
    \log s(z_i) - \log|\det J_{T_{\theta}}(z_i)|
    \big) - \log p(x_i)\bigg]
$$ where the approximation replaces the expectation with an average over a sample $\{z_i\}_{i=1}^N$ drawn from $s(z)$, and $x_i = T_{\theta}(z_i)$. In particular, we do not require the ability to invert $T$.
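
To make the estimator concrete, here is a minimal sketch for a case where the answer is known in closed form. It assumes a hypothetical affine flow $T(z) = az + b$ with a standard normal base and a Gaussian target, so that $q = \mathcal{N}(b, a^2)$ and the KL divergence can be computed exactly. All names and parameter values are illustrative and not part of the notebook.

```python
import math
import random

def log_normal_pdf(x, mean, std):
    """Log-density of N(mean, std^2) at x."""
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2.0 * math.pi)

# Hypothetical flow T(z) = a * z + b and target p = N(mu, sigma^2).
a, b = 1.5, 0.5
mu, sigma = 0.0, 1.0

random.seed(0)
N = 200_000
kl_estimate = 0.0
for _ in range(N):
    z = random.gauss(0.0, 1.0)            # z_i ~ s(z)
    x = a * z + b                         # x_i = T(z_i)
    log_s = log_normal_pdf(z, 0.0, 1.0)   # log s(z_i)
    log_det = math.log(abs(a))            # log |det J_T(z_i)|
    kl_estimate += log_s - log_det - log_normal_pdf(x, mu, sigma)
kl_estimate /= N

# Closed form for KL(N(b, a^2) || N(mu, sigma^2)).
kl_exact = math.log(sigma / abs(a)) + (a * a + (b - mu) ** 2) / (2.0 * sigma ** 2) - 0.5
print(kl_estimate, kl_exact)
```

The two numbers should agree up to Monte Carlo error, which shrinks like $1/\sqrt{N}$.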

## Training Considerations

Following the derivation above, the training algorithm is:


* Initialize flow model $T$, parametrized by parameter class $\theta$

* for $M$ epochs:

    * draw $N$ sample data points $\{z_i\}_{i=1}^N$ from $s(z)$
    
        * evaluate $\sum_{i=1}^N\log s(z_i)$
    
    * flow the dataset to obtain $x_i = T(z_i)$
        * evaluate $\sum_{i=1}^N \log|\det J_{T_{\theta}}(z_i)|$
        
        * evaluate $\sum_{i=1}^N \log p(x_i)$
    * compute loss $KL = \frac{1}{N}(\sum_{i=1}^N \log s(z_i) - \sum_{i=1}^N \log|\det J_{T_{\theta}}(z_i)| - \sum_{i=1}^N \log p(x_i))$
    
    * backpropagate (done automatically in PyTorch) and update $\theta$
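
The loop above can be sketched end to end on a toy problem. The example below is a hedged illustration, not the notebook's training code: it assumes a hypothetical affine flow $T(z) = az + b$, a Gaussian target, and hand-derived gradients in place of PyTorch's automatic differentiation, so that it runs in plain Python.

```python
import math
import random

# Hypothetical target p = N(mu, sigma^2); base s = N(0, 1); flow T(z) = a * z + b.
mu, sigma = 2.0, 0.5
a, b = 1.0, 0.0           # initial theta = (a, b)
lr, M, N = 0.05, 500, 512

random.seed(0)
for epoch in range(M):
    zs = [random.gauss(0.0, 1.0) for _ in range(N)]   # draw z_i ~ s(z)
    xs = [a * z + b for z in zs]                      # flow: x_i = T(z_i)
    # Gradients of (1/N) * sum[log s(z_i) - log|a| - log p(x_i)] w.r.t. a and b.
    # log s(z_i) does not depend on theta, so it drops out of the gradient.
    grad_a = -1.0 / a + sum((x - mu) * z for x, z in zip(xs, zs)) / (N * sigma ** 2)
    grad_b = sum(x - mu for x in xs) / (N * sigma ** 2)
    a -= lr * grad_a                                  # gradient step on theta
    b -= lr * grad_b

print(a, b)  # should approach sigma and mu respectively
```

Since the target lies inside the flow family here, the minimizer is $a = \sigma$, $b = \mu$, which gives a simple way to confirm that the loss and its sign conventions are right before moving to a richer flow.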

In [None]:
# normalizing flow as correction to TT-IRT
import numpy as np
import scipy as sp
import scipy.io
import scipy.stats
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as distrib
import torch.distributions.transforms as transform
from IPython.display import HTML
import math

# very high precision
torch.set_default_dtype(torch.float64)

import torch.optim as optim

# normalizing flow used as a correction to TT samples
# adapted from: https://github.com/acids-ircam/pytorch_flows/blob/master/flows_04.ipynb
class Flow(transform.Transform, nn.Module):
    
    def __init__(self):
        transform.Transform.__init__(self)
        nn.Module.__init__(self)
    
    # Init all parameters
    def init_parameters(self):
        for param in self.parameters():
            # use random parameters
            param.data.uniform_(-0.01, 0.01)
            
    # Hacky hash bypass
    def __hash__(self):
        return nn.Module.__hash__(self)
    
    # forward evaluation: x = f(z)
    def forward(self, z):
        pass

# Main class for normalizing flow
class NormalizingFlow(nn.Module):
    def __init__(self, dim, blocks, flow_length, density):
        super().__init__()
        biject = []
        for f in range(flow_length):
            if blocks is None:
                # by default uses planar flow, which has no closed-form inverse
                biject.append(PlanarFlow(dim))
            else:
                # alternate among the blocks
                for flow in blocks:
                    biject.append(flow(dim))
        self.transforms = transform.ComposeTransform(biject)
        self.bijectors = nn.ModuleList(biject)
        self.base_density = density
        #self.final_density = distrib.TransformedDistribution(density, self.transforms)
        self.log_det = []

    def forward(self, z):
        self.log_det = []
        # Applies series of flows
        for b in range(len(self.bijectors)):
            self.log_det.append(self.bijectors[b].log_abs_det_jacobian(z))
            z = self.bijectors[b](z)
        return z, self.log_det


class PlanarFlow(Flow):

    def __init__(self, dim):
        super(PlanarFlow, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(1, dim))
        self.scale = nn.Parameter(torch.Tensor(1, dim))
        self.bias = nn.Parameter(torch.Tensor(1))
        self.init_parameters()

    def _call(self, z):
        f_z = F.linear(z, self.weight, self.bias)
        return z + self.scale * torch.tanh(f_z)

    def log_abs_det_jacobian(self, z):
        f_z = F.linear(z, self.weight, self.bias)
        psi = (1 - torch.tanh(f_z) ** 2) * self.weight
        det_grad = 1 + torch.mm(psi, self.scale.t())
        return torch.log(det_grad.abs() + 1e-9)

# Affine coupling flow
class AffineCouplingFlow(Flow):
    def __init__(self, dim, n_hidden=64, n_layers=3, activation=nn.ReLU):
        super(AffineCouplingFlow, self).__init__()
        self.k = dim // 2
        self.g_mu = self.transform_net(self.k, dim - self.k, n_hidden, n_layers, activation)
        self.g_sig = self.transform_net(self.k, dim - self.k, n_hidden, n_layers, activation)
        self.init_parameters()
        self.bijective = True

    def transform_net(self, nin, nout, nhidden, nlayer, activation):
        net = nn.ModuleList()
        for l in range(nlayer):
            net.append(nn.Linear(l==0 and nin or nhidden, l==nlayer-1 and nout or nhidden))
            net.append(activation())
        return nn.Sequential(*net)
        
    def _call(self, z):
        z_k, z_D = z[:, :self.k], z[:, self.k:]
        zp_D = z_D * torch.exp(self.g_sig(z_k)) + self.g_mu(z_k)
        return torch.cat((z_k, zp_D), dim = 1)

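    def _inverse(self, z):
        zp_k, zp_D = z[:, :self.k], z[:, self.k:]
        # undo the affine map: subtract the shift, then divide by exp(log-scale)
        z_D = (zp_D - self.g_mu(zp_k)) * torch.exp(-self.g_sig(zp_k))
        return torch.cat((zp_k, z_D), dim=1)

    def log_abs_det_jacobian(self, z):
        z_k = z[:, :self.k]
        # the Jacobian is triangular with diagonal exp(g_sig), so log|det| is the sum of g_sig
        return torch.sum(self.g_sig(z_k), dim=1, keepdim=True)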
    

class ReverseFlow(Flow):

    def __init__(self, dim):
        super(ReverseFlow, self).__init__()
        self.permute = torch.arange(dim-1, -1, -1)
        self.inverse = torch.argsort(self.permute)

    def _call(self, z):
        return z[:, self.permute]

    def _inverse(self, z):
        return z[:, self.inverse]

    def log_abs_det_jacobian(self, z):
        return torch.zeros(z.shape[0], 1)
    
class ShuffleFlow(ReverseFlow):

    def __init__(self, dim):
        super(ShuffleFlow, self).__init__(dim)
        self.permute = torch.randperm(dim)
        self.inverse = torch.argsort(self.permute)
    
    
class BatchNormFlow(Flow):

    def __init__(self, dim, momentum=0.95, eps=1e-5):
        super(BatchNormFlow, self).__init__()
        # Running batch statistics
        self.r_mean = torch.zeros(dim)
        self.r_var = torch.ones(dim)
        # Momentum
        self.momentum = momentum
        self.eps = eps
        # Trainable scale and shift (cf. original paper)
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        
    def _call(self, z):
        if self.training:
            # Current batch stats
            self.b_mean = z.mean(0)
            self.b_var = (z - self.b_mean).pow(2).mean(0) + self.eps
            # Running mean and var
            self.r_mean = self.momentum * self.r_mean + ((1 - self.momentum) * self.b_mean)
            self.r_var = self.momentum * self.r_var + ((1 - self.momentum) * self.b_var)
            mean = self.b_mean
            var = self.b_var
        else:
            mean = self.r_mean
            var = self.r_var
        x_hat = (z - mean) / var.sqrt()
        y = self.gamma * x_hat + self.beta
        return y

    def _inverse(self, x):
        if self.training:
            mean = self.b_mean
            var = self.b_var
        else:
            mean = self.r_mean
            var = self.r_var
        x_hat = (x - self.beta) / self.gamma
        y = x_hat * var.sqrt() + mean
        return y
        
    def log_abs_det_jacobian(self, z):
        # Here we only need the variance
        mean = z.mean(0)
        var = (z - mean).pow(2).mean(0) + self.eps
        # var already includes eps; abs() keeps the log defined if gamma turns negative
        log_det = torch.log(torch.abs(self.gamma)) - 0.5 * torch.log(var)
        return torch.sum(log_det, -1)
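
`PlanarFlow.log_abs_det_jacobian` relies on the matrix determinant lemma: for $f(z) = z + u\tanh(w^\top z + b)$ the Jacobian is $I + u\psi^\top$ with $\psi = (1 - \tanh^2(w^\top z + b))\,w$, so its determinant is $1 + \psi^\top u$. The sketch below checks this identity against a finite-difference Jacobian in two dimensions. The parameter values are hypothetical and plain Python is used so the check does not depend on the classes above.

```python
import math

# Hypothetical 2-D planar flow parameters (not taken from the trained model).
w = [0.7, -0.4]
u = [0.5, 0.3]
b = 0.1

def planar(z):
    """f(z) = z + u * tanh(w . z + b)."""
    t = math.tanh(w[0] * z[0] + w[1] * z[1] + b)
    return [z[0] + u[0] * t, z[1] + u[1] * t]

def analytic_det(z):
    """det(I + u psi^T) = 1 + psi . u, with psi = (1 - tanh^2(w . z + b)) * w."""
    t = math.tanh(w[0] * z[0] + w[1] * z[1] + b)
    return 1.0 + (1.0 - t * t) * (w[0] * u[0] + w[1] * u[1])

def finite_difference_det(z, eps=1e-6):
    """Determinant of the 2x2 Jacobian built from central differences."""
    cols = []
    for j in range(2):
        zp, zm = list(z), list(z)
        zp[j] += eps
        zm[j] -= eps
        fp, fm = planar(zp), planar(zm)
        # cols[j][i] approximates d f_i / d z_j
        cols.append([(fp[i] - fm[i]) / (2.0 * eps) for i in range(2)])
    return cols[0][0] * cols[1][1] - cols[1][0] * cols[0][1]

z0 = [0.3, -1.2]
det_exact = analytic_det(z0)
det_numeric = finite_difference_det(z0)
print(det_exact, det_numeric)
```

With a `tanh` nonlinearity the determinant stays positive whenever $w^\top u > -1$, which is the usual sufficient condition for a planar flow to be invertible. The class above does not enforce it.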

## Train 11-dimensional Rosenbrock Distribution Data

We use the planar flow defined above to correct samples drawn from a severely truncated tensor-train (TT) approximation.

In [None]:
# load 11d Rosen sample data
data2 = scipy.io.loadmat("./data/tt_irt_rosen_sample.mat")
data2.keys()

In [None]:
training_data = data2['training_data'] # severely truncated TT samples, used as training data to flow in every epoch
# get log approximate sample densities from TT-IRT
log_training_data_densities = data2['training_data_densities']
#training_data_densities = np.exp(log_training_data_densities)
#training_data_densities_normalized = training_data_densities / training_data_densities.sum()
#log_training_data_densities = np.log(training_data_densities)



# data used for training
N = data2['N'][0][0]
num_epoch = data2['epoch'][0][0]
rosen_dim = data2['d'][0][0]
# normalization constant
rosen_norm = data2['rosen_norm_const'][0][0]
print(">>> shape of all training data: ", training_data.shape)


#print(">>> shape of sample data from untruncated TT: ", tt_sample_trunc.shape)
#print(">>> shape of sample data from truncated TT: ", tt_sample_trunc.shape)

In [None]:
# look at heavy tail
np.random.seed(10)
data_idx = np.random.randint(0, training_data.shape[0], N)
plt.figure(figsize=(20,8));
plt.subplot(1,2,1); plt.grid(True); 
plt.xlabel("$\\theta_{{{}}}$".format(rosen_dim-1)); 
plt.ylabel("$\\theta_{{{}}}$".format(rosen_dim)); 
plt.title("Samples from Truncated TT"); 
plot_data_trunc = training_data[data_idx, :]
plt.scatter(plot_data_trunc[:,rosen_dim-2], plot_data_trunc[:,rosen_dim-1], color='blue', s=1);

In [None]:
# begin correction

# define Rosenbrock function for density evaluation
def rosen(theta):
    """ Rosenbrock-type density; theta is a torch tensor of shape (N, d). """
    d = theta.shape[1]
    # formula from paper: An n-dimensional Rosenbrock Distribution for MCMC Testing
    # the normalization constant rosen_norm is loaded from the data file
    result = 0
    for k in range(d-1):
        result += (theta[:,k])**2 + ( theta[:,k+1] + 5 * ((theta[:,k])**2 + 1) )**2
    return torch.exp(-0.5*result) / rosen_norm

# need to redefine loss (takes in a PDF function instead of a pytorch distribution object)
def loss2(density, zk, log_jacobians):
    # flatten to shape (N,): an (N, 1) term would broadcast the loss to (N, N)
    # (the mean is unchanged, but the intermediate tensor is much larger)
    sum_of_log_jacobians = sum(log_jacobians).reshape(-1)
    # free energy bound [Variational Inference with Normalizing Flows]
    return (-sum_of_log_jacobians - torch.log(density(zk)+1e-10)).mean()

def loss_kl(prior_distrib, targ_distrib, z0, zk, log_jacobians):
    """ prior_distrib, targ_distrib need to be PyTorch distribution objects. """
    sum_of_log_jacobians = sum(log_jacobians).reshape(-1)
    # Monte Carlo estimate of KL(q || p)
    return (prior_distrib.log_prob(z0) - sum_of_log_jacobians - targ_distrib.log_prob(zk)).mean()

def loss_kl_2(prior_distrib, targ_distrib, z0, zk, log_jacobians):
    """ same as loss_kl() but targ_distrib is a callable density, rather than a PyTorch distribution. """
    sum_of_log_jacobians = sum(log_jacobians).reshape(-1)
    return (prior_distrib.log_prob(z0) - sum_of_log_jacobians - torch.log(targ_distrib(zk) + 1e-40)).mean()

def loss_kl_3(log_prior_distrib, targ_distrib, z0, zk, log_jacobians):
    """ same as loss_kl() but:
        - log_prior_distrib is precomputed and passed in as a torch.Tensor of shape (N,)
        - targ_distrib is a callable density, rather than a PyTorch distribution. 
    """
    sum_of_log_jacobians = sum(log_jacobians).reshape(-1)
    return (log_prior_distrib - sum_of_log_jacobians - torch.log(targ_distrib(zk) + 1e-40)).mean()
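
One caveat with `torch.log(targ_distrib(zk) + 1e-40)`: for flowed samples far from the mode, the density underflows to zero in double precision, so the log is clamped at roughly $-92.1$ and carries no gradient information. A possible alternative, sketched below in plain Python for a single point, is to evaluate the target directly in log space. The helper `log_rosen` and its `log_norm_const` argument are hypothetical names, and the normalization constant is set to a placeholder value.

```python
import math

def log_rosen(theta, log_norm_const=0.0):
    """Log-density of the Rosenbrock-type target at one point (a list of floats)."""
    result = 0.0
    for k in range(len(theta) - 1):
        result += theta[k] ** 2 + (theta[k + 1] + 5.0 * (theta[k] ** 2 + 1.0)) ** 2
    return -0.5 * result - log_norm_const

theta = [3.0] * 11                                 # a point far from the mode
exact_log_p = log_rosen(theta)                     # finite when computed in log space
underflowed_p = math.exp(exact_log_p)              # underflows to 0.0
clamped_log_p = math.log(underflowed_p + 1e-40)    # what log(p + 1e-40) would return
print(exact_log_p, underflowed_p, clamped_log_p)
```

The same loop carries over to PyTorch by indexing tensor columns as in `rosen`, which would make the `1e-40` offset unnecessary.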

In [None]:
# begin training
flow_type = "planar" # autoregressive
if flow_type == "planar":
    rosen_planar_flow = NormalizingFlow(dim=rosen_dim, blocks=None, flow_length=32, density=None)
elif flow_type == "autoregressive":
    block_real_nvp = [ AffineCouplingFlow, ReverseFlow, BatchNormFlow ]
    rosen_planar_flow = NormalizingFlow(dim=rosen_dim, blocks=block_real_nvp, flow_length=12, \
                                        density=None)

In [None]:
# create optimizer with adaptive learning rate
import torch.optim as optim
# Create optimizer algorithm
optimizer = optim.Adam(rosen_planar_flow.parameters(), lr=1e-3)
# Add learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)
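
For intuition about how aggressive this schedule is, the arithmetic below computes the learning rate after the full training run, assuming `scheduler.step()` is called once per iteration for the 10001 iterations of the main loop.

```python
initial_lr = 1e-3     # lr passed to Adam
gamma = 0.9999        # decay factor passed to ExponentialLR
num_steps = 10001     # one scheduler step per iteration of range(num_iter + 1), num_iter = 10000

# ExponentialLR multiplies the learning rate by gamma at every step.
final_lr = initial_lr * gamma ** num_steps
print(final_lr)       # roughly initial_lr / e
```

So the learning rate decays by a factor of about $e$ over training, a fairly gentle schedule.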

In [None]:
# Begin training! we use the sample from severely truncated TT and 
# the flow
#          x_TT = f(x_TT_truncated)
# train the inverse flow, i.e.
#          x_TT_truncated = g(x_TT) = f^{-1}(x_TT)

# check exists
training_data; N; data_idx; rosen_dim;
save_training_data = training_data

# As per discussion with Michael, shift the Gaussian by estimated mean 
# and scale by estimated std in each dimension
est_mean = training_data.mean(0)
est_std = training_data.std(0)

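# Gaussian reference with the estimated mean and per-dimension variance (est_std**2,
# since MultivariateNormal expects a covariance matrix), and a d-dimensional
# standard normal as a toy target
ref_distrib = distrib.MultivariateNormal(torch.Tensor(est_mean), torch.diag(torch.Tensor(est_std) ** 2))
targ_distrib = distrib.MultivariateNormal(torch.zeros([rosen_dim]), torch.eye(rosen_dim))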


use_normal = False # use multivariate normal as base distribution
if use_normal:
    training_data = ref_distrib.sample((training_data.shape[0], ))
else:
    training_data = save_training_data

In [None]:
# plot a sample from training data
data_idx;
plot_data = training_data[data_idx, :]
plt.figure(figsize=(8,6)); plt.grid(True); plt.xlabel("dim 10"); plt.ylabel("dim 11");
plt.title("base distribution");
plt.scatter(plot_data[:,rosen_dim-2], plot_data[:,rosen_dim-1], color="blue", s=2);

# plot a sample from standard normal distribution
data_idx;
plot_data = targ_distrib.sample((len(data_idx), ))
plt.figure(figsize=(8,6)); plt.grid(True); plt.xlabel("dim 10"); plt.ylabel("dim 11");
plt.title("standard normal distribution");
plt.scatter(plot_data[:,rosen_dim-2], plot_data[:,rosen_dim-1], color="red", s=2);

In [None]:
# Main optimization loop
num_iter = 10000
num_subplot = 0
all_losses = []

# define loss function we are using
criterion = torch.nn.KLDivLoss(reduction='mean', log_target=True)
for it in range(num_iter+1):
    # Draw a random sample batch from "base" (the truncated TT samples, in training data)
    # row indices
    sample_idx = np.random.choice(training_data.shape[0], N)
    samples = torch.Tensor(training_data[sample_idx,:])
    #samples = torch.Tensor(training_data)
    
    # draw from normal and see if it captures Rosenbrock
    #samples = ref_distrib.sample((N, ))
    # flow this sample
    zk, log_jacobians = rosen_planar_flow(samples)
    
    if use_normal:
        # if we are using normal, the base density is simple to evaluate
        base_prob = ref_distrib.log_prob(samples).reshape(-1)
        # volume correction (approximate dist.), kept separate from base_prob
        log_base_prob = base_prob - sum(log_jacobians).reshape(-1)
    else:
        # TT base: log-densities were precomputed by TT-IRT and loaded from file
        log_base_prob = torch.Tensor(log_training_data_densities[sample_idx, :]).reshape(-1)
        # no volume correction here: loss_kl_3 subtracts the log-Jacobians itself
        base_prob = log_base_prob
    # log-density of the toy Gaussian target at the flowed samples
    #targ_prob = rosen(torch.Tensor(zk))
    targ_prob = targ_distrib.log_prob(torch.Tensor(zk))
        
    ### visual reporting
    if (it % 1000 == 0):
        #print(rosen_planar_flow.weight)
        # plot the flowed samples
        num_subplot += 1
        plt.figure(2, figsize=(100,80));
        plt.subplot(6,6,num_subplot); plt.grid(True); 
        plt.xlabel("$dim {}$".format(rosen_dim-1)); 
        plt.ylabel("$dim {}$".format(rosen_dim)); 
        plt.title("{} at iter. {}".format(num_subplot, it));
        # pick a sample from training data, flow, and plot it
        data_idx = np.random.randint(0, training_data.shape[0], N)
        plot_data = torch.Tensor(training_data[data_idx, :])
        flowed_plot_data, _ = rosen_planar_flow(plot_data)
        flowed_plot_data = flowed_plot_data.detach().numpy()
        plt.scatter(flowed_plot_data[:,rosen_dim-2], flowed_plot_data[:,rosen_dim-1], color='purple', s=5);
        plt.show()
    
    
    ###
    
    # compute loss on the flowed sample
    optimizer.zero_grad()
    
    # <works>
    #loss_v = loss2(rosen, zk, log_jacobians)
    
    # <does not work>
    #loss_v = criterion(log_base_prob, targ_prob)
    
    # <works>
    #loss_v = loss_kl(ref_distrib, targ_distrib, samples, zk, log_jacobians)
    
    # <works, but learns wrong distribution)
    #loss_v = loss_kl_2(ref_distrib, rosen, samples, zk, log_jacobians)
    
    # <works, ideal solution>
    loss_v = loss_kl_3(base_prob, rosen, samples, zk, log_jacobians)
    loss_v.backward()
    optimizer.step()
    scheduler.step()
    if (it % 500 == 0):
        #print('Log ML Loss (it. %i) : %f'%(it, loss_v.item()))
        print('KL Divergence Loss (it. %i) : %f'%(it, loss_v.item()))
        # for plotting loss
        all_losses.append(loss_v.item())

In [None]:
# plot loss
plt.figure(figsize=(10,8));
plt.grid(True);
plt.title("KL Divergence Loss vs. Training Iteration");
plt.plot(500*np.array(range(len(all_losses))), all_losses, color='green', marker='o', label='loss');
plt.xlabel("num iter. "); plt.ylabel('KL divergence loss');
plt.legend();

In [None]:
print(np.round(all_losses, 4))

In [None]:
# plot flowed result after training

# flow all truncated data
plt.figure(1, figsize=(20,8)); 
plt.subplot(1,2,1);
plt.grid(True);
plt.title("Heavy Tail Samples Before Flow Correction (KL Div.), Num Samples = {}".format(N)); 
plt.xlabel("$\\theta_{{{}}}$".format(rosen_dim-1)); plt.ylabel("$\\theta_{{{}}}$".format(rosen_dim));
# scatter the heavy tail part
plt.scatter(training_data[data_idx,:][:,rosen_dim-2], training_data[data_idx,:][:,rosen_dim-1], s=1, color='green');

plt.subplot(1,2,2);
plt.grid(True);
plt.title("Heavy Tail Samples After Flow Correction (KL Div.), Num Samples = {}".format(N)); 
plt.xlabel("$\\theta_{{{}}}$".format(rosen_dim-1)); plt.ylabel("$\\theta_{{{}}}$".format(rosen_dim));
# use NF to flow first
# row indices
#sample_idx = np.random.randint(0, training_data.shape[0], N)
samples = torch.Tensor(training_data[sample_idx,:])
#samples = ref_distrib.sample((N, ))

zk, _ = rosen_planar_flow(samples)

# plot
zk = zk.detach().numpy()
plt.scatter(zk[:,rosen_dim-2], zk[:,rosen_dim-1], s=1, color='blue');
print(zk.mean(0))
print(zk.std(0))

## 50-Dimensional Transition State Governed by Ginzburg-Landau Energy in 1d

In [None]:
# load 50d Ginzburg-Landau sample data
data_ginz = scipy.io.loadmat("./data/tt_irt_ginzburg1d_sample.mat")
display(data_ginz.keys())
# very high precision
torch.set_default_dtype(torch.float64)

In [None]:
# training data from truncated TT
training_data = data_ginz['training_data_ginsburg']
log_training_data_densities = data_ginz['training_densities_ginsburg']
# used for comparison, from accurate TT
comparison_data = data_ginz['comparison_data_ginsburg']
comparison_data_densities = data_ginz['comparison_densities_ginsburg']
print("====== shape of training data = {}".format(training_data.shape))
num_samples = data_ginz['N'][0][0]
num_epoch = data_ginz['epoch'][0][0]
# seed used to generate samples
unif_seed = data_ginz['unif_sample']
# dimension of the PDF
ginz_dim = unif_seed.shape[1]
# normalization constant
Z_beta = data_ginz['Z_beta'][0][0]
# parameters
temperature = data_ginz['temp'][0][0]
ginz_delta = data_ginz['delta'][0][0]

In [None]:
# helper functions
def ginzburg_landau_energy1d(U, delta=ginz_delta):
    """ Computes the 1d GL energy for U, a torch.Tensor of shape (N, d). 
    
    Outputs the energy as a tensor of shape (N,).
    
    """
    # convert input to a torch tensor
    U = torch.Tensor(U)
    N = U.shape[0]
    d = U.shape[1]
    # compute stepsize
    h = 1 / (d+1)
    # zero padding for the boundary values U_0 and U_{d+1} is disabled below,
    # so only the interior terms are summed
    #U = torch.cat((torch.zeros([N, 1]), U, torch.zeros([N, 1])), dim=1)
    
    
    V = ( (delta/2) * ( ((1/h) * (U[:,1:d] - U[:,0:d-1]))**2 ) + \
    (1/(4 * delta)) * ( (1 - U[:,1:d]**2 )**2 ) ).sum(1)
    return V

def equilibrium_pdf(U, delta=ginz_delta, beta=1/temperature, normalization_const=Z_beta):
    V = ginzburg_landau_energy1d(U, delta)
    prob = ( 1/normalization_const ) * torch.exp(-beta * V)
    return prob

print("===== testing ginzburg_landau_energy1d\n")
print(ginzburg_landau_energy1d(torch.rand([10,50])))

print("===== testing equilibrium probability density\n")
print(equilibrium_pdf(torch.rand([10,50])))

def loss_kl_3(log_prior_distrib, targ_distrib, z0, zk, log_jacobians):
    """ same as loss_kl() but:
        - log_prior_distrib is precomputed and passed in as a torch.Tensor of shape (N,)
        - targ_distrib is a callable density, rather than a PyTorch distribution. 
    """
    # flatten to shape (N,): an (N, 1) term would broadcast the loss to (N, N)
    sum_of_log_jacobians = sum(log_jacobians).reshape(-1)
    return (log_prior_distrib - sum_of_log_jacobians - torch.log(targ_distrib(zk) + 1e-40)).mean()

def loss_kl_2(prior_distrib, targ_distrib, z0, zk, log_jacobians):
    """ same as loss_kl() but targ_distrib is a callable density, rather than a PyTorch distribution. """
    sum_of_log_jacobians = sum(log_jacobians).reshape(-1)
    return (prior_distrib.log_prob(z0) - sum_of_log_jacobians - torch.log(targ_distrib(zk) + 1e-40)).mean()
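
To make the discretization in `ginzburg_landau_energy1d` easier to check by hand, here is a plain Python mirror of the same interior-only sum for a single configuration. The function name `gl_energy_1d` and the value of `delta` are hypothetical. The notebook reads `delta` from the `.mat` file.

```python
def gl_energy_1d(u, delta):
    """Interior-only discrete GL energy of one configuration u (a list of d floats)."""
    d = len(u)
    h = 1.0 / (d + 1)                       # grid spacing
    energy = 0.0
    for i in range(1, d):
        gradient = (u[i] - u[i - 1]) / h    # forward difference between neighbours
        energy += 0.5 * delta * gradient ** 2 + (1.0 - u[i] ** 2) ** 2 / (4.0 * delta)
    return energy

delta = 0.04   # hypothetical value for illustration
flat_state = gl_energy_1d([1.0] * 50, delta)     # constant state at a well minimum
kinked_state = gl_energy_1d([1.0, 0.0], delta)   # one jump between neighbours
print(flat_state, kinked_state)
```

A constant state at $u = 1$ has zero energy under this sum, while the two-point example pays both a gradient penalty ($0.18$) and a double-well penalty ($6.25$).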

In [None]:
np.random.seed(10)
data_idx = np.random.randint(0, training_data.shape[0], 100*num_samples)
plt.figure(figsize=(20,8));
plt.subplot(1,2,1); plt.grid(True); 
plt.xlabel("$\\theta {}$".format(ginz_dim-1)); 
plt.ylabel("$\\theta {}$".format(ginz_dim)); 
plt.title("Samples from Truncated TT"); 
plot_data_trunc = training_data[data_idx, :]
# it should look like two modes
plt.hexbin(plot_data_trunc[:,ginz_dim-2], plot_data_trunc[:,ginz_dim-1]);

# comparison with accurate TT samples
plt.figure(figsize=(20,8));
plt.subplot(1,2,1); plt.grid(True); 
plt.xlabel("$\\theta {}$".format(ginz_dim-1)); 
plt.ylabel("$\\theta {}$".format(ginz_dim)); 
plt.title("Samples from Accurate TT"); 
plot_data_trunc = comparison_data[data_idx, :]
# it should look like two modes
plt.hexbin(plot_data_trunc[:,ginz_dim-2], plot_data_trunc[:,ginz_dim-1]);

In [None]:
# begin training 
flow_type = 'planar_flow'
if flow_type == 'planar_flow':
    ginz_flow = NormalizingFlow(dim=ginz_dim, blocks=None, flow_length=32, density=None)
elif flow_type == 'autoregressive':
    raise NotImplementedError("Autoregressive Flow needs to be implemented. ")
    
# Create optimizer algorithm
optimizer = optim.Adam(ginz_flow.parameters(), lr=1e-2)
# Add learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

In [None]:
# Main optimization loop
num_iter = 10000
num_subplot = 0
all_losses = []
batch_size = 2**12


# check exists
training_data; ginz_dim;
save_training_data = training_data

# As per discussion with Michael, shift the Gaussian by estimated mean 
# and scale by estimated std in each dimension
est_mean = training_data.mean(0)
est_std = training_data.std(0)

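# Gaussian reference with the estimated mean and per-dimension variance (est_std**2,
# since MultivariateNormal expects a covariance matrix), and a d-dimensional
# standard normal as a toy target
ref_distrib = distrib.MultivariateNormal(torch.Tensor(est_mean), torch.diag(torch.Tensor(est_std) ** 2))
targ_distrib = distrib.MultivariateNormal(torch.zeros([ginz_dim]), torch.eye(ginz_dim))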


use_normal = False # use multivariate normal as base distribution
if use_normal:
    training_data = ref_distrib.sample((training_data.shape[0], ))
else:
    training_data = save_training_data

# define loss function we are using
#criterion = torch.nn.KLDivLoss(reduction='mean', log_target=True)
for it in range(num_iter+1):
    # Draw a random sample batch from "base" (the truncated TT samples, in training data)
    # row indices
    sample_idx = np.random.choice(training_data.shape[0], batch_size)
    samples = torch.Tensor(training_data[sample_idx,:])
    #samples = torch.Tensor(training_data)
    
    # draw from normal and see if it captures the target
    #samples = ref_distrib.sample((N, ))
    # flow this sample
    zk, log_jacobians = ginz_flow(samples)
    
    if use_normal:
        # if we are using normal, the base density is simple to evaluate
        log_base_prob = ref_distrib.log_prob(samples).reshape(-1)
    else:
        # TT base: log-densities were precomputed by TT-IRT and loaded from file
        log_base_prob = torch.Tensor(log_training_data_densities[sample_idx, :]).reshape(-1)
    # no volume correction here: loss_kl_3 subtracts the log-Jacobians itself
    base_prob = log_base_prob
        
    ### visual reporting
    if (it % 1000 == 0):
        # plot the flowed samples
        num_subplot += 1
        plt.figure(2, figsize=(8,6));
        plt.grid(True); 
        plt.xlabel("$dim {}$".format(ginz_dim-1)); 
        plt.ylabel("$dim {}$".format(ginz_dim)); 
        plt.title("{} at iter. {}".format(num_subplot, it));
        # pick a sample from training data, flow, and plot it
        data_idx = np.random.randint(0, training_data.shape[0], 10*num_samples)
        plot_data = torch.Tensor(training_data[data_idx, :])
        flowed_plot_data, _ = ginz_flow(plot_data)
        flowed_plot_data = flowed_plot_data.detach().numpy()
        plt.hexbin(flowed_plot_data[:,ginz_dim-2], flowed_plot_data[:,ginz_dim-1]);
        # save figure
        plt.savefig("./img/GL_1d/ginzburg1d_batchsz{}_train_iter{}".format(batch_size, it))
        plt.show()
        
    
    
    ###
    
    # compute loss on the flowed sample
    optimizer.zero_grad()
    
    # <works>
    #loss_v = loss2(rosen, zk, log_jacobians)
    
    # <does not work>
    #loss_v = criterion(log_base_prob, targ_prob)
    
    # <works>
    #loss_v = loss_kl(ref_distrib, targ_distrib, samples, zk, log_jacobians)
    
    # <works, but learns wrong distribution)
    #loss_v = loss_kl_2(ref_distrib, equilibrium_pdf, samples, zk, log_jacobians)
    
    # <works, ideal solution>
    loss_v = loss_kl_3(base_prob, equilibrium_pdf, samples, zk, log_jacobians)
    loss_v.backward()
    optimizer.step()
    scheduler.step()
    if (it % 500 == 0):
        #print('Log ML Loss (it. %i) : %f'%(it, loss_v.item()))
        print('KL Divergence Loss (it. %i) : %f'%(it, loss_v.item()))
        # for plotting loss
        all_losses.append(loss_v.item())

## 100-Dimensional Transition State Governed by Ginzburg-Landau Energy in 2d

In [None]:
# load 100d GL sample data
data_ginz2d = scipy.io.loadmat("./data/tt_irt_ginzburg2d_sample.mat")
display(data_ginz2d.keys())
# very high precision
torch.set_default_dtype(torch.float64)

In [None]:
# training data from truncated TT
training_data = data_ginz2d['training_data']
log_training_data_densities = data_ginz2d['training_densities']
# used for comparison, from accurate TT
comparison_data = data_ginz2d['gl_TT_samples']

print("====== shape of training data = {}".format(training_data.shape))
num_samples = data_ginz2d['N'][0][0]
N = int(num_samples / 10)
# seed used to generate samples
unif_seed = data_ginz2d['unif_sample']
# dimension of the PDF
ginz_dim = unif_seed.shape[1]
# normalization constant
Z_beta = data_ginz2d['Z_beta'][0][0]
# parameters
temperature = data_ginz2d['temp'][0][0]
ginz_delta = data_ginz2d['delta'][0][0]

In [None]:
# helper functions
def ginzburg_landau_energy2d(U, delta=ginz_delta):
    """ Computes the 2d GL energy for U, a torch.Tensor of shape (N, d) with d a perfect square. 
    
    Outputs the energy as a tensor of shape (N,).
    
    """
    # convert input to a torch tensor
    U = torch.Tensor(U)
    N = U.shape[0]
    d = U.shape[1]
    # dimension must be a perfect square
    board_dim = int(math.sqrt(d) + 0.5) # spatial dimension, instead of the PDF dimension
    assert board_dim**2 == d, "dimension is not a perfect square"
    # compute stepsize
    h = 1 / (board_dim+1)
    # torch reshape is different from MATLAB reshape
    # MATLAB is by column, torch is by row
    # we use MATLAB's format here, so need to transpose the result of torch.reshape
    U_2d = U.reshape([N, board_dim, board_dim])
    U_2d = torch.transpose(U_2d, 1, 2)
    # add boundary values
    U_2d_save = U_2d.clone()
    U_2d = torch.zeros([N, board_dim+2, board_dim+2]) # including U0, Ud+1
    U_2d[:,1:board_dim+1,1:board_dim+1] = U_2d_save

    # compute energy with U_0 and U_d+1
    u_x = (1 / h) * (U_2d[:, 1:board_dim+2, :] - U_2d[:, 0:board_dim+1, :])
    u_y = (1 / h) * (U_2d[:, :, 1:board_dim+2] - U_2d[:, :, 0:board_dim+1])
    V = (delta/2)*( (u_x**2).sum([1,2]) ) + \
     (delta/2) * ( (u_y**2).sum([1,2]) ) + \
    (1/(4*delta)) * ((1 - U_2d ** 2) ** 2).sum([1,2])
    return V

def equilibrium_pdf(U, delta=ginz_delta, beta=1/temperature, normalization_const=Z_beta):
    V = ginzburg_landau_energy2d(U, delta)
    prob = ( 1/normalization_const ) * torch.exp(-beta * V)
    return prob

print("===== testing ginzburg_landau_energy2d\n")
print(ginzburg_landau_energy2d(torch.rand([2**15,100])))

print("===== testing equilibrium probability density\n")
print(equilibrium_pdf(torch.rand([2**15,100])))

def loss_kl_3(log_prior_distrib, targ_distrib, z0, zk, log_jacobians):
    """ same as loss_kl() but:
        - log_prior_distrib is precomputed and passed in as a torch.Tensor of shape (N,)
        - targ_distrib is a callable density, rather than a PyTorch distribution. 
    """
    # flatten to shape (N,): an (N, 1) term would broadcast the loss to (N, N)
    sum_of_log_jacobians = sum(log_jacobians).reshape(-1)
    return (log_prior_distrib - sum_of_log_jacobians - torch.log(targ_distrib(zk) + 1e-40)).mean()

def loss_kl_2(prior_distrib, targ_distrib, z0, zk, log_jacobians):
    """ same as loss_kl() but targ_distrib is a callable density, rather than a PyTorch distribution. """
    sum_of_log_jacobians = sum(log_jacobians).reshape(-1)
    return (prior_distrib.log_prob(z0) - sum_of_log_jacobians - torch.log(targ_distrib(zk) + 1e-40)).mean()
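
The transpose in `ginzburg_landau_energy2d` exists because the samples were flattened in MATLAB's column-major order, whereas `torch.reshape` fills row by row. The small sketch below illustrates the index bookkeeping with nested lists. It is an illustration of the two conventions only and does not touch the notebook's tensors.

```python
n = 3
flat = list(range(n * n))   # stand-in for one sample's d = n * n values

# Row-major reshape (the torch.reshape convention): entry (i, j) is flat[i * n + j].
row_major = [[flat[i * n + j] for j in range(n)] for i in range(n)]

# Column-major reshape (the MATLAB convention): entry (i, j) is flat[j * n + i].
col_major = [[flat[j * n + i] for j in range(n)] for i in range(n)]

# Transposing the row-major grid recovers the column-major one.
transposed = [[row_major[j][i] for j in range(n)] for i in range(n)]
print(transposed == col_major)
```

Because the energy as coded treats the two grid directions symmetrically, the transpose should not change its value, but keeping the layout aligned with MATLAB makes individual grid entries comparable across the two code bases.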

In [None]:
ginzburg_landau_energy2d(2*torch.ones([1, 36]))

In [None]:
np.random.seed(10)
which = 3 # which plane are we plotting, plot(dim(D-which-1),dim(D-which))
plot_samples = num_samples * 4
data_idx = np.random.randint(0, training_data.shape[0], plot_samples)
plt.figure(figsize=(20,8));
plt.subplot(1,2,1); plt.grid(True); 
plt.xlabel("$\\theta {}$".format(ginz_dim-which)); 
plt.ylabel("$\\theta {}$".format(ginz_dim-which+1)); 
plt.title("Samples from Truncated TT"); 
plot_data_trunc = training_data[data_idx, :]

plt.hexbin(plot_data_trunc[:,ginz_dim-which-1], plot_data_trunc[:,ginz_dim-which]);

# comparison with accurate TT samples
plt.figure(figsize=(20,8));
plt.subplot(1,2,1); plt.grid(True); 
plt.xlabel("$\\theta {}$".format(ginz_dim-which)); 
plt.ylabel("$\\theta {}$".format(ginz_dim-which+1)); 
plt.title("Samples from Accurate TT"); 
plot_data_trunc = comparison_data[data_idx, :]
# it should look like two modes
plt.hexbin(plot_data_trunc[:,ginz_dim-which-1], plot_data_trunc[:,ginz_dim-which]);

In [None]:
# begin training 
flow_type = 'planar_flow'
if flow_type == 'planar_flow':
    ginz_flow = NormalizingFlow(dim=ginz_dim, blocks=None, flow_length=32, density=None)
elif flow_type == 'autoregressive':
    raise NotImplementedError("Autoregressive Flow needs to be implemented. ")
    
# Create optimizer algorithm
optimizer = optim.Adam(ginz_flow.parameters(), lr=1e-3)
# Add learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

In [None]:
# Main optimization loop
num_iter = 10000
num_subplot = 0
all_losses = []
batch_size = 2**12


# check exists
training_data; ginz_dim;
save_training_data = training_data

# As per discussion with Michael, shift the Gaussian by estimated mean 
# and scale by estimated std in each dimension
est_mean = training_data.mean(0)
est_std = training_data.std(0)

# Gaussian base with the estimated mean and per-dimension variance (covariance = diag(std^2))
ref_distrib = distrib.MultivariateNormal(torch.Tensor(est_mean), torch.diag(torch.Tensor(est_std) ** 2))
# d-dimensional standard normal as toy target
targ_distrib = distrib.MultivariateNormal(torch.zeros([ginz_dim]), torch.eye(ginz_dim))


use_normal = False # use multivariate normal as base distribution
if use_normal:
    training_data = ref_distrib.sample((training_data.shape[0], ))
else:
    training_data = save_training_data

# define loss function we are using
#criterion = torch.nn.KLDivLoss(reduction='mean', log_target=True)
for it in range(num_iter+1):
    # Draw a random sample batch from "base" (the truncated TT samples, in training data)
    # row indices
    sample_idx = np.random.choice(training_data.shape[0], batch_size)
    samples = torch.Tensor(training_data[sample_idx,:])
    #samples = torch.Tensor(training_data)
    
    # draw from normal and see if it captures Rosenbrock
    #samples = ref_distrib.sample((N, ))
    # flow this sample
    zk, log_jacobians = ginz_flow(samples)
    
    if use_normal:
        # with a normal base, the log-density is simple to evaluate
        log_base_prob = ref_distrib.log_prob(samples).reshape(-1)
    else:
        # otherwise look up the precomputed TT densities for the sampled rows
        log_base_prob = torch.Tensor(log_training_data_densities[sample_idx, :]).reshape(-1)
        # volume correction (approximate dist.)
        #log_base_prob -= sum(log_jacobians).reshape(-1, 1)
    # set in both branches so the loss below never sees an undefined name
    base_prob = log_base_prob
        
    ### visual reporting
    if (it % 1000 == 0):
        #print(rosen_planar_flow.weight)
        # plot the flowed samples
        num_subplot += 1
        plt.figure(2, figsize=(8,6));
        plt.grid(True); 
        plt.xlabel("$dim {}$".format(ginz_dim-1)); 
        plt.ylabel("$dim {}$".format(ginz_dim)); 
        plt.title("{} at iter. {}".format(num_subplot, it));
        # pick a sample from training data, flow, and plot it
        data_idx = np.random.randint(0, training_data.shape[0], 10*num_samples)
        plot_data = torch.Tensor(training_data[data_idx, :])
        flowed_plot_data, _ = ginz_flow(plot_data)
        flowed_plot_data = flowed_plot_data.detach().numpy()
        plt.hexbin(flowed_plot_data[:,ginz_dim-2], flowed_plot_data[:,ginz_dim-1]);
        # save figure
        plt.savefig("./img/GL_2d/ginzburg2d_batchsz{}_train_iter{}".format(batch_size, it))
        plt.show()
        
    
    
    ###
    
    # compute loss on the flowed sample
    optimizer.zero_grad()
    
    # <works>
    #loss_v = loss2(rosen, zk, log_jacobians)
    
    # <does not work>
    #loss_v = criterion(log_base_prob, targ_prob)
    
    # <works>
    #loss_v = loss_kl(ref_distrib, targ_distrib, samples, zk, log_jacobians)
    
    # <works, but learns wrong distribution>
    #loss_v = loss_kl_2(ref_distrib, equilibrium_pdf, samples, zk, log_jacobians)
    
    # <works, ideal solution>
    loss_v = loss_kl_3(base_prob, equilibrium_pdf, samples, zk, log_jacobians)
    loss_v.backward()
    optimizer.step()
    scheduler.step()
    if (it % 500 == 0):
        #print('Log ML Loss (it. %i) : %f'%(it, loss_v.item()))
        print('KL Divergence Loss (it. %i) : %f'%(it, loss_v.item()))
        # for plotting loss
        all_losses.append(loss_v.item())

# Nonequilibrium Path Sampling Training

The path space is governed by the following SDE:
$$
    dX_t = b(X_t)dt + \sqrt{2\beta^{-1}}dW_t
$$ where $X_t\in\mathbb{R}^2$, the drift $b$ is non-conservative, and $\beta \propto \frac{1}{T}$ is the inverse temperature of the system. Conditioning on the end points in path space, we may (formally) define a path-integral type distribution:
$$
    P_{*}(x_{[0,t_m]}) \propto \exp\left(-\frac{\beta}{4}\int_{0}^{t_m} |\dot{x}_t - b(x_t) |^2 dt\right)
$$
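As a point of reference, the unconditioned SDE above can be simulated with an Euler-Maruyama scheme, $X_{n+1} = X_n + b(X_n)\Delta t + \sqrt{2\beta^{-1}\Delta t}\,\xi_n$ with $\xi_n$ standard normal. The sketch below is a plain-Python illustration and not part of the notebook: `euler_maruyama` and `placeholder_drift` are hypothetical names, the linear drift merely stands in for $b$, and the end-point conditioning is not enforced.

```python
import math
import random

def euler_maruyama(b, x0, beta, dt, n_steps, rng):
    """Simulate dX = b(X) dt + sqrt(2/beta) dW in R^2 with step size dt."""
    sigma = math.sqrt(2.0 / beta)
    x = list(x0)
    path = [tuple(x)]
    for _ in range(n_steps):
        bx = b(x)
        # one Euler-Maruyama step per coordinate, with independent noise
        x = [x[k] + bx[k] * dt + sigma * math.sqrt(dt) * rng.gauss(0.0, 1.0)
             for k in range(2)]
        path.append(tuple(x))
    return path

def placeholder_drift(x):
    """Assumed stand-in drift (linear restoring force), NOT the drift b of the text."""
    return (-x[0], -x[1])

rng = random.Random(0)
path = euler_maruyama(placeholder_drift, (-1.0, 0.0), beta=1.0, dt=0.01,
                      n_steps=100, rng=rng)
print(len(path), path[0])
```

Paths generated this way start at $x_A$ but end wherever the dynamics take them, which is why the conditioned path distribution is needed to target transitions that end at $x_B$.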

The end point conditions are $x_0 = x_A = [-1,0]$, $x_{t_m} = x_B = [1,0]$. Furthermore, we need to discretize the integral using equidistant time steps $t_0, t_1, \cdots, t_N=t_m$, and a step size $\Delta t$, which yields the target distribution:
$$
    \tilde{P}_{*}(x_1,x_2,\cdots, x_{N-1}; x_0=x_A,x_N=x_B) \propto \exp\bigg[
        -\frac14 S_*(x_1,x_2,\cdots, x_{N-1};x_0,x_N)
    \bigg]
$$ where $S_*$ denotes the target path action:
$$
    S_*(x_1,\cdots, x_{N-1}; x_0, x_N) = \beta \Delta t \bigg[\big(\sum_{i=1}^{N-2}| \frac{x_{i+1} - x_{i}}{\Delta t} - b(x_i)|^2\big) + |\frac{x_1 - x_A}{\Delta t} - b(x_1)|^2 + |\frac{x_B - x_{N-1}}{\Delta t} - b(x_{N-1})|^2 \bigg]
$$ where the last two terms are "penalties" for boundary points.
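A minimal sketch of this discretized action for a single path, in plain Python, may help to check the bookkeeping of interior and boundary terms. It is an illustration under stated assumptions rather than the notebook's implementation: `discretized_action` and `zero_drift` are hypothetical names, the drift is passed in as a function, and each interior increment is counted once, with the two end-point terms added separately.

```python
def discretized_action(interior, x_a, x_b, b, beta, dt):
    """beta * dt * (sum of |increment / dt - b|^2) along x_a, interior..., x_b."""
    def term(x_from, x_to, x_eval):
        bx = b(x_eval)
        return sum(((x_to[k] - x_from[k]) / dt - bx[k]) ** 2 for k in range(2))

    # start boundary term: drift evaluated at the first interior point
    total = term(x_a, interior[0], interior[0])
    # increments between consecutive interior points
    for x_i, x_next in zip(interior[:-1], interior[1:]):
        total += term(x_i, x_next, x_i)
    # end boundary term: drift evaluated at the last interior point
    total += term(interior[-1], x_b, interior[-1])
    return beta * dt * total

def zero_drift(x):
    """Assumed placeholder drift b = 0, used only to make the result easy to check."""
    return (0.0, 0.0)

# straight line from x_A = (-1, 0) to x_B = (1, 0) with three interior points
interior = [(-0.5, 0.0), (0.0, 0.0), (0.5, 0.0)]
S = discretized_action(interior, (-1.0, 0.0), (1.0, 0.0), zero_drift, beta=1.0, dt=0.5)
print(S)  # 2.0
```

Here each of the four increments contributes $(0.5/0.5)^2 = 1$, so $S_* = \beta\,\Delta t \cdot 4 = 2$ and the unnormalized weight of this path is $\exp(-S_*/4) = e^{-1/2}$.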

The non-conservative drift term is given by $b: \mathbb{R}^2 \rightarrow \mathbb{R}^2$:
$$
    b(x) = -\nabla V(x) + f(x)
$$ where:

$$
     V(x) = A_1 \exp(-{|{x-\mu_1}|^2}) + A_2 \exp(-{|{x-\mu_2}|^2}) + A_3 \exp(-{|{x-\mu_3}|^2}) + A_4 \exp(-{|{x-\mu_4}|^2}) + A_5\big[(x^{(1)} - \mu_1^{(1)})^4 + (x^{(2)} - \mu_1^{(2)})^4\big]
$$  


$V$ is a weighted sum of Gaussians that defines our three-hole potential [[Metzner]](https://aip-scitation-org.proxy.uchicago.edu/doi/10.1063/1.2335447). The two quartic terms, weighted by $A_5$, confine the system: the Gaussians decay exponentially away from their centers, so without the quartic terms $V$ would flatten out and probability mass would accumulate toward the boundary of the domain.

$f(x) = (-x^{(2)}, x^{(1)})^T$ is the non-conservative term: it rotates $x$ counterclockwise by $\pi/2$ (for example, $f: (\sin z, \cos z)\mapsto (-\cos z, \sin z)$), which adds a counterclockwise spin to the path.

In our experiments:
$$
    A_1 = 40; A_2 = -55; A_3 = -50; A_4 = -50; A_5 = 0.2;
$$
$$
    \beta = 1, N = ??
$$
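The drift $b = -\nabla V + f$ can be written out pointwise with an analytic gradient. The sketch below is plain Python and only illustrative: the amplitudes are the ones listed above, the centers $\mu_k$ are copied from the `V` helper later in this notebook, and the names `potential`, `grad_potential`, `rotation` and `drift_point` are hypothetical (the notebook's own `drift` is still a stub).

```python
import math

# amplitudes from the text; centers copied from the notebook's V() helper
A = [40.0, -55.0, -50.0, -50.0]
MU = [(0.0, 1.0 / 3.0), (0.0, 5.0 / 3.0), (-1.0, 0.0), (1.0, 0.0)]
A5 = 0.2

def potential(x):
    """Three-hole potential V at a point x = (x1, x2)."""
    v = 0.0
    for a, (m1, m2) in zip(A, MU):
        v += a * math.exp(-((x[0] - m1) ** 2 + (x[1] - m2) ** 2))
    m1, m2 = MU[0]
    return v + A5 * ((x[0] - m1) ** 4 + (x[1] - m2) ** 4)

def grad_potential(x):
    """Analytic gradient of V at x."""
    g1 = g2 = 0.0
    for a, (m1, m2) in zip(A, MU):
        w = -2.0 * a * math.exp(-((x[0] - m1) ** 2 + (x[1] - m2) ** 2))
        g1 += w * (x[0] - m1)
        g2 += w * (x[1] - m2)
    m1, m2 = MU[0]
    g1 += 4.0 * A5 * (x[0] - m1) ** 3
    g2 += 4.0 * A5 * (x[1] - m2) ** 3
    return (g1, g2)

def rotation(x):
    """Non-conservative part f(x) = (-x2, x1), a rotation of x by pi/2."""
    return (-x[1], x[0])

def drift_point(x):
    """b(x) = -grad V(x) + f(x) at a single point."""
    g = grad_potential(x)
    f = rotation(x)
    return (-g[0] + f[0], -g[1] + f[1])

print(drift_point((0.3, -0.4)))
```

Because $f(x)\cdot x = 0$, the rotational part never pushes the state toward or away from the origin; it only circulates it, which is what makes $b$ non-conservative.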

In [None]:
# load transition path sampling data
data_transition = scipy.io.loadmat("./data/tt_irt_path_sampling.mat")
display(data_transition.keys())
# very high precision
torch.set_default_dtype(torch.float64)

In [None]:
# training data from truncated TT
training_data = data_transition['path_samples']
log_training_data_densities = np.log(data_transition['path_sample_densities']).reshape(-1)

print("====== shape of training data = {}".format(training_data.shape))
num_samples = data_transition['M'][0][0]
# seed used to generate samples
unif_seed = data_transition['unif_sample']
# dimension of the PDF
transition_dim = data_transition['N'][0][0] * 2

# parameters
path_beta = data_transition['beta'][0][0]
path_dt = data_transition['dt'][0][0]

In [None]:
# visualize paths (uncorrected)
path_sampels_uncorrected = training_data.reshape([num_samples, 2, transition_dim // 2]).transpose(0, 2, 1)
for i in range(num_samples):
    plt.plot(path_sampels_uncorrected[i,:,0], path_sampels_uncorrected[i,:,1])

In [None]:
# helper functions
def path_action(U, beta=path_beta, dt=path_dt):
    """ computes the effective energy (path integral) for path sample U. 
    
    Inputs:
        U,                          (torch.Tensor) (M x 2*N) path samples, each row is a path
                                    with N time points, stored as [x-coordinates, y-coordinates]
        beta,                       (scalar)       inverse temperature, parameter
        dt,                         (scalar)       time step size, parameter
        
    Outputs:
        S,                          (torch.Tensor) (M,)      action for each path of U
    """
    U = torch.Tensor(U) # make sure of formatting
    assert U.shape[1] % 2 == 0, "Dimension of path must be divisible by 2. "
    assert U.requires_grad, "U must require grad, for auto-differentiation. "
    xA = torch.tensor([[-1., 0.]]); xB = torch.tensor([[1., 0.]])
    M = int(U.shape[0])
    N = int(U.shape[1] // 2)
    # reshape to (M x N x 2), torch reshapes by row, hence the transpose
    U_2d = U.reshape([M, 2, N])
    U_2d = U_2d.transpose(1, 2) # U_2d has size (M x N x 2)
    U_2d = U_2d.view(U_2d.shape[0], -1, 2)
    # get gradient of potential, has size (M x N x 2) the same as U_2d
    grad_V = torch.autograd.grad(V(U_2d).sum(), U_2d, create_graph=True)[0]
    # compute path integral with differencing
    S = torch.sum( ( (U_2d[:,1:,:] - U_2d[:,:-1,:]) / dt + grad_V[:,:-1,:] )**2, dim=(1, 2))
    # penalize boundary xA
    S += torch.sum(
            ( (U_2d[:,0,:] - xA) / dt + grad_V[:,0,:] )**2, dim=1)
    # penalize boundary xB
    S += torch.sum(
            ( (xB - U_2d[:,-1,:]) / dt + grad_V[:,-1,:] )**2, dim=1)
    return beta * dt * S

def drift(U):
    """ Nonequilibrium drift, currently none. """
    pass

def V(U):
    """ Three-hole potential (weighted sum of Gaussian) 
    
    Inputs:
        U,                          (torch.Tensor) (M x N x 2) path samples
        
    Outputs:
        V,                          (M x N) scalar, three-hole potential
    """
    # parameters for the Gaussian, make sure they match with MATLAB
    mu1 = torch.tensor([[0., 1./3.]])
    mu2 = torch.tensor([[0., 5./3.]])
    mu3 = torch.tensor([[-1., 0.]])
    mu4 = torch.tensor([[1., 0.]])
    A1 = 40.; A2 = -55.; A3 = -50.; A4 = -50.; A5 = 0.2;
    # weighted sum of Gaussians plus a quartic confining term (weight A5)
    return g(U, mu1, A1) + g(U, mu2, A2) +\
        g(U, mu3, A3) + g(U, mu4, A4) + A5 * torch.sum((U - mu1)**4, dim=-1)
    
    
def g(U, mu, amp):
    """ Gaussian function. 
    
    Inputs:
        U,                          (torch.Tensor) (M x N x 2) A path in R^2
        mu,                         (torch.Tensor) (1 x 2) mean of Gaussian
        amp,                        (scalar) amplitude of Gaussian
        
    Outputs:
        g_U,                        (M x N) result of Gaussian function evaluated at U 
    """
    M = U.shape[0]; N = U.shape[1]
    return amp * torch.exp(-torch.sum((U - mu)**2, dim=-1))

In [None]:
# criterion for training
def loss_kl_3(log_prior_distrib, targ_distrib, z0, zk, log_jacobians):
    """ same as LOSS_KL() but:
        - log_prior_distrib is precomputed and passed in as a torch.Tensor
        - targ_distrib is callable, rather than PyTorch distribution. 
    """
    sum_of_log_jacobians = sum(log_jacobians)
    #return (log_prior_distrib - sum_of_log_jacobians - torch.log(targ_distrib(zk) + 1e-40)).mean()
    return (log_prior_distrib - sum_of_log_jacobians - torch.log(targ_distrib(zk) )).mean()

def loss_kl_2(prior_distrib, targ_distrib, z0, zk, log_jacobians):
    """ sae as LOSS_KL() but targ_distrib is callable, rather than PyTorch distribution. """
    sum_of_log_jacobians = sum(log_jacobians)
    return (prior_distrib.log_prob(z0) - sum_of_log_jacobians - torch.log(targ_distrib(zk) + 1e-40)).mean()


# equilibrium PDF used for training
def equilibrium_pdf(U, beta=path_beta, dt=path_dt):
    
    # Domain truncation (to stay consistent with MATLAB, [-1.2,1.2] x [-0.7,2])
    # is currently disabled, see the commented-out loop below, so every path
    # keeps its weight exp(-S/4).
    
    assert U.shape[1] % 2 == 0, "Dimension of path must be divisible by 2. "
    # reshape to (B x N x 2), torch reshapes by row, hence the transpose, B is batch size
    N = int(U.shape[1] // 2)
    U_2d = U.reshape([-1, 2, N])
    U_2d = U_2d.transpose(1, 2) # U_2d has size (B x N x 2)
    # check if every point for each batch path is in the domain, if not, it should have 0 probability
    prob = torch.zeros([U_2d.shape[0]])
    valid_prob = torch.exp(-0.25 * path_action(U, beta, dt)) 
    #for j in range(U_2d.shape[0]):
    #    path_j = U_2d[j,:,:]
    #    if ( ( path_j[:,0] < -1.5 ).any() ) or ( ( path_j[:,0] > 1.5 ).any() ):
    #        prob[j] = 0
    #    elif ( ( path_j[:,1] < -0.9 ).any() ) or ( ( path_j[:,1] > -2.2 ).any() ):
    #        prob[j] = 0
    #    else:
    #        prob[j] = valid_prob[j]
    return valid_prob
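One caveat about taking `torch.log` of a density of the form $\exp(-S/4)$: for a large action the exponential underflows to zero in double precision, and its logarithm is then no longer finite, even though the log-density itself is just $-S/4$. The plain-Python sketch below illustrates the effect with an assumed action value; `log_target_unnormalized` is a hypothetical helper, not a function of this notebook.

```python
import math

def log_target_unnormalized(action):
    """Log of exp(-action / 4), computed without ever forming the exponential."""
    return -0.25 * action

action = 4000.0                                  # assumed (large) path action
density = math.exp(-0.25 * action)               # underflows to exactly 0.0
log_density = log_target_unnormalized(action)    # stays finite: -1000.0
print(density, log_density)
```

Working with the log-density directly is one way to avoid padding the density with a small constant before taking the logarithm.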

In [None]:
np.random.seed(10)
which = 5 # which plane are we plotting, plot(dim(D-which-1),dim(D-which))
plot_samples = num_samples
data_idx = np.random.randint(0, training_data.shape[0], plot_samples)
plt.figure(figsize=(20,8));
plt.subplot(1,2,1); plt.grid(True); 
plt.xlabel("$\\theta {}$".format(transition_dim-which)); 
plt.ylabel("$\\theta {}$".format(transition_dim-which+1)); 
plt.title("Samples from Truncated TT"); 
plot_data_trunc = training_data[data_idx, :]

# it should look like two modes
plt.hexbin(plot_data_trunc[:,transition_dim-which-1], plot_data_trunc[:,transition_dim-which]);

In [None]:
# begin training 
flow_type = 'planar_flow'
if flow_type == 'planar_flow':
    path_flow = NormalizingFlow(dim=transition_dim, blocks=None, flow_length=32, density=None)
elif flow_type == 'autoregressive':
    raise NotImplementedError("Autoregressive Flow needs to be implemented. ")
    
# Create optimizer algorithm
optimizer = optim.Adam(path_flow.parameters(), lr=1e-2)
# Add learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

In [None]:
# Main optimization loop
num_iter = 10000
num_subplot = 0
all_losses = []
batch_size = 2**12


# check exists
training_data; transition_dim;
save_training_data = training_data

# As per discussion with Michael, shift the Gaussian by estimated mean 
# and scale by estimated std in each dimension
est_mean = training_data.mean(0)
est_std = training_data.std(0)

# Gaussian base with the estimated mean and per-dimension variance (covariance = diag(std^2))
ref_distrib = distrib.MultivariateNormal(torch.Tensor(est_mean), torch.diag(torch.Tensor(est_std) ** 2))
# d-dimensional standard normal as toy target
targ_distrib = distrib.MultivariateNormal(torch.zeros([transition_dim]), torch.eye(transition_dim))


use_normal = False # use multivariate normal as base distribution
if use_normal:
    training_data = ref_distrib.sample((training_data.shape[0], ))
else:
    training_data = save_training_data

# define loss function we are using
#criterion = torch.nn.KLDivLoss(reduction='mean', log_target=True)
for it in range(num_iter+1):
    # Draw a random sample batch from "base" (the truncated TT samples, in training data)
    # row indices
    sample_idx = np.random.choice(training_data.shape[0], batch_size)
    samples = torch.Tensor(training_data[sample_idx,:])
    #samples = torch.Tensor(training_data)
    
    # draw from normal and see if it captures Rosenbrock
    #samples = ref_distrib.sample((N, ))
    # flow this sample
    zk, log_jacobians = path_flow(samples)
    
    if use_normal:
        # with a normal base, the log-density is simple to evaluate
        log_base_prob = ref_distrib.log_prob(samples).reshape(-1)
    else:
        # otherwise look up the precomputed TT log-densities for the sampled rows
        log_base_prob = torch.Tensor(log_training_data_densities[sample_idx]).reshape(-1)
        # volume correction (approximate dist.)
        #log_base_prob -= sum(log_jacobians).reshape(-1, 1)
    # set in both branches so the loss below never sees an undefined name
    base_prob = log_base_prob
        
    ### visual reporting
    if (it % 1000 == 0):
        #print(rosen_planar_flow.weight)
        # plot the flowed samples
        num_subplot += 1
        plt.figure(2, figsize=(8,6));
        plt.grid(True); 
        plt.xlabel("$dim {}$".format(transition_dim-1)); 
        plt.ylabel("$dim {}$".format(transition_dim)); 
        plt.title("{} at iter. {}".format(num_subplot, it));
        # pick a sample from training data, flow, and plot it
        data_idx = np.random.randint(0, training_data.shape[0], num_samples)
        plot_data = torch.Tensor(training_data[data_idx, :])
        flowed_plot_data, _ = path_flow(plot_data)
        flowed_plot_data = flowed_plot_data.detach().numpy()
        plt.hexbin(flowed_plot_data[:,transition_dim-2], flowed_plot_data[:,transition_dim-1]);
        # save figure
        plt.savefig("./img/twochannel/transition_path_batchsz{}_train_iter{}".format(batch_size, it))
        #plt.show()
        
    
    
    ###
    
    # compute loss on the flowed sample
    optimizer.zero_grad()
    loss_v = loss_kl_3(base_prob, equilibrium_pdf, samples, zk, log_jacobians)
    loss_v.backward()
    optimizer.step()
    scheduler.step()
    if (it % 500 == 0):
        #print('Log ML Loss (it. %i) : %f'%(it, loss_v.item()))
        print('KL Divergence Loss (it. %i) : %f'%(it, loss_v.item()))
        # for plotting loss
        all_losses.append(loss_v.item())

In [None]:
# plot samples after correction
path_samples_corrected = path_flow(torch.Tensor(training_data))[0]

In [None]:
path_samples_corrected = path_samples_corrected.detach().numpy().reshape([num_samples, 2, \
                                                                          transition_dim // 2]).transpose(0, 2, 1)

In [None]:
for i in range(num_samples):
    plt.plot(path_samples_corrected[i,:,0], path_samples_corrected[i,:,1])

## Corrections using 4d Gaussian (old example)

```python
# load 4d sample data
data = scipy.io.loadmat("./tt_irt_sample.mat")
data.keys()

# load a few useful variables
unif_sample = data['unif_sample']
grid = data['x']
# Gaussian parameters used to generate data
mu = data['mu']
sigma = data['sigma']
d = data['d'][0][0]
# generated samples
sample_exact = data['s_exact'] # can only be used for comparison
sample_tt = data['xq'] # sample from TT-IRT
print("=== Ftt rank before truncation: ", data['ftt_rank_before'])
print("== Ftt rank after truncation: ", data['ftt_rank_after'])
print("=== data size: ", sample_tt.shape)
print("=== current Wasserstein metric: ", scipy.stats.wasserstein_distance(sample_exact.flatten(), \
                                                                           sample_tt.flatten()))

# Create normalizing flow
flow = NormalizingFlow(dim=d, flow_length=16, density=None)

def loss(distrib, zk, log_jacobians):
    sum_of_log_jacobians = sum(log_jacobians)
    return (-sum_of_log_jacobians - distrib.log_prob(zk)+1e-9).mean()

import torch.optim as optim
# Create optimizer algorithm
optimizer = optim.Adam(flow.parameters(), lr=2e-3)
# Add learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)
```

```python
# exact prob density
target_distrib = distrib.MultivariateNormal(torch.Tensor(mu.T), torch.Tensor(sigma))


# we use the sample from severely truncated TT


#id_figure=2
#plt.figure(figsize=(16, 18))

# Main optimization loop
for it in range(10001):
    # Draw a sample batch from TT sample
    #sample_idx = np.random.randint(0, sample_tt.shape[0], 1024)
    #samples = torch.Tensor(sample_tt[sample_idx, :])
    samples = torch.Tensor(sample_tt)
    # Evaluate flow of transforms
    zk, log_jacobians = flow(samples)
    # Evaluate loss and backprop
    optimizer.zero_grad()
    loss_v = loss(target_distrib, zk, log_jacobians)
    loss_v.backward()
    optimizer.step()
    scheduler.step()
    if (it % 500 == 0):
        print('Loss (it. %i) : %f'%(it, loss_v.item()))
        # Draw random samples
        #samples = ref_distrib.sample((int(1e5), ))
        # Evaluate flow and plot
        #zk, _ = flow(samples)
        #zk = zk.detach().numpy()
        #plt.subplot(3,4,id_figure)
        #plt.hexbin(zk[:,0], zk[:,1], cmap='rainbow')
        #plt.title('Iter.%i'%(it), fontsize=15);
        #id_figure += 1

# check after training
samples_corrected, _ = flow(torch.Tensor(sample_tt))
samples_corrected = samples_corrected.detach().numpy()
print("=== Wasserstein metric (corrected): ", scipy.stats.wasserstein_distance(sample_exact.flatten(), \
                                                                           samples_corrected.flatten()))
```

In [None]:
# load 1d Ginzburg-Landau sample data
data_ginz = scipy.io.loadmat("./data/tt_irt_ginzburg1d_sample.mat")
display(data_ginz.keys())
# very high precision
torch.set_default_dtype(torch.float64)

In [None]:
# training data from truncated TT
training_data = data_ginz['training_data_ginsburg']
log_training_data_densities = data_ginz['training_densities_ginsburg']
# used for comparison, from accurate TT
comparison_data = data_ginz['comparison_data_ginsburg']
comparison_data_densities = data_ginz['comparison_densities_ginsburg']
print("====== shape of training data = {}".format(training_data.shape))
num_samples = data_ginz['N'][0][0]
num_epoch = data_ginz['epoch'][0][0]
# seed used to generate samples
unif_seed = data_ginz['unif_sample']
# dimension of the PDF
ginz_dim = unif_seed.shape[1]
# normalization constant
Z_beta = data_ginz['Z_beta'][0][0]
# parameters
temperature = data_ginz['temp'][0][0]
ginz_delta = data_ginz['delta'][0][0]