In [1]:
import torch
from nits.model import *
from nits.autograd_model import *

device = 'cpu'
# device = 'cuda:2'

base_arch = [8, 1]

n = 32
start, end = -2., 2.
arch = [1] + base_arch
monotonic_const = 1e-2

print("Testing NITS.")

for d in [1, 2, 10]:
    for A_constraint in ['neg_exp', 'exp']:
        for final_layer_constraint in ['softmax', 'exp']:
#             print("""
#             Testing configuration:
#                 d: {}
#                 A_constraint: {}
#                 final_layer_constraint: {}
#                   """.format(d, A_constraint, final_layer_constraint))
            ############################
            # DEFINE MODELS            #
            ############################

            model = NITS(d=d, start=start, end=end, arch=arch,
                                 monotonic_const=monotonic_const, A_constraint=A_constraint,
                                 final_layer_constraint=final_layer_constraint).to(device)
            params = torch.randn((n, d * model.n_params)).to(device)

            ############################
            # SANITY CHECKS            #
            ############################

            # check that the function integrates to 1
            assert torch.allclose(torch.ones((n, d)).to(device),
                                  model.cdf(model.end, params) - model.cdf(model.start, params), atol=1e-5)

            # check that the pdf is all positive
            z = torch.linspace(start, end, steps=n, device=device)[:,None].tile((1, d))
            assert (model.pdf(z, params) >= 0).all()

            # check that the cdf is the inverted
            cdf = model.cdf(z, params[0:1])
            icdf = model.icdf(cdf, params[0:1])
            assert (z - icdf <= 1e-3).all()

            ############################
            # COMPARE TO AUTOGRAD NITS #
            ############################
            autograd_model = ModelInverse(arch=arch, start=start, end=end, store_weights=False,
                                          A_constraint=A_constraint, monotonic_const=monotonic_const,
                                          final_layer_constraint=final_layer_constraint).to(device)

            def zs_params_to_forwards(zs, params):
                out = []
                for z, param in zip(zs, params):
                    for d_ in range(d):
                        start_idx, end_idx = d_ * autograd_model.n_params, (d_ + 1) * autograd_model.n_params
                        autograd_model.set_params(param[start_idx:end_idx])
                        out.append(autograd_model.apply_layers(z[d_:d_+1][None,:]))

                out = torch.cat(out, axis=0).reshape(-1, d)
                return out

            autograd_outs = zs_params_to_forwards(z, params)
            outs = model.forward_(z, params)
            assert torch.allclose(autograd_outs, outs, atol=1e-4)

            def zs_params_to_cdfs(zs, params):
                out = []
                for z, param in zip(zs, params):
                    for d_ in range(d):
                        start_idx, end_idx = d_ * autograd_model.n_params, (d_ + 1) * autograd_model.n_params
                        autograd_model.set_params(param[start_idx:end_idx])
                        out.append(autograd_model.cdf(z[d_:d_+1][None,:]))

                out = torch.cat(out, axis=0).reshape(-1, d)
                return out

            autograd_outs = zs_params_to_cdfs(z, params)
            outs = model.cdf(z, params)
            assert torch.allclose(autograd_outs, outs, atol=1e-4)

            def zs_params_to_pdfs(zs, params):
                out = []
                for z, param in zip(zs, params):
                    for d_ in range(d):
                        start_idx, end_idx = d_ * autograd_model.n_params, (d_ + 1) * autograd_model.n_params
                        autograd_model.set_params(param[start_idx:end_idx])
                        out.append(autograd_model.pdf(z[d_:d_+1][None,:]))

                out = torch.cat(out, axis=0).reshape(-1, d)
                return out

            autograd_outs = zs_params_to_pdfs(z, params)
            outs = model.pdf(z, params)
            assert torch.allclose(autograd_outs, outs, atol=1e-4)

            def zs_params_to_icdfs(zs, params):
                out = []
                for z, param in zip(zs, params):
                    for d_ in range(d):
                        start_idx, end_idx = d_ * autograd_model.n_params, (d_ + 1) * autograd_model.n_params
                        autograd_model.set_params(param[start_idx:end_idx])
                        out.append(autograd_model.F_inv(z[d_:d_+1][None,:]))

                out = torch.cat(out, axis=0).reshape(-1, d)
                return out

            y = torch.rand((n, d)).to(device)
            autograd_outs = zs_params_to_icdfs(y, params)
            outs = model.icdf(y, params)
            assert torch.allclose(autograd_outs, outs, atol=1e-1)

            # try with single parameter, many zs

            def zs_params_to_pdfs(zs, param):
                out = []
                for z in zs:
                    for d_ in range(d):
                        start_idx, end_idx = d_ * autograd_model.n_params, (d_ + 1) * autograd_model.n_params
                        autograd_model.set_params(param[start_idx:end_idx])
                        out.append(autograd_model.pdf(z[d_:d_+1][None,:]))

                out = torch.cat(out, axis=0).reshape(-1, d)
                return out

            autograd_outs = zs_params_to_pdfs(z, params[0])
            outs = model.pdf(z, params[0:1])
            assert torch.allclose(autograd_outs, outs, atol=1e-4)

            # try with single z, many parameters

            def zs_params_to_pdfs(z, params):
                out = []
                for param in params:
                    for d_ in range(d):
                        start_idx, end_idx = d_ * autograd_model.n_params, (d_ + 1) * autograd_model.n_params
                        autograd_model.set_params(param[start_idx:end_idx])
                        out.append(autograd_model.pdf(z[d_:d_+1][None,:]))

                out = torch.cat(out, axis=0).reshape(-1, d)
                return out

            autograd_outs = zs_params_to_pdfs(z[0], params)
            outs = model.pdf(z[0:1], params)
            assert torch.allclose(autograd_outs, outs, atol=1e-4)

from nits.discretized_mol import *
print("Testing arch = [1, 10, 1], 'neg_exp' A_constraint, 'softmax' final_layer_constraint " \
      "against discretized mixture of logistics.")

model = NITS(d=1, start=-1e5, end=1e5, arch=[1, 10, 1],
                     monotonic_const=0., A_constraint='neg_exp',
                     final_layer_constraint='softmax').to(device)
params = torch.randn((n, model.n_params, 1, 1))
z = torch.randn((n, 1, 1, 1))

loss1 = discretized_mix_logistic_loss_1d(z, params)
loss2 = discretized_nits_loss(z, params, nits_model=model)

assert (loss1 - loss2).norm() < 1e-2, (loss1 - loss2).norm()

model = NITS(d=1, start=-1e7, end=1e7, arch=[1, 10, 1],
                     monotonic_const=0., A_constraint='neg_exp',
                     final_layer_constraint='softmax').to(device)

loss1 = discretized_mix_logistic_loss_1d(z, params)
loss2 = discretized_nits_loss(z, params, nits_model=model)

assert (loss1 - loss2).norm() < 1e-3, (loss1 - loss2).norm()

print("All tests passed!")

print("Testing Conditional NITS.")
start, end = -2., 2.
monotonic_const = 1e-2
d = 2
c_arch = [d] + base_arch
A_constraint = 'exp'
final_layer_constraint = 'softmax'
device = 'cpu'

c_model = ConditionalNITS(d=d, start=start, end=end, arch=c_arch,
                          monotonic_const=monotonic_const, A_constraint=A_constraint,
                          final_layer_constraint=final_layer_constraint,
                          autoregressive=False).to(device)

c_params = torch.randn((n, c_model.tot_params)).to(device)
z = torch.linspace(start, end, steps=n, device=device)[:,None].tile((1, d)).to(device)

def cond_zs_params_to_cdfs(zs, params):
    out = []
    for z, param in zip(zs, params):
        for d_ in range(d):
            c_autograd_model = ModelInverse(arch=c_arch, start=start, end=end, store_weights=False,
                                           A_constraint=A_constraint, monotonic_const=monotonic_const,
                                           final_layer_constraint=final_layer_constraint,
                                           non_conditional_dim=d_).to(device)
            start_idx, end_idx = d_ * c_autograd_model.n_params, (d_ + 1) * c_autograd_model.n_params
            c_autograd_model.set_params(param[start_idx:end_idx])
            out.append(c_autograd_model.cdf(z[None,:]))

    out = torch.cat(out, axis=0).reshape(-1, d)
    return out

autograd_outs = cond_zs_params_to_cdfs(z, c_params)
outs = c_model.cdf(z, c_params)
assert torch.allclose(autograd_outs, outs, atol=1e-3), (autograd_outs - outs).norm()

def cond_zs_params_to_pdfs(zs, params):
    out = []
    for z, param in zip(zs, params):
        for d_ in range(d):
            c_autograd_model = ModelInverse(arch=c_arch, start=start, end=end, store_weights=False,
                                           A_constraint=A_constraint, monotonic_const=monotonic_const,
                                           final_layer_constraint=final_layer_constraint,
                                           non_conditional_dim=d_).to(device)
            start_idx, end_idx = d_ * c_autograd_model.n_params, (d_ + 1) * c_autograd_model.n_params
            c_autograd_model.set_params(param[start_idx:end_idx])
            out.append(c_autograd_model.pdf(z[None,:]))

    out = torch.cat(out, axis=0).reshape(-1, d)
    return out

autograd_outs = cond_zs_params_to_pdfs(z, c_params)
outs = c_model.pdf(z, c_params)
assert torch.allclose(autograd_outs, outs, atol=1e-4)

# testing the inverse_cdf function

def cond_zs_params_to_icdfs(ys, zs, params):
    out = []
    for y, z, param in zip(ys, zs, params):
        for d_ in range(d):
            c_autograd_model = ModelInverse(arch=c_arch, start=start, end=end, store_weights=False,
                                           A_constraint=A_constraint, monotonic_const=monotonic_const,
                                           final_layer_constraint=final_layer_constraint,
                                           non_conditional_dim=d_).to(device)
            start_idx, end_idx = d_ * c_autograd_model.n_params, (d_ + 1) * c_autograd_model.n_params
            c_autograd_model.set_params(param[start_idx:end_idx])
            out.append(c_autograd_model.F_inv(y[d_:d_+1][None,:], given_x=z[None,:]))

    out = torch.cat(out, axis=0).reshape(-1, d)
    return out

y = torch.rand((n, d)).to(device)
autograd_outs = cond_zs_params_to_icdfs(y, z, c_params)
outs = c_model.icdf(y, c_params, given_x=z)
assert torch.allclose(autograd_outs, outs, atol=1e-1)

for i in range(d):
    tmp = torch.cat([z[:,:i], outs[:,i:i+1], z[:,i+1:]], axis=1)
    res = c_model.cdf(tmp, c_params)
    assert torch.allclose(res[:,i], y[:,i], atol=1e-2)

for i in range(d):
    tmp = torch.cat([z[:,:i], outs[:,i:i+1], z[:,i+1:]], axis=1)
    res = cond_zs_params_to_cdfs(tmp, c_params)
    assert torch.allclose(res[:,i], y[:,i], atol=1e-2)

print("All tests passed!")

print('Testing autoregressive conditional NITS.')
start, end = -2., 2.
monotonic_const = 1e-2
A_constraint = 'neg_exp'
final_layer_constraint = 'softmax'
device = 'cpu'

c_model = ConditionalNITS(d=d, start=start, end=end, arch=c_arch,
                          monotonic_const=monotonic_const, A_constraint=A_constraint,
                          final_layer_constraint=final_layer_constraint,
                          autoregressive=True).to(device)

c_params = torch.randn((n, c_model.tot_params)).to(device)
z = torch.linspace(start, end, steps=n, device=device)[:,None].tile((1, d)).to(device)

def causal_mask(x, i):
    x = x.clone()[None,:]
    x[:,i+1:] = 0.
    return x

def cond_zs_params_to_cdfs(zs, params):
    out = []
    for z, param in zip(zs, params):
        for d_ in range(d):
            c_autograd_model = ModelInverse(arch=c_arch, start=start, end=end, store_weights=False,
                                           A_constraint=A_constraint, monotonic_const=monotonic_const,
                                           final_layer_constraint=final_layer_constraint,
                                           non_conditional_dim=d_, b_constraint='tanh_conditional').to(device)
            start_idx, end_idx = d_ * c_autograd_model.n_params, (d_ + 1) * c_autograd_model.n_params
            c_autograd_model.set_params(param[start_idx:end_idx])

            # set mask and apply function
            z_masked = causal_mask(z, d_)
            out.append(c_autograd_model.cdf(z_masked))

    out = torch.cat(out, axis=0).reshape(-1, d)
    return out

autograd_outs = cond_zs_params_to_cdfs(z, c_params)
outs = c_model.cdf(z, c_params)
assert torch.allclose(autograd_outs, outs, atol=1e-4)

def cond_zs_params_to_pdfs(zs, params):
    out = []
    for z, param in zip(zs, params):
        for d_ in range(d):
            c_autograd_model = ModelInverse(arch=c_arch, start=start, end=end, store_weights=False,
                                           A_constraint=A_constraint, monotonic_const=monotonic_const,
                                           final_layer_constraint=final_layer_constraint,
                                           non_conditional_dim=d_, b_constraint='tanh_conditional').to(device)
            start_idx, end_idx = d_ * c_autograd_model.n_params, (d_ + 1) * c_autograd_model.n_params
            c_autograd_model.set_params(param[start_idx:end_idx])

            # set mask and apply function
            z_masked = causal_mask(z, d_)
            out.append(c_autograd_model.pdf(z_masked))

    out = torch.cat(out, axis=0).reshape(-1, d)
    return out

autograd_outs = cond_zs_params_to_pdfs(z, c_params)
outs = c_model.pdf(z, c_params)
assert torch.allclose(autograd_outs, outs, atol=1e-4)

# testing the inverse_cdf function

def cond_zs_params_to_icdfs(ys, zs, params):
    out = []
    for y, z, param in zip(ys, zs, params):
        for d_ in range(d):
            c_autograd_model = ModelInverse(arch=c_arch, start=start, end=end, store_weights=False,
                                           A_constraint=A_constraint, monotonic_const=monotonic_const,
                                           final_layer_constraint=final_layer_constraint,
                                           non_conditional_dim=d_, b_constraint='tanh_conditional').to(device)
            start_idx, end_idx = d_ * c_autograd_model.n_params, (d_ + 1) * c_autograd_model.n_params
            c_autograd_model.set_params(param[start_idx:end_idx])

            # set mask and apply function
            z_masked = torch.cat(out[len(out)-d_:] + [torch.zeros((1, d - d_))], axis=1)
            out.append(c_autograd_model.F_inv(y[d_:d_+1][None,:], given_x=z_masked))

    out = torch.cat(out, axis=0).reshape(-1, d)
    return out

y = torch.rand((n, d)).to(device)
autograd_outs = cond_zs_params_to_icdfs(y, z, c_params)
outs = c_model.icdf(y, c_params)
assert torch.allclose(autograd_outs, outs, atol=1e-1)

assert torch.allclose(c_model.cdf(outs, c_params), y, atol=1e-3)
assert torch.allclose(cond_zs_params_to_cdfs(autograd_outs, c_params), y, atol=1e-3)

print("All tests passed!")

print("Passed all unit tests!")



Testing NITS.
Testing arch = [1, 10, 1], 'neg_exp' A_constraint, 'softmax' final_layer_constraint against discretized mixture of logistics.
All tests passed!
Testing Conditional NITS.
All tests passed!
Testing autoregressive conditional NITS.
All tests passed!
Passed all unit tests!


In [3]:
assert not discretized_nits_loss(z, c_params, c_model).isnan().any()

RuntimeError: number of dims don't match in permute

In [None]:
sample = nits_sample(c_params, arch, 0, 0, c_model)

In [None]:
discretized_nits_loss(sample, c_params, arch, c_model)

In [None]:
nits_model.cdf(imgs, params[:,i,j,:].reshape(-1, nits_model.tot_params))

In [None]:
def loss_reshape(x, l, nits_model):
    """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """
    # Pytorch ordering
    x = x.permute(0, 2, 3, 1)
    l = l.permute(0, 2, 3, 1)
    xs = [int(y) for y in x.size()]
    ls = [int(y) for y in l.size()]

    # here and below: getting the means and adjusting them based on preceding
    # sub-pixels
    x = x.contiguous()

    nits_model = nits_model.to(x.device)
    x = x.reshape(-1, nits_model.d)
    params = l.reshape(-1, nits_model.tot_params)

    return x, params

In [None]:
# make sure that the reshaping operation is correct
z = torch.arange(n).reshape(n, 1, 1, 1).tile(1, d, 1, 1)
c_params = torch.arange(n).reshape(n, 1, 1, 1).tile(1, c_model.tot_params, 1, 1)
new_z, new_params = reshape(z, c_params, c_model)
assert (new_z == torch.arange(n).reshape(n, 1).tile(1, d)).all()
assert (new_params == torch.arange(n).reshape(n, 1).tile(1, c_model.tot_params)).all()

z = torch.arange(d).reshape(1, d, 1, 1).tile(n, 1, 1, 1)
c_params = torch.arange(c_model.tot_params).reshape(c_model.tot_params, 1, 1, 1).tile(n, 1, 1, 1)
new_z, new_params = reshape(z, c_params, c_model)
assert (new_z == torch.arange(d).reshape(1, d).tile(n, 1)).all()
assert (new_params == torch.arange(c_model.tot_params).reshape(1, c_model.tot_params).tile(n, 1)).all()

In [None]:
def sample_reshape(params, nits_model):
    params = params.permute(0, 2, 3, 1)
    batch_size, height, width, params_per_pixel = params.shape

    nits_model = nits_model.to(params.device)

    n_params = nits_model.n_params
    n_channels = int(params_per_pixel / n_params)

    data = torch.zeros((batch_size, n_channels, height, width), device=params.device)
    
    imgs = nits_model.sample(1, params[:,i,j,:].reshape(-1, nits_model.tot_params)).clamp(min=-1., max=1.)
    data[:,:,i,j] = imgs.reshape((batch_size, n_channels))

    return data