In [1]:
import argparse
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from nits.model import *
from nits.fc_model import *
from maf.datasets import *
from nits.resmade import ResidualMADE

def list_str_to_list(s):
    print(s)
    assert s[0] == '[' and s[-1] == ']'
    s = s[1:-1]
    s = s.replace(' ', '')
    s = s.split(',')

    s = [int(x) for x in s]

    return s
    
def create_batcher(x, batch_size=1):
    idx = 0
    p = torch.randperm(len(x))
    x = x[p]

    while idx + batch_size < len(x):
        yield torch.tensor(x[idx:idx+batch_size], device=device).float()
        idx += batch_size
    else:
        yield torch.tensor(x[idx:], device=device).float()
        
class Dataset:
    def __init__(self, x, permute=False, train_idx=0, val_idx=0):
        # splits x into train, val, and test
        self.n = len(x)
        if permute:
            p = np.random.permutation(self.n)
            x = x[p]
            
        train_idx = train_idx if train_idx else int(0.8 * self.n)
        val_idx = val_idx if val_idx else int(0.9 * self.n)
        
        class DataHolder:
            def __init__(self, x):
                self.x = x
                
        self.trn = DataHolder(x[:train_idx])
        self.val = DataHolder(x[train_idx:val_idx])
        self.tst = DataHolder(x[val_idx:])
        
def build_tridiagonal(n=1000000, d=30, k=1, permute=False):
    precov = np.random.normal(size=(d, d))
    precov = np.matmul(precov, precov.T)
    cov = np.tril(np.triu(precov, -k), k)
    cov = cov / np.diag(cov).mean()
    
    pre_x = np.random.normal(size=(n, d))
    x = np.matmul(pre_x, cov)
    
    # normalize
    m = np.mean(x, axis=0, keepdims=True)
    std = np.std(x, axis=0, keepdims=True)
    assert m.shape == x[0:1].shape and std.shape == x[0:1].shape
    x = (x - m) / std
    
    return Dataset(x.astype(np.float))

def permute_data(dataset):
    d = dataset.trn.x.shape[1]
    train_idx = len(dataset.trn.x)
    val_idx = train_idx + len(dataset.val.x)
    x = np.concatenate([dataset.trn.x, dataset.val.x, dataset.tst.x], axis=0)
    
    P = np.eye(d)
    P = P[np.random.permutation(d)]
    permuted_x = np.matmul(x, P)
    assert np.allclose(np.matmul(permuted_x, P.T), x)
    
    return Dataset(permuted_x.astype(np.float), train_idx=train_idx, val_idx=val_idx), P.astype(np.float)

In [2]:
parser = argparse.ArgumentParser()

parser.add_argument('-d', '--dataset', type=str, default='gas')
parser.add_argument('-g', '--gpu', type=str, default='')
parser.add_argument('-s', '--seed', type=int, default=1)
parser.add_argument('-b', '--batch_size', type=int, default=1024)
parser.add_argument('-hi', '--hidden_dim', type=int, default=64)
parser.add_argument('-nr', '--n_residual_blocks', type=int, default=2)
parser.add_argument('-n', '--patience', type=int, default=-1)
parser.add_argument('-ga', '--gamma', type=float, default=1)
parser.add_argument('-pd', '--polyak_decay', type=float, default=1 - 5e-5)
parser.add_argument('-a', '--nits_arch', type=list_str_to_list, default='[16,16,1]')
parser.add_argument('-r', '--rotate', action='store_true')
parser.add_argument('-dn', '--dont_normalize_inverse', type=bool, default=False)
parser.add_argument('-l', '--learning_rate', type=float, default=2e-4)
parser.add_argument('-p', '--dropout', type=float, default=-1.0)
parser.add_argument('-rc', '--add_residual_connections', type=bool, default=False)
parser.add_argument('-bm', '--bound_multiplier', type=float, default=1.0)
parser.add_argument('-pe', '--permute_data', action='store_true')

# args = parser.parse_args(['-g', '6', '-d', 'tridiagonal', '-pe', '-r',])
args = parser.parse_args(['-g', '6', '-pe', '-r', '-b', '1024', '-hi', '1024'])

np.random.seed(args.seed)
torch.manual_seed(args.seed)

device = 'cuda:' + args.gpu if args.gpu else 'cpu'

use_batch_norm = False
zero_initialization = True
weight_norm = False
default_patience = 10
if args.dataset == 'gas':
    # training set size: 852,174
    data = gas.GAS()
    default_dropout = 0.1
elif args.dataset == 'power':
    # training set size: 1,659,917
    data = power.POWER()
    default_dropout = 0.1
elif args.dataset == 'miniboone':
    # training set size: 29,556
    data = miniboone.MINIBOONE()
    default_dropout = 0.3
elif args.dataset == 'hepmass':
    # training set size: 315,123
    data = hepmass.HEPMASS()
    default_dropout = 0.5
    default_pateince = 3
elif args.dataset == 'bsds300':
    # training set size: 1,000,000
    data = bsds300.BSDS300()
    default_dropout = 0.2
elif args.dataset == 'tridiagonal':
    data = build_tridiagonal()
    default_dropout = 0.0
    
if args.permute_data:
    print("PERMUTED DATA")
    data, P = permute_data(data)
    print("P\n", np.arange(data.trn.x.shape[1]).reshape(1, -1)@(P))

args.patience = args.patience if args.patience >= 0 else default_patience
args.dropout = args.dropout if args.dropout >= 0.0 else default_dropout
print(args)

d = data.trn.x.shape[1]

max_val = max(data.trn.x.max(), data.val.x.max(), data.tst.x.max())
min_val = min(data.trn.x.min(), data.val.x.min(), data.tst.x.min())
max_val, min_val = torch.tensor(max_val).to(device).float(), torch.tensor(min_val).to(device).float()

max_val *= args.bound_multiplier
min_val *= args.bound_multiplier

nits_model = NITS(d=d, start=min_val, end=max_val, monotonic_const=1e-5,
                  A_constraint='neg_exp', arch=[1] + args.nits_arch,
                  final_layer_constraint='softmax',
                  add_residual_connections=args.add_residual_connections,
                  normalize_inverse=(not args.dont_normalize_inverse),
                  softmax_temperature=False).to(device)

model = ResMADEModel(
    d=d,
    rotate=args.rotate,
    nits_model=nits_model,
    n_residual_blocks=args.n_residual_blocks,
    hidden_dim=args.hidden_dim,
    dropout_probability=args.dropout,
    use_batch_norm=use_batch_norm,
    zero_initialization=zero_initialization,
    weight_norm=weight_norm
).to(device)

shadow = ResMADEModel(
    d=d,
    rotate=args.rotate,
    nits_model=nits_model,
    n_residual_blocks=args.n_residual_blocks,
    hidden_dim=args.hidden_dim,
    dropout_probability=args.dropout,
    use_batch_norm=use_batch_norm,
    zero_initialization=zero_initialization,
    weight_norm=weight_norm
).to(device)

# initialize weight norm
if weight_norm:
    with torch.no_grad():
        for i, x in enumerate(create_batcher(data.trn.x, batch_size=args.batch_size)):
            params = model(x)
            break

model = EMA(model, shadow, decay=args.polyak_decay).to(device)

print_every = 10
optim = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=1, gamma=args.gamma)

time_ = time.time()
epoch = 0
train_ll = 0.
max_val_ll = -np.inf
patience = args.patience
keep_training = True
while keep_training:
    model.train()
    for i, x in enumerate(create_batcher(data.trn.x, batch_size=args.batch_size)):
        ll = model(x)
        optim.zero_grad()
        (-ll).backward()
        train_ll += ll.detach().cpu().numpy()

        optim.step()
        scheduler.step()
        model.update()

    epoch += 1

    if epoch % print_every == 0:
        # compute train loss
        train_ll /= len(data.trn.x) * print_every
        lr = optim.param_groups[0]['lr']

        with torch.no_grad():
            model.eval()
            val_ll = 0.
            ema_val_ll = 0.
            for i, x in enumerate(create_batcher(data.val.x, batch_size=args.batch_size)):
                x = torch.tensor(x, device=device)
                val_ll += model.model(x).detach().cpu().numpy()
                ema_val_ll += model(x).detach().cpu().numpy()

            val_ll /= len(data.val.x)
            ema_val_ll /= len(data.val.x)

        # early stopping
        if ema_val_ll > max_val_ll + 1e-4:
            patience = args.patience
            max_val_ll = ema_val_ll
        else:
            patience -= 1

        if patience == 0:
            print("Patience reached zero. max_val_ll stayed at {:.3f} for {:d} iterations.".format(max_val_ll, args.patience))
            keep_training = False

        with torch.no_grad():
            model.eval()
            test_ll = 0.
            ema_test_ll = 0.
            for i, x in enumerate(create_batcher(data.tst.x, batch_size=args.batch_size)):
                x = torch.tensor(x, device=device)
                test_ll += model.model(x).detach().cpu().numpy()
                ema_test_ll += model(x).detach().cpu().numpy()

            test_ll /= len(data.tst.x)
            ema_test_ll /= len(data.tst.x)

        fmt_str1 = 'epoch: {:3d}, time: {:3d}s, train_ll: {:.3f},'
        fmt_str2 = ' ema_val_ll: {:.3f}, ema_test_ll: {:.3f},'
        fmt_str3 = ' val_ll: {:.3f}, test_ll: {:.3f}, lr: {:.2e}'

        print((fmt_str1 + fmt_str2 + fmt_str3).format(
            epoch,
            int(time.time() - time_),
            train_ll,
            ema_val_ll,
            ema_test_ll,
            val_ll,
            test_ll,
            lr))

        time_ = time.time()
        train_ll = 0.

    if epoch % (print_every * 10) == 0:
        print(args)

[16,16,1]
PERMUTED DATA
P
 [[4. 2. 1. 6. 5. 7. 3. 0.]]
Namespace(add_residual_connections=False, batch_size=1024, bound_multiplier=1.0, dataset='gas', dont_normalize_inverse=False, dropout=0.1, gamma=1, gpu='6', hidden_dim=1024, learning_rate=0.0002, n_residual_blocks=2, nits_arch=[16, 16, 1], patience=10, permute_data=True, polyak_decay=0.99995, rotate=True, seed=1)


  self.register_buffer('start_val', torch.tensor(start))
  self.register_buffer('end_val', torch.tensor(end))
  self.register_buffer('start', torch.tensor(start).reshape(1, 1).tile(1, d))
  self.register_buffer('end', torch.tensor(end).reshape(1, 1).tile(1, d))


epoch:  10, time: 149s, train_ll: 4.913, ema_val_ll: -33.917, ema_test_ll: -33.897, val_ll: 7.336, test_ll: 7.350, lr: 2.00e-04
epoch:  20, time: 149s, train_ll: 8.398, ema_val_ll: -23.221, ema_test_ll: -23.213, val_ll: 8.909, test_ll: 8.929, lr: 2.00e-04
epoch:  30, time: 149s, train_ll: 9.207, ema_val_ll: -14.361, ema_test_ll: -14.363, val_ll: 9.272, test_ll: 9.283, lr: 2.00e-04
epoch:  40, time: 146s, train_ll: 8.199, ema_val_ll: -12.623, ema_test_ll: -12.661, val_ll: 9.424, test_ll: 9.437, lr: 2.00e-04
epoch:  50, time: 143s, train_ll: 9.578, ema_val_ll: -26.166, ema_test_ll: -26.219, val_ll: 9.973, test_ll: 9.988, lr: 2.00e-04
epoch:  60, time: 143s, train_ll: 10.084, ema_val_ll: -6.818, ema_test_ll: -6.837, val_ll: 10.446, test_ll: 10.454, lr: 2.00e-04
epoch:  70, time: 144s, train_ll: 10.405, ema_val_ll: -2.838, ema_test_ll: -2.848, val_ll: 10.568, test_ll: 10.578, lr: 2.00e-04
epoch:  80, time: 146s, train_ll: 10.642, ema_val_ll: 0.714, ema_test_ll: 0.715, val_ll: 10.903, test_

epoch: 530, time: 141s, train_ll: 12.486, ema_val_ll: 12.757, ema_test_ll: 12.760, val_ll: 12.513, test_ll: 12.519, lr: 2.00e-04
epoch: 540, time: 140s, train_ll: 12.496, ema_val_ll: 12.764, ema_test_ll: 12.768, val_ll: 12.478, test_ll: 12.476, lr: 2.00e-04
epoch: 550, time: 136s, train_ll: 12.507, ema_val_ll: 12.770, ema_test_ll: 12.773, val_ll: 12.460, test_ll: 12.462, lr: 2.00e-04
epoch: 560, time: 138s, train_ll: 12.516, ema_val_ll: 12.776, ema_test_ll: 12.781, val_ll: 12.475, test_ll: 12.484, lr: 2.00e-04
epoch: 570, time: 138s, train_ll: 12.526, ema_val_ll: 12.781, ema_test_ll: 12.786, val_ll: 12.369, test_ll: 12.376, lr: 2.00e-04
epoch: 580, time: 140s, train_ll: 12.536, ema_val_ll: 12.785, ema_test_ll: 12.790, val_ll: 12.568, test_ll: 12.581, lr: 2.00e-04
epoch: 590, time: 142s, train_ll: 12.545, ema_val_ll: 12.788, ema_test_ll: 12.793, val_ll: 12.547, test_ll: 12.550, lr: 2.00e-04
epoch: 600, time: 142s, train_ll: 12.555, ema_val_ll: 12.796, ema_test_ll: 12.801, val_ll: 12.492

epoch: 1050, time: 129s, train_ll: 12.824, ema_val_ll: 12.927, ema_test_ll: 12.938, val_ll: 12.732, test_ll: 12.742, lr: 2.00e-04
epoch: 1060, time: 130s, train_ll: 12.828, ema_val_ll: 12.927, ema_test_ll: 12.938, val_ll: 12.745, test_ll: 12.752, lr: 2.00e-04
epoch: 1070, time: 129s, train_ll: 12.833, ema_val_ll: 12.931, ema_test_ll: 12.941, val_ll: 12.731, test_ll: 12.745, lr: 2.00e-04
epoch: 1080, time: 129s, train_ll: 12.836, ema_val_ll: 12.932, ema_test_ll: 12.942, val_ll: 12.718, test_ll: 12.728, lr: 2.00e-04
epoch: 1090, time: 130s, train_ll: 12.840, ema_val_ll: 12.934, ema_test_ll: 12.944, val_ll: 12.729, test_ll: 12.736, lr: 2.00e-04
epoch: 1100, time: 129s, train_ll: 12.844, ema_val_ll: 12.935, ema_test_ll: 12.945, val_ll: 12.699, test_ll: 12.707, lr: 2.00e-04
Namespace(add_residual_connections=False, batch_size=1024, bound_multiplier=1.0, dataset='gas', dont_normalize_inverse=False, dropout=0.1, gamma=1, gpu='6', hidden_dim=1024, learning_rate=0.0002, n_residual_blocks=2, nit

epoch: 1560, time: 129s, train_ll: 12.980, ema_val_ll: 13.011, ema_test_ll: 13.020, val_ll: 12.793, test_ll: 12.804, lr: 2.00e-04
epoch: 1570, time: 129s, train_ll: 12.983, ema_val_ll: 13.012, ema_test_ll: 13.022, val_ll: 12.816, test_ll: 12.829, lr: 2.00e-04
epoch: 1580, time: 129s, train_ll: 12.986, ema_val_ll: 13.013, ema_test_ll: 13.022, val_ll: 12.816, test_ll: 12.823, lr: 2.00e-04
epoch: 1590, time: 129s, train_ll: 12.987, ema_val_ll: 13.014, ema_test_ll: 13.023, val_ll: 12.831, test_ll: 12.837, lr: 2.00e-04
epoch: 1600, time: 129s, train_ll: 12.990, ema_val_ll: 13.013, ema_test_ll: 13.022, val_ll: 12.840, test_ll: 12.845, lr: 2.00e-04
Namespace(add_residual_connections=False, batch_size=1024, bound_multiplier=1.0, dataset='gas', dont_normalize_inverse=False, dropout=0.1, gamma=1, gpu='6', hidden_dim=1024, learning_rate=0.0002, n_residual_blocks=2, nits_arch=[16, 16, 1], patience=10, permute_data=True, polyak_decay=0.99995, rotate=True, seed=1)
epoch: 1610, time: 129s, train_ll: 

epoch: 2070, time: 129s, train_ll: 13.080, ema_val_ll: 13.034, ema_test_ll: 13.040, val_ll: 12.844, test_ll: 12.850, lr: 2.00e-04
epoch: 2080, time: 129s, train_ll: 13.081, ema_val_ll: 13.032, ema_test_ll: 13.039, val_ll: 12.832, test_ll: 12.842, lr: 2.00e-04
epoch: 2090, time: 129s, train_ll: 13.083, ema_val_ll: 13.035, ema_test_ll: 13.042, val_ll: 12.860, test_ll: 12.862, lr: 2.00e-04
epoch: 2100, time: 129s, train_ll: 13.085, ema_val_ll: 13.035, ema_test_ll: 13.042, val_ll: 12.809, test_ll: 12.819, lr: 2.00e-04
Namespace(add_residual_connections=False, batch_size=1024, bound_multiplier=1.0, dataset='gas', dont_normalize_inverse=False, dropout=0.1, gamma=1, gpu='6', hidden_dim=1024, learning_rate=0.0002, n_residual_blocks=2, nits_arch=[16, 16, 1], patience=10, permute_data=True, polyak_decay=0.99995, rotate=True, seed=1)
epoch: 2110, time: 129s, train_ll: 13.087, ema_val_ll: 13.036, ema_test_ll: 13.042, val_ll: 12.883, test_ll: 12.886, lr: 2.00e-04
epoch: 2120, time: 129s, train_ll: 

epoch: 2580, time: 128s, train_ll: 13.152, ema_val_ll: 13.049, ema_test_ll: 13.054, val_ll: 12.811, test_ll: 12.821, lr: 2.00e-04
epoch: 2590, time: 128s, train_ll: 13.152, ema_val_ll: 13.049, ema_test_ll: 13.054, val_ll: 12.864, test_ll: 12.869, lr: 2.00e-04
epoch: 2600, time: 128s, train_ll: 13.154, ema_val_ll: 13.050, ema_test_ll: 13.055, val_ll: 12.879, test_ll: 12.885, lr: 2.00e-04
Namespace(add_residual_connections=False, batch_size=1024, bound_multiplier=1.0, dataset='gas', dont_normalize_inverse=False, dropout=0.1, gamma=1, gpu='6', hidden_dim=1024, learning_rate=0.0002, n_residual_blocks=2, nits_arch=[16, 16, 1], patience=10, permute_data=True, polyak_decay=0.99995, rotate=True, seed=1)
epoch: 2610, time: 128s, train_ll: 13.155, ema_val_ll: 13.050, ema_test_ll: 13.055, val_ll: 12.896, test_ll: 12.899, lr: 2.00e-04
epoch: 2620, time: 128s, train_ll: 13.156, ema_val_ll: 13.050, ema_test_ll: 13.055, val_ll: 12.866, test_ll: 12.876, lr: 2.00e-04
epoch: 2630, time: 129s, train_ll: 

In [3]:
P_hat = model.model.get_P()
v = torch.ones(d, device=device).reshape(1, -1) / np.sqrt(d)
print(v.norm())
print((v @ P_hat).norm())
P_hat.argmax(axis=0)

tensor(1.0000, device='cuda:6')
tensor(1., device='cuda:6', grad_fn=<CopyBackwards>)


tensor([2, 2, 0, 5, 6, 5, 0, 3], device='cuda:6')