In [1]:
import argparse
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from nits.model import *
from nits.layer import *
from nits.fc_model import *
from nits.cnn_model import *
from maf.datasets import *
from nits.resmade import ResidualMADE
    
def create_batcher(x, batch_size=1):
    idx = 0
    p = torch.randperm(len(x))
    x = x[p]

    while idx + batch_size < len(x):
        yield torch.tensor(x[idx:idx+batch_size], device=device)
        idx += batch_size

In [2]:
parser = argparse.ArgumentParser()

parser.add_argument('-d', '--dataset', type=str, default='gas')
parser.add_argument('-g', '--gpu', type=str, default='')
parser.add_argument('-r', '--rotate', type=bool, default=False)

args = parser.parse_args(['-g', '6', '-d', 'gas'])

device = 'cuda:' + args.gpu if args.gpu else 'cpu'
print('device:', device)

lr = 1e-3
n_residual_blocks = 4
hidden_dim = 512
use_batch_norm = False
zero_initialization = True
if args.dataset == 'gas':
    data = gas.GAS()
    dropout_probability = 0.1
    nits_arch = [16, 16, 1]
    gamma = 1 - 5e-7
elif args.dataset == 'power':
    data = power.POWER()
    dropout_probability = 0.1
    nits_arch = [16, 16, 1]
    gamma = 1 - 5e-7
elif args.dataset == 'miniboone':
    data = miniboone.MINIBOONE()
    dropout_probability = 0.5
    nits_arch = [16, 16, 1]
    gamma = 1 - 5e-7
elif args.dataset == 'hepmass':
    data = hepmass.HEPMASS()
    dropout_probability = 0.2
    nits_arch = [16, 16, 1]
    gamma = 1 - 5e-7
elif args.dataset == 'bsds300':
    data = bsds300.BSDS300()
    dropout_probability = 0.2
    nits_arch = [16, 16, 1]
    gamma = 1 - 5e-7

d = data.trn.x.shape[1]

max_val = max(data.trn.x.max(), data.val.x.max(), data.tst.x.max())
min_val = min(data.trn.x.min(), data.val.x.min(), data.tst.x.min())
max_val, min_val = torch.tensor(max_val).to(device).float(), torch.tensor(min_val).to(device).float()

nits_model = NITS(d=d, start=min_val, end=max_val, monotonic_const=1e-5,
                             A_constraint='neg_exp', arch=[1] + nits_arch,
                             final_layer_constraint='softmax',
                             softmax_temperature=False).to(device)

# model_arch = [hidden_dim] * n_residual_blocks
# model = RotationParamModel(arch=[d] + model_arch + [nits_model.n_params], 
#                            rotate=args.rotate, nits_model=nits_model).to(device)
# shadow = RotationParamModel(arch=[d] + model_arch + [nits_model.n_params], 
#                             rotate=args.rotate, nits_model=nits_model).to(device)

model = ResMADEModel(
    d=d, 
    rotate=args.rotate, 
    nits_model=nits_model,
    n_residual_blocks=n_residual_blocks,
    hidden_dim=hidden_dim,
    dropout_probability=dropout_probability,
    use_batch_norm=use_batch_norm,
    zero_initialization=zero_initialization
).to(device)

shadow = ResMADEModel(
    d=d, 
    rotate=args.rotate, 
    nits_model=nits_model,
    n_residual_blocks=n_residual_blocks,
    hidden_dim=hidden_dim,
    dropout_probability=dropout_probability,
    use_batch_norm=use_batch_norm,
    zero_initialization=zero_initialization
).to(device)

device: cuda:6


  self.register_buffer('start_val', torch.tensor(start))
  self.register_buffer('end_val', torch.tensor(end))
  self.register_buffer('start', torch.tensor(start).reshape(1, 1).tile(1, d))
  self.register_buffer('end', torch.tensor(end).reshape(1, 1).tile(1, d))


In [3]:
max_epochs = 20000
batch_size = 512
optim = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=1, gamma=gamma)

# # initialize weight norm
# for i, x in enumerate(create_batcher(data.trn.x, batch_size=batch_size)):
#     params = model(x)
#     break
    
model = EMA(model, shadow, decay=0.9995).to(device)

In [4]:
time_ = time.time()
for epoch in range(max_epochs):
    model.train()
    train_ll = 0.
    for i, x in enumerate(create_batcher(data.trn.x, batch_size=batch_size)):
        orig_x = x.cpu().detach().clone()
        ll = model(x)
        optim.zero_grad()
        (-ll).backward()
        train_ll += ll.detach().cpu().numpy()

        optim.step()
        scheduler.step()
        model.update()

    if epoch % 10 == 0:
        # compute train loss
        train_ll /= i * batch_size

        with torch.no_grad():
            model.eval()
            val_ll = 0.
            lr = optim.param_groups[0]['lr']
            for i, x in enumerate(create_batcher(data.val.x, batch_size=batch_size)):
                x = torch.tensor(x, device=device)
                ll = model(x)
                val_ll += ll.detach().cpu().numpy()

            val_ll /= i * batch_size
            fmt_str1 = 'epoch: {:4d}, time: {:.2f}, train_ll: {:.4f},'
            fmt_str2 = ' val_ll: {:.4f}, lr: {:.4e}'

            print((fmt_str1 + fmt_str2).format(
                epoch,
                time.time() - time_,
                train_ll,
                val_ll,
                lr))
            
            time_ = time.time()



epoch:    0, time: 26.03, train_ll: 3.1337, val_ll: -22.4827, lr: 9.9917e-04
epoch:   10, time: 260.53, train_ll: 10.3799, val_ll: 11.1316, lr: 9.9089e-04
epoch:   20, time: 284.85, train_ll: 11.0317, val_ll: 11.8190, lr: 9.8268e-04
epoch:   30, time: 302.61, train_ll: 11.3194, val_ll: 12.0902, lr: 9.7454e-04
epoch:   40, time: 301.44, train_ll: 11.5056, val_ll: 12.2510, lr: 9.6646e-04
epoch:   50, time: 301.64, train_ll: 11.6294, val_ll: 12.3559, lr: 9.5846e-04
epoch:   60, time: 301.92, train_ll: 11.7186, val_ll: 12.4282, lr: 9.5051e-04
epoch:   70, time: 303.94, train_ll: 11.7954, val_ll: 12.4859, lr: 9.4264e-04
epoch:   80, time: 301.64, train_ll: 11.8577, val_ll: 12.5333, lr: 9.3483e-04
epoch:   90, time: 303.10, train_ll: 11.9096, val_ll: 12.5716, lr: 9.2708e-04
epoch:  100, time: 302.85, train_ll: 11.9554, val_ll: 12.6047, lr: 9.1940e-04
epoch:  110, time: 302.97, train_ll: 11.9950, val_ll: 12.6324, lr: 9.1178e-04
epoch:  120, time: 302.56, train_ll: 12.0337, val_ll: 12.6604, lr

epoch: 1060, time: 295.55, train_ll: 12.8467, val_ll: 13.0453, lr: 4.1364e-04
epoch: 1070, time: 295.83, train_ll: 12.8493, val_ll: 13.0439, lr: 4.1022e-04
epoch: 1080, time: 295.79, train_ll: 12.8376, val_ll: 13.0433, lr: 4.0682e-04
epoch: 1090, time: 295.95, train_ll: 12.8585, val_ll: 13.0470, lr: 4.0345e-04
epoch: 1100, time: 300.21, train_ll: 12.8616, val_ll: 13.0486, lr: 4.0010e-04
epoch: 1110, time: 303.70, train_ll: 12.8624, val_ll: 13.0492, lr: 3.9679e-04
epoch: 1120, time: 298.02, train_ll: 12.8633, val_ll: 13.0491, lr: 3.9350e-04
epoch: 1130, time: 296.64, train_ll: 12.8720, val_ll: 13.0516, lr: 3.9024e-04
epoch: 1140, time: 296.45, train_ll: 12.8709, val_ll: 13.0525, lr: 3.8701e-04
epoch: 1150, time: 297.88, train_ll: 12.8760, val_ll: 13.0541, lr: 3.8380e-04
epoch: 1160, time: 299.14, train_ll: 12.8774, val_ll: 13.0524, lr: 3.8062e-04
epoch: 1170, time: 299.87, train_ll: 12.8825, val_ll: 13.0534, lr: 3.7747e-04
epoch: 1180, time: 293.21, train_ll: 12.8863, val_ll: 13.0537, l

KeyboardInterrupt: 

In [5]:
with torch.no_grad():
    model.eval()
    test_ll = 0.
    for i, x in enumerate(create_batcher(data.tst.x, batch_size=batch_size)):
        x = torch.tensor(x, device=device)
        ll = model(x)
        test_ll += ll.detach().cpu().numpy()
        
    test_ll /= i * batch_size

    print('test_ll: {:4f}'.format(test_ll))

  """


test_ll: 13.079126


In [None]:
sampled_x = model.sample(10000)

In [None]:
plt.scatter(sampled_x[:,0].cpu(), sampled_x[:,1].cpu(), s=1, alpha=0.05)

In [None]:
plt.scatter(data.trn.x[:,0].cpu(), data.trn.x[:,1].cpu(), alpha=0.05)