In [1]:
import argparse
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from nits.model import *
from nits.fc_model import *
from maf.datasets import *
from nits.resmade import ResidualMADE
    
def create_batcher(x, batch_size=1, device=None):
    """Yield randomly shuffled mini-batches of `x` as torch tensors.

    Args:
        x: indexable dataset (e.g. a NumPy array) whose first axis indexes
            examples; it is shuffled with a single `torch.randperm` draw.
        batch_size: maximum number of rows per batch; must be >= 1.
        device: device for the yielded tensors. Defaults to the notebook-level
            `device` global (set in the configuration cell), or 'cpu' if that
            global does not exist yet.

    Yields:
        Tensors of `batch_size` rows each, except possibly a shorter final
        batch. Nothing is yielded when `x` is empty.

    Raises:
        ValueError: if `batch_size < 1` (previously this looped forever).
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))
    if device is None:
        # Backward compatibility: callers historically relied on the global.
        device = globals().get('device', 'cpu')

    n = len(x)
    shuffled = x[torch.randperm(n)]
    for start in range(0, n, batch_size):
        yield torch.tensor(shuffled[start:start + batch_size], device=device)

In [2]:
def _str2bool(value):
    """Parse a command-line boolean ('true'/'false', '1'/'0', 'yes'/'no', ...).

    `type=bool` is an argparse pitfall: bool('False') is True, so any
    non-empty value would have enabled the flag.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))


parser = argparse.ArgumentParser()

parser.add_argument('-d', '--dataset', type=str, default='gas')
parser.add_argument('-g', '--gpu', type=str, default='')
parser.add_argument('-r', '--rotate', type=_str2bool, default=False)

args = parser.parse_args(['-g', '0', '-d', 'gas'])

device = 'cuda:' + args.gpu if args.gpu else 'cpu'
print('device:', device)

lr = 5e-4
polyak_decay = 0.9995
n_residual_blocks = 4
hidden_dim = 512
use_batch_norm = False
zero_initialization = True
weight_norm = False
batch_size = 512
print('batch_size:', batch_size)

# Per-dataset settings: (loader, dropout probability). Loaders are lambdas so
# only the selected dataset is resolved and loaded.
DATASET_CONFIGS = {
    'gas': (lambda: gas.GAS(), 0.1),                    # training set size: 852,174
    'power': (lambda: power.POWER(), 0.1),              # training set size: 1,659,917
    'miniboone': (lambda: miniboone.MINIBOONE(), 0.5),  # training set size: 29,556
    'hepmass': (lambda: hepmass.HEPMASS(), 0.5),        # training set size: 315,123
    'bsds300': (lambda: bsds300.BSDS300(), 0.2),        # training set size: 1,000,000
}
if args.dataset not in DATASET_CONFIGS:
    # Previously an unknown name left `data` undefined and failed later with NameError.
    raise ValueError('unknown dataset {!r}; expected one of {}'.format(
        args.dataset, sorted(DATASET_CONFIGS)))
load_dataset, dropout_probability = DATASET_CONFIGS[args.dataset]
data = load_dataset()

# These were identical for every supported dataset.
nits_arch = [8, 8, 8, 1]
# LR decay factor for StepLR(step_size=1); the training loop steps the
# scheduler once per batch, so this is a per-batch decay.
gamma = 1 - 5e-7

d = data.trn.x.shape[1]

# Support bounds for NITS, taken over all splits.
max_val = max(data.trn.x.max(), data.val.x.max(), data.tst.x.max())
min_val = min(data.trn.x.min(), data.val.x.min(), data.tst.x.min())
max_val, min_val = torch.tensor(max_val).to(device).float(), torch.tensor(min_val).to(device).float()

normalizer = Normalizer(d, d)
normalizer.set_weights(torch.tensor(data.trn.x, device='cpu'), device=device)

nits_model = NITS(d=d, start=min_val, end=max_val, monotonic_const=1e-5,
                 A_constraint='neg_exp', arch=[1] + nits_arch,
                 final_layer_constraint='softmax',
                 softmax_temperature=False,
                 normalizer=normalizer).to(device)


def _build_resmade():
    """Construct a ResMADEModel from this cell's hyperparameters.

    Used for both the trained model and its EMA shadow so the two cannot drift.
    NOTE(review): both instances receive the same `nits_model` and
    `normalizer` objects — confirm that sharing is intended for the shadow.
    """
    return ResMADEModel(
        d=d,
        rotate=args.rotate,
        nits_model=nits_model,
        n_residual_blocks=n_residual_blocks,
        hidden_dim=hidden_dim,
        dropout_probability=dropout_probability,
        use_batch_norm=use_batch_norm,
        zero_initialization=zero_initialization,
        weight_norm=weight_norm,
        normalizer=normalizer
    )


model = _build_resmade()
shadow = _build_resmade()

device: cuda:0
batch_size: 512


  self.register_buffer('start_val', torch.tensor(start))
  self.register_buffer('end_val', torch.tensor(end))
  self.register_buffer('start', torch.tensor(start).reshape(1, 1).tile(1, d))
  self.register_buffer('end', torch.tensor(end).reshape(1, 1).tile(1, d))


In [3]:
# Initialize weight norm by running one forward pass on a training batch.
# NOTE(review): presumably this triggers data-dependent weight-norm
# initialization inside ResMADEModel — confirm.
if weight_norm:
    # create_batcher puts the batch on `device`; the model must be there too,
    # otherwise this forward pass fails with a device mismatch.
    model = model.to(device)
    with torch.no_grad():
        init_batch = next(create_batcher(data.trn.x, batch_size=4096))
        model(init_batch)

# Wrap the model together with its shadow copy (EMA, decay=polyak_decay).
model = EMA(model, shadow, decay=polyak_decay).to(device)

max_epochs = 20000
print_every = 10
optim = torch.optim.Adam(model.parameters(), lr=lr)
# step_size=1: the LR is multiplied by `gamma` on every scheduler.step() call;
# the training loop calls it once per batch.
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=1, gamma=gamma)

In [4]:
def _mean_log_likelihood(x_data):
    """Average log-likelihood per example of `model` over `x_data`.

    Runs in eval mode without autograd. Assumes `model(batch)` returns the
    batch's summed log-likelihood as a scalar tensor (the training loop
    normalizes the same way) — TODO confirm against EMA/ResMADEModel.
    """
    model.eval()
    total = 0.
    with torch.no_grad():
        # create_batcher already yields tensors on `device`; no re-wrapping needed.
        for batch in create_batcher(x_data, batch_size=batch_size):
            total += model(batch).item()
    return total / len(x_data)


time_ = time.time()
train_ll = 0.
for epoch in range(max_epochs):
    model.train()
    for x in create_batcher(data.trn.x, batch_size=batch_size):
        ll = model(x)
        optim.zero_grad()
        (-ll).backward()
        train_ll += ll.item()

        optim.step()
        scheduler.step()  # per-batch LR decay (StepLR with step_size=1)
        model.update()    # presumably updates the EMA shadow weights

    if (epoch + 1) % print_every == 0:
        # Mean train log-likelihood over the last `print_every` epochs.
        train_ll /= len(data.trn.x) * print_every
        # Separate name so the configured `lr` is not overwritten (re-running
        # the optimizer cell would otherwise pick up the decayed value).
        current_lr = optim.param_groups[0]['lr']

        val_ll = _mean_log_likelihood(data.val.x)
        test_ll = _mean_log_likelihood(data.tst.x)

        fmt_str1 = 'epoch: {:4d}, time: {:.2f}, train_ll: {:.4f},'
        fmt_str2 = ' val_ll: {:.4f}, test_ll: {:.4f}, lr: {:.4e}'

        print((fmt_str1 + fmt_str2).format(
            epoch + 1,
            time.time() - time_,
            train_ll,
            val_ll,
            test_ll,
            current_lr))

        time_ = time.time()
        train_ll = 0.

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x1 and 8x8)

In [None]:
# Draw 10,000 samples from the wrapped model (`model.model`, presumably the
# non-averaged network inside EMA — confirm whether the shadow should be used).
# eval() disables dropout, which stays active if training stopped mid-epoch
# (each epoch starts with model.train()); no_grad() avoids building a graph
# for samples that are only plotted.
model.model.eval()
with torch.no_grad():
    sampled_x = model.model.sample(10000)

In [None]:
# Scatter the first two dimensions of the model samples.
# detach() guards against samples that still require grad.
samples_np = sampled_x.detach().cpu().numpy()
fig, ax = plt.subplots()
ax.scatter(samples_np[:, 0], samples_np[:, 1], s=1, alpha=0.05)
ax.set(title='Model samples (first two dimensions)', xlabel='x[0]', ylabel='x[1]');

In [None]:
# Scatter the first two dimensions of the training data for comparison.
# data.trn.x appears to be a NumPy array (it is wrapped with torch.tensor and
# indexed by a torch permutation elsewhere), which has no .cpu(); np.asarray
# also accepts CPU tensors.
train_np = np.asarray(data.trn.x)
fig, ax = plt.subplots()
ax.scatter(train_np[:, 0], train_np[:, 1], alpha=0.05)
ax.set(title='Training data (first two dimensions)', xlabel='x[0]', ylabel='x[1]');

In [None]:
# Scratch check: turn a random length-8 vector into an 8x8 diagonal matrix.
v = torch.randn(8)
print(v.shape)
diag_v = torch.diag(v)
diag_v