# Neural Spline Flow

In [1]:
# Import required packages
import torch
import numpy as np
import normflows as nf
from torchviz import make_dot
from sklearn.datasets import make_moons
from scipy.special import erf, gamma
import matplotlib as mat
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import benchmark
import xfacpy
from tqdm import tqdm
from functools import reduce

In [2]:
plt.switch_backend('TkAgg')

plt.style.use(['science'])
fsize = 38
mat.rcParams.update({'font.size': fsize})
mat.rcParams["font.family"] = "Times New Roman"
#Colormap = ["blue", "orange", "green", "red", "black", "brown",
#            "pink", "gray", "olive", "cyan",  "black"]
pt_L = ['p', '^', 'v', 'o', 's', 'p', '<', 'h', '>', 'H', 'D']

In [3]:
#Generate corresponding tensor based on the indices
#xell is the grid of one dimension. Here we use the same uniform grid for all dimensions
#nfm is the flow model.
def Ptensor(I, J, zell, ndims, nfm, well=None):
    # Generate input tensor
    z = torch.zeros(len(I),len(J),ndims)
    w = torch.ones(len(I),len(J))
    for i, Idex in enumerate(I):
        for j, Jdex in enumerate(J):
            z[i,j,:] = torch.tensor(np.array([zell[k] for k in Idex+Jdex])) 
            if well is not None:
                w[i,j] = reduce(lambda x, y: x * y, [well[k] for k in Idex+Jdex])
    #result = torch.tensor(x2)
    #print("test:",result)
    znewshape = (torch.prod(torch.tensor(z.shape[0:-1])).item(), z.shape[-1])
    
    # Reshape the input z two 2D tensor, (batchsize, ndims), required by the flow.prob.
    x, logJ = nfm.forward_and_log_det(z.view(znewshape))
    result = torch.exp(nfm.p.log_prob(x) + logJ)
    # Reshape back to the Matrix form for loss calculation
    return w*result.view(z.shape[0:-1])

def Ttensor(I, P, J, zell, ndims, nfm, well=None,zfine=None, wfine=None):
    z = torch.zeros(len(I),len(P), len(J), ndims)
    w = torch.ones(len(I),len(P), len(J))
    for i, Idex in enumerate(I):
        for p, Pdex in enumerate(P):
            for j, Jdex in enumerate(J):
                if zfine is not None:
                    grid = np.array([zell[k] for k in Idex]+[zfine[Pdex]]+[zell[k] for k in Jdex])
                else:
                    grid = np.array([zell[k] for k in Idex+[Pdex]+Jdex])
                z[i,p,j,:] = torch.tensor(grid)
                if well is not None:
                    if wfine is not None:
                        weight = [well[k] for k in Idex]+[wfine[Pdex]]+[well[k] for k in Jdex]
                    else:
                        weight = [well[k] for k in Idex+[Pdex]+Jdex]
                    w[i,p,j] = reduce(lambda x, y: x * y,weight)
    znewshape = (torch.prod(torch.tensor(z.shape[0:-1])).item(), z.shape[-1])
    
    # Reshape the input z two 2D tensor, (batchsize, ndims), required by the flow.prob.
    x, logJ = nfm.forward_and_log_det(z.view(znewshape))
    result = torch.exp(nfm.p.log_prob(x) + logJ)
    # Reshape back to the Matrix form for loss calculation
    return w*result.view(z.shape[0:-1])

def Pitensor(I, P1, P2, J, zell, ndims, nfm, well=None, zfine=None, wfine=None):
    z = torch.zeros(len(I),len(P1), len(P2),len(J), ndims)
    w = torch.ones(len(I),len(P1), len(P2),len(J))
    for i, Idex in enumerate(I):
        for p1, P1dex in enumerate(P1):
            for p2,P2dex in enumerate(P2):
                for j, Jdex in enumerate(J):
                    if zfine is not None:
                        grid = np.array([zell[k] for k in Idex]+[zfine[P1dex]]+[zfine[P2dex]]+[zell[k] for k in Jdex])
                    else:
                        grid = np.array([zell[k] for k in Idex+[P1dex]+[P2dex]+Jdex])
                    z[i,p1,p2,j,:] = torch.tensor(grid)  
                    if well is not None:
                        if wfine is not None:
                            weight = [well[k] for k in Idex]+[wfine[P1dex]]+[wfine[P2dex]]+[well[k] for k in Jdex]
                        else:
                            weight = [well[k] for k in Idex+[P1dex]+[P2dex]+Jdex]
                        w[i,p1,p2,j] = reduce(lambda x, y: x * y, weight)
    znewshape = (torch.prod(torch.tensor(z.shape[0:-1])).item(), z.shape[-1])
    
    # Reshape the input z two 2D tensor, (batchsize, ndims), required by the flow.prob.
    x, logJ = nfm.forward_and_log_det(z.view(znewshape))
    result = torch.exp(nfm.p.log_prob(x) + logJ)
    # Reshape back to the Matrix form for loss calculation
    return w*result.view(z.shape[0:-1])


def insample_error(T1,P,T2,Pi):
    T2newshape =(T2.shape[0], torch.prod(torch.tensor(T2.shape[1:])).item()) 
    T1newshape =(torch.prod(torch.tensor(T1.shape[0:-1])).item(), T1.shape[-1]) 
    Pinewshape =(T1newshape[0], T2newshape[1]) 
    return Pi.view(Pinewshape) - T1.view(T1newshape) @ torch.linalg.solve(P, T2.view(T2newshape)), torch.sum(T1.view(T1newshape) @ torch.linalg.solve(P, T2.view(T2newshape)))

In [4]:
# Set up flow model
ndims = 2
batchsize = 1000
delta = 0.3

sigma_x = 0.15
sigma_y = 0.75

# Create diagonal covariance matrix
cov = torch.diag(torch.tensor([sigma_x**2, sigma_y**2]))
theta = torch.tensor(np.pi / 4)
#theta = torch.tensor(0)
target = benchmark.Gauss(batchsize,ndims,cov, theta)
# Define flows
K = 4
torch.manual_seed(0)

latent_size = ndims
hidden_units = 4
num_blocks = 2

flows = []
# for i in range(K):
#     flows += [nf.flows.CoupledRationalQuadraticSpline(latent_size, num_blocks, hidden_units)]
#     flows += [nf.flows.LULinearPermute(latent_size)]
# flows += [nf.flows.CoupledRationalQuadraticSpline(latent_size, num_blocks, hidden_units)]
# flows += [nf.flows.CoupledRationalQuadraticSpline(latent_size, num_blocks, hidden_units, reverse_mask=True)]

masks = nf.utils.iflow_binary_masks(latent_size)  # mask0
# masks = [torch.ones(num_input_channels)]
print(masks)
for mask in masks[::-1]:
    flows += [nf.flows.CoupledRationalQuadraticSpline(latent_size, num_blocks, hidden_units, mask=mask)]
# Set base distribuiton
q0 = nf.distributions.base.Uniform(2, 0.0, 1.0)
    
# Construct flow model
nfm = nf.NormalizingFlow(q0, flows, target)

# Move model on GPU if available
enable_cuda = True
device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')
nfm = nfm.to(device)

[tensor([1, 0], dtype=torch.uint8), tensor([0, 1], dtype=torch.uint8)]


In [5]:
print("total model params",sum(p.numel() for p in nfm.parameters() if p.requires_grad))

total model params 476


In [31]:
def target_prob(x):
    x_tensor = torch.tensor(x)
    return torch.exp(target.log_prob(x_tensor.view(1,x_tensor.shape[0])))
xell, well = xfacpy.GK15(0, 1)

# n=20
# points, weights = np.polynomial.legendre.leggauss(n)
# xell = 0.5 * (points + 1)
# # Adjust the weights accordingly (scale by 0.5 due to the transformation)
# well = 0.5 * weights

print("Quadrature grid number",len(xell))
par = xfacpy.TensorCI2Param()
par.weight = [well] * ndims
par.bondDim = 2
#xell = np.linspace(0.0, 1.0, 1000)
#well = xell*0.0 + (xell[1]-xell[0])
tci = xfacpy.CTensorCI2(target_prob, [xell] * ndims, par)
#tci = xfacpy.CTensorCI1(gaussian, [xell] * ndims)
# Estimate integral and error
itci = []
x = [0.78, 0.34]
#print(target_prob(x))#, gaussian(x))
hsweeps = range(1)
#for hsweep in hsweeps:
tci.iterate()
tci.iterate()
#itci.append(tci.get_TensorTrain().sum([well] * ndims))
itci.append(tci.tt.sum([well] * ndims))
print("Integration value",itci[-1])
print("In sample error",tci.pivotError[-1])
print("MPS rank", [M.shape for M in tci.tt.core])

Quadrature grid number 15
Integration value 0.45322325581371137
In sample error 0.9958070055600824
MPS rank [(1, 15, 2), (2, 15, 1)]


In [7]:
Iset = tci.getIset()
Jset = tci.getJset()
o = list(range(len(xell)))
loss = torch.zeros(1)
F = Ttensor(Iset[0], o, Jset[0], xell, ndims, nfm, well)
for l in range(len(Jset)-1):
    Pi = Pitensor(Iset[l], o, o, Jset[l+1], xell, ndims, nfm, well)
    T1 = Ttensor(Iset[l], o, Jset[l], xell, ndims, nfm, well)
    P = Ptensor(Iset[l+1], Jset[l],xell, ndims, nfm, well)
    print("test sum Pi",torch.sum(Pi))
    T2 = Ttensor(Iset[l+1], o, Jset[l+1],xell, ndims, nfm, well)
    T2newshape =(T2.shape[0], torch.prod(torch.tensor(T2.shape[1:])).item()) 
    Fnewshape =(torch.prod(torch.tensor(F.shape[0:-1])).item(), F.shape[-1]) 
    print(F.shape, P.shape, T2.shape)
    F = F.view(Fnewshape) @ torch.linalg.solve(P, T2.view(T2newshape))
    err = insample_error(T1, P, T2, Pi)

test sum Pi tensor(0.5633, grad_fn=<SumBackward0>)
torch.Size([1, 15, 15]) torch.Size([15, 15]) torch.Size([15, 15, 1])


In [8]:
# Plot target distribution
# x_np, _ = make_moons(2 ** 20, noise=0.1)
# plt.figure(figsize=(15, 15))
# plt.hist2d(x_np[:, 0], x_np[:, 1], bins=200)
# plt.show()

# Plot initial flow distribution
grid_size = 100
xx, yy = torch.meshgrid(torch.linspace(0.0, 1.0, grid_size), torch.linspace(0.0, 1.0, grid_size))
zz = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).view(-1, 2)
zz = zz.to(device)

nfm.eval()
log_prob = nfm.p.log_prob(zz).to('cpu').view(*xx.shape)
prob = nfm.p.prob(zz).to('cpu').view(*xx.shape)
nfm.train()
prob = torch.exp(log_prob)
prob[torch.isnan(prob)] = 0

plt.figure(figsize=(13.5, 13.5))
plt.pcolormesh(xx, yy, prob.data.numpy())
plt.gca().set_aspect('equal', 'box')
plt.savefig("target.png")
plt.show()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [9]:
#A wrapper 
def prob_model(z_list):
    z =torch.tensor(z_list)
    #print(xsize, x_input)
    #z = torch.reshape(z, (1, len(z_list)))
    x, logJ = nfm.forward_and_log_det(z.view((1, len(z_list))))
#     print(target.prob(z.view((1, len(z_list)))))
#     print(z, x, logJ, nfm.p.log_prob(x))
    return torch.exp(nfm.p.log_prob(x) + logJ)
# z = np.array([xell[0], xell[1]])
# prob_model(z)

In [10]:
@torch.no_grad()
def integrate(num_samples, nfm):
    """Importance sampling integration with flow-based approximate distribution

    Args:
      num_samples: Number of samples to draw

    Returns:
      mean, variance
    """
#     z, log_q = nfm.sample(num_samples)
#     q = torch.exp(log_q)
#     func = nfm.p.prob(z)
#     return torch.mean(func / q, dim=0) , torch.mean(q, dim=0)
    z, log_q_ = nfm.q0(num_samples)
    x, logJ = nfm.forward_and_log_det(z)
    logJ -= log_q_
    #print(log_q_)
    func = torch.exp(nfm.p.log_prob(x) + logJ)
    func2 = nfm.p.prob(z)
    return torch.mean(func, dim=0), torch.mean(torch.exp(-logJ), dim=0)#, torch.mean(func2, dim=0)


In [11]:
# @torch.no_grad()
# def integrate(num_samples, nfm):
#     """Importance sampling integration with flow-based approximate distribution

#     Args:
#       num_samples: Number of samples to draw

#     Returns:
#       mean, variance
#     """
#     z, log_q = nfm.q0(num_samples)
#     for flow in nfm.flows:
#         z, log_det = flow(z)
#         log_q -= log_det
#     q = torch.exp(log_q)
#     func = nfm.p.prob(z)
#     return torch.mean(func / q, dim=0) , torch.mean(q, dim=0)


In [12]:
# nn = 10
# result = 0
# for _ in range(nn):
#     r = integrate(1000000, nfm)
#     print(r)
#     #result += r#*r[1]
#     print(r[0], r[1]) #, r[2])
#     result += r[0]#*r[1]
# print(result/nn)

In [13]:
xell, well = xfacpy.GK15(0, 1)
# print(xell)
# n = 50
# points, weights = np.polynomial.legendre.leggauss(n)
# xell = 0.5 * (points + 1)
# Adjust the weights accordingly (scale by 0.5 due to the transformation)
# well = 0.5 * weights

#xell = np.linspace(0.0, 1.0, 20)
#well = xell*0.0 + (xell[1]-xell[0])
# Add weight to parameter envokes environment error mode
par = xfacpy.TensorCI2Param()
par.weight = [well] * ndims
par.pivot1 = [1]*ndims
par.bondDim = 2 #max bond dimension


In [14]:
tci = xfacpy.CTensorCI2(prob_model, [xell] * ndims, par)
tci.iterate()
tci.iterate()
tci.iterate()
tci.iterate()
print("Initial integral error", itci[-1], tci.tt.sum([well]*ndims) - itci[-1] )

Initial integral error 0.56327004465845 -0.1131297508877217


In [15]:
# Train model
max_iter = 1000
num_samples = 1000
show_iter = 20
clip = 10.0

loss_hist = np.array([])
integral_error_hist = np.array([])
in_sample_error_list = np.array([])
int_error_hist = np.array([])
optimizer = torch.optim.Adam(nfm.parameters(), lr=8e-4)#, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter)

# n = 30
# points, weights = np.polynomial.legendre.leggauss(n)
# xfine = 0.5 * (points + 1)
# # Adjust the weights accordingly (scale by 0.5 due to the transformation)
# wfine = 0.5 * weights
tci = xfacpy.CTensorCI2(prob_model, [xell] * ndims, par) 
tci.iterate()
Iset = tci.getIset()
Jset = tci.getJset()
o = list(range(len(xell)))
#integral_error_hist = np.append(integral_error_hist, np.fabs(tci.tt.sum([well]*ndims) - itci[-1]))
for it in tqdm(range(max_iter)):
    optimizer.zero_grad()
    #tci.iterate()      
    #hsweeps = range(1, 2)
    #for hsweep in hsweeps:
    loss = torch.zeros(1)
    #F = Ttensor(Iset[0], o, Jset[0], xell, ndims, nfm, well)
    for l in range(len(Jset)-1):
        Pi = Pitensor(Iset[l], o, o, Jset[l+1], xell, ndims, nfm, well)#, xfine, wfine)
        T1 = Ttensor(Iset[l], o, Jset[l], xell, ndims, nfm, well)#, xfine, wfine)
        P = Ptensor(Iset[l+1], Jset[l],xell, ndims, nfm, well)
        #print("test sum Pi",torch.sum(Pi))
        T2 = Ttensor(Iset[l+1], o, Jset[l+1],xell, ndims, nfm, well)#, xfine, wfine)
        # err = torch.mean(torch.square(insample_error(T1, P, T2, Pi)))
        err, int_value = insample_error(T1, P, T2, Pi)
        #print(err,int_value)
        #loss += torch.sum(torch.square(err))
        loss += torch.square(torch.sum(err))
    loss = loss/(len(Jset)-1)
    #loss = nfm.reverse_kld(num_samples)
    int_error_hist=np.append(int_error_hist, np.fabs(int_value.item()-itci[-1]))
    integral_error_hist = np.append(integral_error_hist, np.fabs(tci.tt.sum([well]*ndims) - itci[-1]))
    in_sample_error_list = np.append(in_sample_error_list, tci.pivotError[-1])
    # Do backprop and optimizer step
    if ~(torch.isnan(loss) | torch.isinf(loss)):
        loss.backward()
#         if (it+1)%show_iter == 0:
#             #print(loss, loss2)
#             # make_dot(z, params=dict(nfm.named_parameters()))
#             for name, param in nfm.named_parameters():
#                 print(f"Gradient of {name} is \n{param.grad}")
        torch.nn.utils.clip_grad_value_(nfm.parameters(), clip)
        optimizer.step()
    
    # Log loss
    loss_hist = np.append(loss_hist, loss.to('cpu').data.numpy())
    
    # Plot learned distribution
    if (it + 1) % show_iter == 0:
        print("in-sample tci error",tci.pivotError[-1])
        print("loss",loss)
        print([len(I) for I in tci.getIset()])
        print([len(J) for J in tci.getJset()])
        #print("MPS rank", [M.shape for M in tci.get_TensorTrain().core])
        print("MPS rank", [M.shape for M in tci.tt.core])
        #print("integral error", tci.get_TensorTrain().sum([well]*ndims) - itci[-1] )
        print("integral error", tci.tt.sum([well]*ndims) - itci[-1] )
        print("integral error", int_value - itci[-1] )
        print("current loss", loss_hist[-1])
        tci = xfacpy.CTensorCI2(prob_model, [xell] * ndims, par) 
        for _ in range(5):
            tci.iterate()
        Iset = tci.getIset()
        Jset = tci.getJset()
        #print("integral error", torch.sum(Pi) - itci[-1] )
        nfm.eval()
        log_prob = nfm.log_prob(zz)
        nfm.train()
        prob = torch.exp(log_prob.to('cpu').view(*xx.shape))
        prob[torch.isnan(prob)] = 0


    #scheduler.step()

  2%|▋                                        | 18/1000 [00:00<00:47, 20.75it/s]

in-sample tci error 0.756589812131395
loss tensor([0.0476], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.24477348771031898
integral error tensor(-0.2157, grad_fn=<SubBackward0>)
current loss 0.04755919799208641


  4%|█▌                                       | 39/1000 [00:02<00:52, 18.31it/s]

in-sample tci error 1.140667234419877
loss tensor([0.0015], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.039524066196221286
integral error tensor(-0.0340, grad_fn=<SubBackward0>)
current loss 0.001451453659683466


  6%|██▍                                      | 59/1000 [00:04<00:52, 18.05it/s]

in-sample tci error 1.2211555325266001
loss tensor([0.0310], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.20114435878519848
integral error tensor(-0.1811, grad_fn=<SubBackward0>)
current loss 0.031000211834907532


  8%|███▏                                     | 79/1000 [00:05<00:53, 17.32it/s]

in-sample tci error 1.708171586783747
loss tensor([0.0167], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.17969573958116764
integral error tensor(-0.1188, grad_fn=<SubBackward0>)
current loss 0.01665375381708145


 10%|████                                     | 99/1000 [00:07<00:49, 18.22it/s]

in-sample tci error 2.1727013871020593
loss tensor([0.0148], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.1591085789515253
integral error tensor(-0.1084, grad_fn=<SubBackward0>)
current loss 0.01478580106049776


 12%|████▋                                   | 117/1000 [00:08<00:50, 17.49it/s]

in-sample tci error 3.005945179911179
loss tensor([0.0086], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.1686789027199524
integral error tensor(-0.1376, grad_fn=<SubBackward0>)
current loss 0.008556713350117207


 14%|█████▌                                  | 138/1000 [00:10<00:46, 18.43it/s]

in-sample tci error 2.2627014518842614
loss tensor([0.0202], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.2672509660990901
integral error tensor(-0.2789, grad_fn=<SubBackward0>)
current loss 0.02018253318965435


 16%|██████▎                                 | 159/1000 [00:11<00:46, 18.07it/s]

in-sample tci error 2.7790104891426184
loss tensor([0.0049], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.2911052303628417
integral error tensor(-0.2479, grad_fn=<SubBackward0>)
current loss 0.004941368941217661


 18%|███████                                 | 177/1000 [00:13<00:48, 16.87it/s]

in-sample tci error 1.6028306476059822
loss tensor([0.0132], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.3590811640716911
integral error tensor(-0.2924, grad_fn=<SubBackward0>)
current loss 0.013211044482886791


 20%|███████▉                                | 198/1000 [00:14<00:43, 18.30it/s]

in-sample tci error 2.9132811657679323
loss tensor([0.0003], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.253215178076983
integral error tensor(-0.1491, grad_fn=<SubBackward0>)
current loss 0.0003020540752913803


 22%|████████▊                               | 219/1000 [00:16<00:42, 18.31it/s]

in-sample tci error 1.2922514498578486
loss tensor([0.0019], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.2510714388857604
integral error tensor(-0.1639, grad_fn=<SubBackward0>)
current loss 0.0018570476677268744


 24%|█████████▍                              | 237/1000 [00:17<00:44, 17.31it/s]

in-sample tci error 2.683890260269327
loss tensor([3.0226e-05], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.127218000965937
integral error tensor(-0.1355, grad_fn=<SubBackward0>)
current loss 3.0226094168028794e-05


 26%|██████████▎                             | 258/1000 [00:19<00:40, 18.17it/s]

in-sample tci error 1.2823133314228263
loss tensor([0.0003], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.1978469705638991
integral error tensor(-0.0471, grad_fn=<SubBackward0>)
current loss 0.00033243338111788034


 28%|███████████▏                            | 279/1000 [00:20<00:39, 18.21it/s]

in-sample tci error 4.6606873803777304
loss tensor([0.0002], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.16432061369583367
integral error tensor(-0.0747, grad_fn=<SubBackward0>)
current loss 0.00017680485325399786


 30%|███████████▉                            | 297/1000 [00:22<00:40, 17.17it/s]

in-sample tci error 3.449782193030796
loss tensor([3.4931e-05], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.0729311647367632
integral error tensor(-0.0647, grad_fn=<SubBackward0>)
current loss 3.493104668450542e-05


 32%|████████████▋                           | 317/1000 [00:23<00:37, 18.01it/s]

in-sample tci error 2.8047780390063224
loss tensor([3.7952e-07], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.0655989153646242
integral error tensor(-0.0677, grad_fn=<SubBackward0>)
current loss 3.79523754645561e-07


 34%|█████████████▌                          | 338/1000 [00:25<00:36, 18.12it/s]

in-sample tci error 2.873620798447762
loss tensor([7.2042e-07], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06771190397757193
integral error tensor(-0.0675, grad_fn=<SubBackward0>)
current loss 7.20417460797762e-07


 36%|██████████████▎                         | 358/1000 [00:27<00:35, 18.22it/s]

in-sample tci error 2.9361410950696705
loss tensor([2.7018e-08], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06746702616334649
integral error tensor(-0.0677, grad_fn=<SubBackward0>)
current loss 2.7017513559712825e-08


 38%|███████████████▏                        | 379/1000 [00:28<00:33, 18.59it/s]

in-sample tci error 2.913410260386341
loss tensor([1.6780e-09], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06766740238115371
integral error tensor(-0.0677, grad_fn=<SubBackward0>)
current loss 1.6779913014630665e-09


 40%|███████████████▉                        | 399/1000 [00:30<00:33, 17.93it/s]

in-sample tci error 2.914781110024006
loss tensor([1.3597e-09], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.0676578664239228
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 1.3596983539088114e-09


 42%|████████████████▋                       | 418/1000 [00:31<00:33, 17.47it/s]

in-sample tci error 2.9173501122695615
loss tensor([2.3807e-11], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06763963379511284
integral error tensor(-0.0677, grad_fn=<SubBackward0>)
current loss 2.380658170597627e-11


 44%|█████████████████▍                      | 437/1000 [00:33<00:31, 17.91it/s]

in-sample tci error 2.916311152573064
loss tensor([5.1787e-12], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764894230746993
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 5.178749790113457e-12


 46%|██████████████████▎                     | 458/1000 [00:34<00:30, 17.95it/s]

in-sample tci error 2.916332117557318
loss tensor([2.6910e-13], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764754566452869
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 2.690981118491642e-13


 48%|███████████████████▏                    | 479/1000 [00:36<00:28, 18.11it/s]

in-sample tci error 2.9164013818127117
loss tensor([1.2690e-13], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764871379566056
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 1.2690044327856587e-13


 50%|███████████████████▉                    | 497/1000 [00:37<00:29, 17.26it/s]

in-sample tci error 2.9163871535061396
loss tensor([2.4595e-15], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764788048938647
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 2.4594584181747425e-15


 52%|████████████████████▋                   | 518/1000 [00:39<00:26, 18.36it/s]

in-sample tci error 0.8354942981249661
loss tensor([1.9921e-12], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764801832265988
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 1.9921046149479205e-12


 54%|█████████████████████▌                  | 539/1000 [00:40<00:25, 18.23it/s]

in-sample tci error 2.9163847338964164
loss tensor([1.3326e-12], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764709290070153
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 1.3325800966162982e-12


 56%|██████████████████████▎                 | 557/1000 [00:42<00:24, 17.73it/s]

in-sample tci error 0.9404881681523689
loss tensor([3.1147e-12], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764745876279299
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 3.114717685159807e-12


 58%|███████████████████████                 | 578/1000 [00:43<00:23, 17.91it/s]

in-sample tci error 2.9164382980789285
loss tensor([2.1962e-12], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.0676492281243542
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 2.196226273759594e-12


 60%|███████████████████████▉                | 599/1000 [00:45<00:22, 18.20it/s]

in-sample tci error 2.916384996122808
loss tensor([9.2555e-13], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.0676470710542057
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 9.255521696283076e-13


 62%|████████████████████████▋               | 617/1000 [00:46<00:22, 17.32it/s]

in-sample tci error 2.916382431858029
loss tensor([1.5268e-12], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764626201934504
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 1.5267869434010262e-12


 64%|█████████████████████████▌              | 638/1000 [00:48<00:20, 18.08it/s]

in-sample tci error 2.916383978427312
loss tensor([9.4401e-13], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764686922530805
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 9.440109284553078e-13


 66%|██████████████████████████▎             | 659/1000 [00:50<00:18, 18.13it/s]

in-sample tci error 2.9163831807624194
loss tensor([4.0699e-15], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764698723239837
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 4.0698781150760865e-15


 68%|███████████████████████████             | 677/1000 [00:51<00:18, 17.17it/s]

in-sample tci error 1.0875942468497688
loss tensor([6.4113e-13], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764736237469882
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 6.411278300789969e-13


 70%|███████████████████████████▉            | 698/1000 [00:53<00:16, 18.14it/s]

in-sample tci error 2.916397426634942
loss tensor([1.6413e-13], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764781525150648
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 1.6412652487085566e-13


 72%|████████████████████████████▊           | 719/1000 [00:54<00:15, 18.23it/s]

in-sample tci error 2.916396916714432
loss tensor([1.1601e-12], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764762962562382
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 1.1600896025060203e-12


 74%|█████████████████████████████▍          | 737/1000 [00:56<00:15, 17.51it/s]

in-sample tci error 2.916393108089694
loss tensor([1.3843e-13], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.0676476804119534
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 1.384311502233837e-13


 76%|██████████████████████████████▎         | 757/1000 [00:57<00:13, 18.18it/s]

in-sample tci error 2.9164082456582934
loss tensor([8.2031e-15], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764768691186179
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 8.203127847133951e-15


 78%|███████████████████████████████         | 778/1000 [00:59<00:12, 18.12it/s]

in-sample tci error 2.91637993389921
loss tensor([9.3030e-16], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764791814458193
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 9.30299674101187e-16


 80%|███████████████████████████████▉        | 799/1000 [01:00<00:10, 18.36it/s]

in-sample tci error 0.8354933213607089
loss tensor([2.1516e-14], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764828600258976
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 2.1515992112974835e-14


 82%|████████████████████████████████▋       | 817/1000 [01:02<00:10, 17.31it/s]

in-sample tci error 2.916390994047912
loss tensor([2.1516e-14], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.0676483011041088
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 2.1515992112974835e-14


 84%|█████████████████████████████████▍      | 837/1000 [01:03<00:09, 18.04it/s]

in-sample tci error 0.8354933213607089
loss tensor([1.2077e-14], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764830110068693
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 1.2077144839750531e-14


 86%|██████████████████████████████████▎     | 859/1000 [01:05<00:07, 19.00it/s]

in-sample tci error 2.916390994047912
loss tensor([1.2077e-14], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764830110068693
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 1.2077144839750531e-14


 88%|███████████████████████████████████     | 877/1000 [01:06<00:07, 17.28it/s]

in-sample tci error 2.916390994047912
loss tensor([2.1516e-14], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764830110068693
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 2.1515992112974835e-14


 90%|███████████████████████████████████▉    | 897/1000 [01:08<00:05, 18.10it/s]

in-sample tci error 2.916390994047912
loss tensor([2.1516e-14], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764830110068693
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 2.1515992112974835e-14


 92%|████████████████████████████████████▋   | 918/1000 [01:10<00:04, 18.00it/s]

in-sample tci error 2.916390994047912
loss tensor([1.1772e-14], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764830033066216
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 1.177205034841311e-14


 94%|█████████████████████████████████████▌  | 939/1000 [01:11<00:03, 16.42it/s]

in-sample tci error 2.916390994047912
loss tensor([1.2077e-14], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764829954850404
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 1.2077144839750531e-14


 96%|██████████████████████████████████████▎ | 958/1000 [01:13<00:02, 17.39it/s]

in-sample tci error 2.916390994047912
loss tensor([2.0838e-14], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764830110068693
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 2.0838365755171395e-14


 98%|███████████████████████████████████████ | 978/1000 [01:14<00:01, 17.92it/s]

in-sample tci error 2.916390994047912
loss tensor([1.2077e-14], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764830110068693
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 1.2077144839750531e-14


100%|███████████████████████████████████████▉| 999/1000 [01:16<00:00, 18.14it/s]

in-sample tci error 2.916390994047912
loss tensor([1.1772e-14], grad_fn=<DivBackward0>)
[1, 2]
[2, 1]
MPS rank [(1, 15, 2), (2, 15, 1)]
integral error -0.06764830032195068
integral error tensor(-0.0676, grad_fn=<SubBackward0>)
current loss 1.177205034841311e-14


100%|███████████████████████████████████████| 1000/1000 [01:16<00:00, 13.00it/s]


In [23]:
# Plot loss
plt.figure(figsize=(13.5, 10))
plt.plot(loss_hist[0:500], linewidth = 2.0,label='loss')
plt.xlabel("Iteration")
#print(in_sample_error_list)
#plt.plot(in_sample_error_list, label='in_sample')
plt.legend()
plt.savefig("loss.png")
plt.show()

In [24]:
plt.figure(figsize=(13.5, 10))
#plt.plot(integral_error_hist, label='Int Error')
print(int_error_hist)
plt.plot(int_error_hist[0:500], linewidth = 2.0, label='Integration Error')
plt.xlabel("Iteration")
#print(in_sample_error_list)
#plt.plot(in_sample_error_list, label='in_sample')
plt.legend()
plt.savefig("error.png")
plt.show()

[0.24477352 0.2435307  0.24225248 0.24093495 0.23957614 0.23817549
 0.23673589 0.23526041 0.23375188 0.23221312 0.23064719 0.22905595
 0.22744162 0.22580643 0.22415228 0.22248251 0.22079824 0.21910299
 0.21739743 0.21568475 0.03952403 0.03901191 0.03853555 0.03813578
 0.03783275 0.03741963 0.03709079 0.03681005 0.03655328 0.0363114
 0.03605212 0.03582479 0.03560377 0.03538306 0.03516198 0.03494014
 0.03471656 0.03449113 0.03426344 0.03403349 0.20114435 0.2005489
 0.19978754 0.19889138 0.19788877 0.19680486 0.19566326 0.194486
 0.19329433 0.19211097 0.19095306 0.18983943 0.18878783 0.18781383
 0.18687992 0.18598298 0.18500934 0.18368087 0.1823736  0.18109761
 0.17969574 0.17798398 0.17624055 0.174244   0.17207764 0.16968251
 0.1663538  0.1639076  0.1615638  0.15848146 0.15504221 0.15140192
 0.14768819 0.14385383 0.13991628 0.13587956 0.13174535 0.12751685
 0.12319834 0.11879383 0.15910856 0.15528531 0.15137769 0.14741944
 0.14345784 0.13951085 0.13561146 0.13193126 0.12904943 0.12632422

In [37]:
nfm.eval()
log_prob = nfm.log_prob(zz)
#x_new, logJ = nfm.forward_and_log_det(zz)
#prob = torch.exp(nfm.p.log_prob(x_new) + logJ).view(*xx.shape)
log_f = nfm.p.log_prob(zz)
prob = torch.exp((log_f-log_prob).to('cpu').view(*xx.shape))

prob[torch.isnan(prob)] = 0
plt.figure(figsize=(13.5, 13.5))
plt.pcolormesh(xx, yy, prob.data.numpy())
plt.gca().set_aspect('equal', 'box')
plt.savefig("learned.png")
plt.show()

In [11]:
# Plot target distribution
# x_np, _ = make_moons(2 ** 20, noise=0.1)
# plt.figure(figsize=(15, 15))
# plt.hist2d(x_np[:, 0], x_np[:, 1], bins=200)
# plt.show()

# Plot initial flow distribution
grid_size = 100
xx, yy = torch.meshgrid(torch.linspace(0.0, 1.0, grid_size), torch.linspace(0.0, 1.0, grid_size))
zz = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).view(-1, 2)
zz = zz.to(device)

nfm.eval()
#log_prob = nfm.p.log_prob(zz).to('cpu').view(*xx.shape)
#prob = nfm.p.prob(zz).to('cpu').view(*xx.shape)
#prob = nfm.p.prob(zz).to('cpu').view(*xx.shape)
#print(prob, log_prob)
log_prob = nfm.q0.log_prob(zz).to('cpu').view(*xx.shape)
#print(prob)
#log_prob = log_prob - log_q
print(log_prob,log_q)
nfm.train()
prob = torch.exp(log_prob)
prob[torch.isnan(prob)] = 0

plt.figure(figsize=(15, 15))
plt.pcolormesh(xx, yy, prob.data.numpy())
plt.gca().set_aspect('equal', 'box')
plt.show()

NameError: name 'log_q' is not defined