In [2]:
import torch
from nits.model import *
from nits.autograd_model import *

device = 'cpu'

n = 8
start, end = -2., 2.
arch = [1, 8, 1]
monotonic_const = 1e-2

print("Testing NITS.")

for d in [1, 2, 10]:
    for constraint_type in ['neg_exp', 'exp']:
        for final_layer_constraint in ['softmax', 'exp']:
#             print("""
#             Testing configuration:
#                 d: {}
#                 constraint_type: {}
#                 final_layer_constraint: {}
#                   """.format(d, constraint_type, final_layer_constraint))
            ############################
            # DEFINE MODELS            #
            ############################

            model = NITS(d=d, start=start, end=end, arch=arch,
                                 monotonic_const=monotonic_const, constraint_type=constraint_type,
                                 final_layer_constraint=final_layer_constraint).to(device)
            params = torch.randn((n, d * model.n_params)).to(device)

            ############################
            # SANITY CHECKS            #
            ############################

            # check that the function integrates to 1
            assert torch.allclose(torch.ones((n, d)).to(device),
                                  model.cdf(model.end, params) - model.cdf(model.start, params), atol=1e-5)

            # check that the pdf is all positive
            z = torch.linspace(start, end, steps=n, device=device)[:,None].tile((1, d))
            assert (model.pdf(z, params) >= 0).all()

            # check that the cdf is the inverted
            cdf = model.cdf(z, params[0:1])
            icdf = model.icdf(cdf, params[0:1])
            assert (z - icdf <= 1e-3).all()

            ############################
            # COMPARE TO AUTOGRAD NITS #
            ############################
            autograd_model = ModelInverse(arch=arch, start=start, end=end, store_weights=False,
                                          constraint_type=constraint_type, monotonic_const=monotonic_const,
                                          final_layer_constraint=final_layer_constraint)

            def zs_params_to_forwards(zs, params):
                out = []
                for z, param in zip(zs, params):
                    for d_ in range(d):
                        start_idx, end_idx = d_ * autograd_model.n_params, (d_ + 1) * autograd_model.n_params
                        autograd_model.set_params(param[start_idx:end_idx])
                        out.append(autograd_model.apply_layers(z[d_:d_+1][None,:]))

                out = torch.cat(out, axis=0).reshape(-1, d)
                return out

            autograd_outs = zs_params_to_forwards(z, params)
            outs = model.forward_(z, params)
            assert torch.allclose(autograd_outs, outs, atol=1e-4)

            def zs_params_to_cdfs(zs, params):
                out = []
                for z, param in zip(zs, params):
                    for d_ in range(d):
                        start_idx, end_idx = d_ * autograd_model.n_params, (d_ + 1) * autograd_model.n_params
                        autograd_model.set_params(param[start_idx:end_idx])
                        out.append(autograd_model.cdf(z[d_:d_+1][None,:]))

                out = torch.cat(out, axis=0).reshape(-1, d)
                return out

            autograd_outs = zs_params_to_cdfs(z, params)
            outs = model.cdf(z, params)
            assert torch.allclose(autograd_outs, outs, atol=1e-4)

            def zs_params_to_pdfs(zs, params):
                out = []
                for z, param in zip(zs, params):
                    for d_ in range(d):
                        start_idx, end_idx = d_ * autograd_model.n_params, (d_ + 1) * autograd_model.n_params
                        autograd_model.set_params(param[start_idx:end_idx])
                        out.append(autograd_model.pdf(z[d_:d_+1][None,:]))

                out = torch.cat(out, axis=0).reshape(-1, d)
                return out

            autograd_outs = zs_params_to_pdfs(z, params)
            outs = model.pdf(z, params)
            assert torch.allclose(autograd_outs, outs, atol=1e-4)

            def zs_params_to_icdfs(zs, params):
                out = []
                for z, param in zip(zs, params):
                    for d_ in range(d):
                        start_idx, end_idx = d_ * autograd_model.n_params, (d_ + 1) * autograd_model.n_params
                        autograd_model.set_params(param[start_idx:end_idx])
                        out.append(autograd_model.F_inv(z[d_:d_+1][None,:]))

                out = torch.cat(out, axis=0).reshape(-1, d)
                return out

            y = torch.rand((n, d)).to(device)
            autograd_outs = zs_params_to_icdfs(y, params)
            outs = model.icdf(y, params)
            assert torch.allclose(autograd_outs, outs, atol=1e-1)

            # try with single parameter, many zs

            def zs_params_to_pdfs(zs, param):
                out = []
                for z in zs:
                    for d_ in range(d):
                        start_idx, end_idx = d_ * autograd_model.n_params, (d_ + 1) * autograd_model.n_params
                        autograd_model.set_params(param[start_idx:end_idx])
                        out.append(autograd_model.pdf(z[d_:d_+1][None,:]))

                out = torch.cat(out, axis=0).reshape(-1, d)
                return out

            autograd_outs = zs_params_to_pdfs(z, params[0])
            outs = model.pdf(z, params[0:1])
            assert torch.allclose(autograd_outs, outs, atol=1e-4)

            # try with single z, many parameters

            def zs_params_to_pdfs(z, params):
                out = []
                for param in params:
                    for d_ in range(d):
                        start_idx, end_idx = d_ * autograd_model.n_params, (d_ + 1) * autograd_model.n_params
                        autograd_model.set_params(param[start_idx:end_idx])
                        out.append(autograd_model.pdf(z[d_:d_+1][None,:]))

                out = torch.cat(out, axis=0).reshape(-1, d)
                return out

            autograd_outs = zs_params_to_pdfs(z[0], params)
            outs = model.pdf(z[0:1], params)
            assert torch.allclose(autograd_outs, outs, atol=1e-4)

from nits.discretized_mol import *
print("Testing arch = [1, 10, 1], 'neg_exp' constraint_type, 'softmax' final_layer_constraint " \
      "against discretized mixture of logistics.")

model = NITS(d=1, start=-1e5, end=1e5, arch=[1, 10, 1],
                     monotonic_const=0., constraint_type='neg_exp',
                     final_layer_constraint='softmax').to(device)
params = torch.randn((n, model.n_params, 1, 1))
z = torch.randn((n, 1, 1, 1))

loss1 = discretized_mix_logistic_loss_1d(z, params)
loss2 = discretized_nits_loss(z, params, arch=[1, 10, 1], nits_model=model)

assert (loss1 - loss2).norm() < 1e-2, (loss1 - loss2).norm()

model = NITS(d=1, start=-1e7, end=1e7, arch=[1, 10, 1],
                     monotonic_const=0., constraint_type='neg_exp',
                     final_layer_constraint='softmax').to(device)

loss1 = discretized_mix_logistic_loss_1d(z, params)
loss2 = discretized_nits_loss(z, params, arch=[1, 10, 1], nits_model=model)

assert (loss1 - loss2).norm() < 1e-3, (loss1 - loss2).norm()

print("All tests passed!")

Testing NITS.
Testing arch = [1, 10, 1], 'neg_exp' constraint_type, 'softmax' final_layer_constraint against discretized mixture of logistics.
All tests passed!


In [3]:
print("Testing Conditional NITS.")
start, end = -2., 2.
monotonic_const = 1e-2
d = 2
c_arch = [d, 8, 1]
constraint_type = 'exp'
final_layer_constraint = 'softmax'
device = 'cpu'

c_model = ConditionalNITS(d=d, start=start, end=end, arch=c_arch,
                          monotonic_const=monotonic_const, constraint_type=constraint_type,
                          final_layer_constraint=final_layer_constraint,
                          autoregressive=False).to(device)

c_params = torch.randn((n, c_model.tot_params))
z = torch.linspace(start, end, steps=n, device=device)[:,None].tile((1, d))

def cond_zs_params_to_cdfs(zs, params):
    out = []
    for z, param in zip(zs, params):
        for d_ in range(d):
            c_autograd_model = ModelInverse(arch=c_arch, start=start, end=end, store_weights=False,
                                           constraint_type=constraint_type, monotonic_const=monotonic_const,
                                           final_layer_constraint=final_layer_constraint,
                                           non_conditional_dim=d_)
            start_idx, end_idx = d_ * c_autograd_model.n_params, (d_ + 1) * c_autograd_model.n_params
            c_autograd_model.set_params(param[start_idx:end_idx])
            out.append(c_autograd_model.cdf(z[None,:]))

    out = torch.cat(out, axis=0).reshape(-1, d)
    return out

autograd_outs = cond_zs_params_to_cdfs(z, c_params)
outs = c_model.cdf(z, c_params)
assert torch.allclose(autograd_outs, outs, atol=1e-4)

def cond_zs_params_to_pdfs(zs, params):
    out = []
    for z, param in zip(zs, params):
        for d_ in range(d):
            c_autograd_model = ModelInverse(arch=c_arch, start=start, end=end, store_weights=False,
                                           constraint_type=constraint_type, monotonic_const=monotonic_const,
                                           final_layer_constraint=final_layer_constraint,
                                           non_conditional_dim=d_)
            start_idx, end_idx = d_ * c_autograd_model.n_params, (d_ + 1) * c_autograd_model.n_params
            c_autograd_model.set_params(param[start_idx:end_idx])
            out.append(c_autograd_model.pdf(z[None,:]))

    out = torch.cat(out, axis=0).reshape(-1, d)
    return out

autograd_outs = cond_zs_params_to_pdfs(z, c_params)
outs = c_model.pdf(z, c_params)
assert torch.allclose(autograd_outs, outs, atol=1e-4)

# testing the inverse_cdf function

def cond_zs_params_to_icdfs(ys, zs, params):
    out = []
    for y, z, param in zip(ys, zs, params):
        for d_ in range(d):
            c_autograd_model = ModelInverse(arch=c_arch, start=start, end=end, store_weights=False,
                                           constraint_type=constraint_type, monotonic_const=monotonic_const,
                                           final_layer_constraint=final_layer_constraint,
                                           non_conditional_dim=d_)
            start_idx, end_idx = d_ * c_autograd_model.n_params, (d_ + 1) * c_autograd_model.n_params
            c_autograd_model.set_params(param[start_idx:end_idx])
            out.append(c_autograd_model.F_inv(y[d_:d_+1][None,:], given_x=z[None,:]))

    out = torch.cat(out, axis=0).reshape(-1, d)
    return out

y = torch.rand((n, d)).to(device)
autograd_outs = cond_zs_params_to_icdfs(y, z, c_params)
outs = c_model.icdf(y, c_params, given_x=z)
assert torch.allclose(autograd_outs, outs, atol=1e-1)
    
for i in range(d):
    tmp = torch.cat([z[:,:i], outs[:,i:i+1], z[:,i+1:]], axis=1)
    res = c_model.cdf(tmp, c_params)
    assert torch.allclose(res[:,i], y[:,i], atol=1e-2)
    
for i in range(d):
    tmp = torch.cat([z[:,:i], outs[:,i:i+1], z[:,i+1:]], axis=1)
    res = cond_zs_params_to_cdfs(tmp, c_params)
    assert torch.allclose(res[:,i], y[:,i], atol=1e-2)
    
print("All tests passed!")

Testing Conditional NITS.
All tests passed!


In [4]:
print('Testing autoregressive conditional NITS.')
start, end = -2., 2.
monotonic_const = 1e-2
d = 2
c_arch = [d, 8, 1]
constraint_type = 'exp'
final_layer_constraint = 'softmax'
device = 'cpu'

c_model = ConditionalNITS(d=d, start=start, end=end, arch=c_arch,
                          monotonic_const=monotonic_const, constraint_type=constraint_type,
                          final_layer_constraint=final_layer_constraint,
                          autoregressive=True).to(device)

c_params = torch.randn((n, c_model.tot_params))
z = torch.linspace(start, end, steps=n, device=device)[:,None].tile((1, d))

def causal_mask(x, i):
    x = x.clone()[None,:]
    x[:,i+1:] = 0.
    return x

def cond_zs_params_to_cdfs(zs, params):
    out = []
    for z, param in zip(zs, params):
        for d_ in range(d):
            c_autograd_model = ModelInverse(arch=c_arch, start=start, end=end, store_weights=False,
                                           constraint_type=constraint_type, monotonic_const=monotonic_const,
                                           final_layer_constraint=final_layer_constraint,
                                           non_conditional_dim=d_)
            start_idx, end_idx = d_ * c_autograd_model.n_params, (d_ + 1) * c_autograd_model.n_params
            c_autograd_model.set_params(param[start_idx:end_idx])
            
            # set mask and apply function
            z_masked = causal_mask(z, d_)
            out.append(c_autograd_model.cdf(z_masked))

    out = torch.cat(out, axis=0).reshape(-1, d)
    return out

autograd_outs = cond_zs_params_to_cdfs(z, c_params)
outs = c_model.cdf(z, c_params)
assert torch.allclose(autograd_outs, outs, atol=1e-4)

def cond_zs_params_to_pdfs(zs, params):
    out = []
    for z, param in zip(zs, params):
        for d_ in range(d):
            c_autograd_model = ModelInverse(arch=c_arch, start=start, end=end, store_weights=False,
                                           constraint_type=constraint_type, monotonic_const=monotonic_const,
                                           final_layer_constraint=final_layer_constraint,
                                           non_conditional_dim=d_)
            start_idx, end_idx = d_ * c_autograd_model.n_params, (d_ + 1) * c_autograd_model.n_params
            c_autograd_model.set_params(param[start_idx:end_idx])
            
            # set mask and apply function
            z_masked = causal_mask(z, d_)
            out.append(c_autograd_model.pdf(z_masked))

    out = torch.cat(out, axis=0).reshape(-1, d)
    return out

autograd_outs = cond_zs_params_to_pdfs(z, c_params)
outs = c_model.pdf(z, c_params)
assert torch.allclose(autograd_outs, outs, atol=1e-4)

# testing the inverse_cdf function

def cond_zs_params_to_icdfs(ys, zs, params):
    out = []
    for y, z, param in zip(ys, zs, params):
        for d_ in range(d):
            c_autograd_model = ModelInverse(arch=c_arch, start=start, end=end, store_weights=False,
                                           constraint_type=constraint_type, monotonic_const=monotonic_const,
                                           final_layer_constraint=final_layer_constraint,
                                           non_conditional_dim=d_)
            start_idx, end_idx = d_ * c_autograd_model.n_params, (d_ + 1) * c_autograd_model.n_params
            c_autograd_model.set_params(param[start_idx:end_idx])
            
            # set mask and apply function
            z_masked = torch.cat(out[len(out)-d_:] + [torch.zeros((1, d - d_))], axis=1)
            out.append(c_autograd_model.F_inv(y[d_:d_+1][None,:], given_x=z_masked))

    out = torch.cat(out, axis=0).reshape(-1, d)
    return out

y = torch.rand((n, d)).to(device)
autograd_outs = cond_zs_params_to_icdfs(y, z, c_params)
outs = c_model.icdf(y, c_params)
assert torch.allclose(autograd_outs, outs, atol=1e-1)

assert torch.allclose(c_model.cdf(outs, c_params), y, atol=1e-3)  
assert torch.allclose(cond_zs_params_to_cdfs(autograd_outs, c_params), y, atol=1e-3)

print("All tests passed!")

Testing autoregressive conditional NITS.
All tests passed!


In [40]:
class ConditionalNITS(NITSPrimitive):
    # TODO: for now, just implement ConditionalNITS such that it sequentially evaluates each dimension
    # this process is (probably) possible to vectorize, but since we're currently only doing 3 dimensions,
    # there's no need to speed things up, because we only gain a factor of 3 speedup
    def __init__(self, d, arch, start=-2., end=2., constraint_type='neg_exp',
                 monotonic_const=1e-2, final_layer_constraint='softmax',
                 autoregressive=True):
        super(ConditionalNITS, self).__init__(arch=arch, start=start, end=end,
                                           constraint_type=constraint_type,
                                           monotonic_const=monotonic_const,
                                           final_layer_constraint=final_layer_constraint)
        self.d = d
        self.tot_params = self.n_params * d
        self.final_layer_constraint = final_layer_constraint
        self.autoregressive = autoregressive

        self.register_buffer('start', torch.tensor(start).reshape(1, 1).tile(1, d))
        self.register_buffer('end', torch.tensor(end).reshape(1, 1).tile(1, d))

        assert arch[0] == d
        self.nits_list = torch.nn.ModuleList()
        for i in range(self.d):
            model = NITSPrimitive(arch=arch, start=start, end=end,
                         constraint_type=constraint_type,
                         monotonic_const=monotonic_const,
                         final_layer_constraint=final_layer_constraint,
                         non_conditional_dim=i)
            self.nits_list.append(model)

    def causal_mask(self, x, i):
        if self.autoregressive:
            x = x.clone()
            x[:,i+1:] = 0.
        return x

    def apply_conditional_func(self, func, x, params):
        n = max(len(x), len(params))
        result = func(x, params)

        if isinstance(result, tuple):
            return (result[0].reshape((n, -1)),) + result[1:]
        else:
            return result.reshape((n, -1))

    def forward_(self, x, params, return_intermediaries=False):
        result = []
        for i in range(self.d):
            x_masked = self.causal_mask(x, i)
            start_idx, end_idx = i * self.n_params, (i + 1) * self.n_params
            func = lambda x, params: self.nits_list[i].forward_(x, params, return_intermediaries)
            result.append(self.apply_conditional_func(func, x_masked, params[:,start_idx:end_idx]))

        result = torch.cat(result, axis=1)
        return result

    def backward_(self, x, params):
        result = []
        for i in range(self.d):
            x_masked = self.causal_mask(x, i)
            start_idx, end_idx = i * self.n_params, (i + 1) * self.n_params
            func = self.nits_list[i].backward_
            result.append(self.apply_conditional_func(func, x_masked, params[:,start_idx:end_idx]))

        result = torch.cat(result, axis=1)
        return result

    def cdf(self, x, params):
        result = []
        for i in range(self.d):
            x_masked = self.causal_mask(x, i)
            start_idx, end_idx = i * self.n_params, (i + 1) * self.n_params
            func = self.nits_list[i].cdf
            result.append(self.apply_conditional_func(func, x_masked, params[:,start_idx:end_idx]))

        result = torch.cat(result, axis=1)
        return result

    def pdf(self, x, params):
        result = []
        for i in range(self.d):
            x_masked = self.causal_mask(x, i)
            start_idx, end_idx = i * self.n_params, (i + 1) * self.n_params
            func = self.nits_list[i].pdf
            result.append(self.apply_conditional_func(func, x_masked, params[:,start_idx:end_idx]))

        result = torch.cat(result, axis=1)
        return result

    def icdf(self, x, params, given_x=None):
        if self.autoregressive and given_x is not None:
            raise NotImplementedError('given_x cannot be supplied if autoregressive == True')
        
        result = []
        for i in range(self.d):
            if self.autoregressive:
                print(result + [torch.zeros(len(x), self.d - len(result))])
                given_x = torch.cat(result + [torch.zeros(len(x), self.d - len(result))], axis=1)
            start_idx, end_idx = i * self.n_params, (i + 1) * self.n_params
            func = lambda x, params: self.nits_list[i].icdf(x, params, given_x=given_x)
            result.append(self.apply_conditional_func(func, x, params[:,start_idx:end_idx]))

        result = torch.cat(result, axis=1)
        return result

    def sample(self, n, params):
        if len(params) == 1:
            params = params.reshape(self.d, self.n_params).tile((n, 1))
        elif len(params) == n:
            params = params.reshape(-1, self.n_params)
            
        z = torch.rand((n, self.d)).to(device)
            
        return self.icdf(z, params)

In [41]:
print('Testing autoregressive conditional NITS loss and sample functions.')
start, end = -2., 2.
monotonic_const = 1e-2
d = 3
c_arch = [d, 8, 1]
constraint_type = 'exp'
final_layer_constraint = 'softmax'
device = 'cpu'

c_model = ConditionalNITS(d=d, start=start, end=end, arch=c_arch,
                          monotonic_const=monotonic_const, constraint_type=constraint_type,
                          final_layer_constraint=final_layer_constraint,
                          autoregressive=True).to(device)

c_params = torch.randn((n, c_model.tot_params, 1, 1))
z = torch.randn((n, d, 1, 1))

Testing autoregressive conditional NITS loss and sample functions.


In [42]:
def discretized_nits_loss(x, l, arch, nits_model):
    """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """
    # Pytorch ordering
    x = x.permute(0, 2, 3, 1)
    l = l.permute(0, 2, 3, 1)
    xs = [int(y) for y in x.size()]
    ls = [int(y) for y in l.size()]

    # here and below: getting the means and adjusting them based on preceding
    # sub-pixels
    x = x.contiguous()

    nits_model = nits_model.to(x.device)
    x = x.reshape(-1, nits_model.d)
    params = l.reshape(-1, nits_model.tot_params)

    x_plus = (x * 127.5 + .5).round() / 127.5
    x_min = (x * 127.5 - .5).round() / 127.5

    cdf_delta = nits_model.cdf(x_plus, params) - nits_model.cdf(x_min, params)
    log_cdf_plus = nits_model.cdf(x_plus, params).log()
    log_one_minus_cdf_min = (1 - nits_model.cdf(x_min, params)).log()
    log_pdf_mid = nits_model.pdf(x, params).log()

    inner_inner_cond = (cdf_delta > 1e-5).float()
    inner_inner_out  = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1. - inner_inner_cond) * (log_pdf_mid - np.log(127.5))
    inner_cond       = (x > 0.999).float()
    inner_out        = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
    cond             = (x < -0.999).float()
    log_probs        = cond * log_cdf_plus + (1. - cond) * inner_out

    return -log_probs.sum()

def nits_sample(params, arch, i, j, nits_model):
    params = params.permute(0, 2, 3, 1)
    batch_size, height, width, params_per_pixel = params.shape

    nits_model = nits_model.to(params.device)

    n_params = nits_model.n_params
    n_channels = int(params_per_pixel / n_params)

    data = torch.zeros((batch_size, n_channels, height, width))

    imgs = nits_model.sample(1, params[:,i,j,:].reshape(-1, nits_model.tot_params)).clamp(min=-1., max=1.)
    data[:,:,i,j] = imgs.reshape((batch_size, n_channels))

    return data

In [43]:
discretized_nits_loss(z, c_params, arch, c_model)

tensor(120.5176)

In [44]:
nits_sample(c_params, arch, 0, 0, c_model)

[tensor([[0., 0., 0.]])]
[tensor([[0.4575],
        [0.6313],
        [0.7710],
        [0.6577],
        [0.7632],
        [0.6108],
        [0.3442],
        [0.8647]]), tensor([[0., 0.]])]


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 8 but got size 1 for tensor number 1 in the list.