In [24]:
import os
import json
import argparse
import pprint
import datetime
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils import data
from bnaf import *
from tqdm import trange
from data.generate2d import sample2d, energy2d

# standard imports
import torch
import torch.nn as nn
from sklearn.datasets import make_moons
# from generate2d import sample2d, energy2d

# FrEIA imports
import FrEIA.framework as Ff
import FrEIA.modules as Fm

from nflows.distributions import normal


In [25]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

BATCHSIZE = 1000  # number of points drawn from make_moons per training step
N_DIM = 2         # dimensionality of the data and of the latent space

# Seed the random number generators so that Restart & Run All reproduces the
# same run: torch drives weight initialisation and torch.randn, and NumPy's
# global RNG is what make_moons draws from when no random_state is passed.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# we define a subnet for use inside an affine coupling block
# for more detailed information see the full tutorial
def subnet_fc(dims_in, dims_out, hidden_dim=512):
    """Build the fully connected subnetwork used inside a coupling block.

    Args:
        dims_in: Number of input features of the first linear layer.
        dims_out: Number of output features of the last linear layer.
        hidden_dim: Width of the single hidden layer. Defaults to 512, the
            value that used to be hard-coded, so two-argument calls behave
            exactly as before.

    Returns:
        An ``nn.Sequential`` of ``Linear(dims_in, hidden_dim)`` -> ``ReLU``
        -> ``Linear(hidden_dim, dims_out)``.
    """
    return nn.Sequential(
        nn.Linear(dims_in, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, dims_out),
    )

# Build the invertible network as a FrEIA SequenceINN over N_DIM-dimensional
# inputs, appending eight AllInOneBlock layers one after another. Every block
# gets subnet_fc as its subnetwork constructor.
inn = Ff.SequenceINN(N_DIM)
for k in range(8):
    inn.append(
        Fm.AllInOneBlock,
        subnet_constructor=subnet_fc,
        permute_soft=True,
    )

# Move the network to the selected device, then let Adam optimise all of its
# parameters.
inn.to(device)
optimizer = torch.optim.Adam(params=inn.parameters(), lr=1e-3)

# Standard-normal latent prior. It depends only on the event shape, so it is
# built once here instead of on every training iteration.
prior = normal.StandardNormal(shape=[N_DIM])

# A very basic maximum-likelihood training loop.
for i in range(2000):
    optimizer.zero_grad()

    # Draw a fresh batch from the moons distribution. The batch is not called
    # `data`, so the imported `torch.utils.data` module is not overwritten.
    moons_xy, _labels = make_moons(n_samples=BATCHSIZE, noise=0.05)
    x = torch.as_tensor(moons_xy, dtype=torch.float32, device=device)

    # Pass to the INN and get the transformed variable z and the log Jacobian
    # determinant of the transformation.
    z, log_jac_det = inn(x)

    # Log-density of z under the standard normal prior.
    log_z = prior.log_prob(z)

    # Negative log-likelihood of the batch (change of variables:
    # log p(x) = log p(z) + log|det J|), averaged over the batch and divided
    # by the number of dimensions.
    loss = -(log_z + log_jac_det).mean() / N_DIM

    # Backpropagate and update the weights.
    loss.backward()
    optimizer.step()

    if i % 100 == 0:
        # .item() logs a plain float instead of a tensor repr with grad_fn.
        print(i, loss.item())


tensor(0.5868, grad_fn=<MeanBackward0>)
tensor(-2.4247, grad_fn=<MeanBackward0>)
tensor(-2.4247, grad_fn=<MeanBackward0>)
0 tensor(0.3183, grad_fn=<DivBackward0>)
tensor(0.6718, grad_fn=<MeanBackward0>)
tensor(-2.5097, grad_fn=<MeanBackward0>)
tensor(-2.5097, grad_fn=<MeanBackward0>)
tensor(0.7935, grad_fn=<MeanBackward0>)
tensor(-2.6314, grad_fn=<MeanBackward0>)
tensor(-2.6314, grad_fn=<MeanBackward0>)
tensor(0.9576, grad_fn=<MeanBackward0>)
tensor(-2.7955, grad_fn=<MeanBackward0>)
tensor(-2.7955, grad_fn=<MeanBackward0>)
tensor(1.1369, grad_fn=<MeanBackward0>)
tensor(-2.9748, grad_fn=<MeanBackward0>)
tensor(-2.9748, grad_fn=<MeanBackward0>)
tensor(1.2671, grad_fn=<MeanBackward0>)
tensor(-3.1050, grad_fn=<MeanBackward0>)
tensor(-3.1050, grad_fn=<MeanBackward0>)
tensor(1.2893, grad_fn=<MeanBackward0>)
tensor(-3.1272, grad_fn=<MeanBackward0>)
tensor(-3.1272, grad_fn=<MeanBackward0>)
tensor(1.2881, grad_fn=<MeanBackward0>)
tensor(-3.1260, grad_fn=<MeanBackward0>)
tensor(-3.1260, grad_fn=

tensor(0.9854, grad_fn=<MeanBackward0>)
tensor(-2.8232, grad_fn=<MeanBackward0>)
tensor(-2.8232, grad_fn=<MeanBackward0>)


KeyboardInterrupt: 

In [None]:
# Scatter the last training batch x (blue) together with its latent image z
# (red). Tensors that live on the GPU must be moved to the CPU before they can
# be converted to NumPy arrays, hence the .cpu() calls.
z_np = z.detach().cpu().numpy()
x_np = x.detach().cpu().numpy()

plt.figure()
plt.plot(z_np[:, 0], z_np[:, 1], 'r.', label='z (latent)')
plt.plot(x_np[:, 0], x_np[:, 1], 'b.', label='x (data)')
plt.legend();

In [None]:
# Reference plot: a fresh draw from the moons data distribution. The array is
# not called `data`, so the imported `torch.utils.data` module stays intact.
moons_xy, _labels = make_moons(n_samples=1000, noise=0.05)
plt.figure()
plt.plot(moons_xy[:, 0], moons_xy[:, 1], '.')
plt.title('make_moons samples')


# sample from the INN by sampling from a standard normal and transforming
# it in the reverse direction
nsam = 1000
# The latent points must live on the same device as the network.
zzz = torch.randn(nsam, N_DIM, device=device)
# No gradients are needed for sampling, so skip building the autograd graph.
with torch.no_grad():
    samples0, _ = inn(zzz, rev=True)

# .cpu() makes the NumPy conversion work when the network runs on the GPU.
samples = samples0.detach().cpu().numpy()
plt.figure()
plt.plot(samples[:, 0], samples[:, 1], '.')
plt.title('INN samples');