In [2]:
import argparse
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from nits.model import *
from nits.fc_model import *
from maf.datasets import *
from nits.resmade import ResidualMADE, CausalTransformer

def list_str_to_list(s):
    """Parse a bracketed, comma-separated list of integers such as ``'[16,16,1]'``.

    Used as an ``argparse`` ``type=`` converter. Malformed input raises
    ``ValueError``, which argparse reports as a normal usage error (an
    ``assert`` would escape as a raw traceback and disappears under ``-O``).

    Args:
        s: Text of the form ``'[a,b,c]'``. Spaces and leading/trailing
            whitespace are ignored.

    Returns:
        list[int]: The parsed integers; ``'[]'`` gives an empty list.

    Raises:
        ValueError: If the brackets are missing or an item is not an integer.
    """
    text = s.strip()
    if len(text) < 2 or text[0] != '[' or text[-1] != ']':
        raise ValueError(
            "expected a list literal such as '[16,16,1]', got {!r}".format(s))

    body = text[1:-1].replace(' ', '')
    if not body:
        # '[]' would otherwise become int('') and fail.
        return []

    return [int(item) for item in body.split(',')]
    
def create_batcher(x, batch_size=1, device=None):
    """Yield shuffled float32 mini-batches of ``x`` as tensors on ``device``.

    The rows of ``x`` are shuffled once with ``torch.randperm`` and then
    yielded in consecutive chunks of ``batch_size`` rows; the last chunk may
    be smaller. Empty input yields no batches.

    Args:
        x: Indexable array of samples (first axis is the sample axis).
        batch_size: Rows per batch; must be at least 1.
        device: Target device. When ``None`` the module-level ``device``
            variable is used if it exists (the previous implicit behaviour),
            otherwise ``'cpu'``.

    Yields:
        torch.Tensor: A float32 batch located on ``device``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (this used to loop
            forever because the start index never advanced).
    """
    if batch_size < 1:
        raise ValueError(
            'batch_size must be a positive integer, got {!r}'.format(batch_size))
    if device is None:
        # The parameter shadows the notebook-level global, so read it explicitly.
        device = globals().get('device', 'cpu')

    p = torch.randperm(len(x))
    shuffled = x[p]

    # range() is empty for empty input, so no zero-row batch is produced.
    for start in range(0, len(shuffled), batch_size):
        chunk = shuffled[start:start + batch_size]
        yield torch.as_tensor(chunk, device=device).float()
        
class Dataset:
    """Train / validation / test split of an array along its first axis.

    The three parts are exposed as ``trn``, ``val`` and ``tst``; each is a
    small holder object whose ``x`` attribute is the corresponding slice.
    """

    def __init__(self, x, permute=False, train_idx=0, val_idx=0):
        """Slice ``x`` into ``[:train_idx]``, ``[train_idx:val_idx]``, ``[val_idx:]``.

        Args:
            x: Array of samples; must support ``len`` and slicing.
            permute: If true, shuffle the rows with ``np.random.permutation``
                before splitting.
            train_idx: End of the training slice. A falsy value (the default
                ``0``) selects ``int(0.8 * len(x))``.
            val_idx: End of the validation slice. A falsy value (the default
                ``0``) selects ``int(0.9 * len(x))``.
        """
        self.n = len(x)
        if permute:
            x = x[np.random.permutation(self.n)]

        # Falsy boundaries fall back to the 80% / 90% marks.
        trn_end = train_idx or int(0.8 * self.n)
        val_end = val_idx or int(0.9 * self.n)

        class DataHolder:
            def __init__(self, x):
                self.x = x

        parts = (x[:trn_end], x[trn_end:val_end], x[val_end:])
        self.trn, self.val, self.tst = map(DataHolder, parts)
        
def build_tridiagonal(n=1000000, d=30, k=1, permute=False):
    """Build a synthetic Gaussian dataset mixed through a banded matrix.

    A random symmetric matrix ``A @ A.T`` is truncated to its main diagonal
    plus ``k`` off-diagonals on each side and rescaled so its mean diagonal
    entry is 1. Standard-normal samples are multiplied by that matrix, then
    every column is standardised to zero mean and unit standard deviation.

    Args:
        n: Number of samples (rows).
        d: Number of features (columns).
        k: Number of off-diagonals kept on each side of the main diagonal.
        permute: Accepted for interface compatibility; not used by this
            function.

    Returns:
        Dataset: The float64 samples split with ``Dataset``'s default
        boundaries.
    """
    precov = np.random.normal(size=(d, d))
    precov = np.matmul(precov, precov.T)
    # triu(., -k) zeroes everything below the k-th sub-diagonal, tril(., k)
    # everything above the k-th super-diagonal, leaving a band of width 2k+1.
    cov = np.tril(np.triu(precov, -k), k)
    cov = cov / np.diag(cov).mean()

    pre_x = np.random.normal(size=(n, d))
    x = np.matmul(pre_x, cov)

    # normalize each feature column
    m = np.mean(x, axis=0, keepdims=True)
    std = np.std(x, axis=0, keepdims=True)
    assert m.shape == x[0:1].shape and std.shape == x[0:1].shape
    x = (x - m) / std

    # np.float was an alias of the builtin float (float64) and was removed in
    # NumPy 1.24; np.float64 is the same dtype.
    return Dataset(x.astype(np.float64))

def permute_data(dataset):
    """Apply one random column permutation to every split of ``dataset``.

    The train, validation and test arrays are concatenated, multiplied by a
    random permutation matrix ``P`` (rows of the identity in random order),
    and split again at the original train/validation/test boundaries.

    Args:
        dataset: Object exposing ``trn.x``, ``val.x`` and ``tst.x`` as 2-D
            arrays with the same number of columns.

    Returns:
        tuple: ``(permuted_dataset, P)`` where ``permuted_dataset`` is a
        ``Dataset`` of float64 samples and ``P`` is the float64 ``(d, d)``
        permutation matrix, so that ``permuted_x @ P.T`` recovers the
        original columns.
    """
    d = dataset.trn.x.shape[1]
    train_idx = len(dataset.trn.x)
    val_idx = train_idx + len(dataset.val.x)
    x = np.concatenate([dataset.trn.x, dataset.val.x, dataset.tst.x], axis=0)

    P = np.eye(d)
    P = P[np.random.permutation(d)]
    permuted_x = np.matmul(x, P)
    # P is orthogonal, so its transpose undoes the permutation.
    assert np.allclose(np.matmul(permuted_x, P.T), x)

    # np.float was removed in NumPy 1.24; np.float64 is the dtype it aliased.
    return (
        Dataset(permuted_x.astype(np.float64), train_idx=train_idx, val_idx=val_idx),
        P.astype(np.float64),
    )

In [3]:
def _str_to_bool(value):
    """argparse ``type=`` converter for booleans given as text.

    ``type=bool`` is unsuitable because ``bool('False')`` is ``True``: every
    non-empty string enables the option.
    """
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('', '0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser()

parser.add_argument('-d', '--dataset', type=str, default='gas')
parser.add_argument('-g', '--gpu', type=str, default='')
parser.add_argument('-s', '--seed', type=int, default=1)
parser.add_argument('-b', '--batch_size', type=int, default=1024)
parser.add_argument('-hi', '--hidden_dim', type=int, default=64)
parser.add_argument('-nr', '--n_blocks', type=int, default=2)
# Negative patience / dropout mean "use the per-dataset default chosen below".
parser.add_argument('-n', '--patience', type=int, default=-1)
parser.add_argument('-ga', '--gamma', type=float, default=1)
parser.add_argument('-pd', '--polyak_decay', type=float, default=1 - 5e-5)
parser.add_argument('-a', '--nits_arch', type=list_str_to_list, default='[16,16,1]')
parser.add_argument('-r', '--rotate', action='store_true')
parser.add_argument('-dn', '--dont_normalize_inverse', type=_str_to_bool, default=False)
parser.add_argument('-l', '--learning_rate', type=float, default=2e-4)
parser.add_argument('-p', '--dropout', type=float, default=-1.0)
parser.add_argument('-rc', '--add_residual_connections', type=_str_to_bool, default=False)
parser.add_argument('-bm', '--bound_multiplier', type=float, default=1.0)
parser.add_argument('-pe', '--permute_data', action='store_true')

# Arguments are hard-coded because the notebook has no command line.
# Alternative configuration:
# args = parser.parse_args(['-g', '6', '-d', 'tridiagonal', '-pe', '-r',])
args = parser.parse_args([
    '--gpu=1',
    '--batch_size=128',
    '--hidden_dim=64',
    '--learning_rate=2e-4',
    '--n_blocks=8'
])

np.random.seed(args.seed)
torch.manual_seed(args.seed)

# An empty --gpu string selects the CPU.
device = 'cuda:' + args.gpu if args.gpu else 'cpu'

# param_model = ResidualMADE
param_model = CausalTransformer
default_patience = 10
if args.dataset == 'gas':
    # training set size: 852,174
    data = gas.GAS()
    default_dropout = 0.1
elif args.dataset == 'power':
    # training set size: 1,659,917
    data = power.POWER()
    default_dropout = 0.1
elif args.dataset == 'miniboone':
    # training set size: 29,556
    data = miniboone.MINIBOONE()
    default_dropout = 0.3
elif args.dataset == 'hepmass':
    # training set size: 315,123
    data = hepmass.HEPMASS()
    default_dropout = 0.5
    # Was misspelled `default_pateince`, so this override never took effect.
    default_patience = 3
elif args.dataset == 'bsds300':
    # training set size: 1,000,000
    data = bsds300.BSDS300()
    default_dropout = 0.2
elif args.dataset == 'tridiagonal':
    data = build_tridiagonal()
    default_dropout = 0.0
else:
    # Fail here instead of with a NameError on `data` in a later cell.
    raise ValueError('unknown dataset: {!r}'.format(args.dataset))

if args.permute_data:
    print("PERMUTED DATA")
    data, P = permute_data(data)
    # Shows where each original column index ends up after the permutation.
    print("P\n", np.arange(data.trn.x.shape[1]).reshape(1, -1)@(P))

[16,16,1]


In [None]:
# Replace the negative "unset" sentinels with the per-dataset defaults.
args.patience = args.patience if args.patience >= 0 else default_patience
args.dropout = args.dropout if args.dropout >= 0.0 else default_dropout
print(args)

d = data.trn.x.shape[1]

# Global value range over all three splits, passed to NITS as start / end.
max_val = max(data.trn.x.max(), data.val.x.max(), data.tst.x.max())
min_val = min(data.trn.x.min(), data.val.x.min(), data.tst.x.min())
max_val, min_val = torch.tensor(max_val).to(device).float(), torch.tensor(min_val).to(device).float()

max_val *= args.bound_multiplier
min_val *= args.bound_multiplier

nits_model = NITS(start=min_val, end=max_val, monotonic_const=1e-5,
                  A_constraint='neg_exp', arch=[d] + args.nits_arch,
                  final_layer_constraint='softmax',
                  add_residual_connections=args.add_residual_connections,
                  normalize_inverse=(not args.dont_normalize_inverse),
                  softmax_temperature=False).to(device)

# The trained model and its EMA shadow are built from identical settings.
model_kwargs = dict(
    d=d,
    rotate=args.rotate,
    nits_model=nits_model,
    param_model=param_model,
    n_blocks=args.n_blocks,
    hidden_dim=args.hidden_dim,
    dropout_probability=args.dropout,
)
model = Model(**model_kwargs).to(device)
shadow = Model(**model_kwargs).to(device)

model = EMA(model, shadow, decay=args.polyak_decay).to(device)

print_every = 1
optim = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=1, gamma=args.gamma)


def _mean_log_likelihoods(split_x):
    """Return ``(raw, ema)`` per-sample log-likelihoods over ``split_x``.

    ``raw`` accumulates ``model.model(x)`` and ``ema`` accumulates ``model(x)``
    with the wrapper in eval mode; both sums are divided by ``len(split_x)``.
    NOTE(review): presumably the EMA wrapper evaluates the shadow weights in
    eval mode -- confirm against nits.model.EMA.
    """
    raw_total = 0.
    ema_total = 0.
    with torch.no_grad():
        model.eval()
        # create_batcher already yields float tensors on `device`, so the
        # batches are used directly without re-wrapping them in torch.tensor.
        for x in create_batcher(split_x, batch_size=args.batch_size):
            raw_total += model.model(x).item()
            ema_total += model(x).item()
    return raw_total / len(split_x), ema_total / len(split_x)


time_ = time.time()
epoch = 0
train_ll = 0.
max_val_ll = -np.inf
patience = args.patience
keep_training = True
while keep_training:
    model.train()
    for x in create_batcher(data.trn.x, batch_size=args.batch_size):
        ll = model(x)
        optim.zero_grad()
        (-ll).backward()  # maximise the log-likelihood
        train_ll += ll.item()

        optim.step()
        # NOTE(review): the scheduler is stepped once per batch, so `gamma`
        # decays the learning rate per iteration, not per epoch -- confirm.
        scheduler.step()
        model.update()

    epoch += 1

    if epoch % print_every == 0:
        # compute train loss
        train_ll /= len(data.trn.x) * print_every
        lr = optim.param_groups[0]['lr']

        val_ll, ema_val_ll = _mean_log_likelihoods(data.val.x)

        # early stopping on the EMA validation log-likelihood
        if ema_val_ll > max_val_ll + 1e-4:
            patience = args.patience
            max_val_ll = ema_val_ll
        else:
            patience -= 1

        # `<=` so a counter that has already passed zero still stops the loop
        # (e.g. --patience 0 combined with a NaN validation score).
        if patience <= 0:
            print("Patience reached zero. max_val_ll stayed at {:.3f} for {:d} iterations.".format(max_val_ll, args.patience))
            keep_training = False

        test_ll, ema_test_ll = _mean_log_likelihoods(data.tst.x)

        fmt_str1 = 'epoch: {:3d}, time: {:3d}s, train_ll: {:.3f},'
        fmt_str2 = ' ema_val_ll: {:.3f}, ema_test_ll: {:.3f},'
        fmt_str3 = ' val_ll: {:.3f}, test_ll: {:.3f}, lr: {:.2e}'

        print((fmt_str1 + fmt_str2 + fmt_str3).format(
            epoch,
            int(time.time() - time_),
            train_ll,
            ema_val_ll,
            ema_test_ll,
            val_ll,
            test_ll,
            lr))

        time_ = time.time()
        train_ll = 0.

    # Re-print the configuration periodically so long logs stay self-describing.
    if epoch % (print_every * 10) == 0:
        print(args)

In [None]:
# Inspect the matrix returned by the wrapped model's get_P().
P_hat = model.model.get_P()
# Row vector with d equal entries, ones(d) / sqrt(d), i.e. unit Euclidean norm.
# Uses `d` as the data dimensionality set in the training cell; the mlpack
# cell below rebinds `d`, so run this cell before that one.
v = torch.ones(d, device=device).reshape(1, -1) / np.sqrt(d)
print(v.norm())
# Norm of the same vector after multiplying by P_hat.
# NOTE(review): presumably a check that P_hat preserves norms -- confirm.
print((v @ P_hat).norm())
# Cell output: for a 2-D P_hat, the row index of the largest entry per column.
P_hat.argmax(axis=0)

In [23]:
# Baseline: run mlpack's `det` on the training split with 10 folds and
# evaluate it on the test split.
# NOTE(review): this import sits mid-notebook instead of in the import cell,
# and the call below rebinds `d`, which earlier cells use as the data
# dimensionality -- re-running those cells after this one will break.
from mlpack import det

d = det(folds=10, test=data.tst.x,
#         max_leaf_size=10,
#         min_leaf_size=5,
#         training=np.concatenate([data.trn.x, data.val.x]), verbose=False)
        training=data.trn.x, verbose=True)

[0;32m[INFO ] [0m134298 leaf nodes in the tree using full dataset; minimum alpha: -42.8865.
[0;32m[INFO ] [0mPerforming 10-fold cross validation.
[0;32m[INFO ] [0m39733 trees in the sequence; maximum alpha: 8.99925.
[0;32m[INFO ] [0mOptimal alpha: 8.83999.
[0;32m[INFO ] [0m100 leaf nodes in the optimally pruned tree; optimal alpha: 8.83999.
[0;33m[WARN ] [0mUnable to open file '' to save tag membership info.


In [24]:
# Mean log of the `det` test-set estimates; here `d` is the result of the
# mlpack `det` call, not the data dimensionality.
# 1e-7 is added before the log, presumably to avoid log(0) -- confirm.
np.log(d['test_set_estimates'] + 1e-7).mean()

-13.7834544349559