In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

class PositiveLinear(nn.Module):
    """Linear layer whose effective weight is obtained from ``pre_weight`` by a constraint map.

    ``constraint_type`` selects the map applied in ``forward``:

    * ``'neg_exp'``: weight = exp(-pre_weight); the bias enters as
      ``-(bias * weight).mean(-1)`` (same form as ``NITS.forward_``).
    * ``'exp'``: weight = exp(pre_weight)
    * ``'softmax'``: weight = softmax(pre_weight) over the input dimension
    * ``'clamp'``: weight = max(pre_weight, 0)
    * ``''``: unconstrained, weight = pre_weight

    Args:
        in_features: number of input columns.
        out_features: number of output columns.
        constraint_type: one of the keys above.
        store_weights: if True, ``pre_weight`` and ``bias`` are created as
            ``nn.Parameter``s; if False the caller must assign both attributes
            before ``forward`` is used (``ModelInverse.set_params`` does this).

    Raises:
        ValueError: if ``constraint_type`` is not one of the supported values.
    """

    _CONSTRAINTS = ('neg_exp', 'exp', 'softmax', 'clamp', '')

    def __init__(self, in_features, out_features, constraint_type='clamp', store_weights=True):
        super(PositiveLinear, self).__init__()
        if constraint_type not in self._CONSTRAINTS:
            raise ValueError("Unknown constraint_type {!r}; expected one of {}.".format(
                constraint_type, self._CONSTRAINTS))
        self.constraint_type = constraint_type
        self.in_features, self.out_features = in_features, out_features

        if store_weights:
            # Initialise the effective weight in Unif(eps, sqrt(k)), k = 1 / in_features,
            # by drawing pre_weight uniformly from the pre-image of that interval.
            eps, bound = 1e-2, np.sqrt(1 / in_features)
            if self.constraint_type == 'exp':
                init_min, init_max = np.log(eps), np.log(bound)
            elif self.constraint_type == 'neg_exp':
                # weight = exp(-pre_weight): the log-interval is negated and flipped
                init_min, init_max = -np.log(bound), -np.log(eps)
            elif self.constraint_type == 'clamp':
                init_min, init_max = eps, bound
            else:
                # '' and 'softmax' accept any real pre_weight; use a symmetric range
                init_min, init_max = -bound, bound
            pre_weight = (torch.rand((out_features, in_features)) * (init_max - init_min)) + init_min

            # bias in Unif(-1 / in_features, 1 / in_features)
            scale = 1 / in_features
            bias = torch.rand((out_features)) * 2 * scale - scale

            self.pre_weight, self.bias = nn.Parameter(pre_weight), nn.Parameter(bias)

    def forward(self, x):
        """Apply the constrained affine map to the 2-D input ``x`` (rows are samples)."""
        if self.constraint_type == 'neg_exp':
            weight = 1 / self.pre_weight.exp()
            return x.mm(weight.T) - (self.bias.unsqueeze(-1) * weight).mean(axis=-1)
        elif self.constraint_type == 'exp':
            weight = self.pre_weight.exp()
        elif self.constraint_type == 'softmax':
            weight = F.softmax(self.pre_weight, dim=-1)
        elif self.constraint_type == 'clamp':
            weight = self.pre_weight.clamp(min=0.)
        elif self.constraint_type == '':
            weight = self.pre_weight
        else:
            raise ValueError("Unknown constraint_type {!r}.".format(self.constraint_type))
        return x.mm(weight.T) + self.bias

class MonotonicInverse(torch.autograd.Function):
    """Differentiable inverse of a model's monotonically increasing scalar function ``F``.

    ``forward(model, y)`` finds ``x`` with ``model.F(x)`` close to ``y`` by bisection
    on ``[model.start, model.end]``.  ``backward`` applies the inverse-function
    rule dx/dy = 1 / F'(x).  No gradient is propagated to the model itself.
    """

    @staticmethod
    def forward(ctx, self, input):
        # ``self`` is the model object providing ``F``, ``start`` and ``end``.
        with torch.no_grad():
            b = bisection_search(self.F, input, self.start, self.end, n_iter=20)

        # Inverse-function rule: d F^{-1}(y) / dy = 1 / F'(F^{-1}(y)).
        dy = 1 / torch.autograd.functional.jacobian(self.F, b, create_graph=True, vectorize=True)
        ctx.save_for_backward(dy.reshape(len(input), 1))
        return b

    @staticmethod
    def backward(ctx, grad_output):
        dy, = ctx.saved_tensors
        # Chain rule: scale the incoming gradient by the local derivative dx/dy.
        return None, grad_output * dy

class ModelInverse(nn.Module):
    """Monotone scalar network used as a CDF, differentiated and inverted with torch autograd.

    ``apply_layers`` stacks ``PositiveLinear`` layers with sigmoid activations and
    adds ``monotonic_const * x``.  ``F`` rescales that output so that
    ``F(start) == 0`` and ``F(end) == 1``.  ``f``/``pdf`` differentiate ``F`` with
    ``torch.autograd.functional.jacobian``, and ``F_inv``/``sample`` invert it with
    ``MonotonicInverse`` (bisection).  In this notebook it serves as the autograd
    reference for the hand-written ``NITS`` derivatives.

    Args:
        arch: layer widths, e.g. ``[1, 8, 1]``.
        start, end: interval on which the CDF is normalised.
        store_weights: if True each layer owns its parameters; if False they
            must be supplied with ``set_params`` before use.
        constraint_type: weight constraint for the hidden layers (see ``PositiveLinear``).
        monotonic_const: coefficient of the ``x`` skip term added to the output.
        final_layer_constraint: weight constraint of the last layer.  With
            ``'softmax'`` the last layer gets no trailing sigmoid and no bias slot
            in the flat parameter vector.
    """

    def __init__(self, arch, start=0., end=1., store_weights=True,
                 constraint_type='exp', monotonic_const=1e-2, final_layer_constraint='softmax'):
        super(ModelInverse, self).__init__()
        self.d = arch[0]
        self.monotonic_const = monotonic_const
        self.store_weights = store_weights
        self.constraint_type = constraint_type
        self.final_layer_constraint = final_layer_constraint
        # index, among the linear layers, of the final linear layer
        self.last_layer = len(arch) - 2
        self.layers = self.build_layers(arch)

        self.register_buffer('start', torch.tensor(start).reshape(1, 1))
        self.register_buffer('end', torch.tensor(end).reshape(1, 1))

    def build_layers(self, arch):
        """Create the layer stack and count the flat parameters in ``self.n_params``."""
        self.n_params = 0
        layers = nn.ModuleList()

        for i, (a1, a2) in enumerate(zip(arch[:-1], arch[1:])):
            self.n_params += (a1 * a2)  # weight matrix
            if i < self.last_layer:
                layers.append(PositiveLinear(a1, a2, store_weights=self.store_weights,
                                             constraint_type=self.constraint_type))
                layers.append(nn.Sigmoid())
                self.n_params += a2  # bias
            else:
                layers.append(PositiveLinear(a1, a2, store_weights=self.store_weights,
                                             constraint_type=self.final_layer_constraint))
                # a 'softmax' final layer has no trailing sigmoid and no bias slot
                if self.final_layer_constraint != 'softmax':
                    layers.append(nn.Sigmoid())
                    self.n_params += a2

        return layers

    def set_params(self, param_tensor):
        """Load a flat parameter vector into the layers (only when ``store_weights`` is False).

        The layout is, per linear layer, the ``(out, in)`` weights followed by the
        ``out`` biases.  A final ``'softmax'`` layer has no bias slot, and its bias
        is set to zeros.

        Raises:
            NotImplementedError: if ``store_weights`` is True.
            AssertionError: if ``len(param_tensor) != self.n_params``.
        """
        if self.store_weights:
            raise NotImplementedError("set_parameters() should not be called if store_weights == True!")

        assert len(param_tensor) == self.n_params, "{} =/= {}".format(str(param_tensor.shape), str(self.n_params))

        cur_idx = 0
        linear_idx = 0
        for layer in self.layers:
            if not isinstance(layer, PositiveLinear):
                continue
            weight_shape = (layer.out_features, layer.in_features)
            n_weights = layer.out_features * layer.in_features
            layer.pre_weight = param_tensor[cur_idx:cur_idx + n_weights].reshape(weight_shape)
            cur_idx += n_weights

            if linear_idx < self.last_layer or self.final_layer_constraint != 'softmax':
                layer.bias = param_tensor[cur_idx:cur_idx + layer.out_features]
                cur_idx += layer.out_features
            else:
                layer.bias = torch.zeros(layer.out_features).to(param_tensor.device)

            linear_idx += 1

    def apply_layers(self, x):
        """Un-normalised monotone network: the layer stack plus ``monotonic_const * x``."""
        y = x
        for l in self.layers:
            y = l(y)

        return y + self.monotonic_const * x

    def scale(self, y):
        """Affinely rescale network outputs so that ``start`` maps to 0 and ``end`` to 1."""
        start, end = self.apply_layers(self.start), self.apply_layers(self.end)
        return (y - start) / (end - start)

    def forward(self, x):
        raise NotImplementedError("forward() should not be used!")

    def _pointwise_derivative(self, fn, x):
        """Compute d fn / dx separately at each row of ``x``; each row must hold a single value.

        Returns a ``(len(x), 1)`` tensor.  ``create_graph=True`` keeps the result
        differentiable.
        """
        dy = [torch.autograd.functional.jacobian(fn, x_.reshape(-1, 1),
                                                 create_graph=True, vectorize=False).reshape(1, 1)
              for x_ in x]
        return torch.cat(dy, axis=0)

    def f(self, x):
        """Density: the derivative of the normalised CDF ``F`` at each row of ``x``."""
        return self._pointwise_derivative(self.F, x)

    def f_(self, x):
        """Derivative of the un-normalised network ``apply_layers`` at each row of ``x``."""
        return self._pointwise_derivative(self.apply_layers, x)

    def F(self, x):
        """Normalised CDF: ``apply_layers`` rescaled to 0 at ``start`` and 1 at ``end``."""
        return self.scale(self.apply_layers(x))

    def pdf(self, x):
        return self.f(x)

    def cdf(self, x):
        return self.F(x)

    def F_inv(self, x):
        """Inverse CDF, solved one row of ``x`` at a time with ``MonotonicInverse``."""
        inverse = MonotonicInverse.apply
        z = [inverse(self, x_).reshape(1, 1) for x_ in x]
        return torch.cat(z, axis=0)

    def sample(self, n, batch_size=1):
        """Draw ``n`` samples by inverse-transform sampling, ``batch_size`` at a time.

        Returns an ``(n, 1)`` tensor.
        """
        x = []
        while n > batch_size:
            x.append(self._sample_batch(batch_size))
            n -= batch_size
        x.append(self._sample_batch(n))

        return torch.cat(x, axis=0)

    def _sample_batch(self, size):
        """Push ``size`` draws of U(0, 1), the range of ``F``, through ``F_inv``."""
        # The noise must live in F's range [0, 1], not in [start, end];
        # otherwise bisection targets values the CDF never reaches.
        u = torch.rand((size, 1), device=self.start.device)
        return self.F_inv(u)

def bisection_search(increasing_func, target, start, end, n_iter=20, eps=1e-3):
    """Recursively bisect ``[start, end]`` for a point where ``increasing_func`` hits ``target``.

    Args:
        increasing_func: monotonically increasing function of a tensor.
        target: value to match.  It is compared with the function output via
            ``>`` in a boolean context, so the comparison must reduce to a
            single element.
        start, end: current bracket.
        n_iter: remaining recursion budget.  When it reaches 0, the midpoint is
            returned and the residual norm is printed.
        eps: stop once ``(f(mid) - target).norm()`` drops below this.

    Returns:
        The midpoint of the final bracket.
    """
    midpoint = (start + end) / 2
    value = increasing_func(midpoint)

    # Out of budget: report the residual of the best guess and return it.
    if n_iter == 0:
        print("bottomed out recursion depth, return best guess epsilon =", (value - target).norm())
        return midpoint

    # Close enough.
    if (value - target).norm() < eps:
        return midpoint

    # Overshot -> the root lies in the lower half; otherwise it lies in the upper half.
    if value > target:
        lower, upper = start, midpoint
    else:
        lower, upper = midpoint, end
    return bisection_search(increasing_func, target, lower, upper, n_iter - 1, eps)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class NITSMonotonicInverse(torch.autograd.Function):
    """Batched inverse CDF for a NITS-style model, differentiable with respect to its input.

    ``forward(model, y, params)`` solves ``model.cdf(x, params) == y`` with
    ``model.bisection_search``.  ``backward`` applies the inverse-function rule
    dx/dy = 1 / pdf(x).  The gradient with respect to ``params`` is not
    implemented and is returned as ``None``.
    """

    @staticmethod
    def forward(ctx, self, input, params):
        # ``self`` is the model object providing ``bisection_search`` and ``pdf``.
        with torch.no_grad():
            b = self.bisection_search(input, params)

        dy = 1 / self.pdf(b, params)
        ctx.save_for_backward(dy.reshape(len(input), -1))
        return b

    @staticmethod
    def backward(ctx, grad_output):
        dy, = ctx.saved_tensors
        # autograd requires one gradient per forward input: (model, input, params).
        # The input gradient follows the chain rule.
        return None, grad_output * dy, None

class NITS(nn.Module):
    """Monotone scalar network with a hand-written derivative (pdf) and a bisection inverse.

    Weights are not stored on the module.  Every method takes a ``params`` tensor
    holding one flat parameter vector (``n_params`` columns) per batch row, and
    the layer maps are applied row-wise, so each row of ``x`` is transformed by
    its own network.

    Args:
        arch: layer widths, e.g. ``[1, 8, 1]``.
        start, end: interval on which ``cdf`` is normalised to [0, 1].
        constraint_type: weight constraint for the hidden layers
            (``'neg_exp'``, ``'exp'``, ``'clamp'``, ``'softmax'`` or ``''``).
        monotonic_const: coefficient of the ``x`` skip term added to the output.
        activation: hidden activation, ``'sigmoid'``, ``'tanh'`` or ``'linear'``.
        final_layer_constraint: constraint for the last layer.  ``'softmax'``
            means the last layer has no bias and no activation.
    """

    def __init__(self, arch, start=0., end=1., constraint_type='neg_exp',
                 monotonic_const=1e-3, activation='sigmoid', final_layer_constraint='softmax'):
        super(NITS, self).__init__()
        self.arch = arch
        self.monotonic_const = monotonic_const
        self.constraint_type = constraint_type
        self.final_layer_constraint = final_layer_constraint
        # index, among the linear layers, of the final layer
        self.last_layer = len(arch) - 2
        self.activation = activation

        # count parameters: weights for every layer, biases except for a final softmax layer
        self.n_params = 0

        for i, (a1, a2) in enumerate(zip(arch[:-1], arch[1:])):
            self.n_params += (a1 * a2)
            if i < self.last_layer or final_layer_constraint != 'softmax':
                self.n_params += a2

        # set start and end tensors
        self.register_buffer('start', torch.tensor(start).reshape(1, arch[0]))
        self.register_buffer('end', torch.tensor(end).reshape(1, arch[0]))

    def apply_constraint(self, A, constraint_type):
        """Map raw weights ``A`` to constrained weights.

        ``'neg_exp'`` -> exp(-A), ``'exp'`` -> exp(A), ``'clamp'`` -> max(A, 0),
        ``'softmax'`` -> softmax over the last dimension, ``''`` -> A unchanged.

        Raises:
            ValueError: for an unrecognised ``constraint_type``.
        """
        if constraint_type == 'neg_exp':
            return (-A).exp()
        elif constraint_type == 'exp':
            return A.exp()
        elif constraint_type == 'clamp':
            return A.clamp(min=0.)
        elif constraint_type == 'softmax':
            return F.softmax(A, dim=-1)
        elif constraint_type == '':
            return A
        raise ValueError("Unknown constraint_type {!r}.".format(constraint_type))

    def apply_act(self, x):
        """Apply the hidden activation selected by ``self.activation``.

        Raises:
            ValueError: for an unrecognised activation name.
        """
        if self.activation == 'tanh':
            return x.tanh()
        elif self.activation == 'sigmoid':
            return x.sigmoid()
        elif self.activation == 'linear':
            return x
        raise ValueError("Unknown activation {!r}.".format(self.activation))

    def forward_(self, x, params, return_intermediaries=False):
        """Evaluate the un-normalised monotone network row by row.

        Args:
            x: (batch, arch[0]) inputs.
            params: (batch, n_params) flat parameters, laid out per layer as
                [weights (out * in), bias (out)]; a final ``'softmax'`` layer has
                no bias slot.
            return_intermediaries: also return what ``backward_primitive_`` needs.

        Returns:
            ``y`` of shape (batch, arch[-1]), or
            ``(y, pre_activations, As, bs, nonlinearities)`` with one entry per
            layer in ``pre_activations``/``As``/``nonlinearities``.
        """
        orig_x = x

        # store pre-activations and weight matrices
        pre_activations = []
        nonlinearities = []
        As = []
        bs = []

        cur_idx = 0

        for i, (in_features, out_features) in enumerate(zip(self.arch[:-1], self.arch[1:])):
            # linear weights: one (out, in) matrix per batch row
            A_end = cur_idx + in_features * out_features
            A = params[:, cur_idx:A_end].reshape(-1, out_features, in_features)
            cur_idx = A_end

            constraint = self.constraint_type if i < self.last_layer else self.final_layer_constraint
            A = self.apply_constraint(A, constraint)
            As.append(A)
            x = torch.einsum('nij,nj->ni', A, x)

            has_bias = i < self.last_layer or self.final_layer_constraint != 'softmax'
            if not has_bias:
                # final softmax layer: no bias, no activation
                pre_activations.append(x)
                nonlinearities.append('linear')
                continue

            b_end = A_end + out_features
            b = params[:, A_end:b_end].reshape(-1, out_features)
            bs.append(b)
            cur_idx = b_end
            if constraint == 'neg_exp':
                # neg_exp layers apply the bias as a weighted mean over the inputs
                x = x - (b.unsqueeze(-1) * A).mean(axis=-1)
            else:
                x = x + b
            pre_activations.append(x)
            x = self.apply_act(x)
            nonlinearities.append(self.activation)

        x = x + self.monotonic_const * orig_x

        if return_intermediaries:
            return x, pre_activations, As, bs, nonlinearities
        else:
            return x

    def cdf(self, x, params, return_intermediaries=False):
        """Normalised CDF: ``forward_`` rescaled so that ``start`` maps to 0 and ``end`` to 1.

        With ``return_intermediaries`` the returned lists also contain the final
        rescaling, recorded as a 'linear' layer with weight ``1 / (end - start)``.
        """
        # get scaling factors
        start = self.forward_(self.start, params)
        end = self.forward_(self.end, params)

        # compute pre-scaled cdf, then scale
        y, pre_activations, As, bs, nonlinearities = self.forward_(x, params, return_intermediaries=True)
        scale = 1 / (end - start)
        y_scaled = (y - start) * scale

        # record the rescaling as one more linear layer so backward_primitive_ covers it
        pre_activations.append(y_scaled)
        As.append(scale.reshape(-1, 1, 1))
        nonlinearities.append('linear')

        if return_intermediaries:
            return y_scaled, pre_activations, As, bs, nonlinearities
        else:
            return y_scaled

    def fc_gradient(self, grad, pre_activation, A, activation):
        """Vector-Jacobian product through one layer.

        Multiplies ``grad`` by the activation derivative at ``pre_activation``, then
        by the weight matrix ``A``.
        """
        if activation == 'linear':
            pass
        elif activation == 'tanh':
            grad = grad * (1 - pre_activation.tanh() ** 2)
        elif activation == 'sigmoid':
            sig_act = pre_activation.sigmoid()
            grad = grad * sig_act * (1 - sig_act)

        return torch.einsum('ni,nij->nj', grad, A)

    def backward_primitive_(self, y, pre_activations, As, bs, nonlinearities):
        """Back-propagate d(output)/d(input) through the recorded layers, last to first.

        The input lists are left unmodified.  The skip term ``monotonic_const * x``
        is not included; callers add it.
        """
        grad = torch.ones_like(y, device=y.device)

        for A, pre_activation, nonlinearity in zip(reversed(As), reversed(pre_activations),
                                                   reversed(nonlinearities)):
            grad = self.fc_gradient(grad, pre_activation, A, activation=nonlinearity)

        return grad

    def backward_(self, x, params):
        """Derivative of the un-normalised ``forward_`` with respect to ``x``."""
        y, pre_activations, As, bs, nonlinearities = self.forward_(x, params, return_intermediaries=True)

        grad = self.backward_primitive_(y, pre_activations, As, bs, nonlinearities)

        return grad + self.monotonic_const

    def pdf(self, x, params):
        """Density: the derivative of ``cdf`` with respect to ``x``."""
        y, pre_activations, As, bs, nonlinearities = self.cdf(x, params, return_intermediaries=True)

        grad = self.backward_primitive_(y, pre_activations, As, bs, nonlinearities)

        # As[-1] is the 1 / (end - start) scale appended by cdf, so the skip term
        # contributes monotonic_const * scale to the density.
        return grad + self.monotonic_const * As[-1].reshape(-1, 1)

    def sample(self, params):
        """Draw one sample per row of ``params`` by inverse-transform sampling (no gradient)."""
        z = torch.rand((len(params), 1), device=params.device)

        with torch.no_grad():
            x = self.icdf(z, params)

        return x

    def icdf(self, z, params):
        """Inverse CDF via ``NITSMonotonicInverse``."""
        func = NITSMonotonicInverse.apply

        return func(self, z, params)

    def bisection_search(self, y, params, eps=1e-3):
        """Batched bisection for ``cdf(x, params) == y`` on ``[start, end]``.

        Iterates until every bracket is at most ``eps`` wide and returns the
        upper end of each bracket.
        """
        low = torch.ones((len(y), 1), device=y.device) * self.start
        high = torch.ones((len(y), 1), device=y.device) * self.end

        while ((high - low) > eps).any():
            x_hat = (low + high) / 2
            y_hat = self.cdf(x_hat, params)
            low = torch.where(y_hat > y, low, x_hat)
            high = torch.where(y_hat > y, x_hat, high)

        return high

class MultiDimNITS(NITS):
    """Applies an independent scalar ``NITS`` transform to each of ``d`` input dimensions.

    Inputs are ``(rows, d)`` and parameters are ``(rows, d * n_params)``, one block
    of ``n_params`` per dimension.  Either argument may have a single row, which
    is tiled against the other.
    """

    def __init__(self, d, arch, start=-2., end=2., constraint_type='neg_exp',
                 monotonic_const=1e-2, final_layer_constraint='softmax'):
        super(MultiDimNITS, self).__init__(arch, start, end,
                                           constraint_type=constraint_type,
                                           monotonic_const=monotonic_const,
                                           final_layer_constraint=final_layer_constraint)
        self.d = d
        self.tot_params = self.n_params * d

        # Replace the parent's (1, arch[0]) buffers with one bound per dimension.
        self.register_buffer('start', torch.tensor(start).reshape(1, 1).tile(1, d))
        self.register_buffer('end', torch.tensor(end).reshape(1, 1).tile(1, d))

        # Scalar NITS that evaluates every flattened (row, dimension) pair.
        self.nits = NITS(arch, start, end,
                         constraint_type=constraint_type,
                         monotonic_const=monotonic_const,
                         final_layer_constraint=final_layer_constraint)

    def multidim_reshape(self, x, params):
        """Flatten ``x`` (rows, d) and ``params`` (rows, d * n_params) for the scalar NITS.

        Returns ``x`` of shape (n * d, 1) and ``params`` of shape (n * d, n_params),
        where ``n = max(len(x), len(params))``.  A single-row argument is tiled
        ``n`` times.

        Raises:
            NotImplementedError: if a row count is neither 1 nor ``n``.
        """
        n = max(len(x), len(params))
        _, d = x.shape
        assert d == self.d
        assert params.shape[1] == self.tot_params
        assert len(x) == len(params) or len(x) == 1 or len(params) == 1

        if len(params) == 1:
            params = params.reshape(self.d, self.n_params).tile((n, 1))
        elif len(params) == n:
            params = params.reshape(-1, self.n_params)
        else:
            raise NotImplementedError('len(params) should be 1 or {}, but it is {}.'.format(n, len(params)))

        if len(x) == 1:
            x = x.reshape(1, self.d).tile((n, 1)).reshape(-1, 1)
        elif len(x) == n:
            x = x.reshape(-1, 1)
        else:
            raise NotImplementedError('len(x) should be 1 or {}, but it is {}.'.format(n, len(x)))

        return x, params

    def forward_(self, x, params, return_intermediaries=False):
        """Per-dimension ``NITS.forward_``; the output is reshaped to (n, d).

        Any intermediaries are returned in the flattened (n * d) layout.
        """
        n = max(len(x), len(params))
        x, params = self.multidim_reshape(x, params)

        if return_intermediaries:
            x, pre_activations, As, bs, nonlinearities = self.nits.forward_(x, params, return_intermediaries)
            x = x.reshape((n, self.d))
            return x, pre_activations, As, bs, nonlinearities
        else:
            x = self.nits.forward_(x, params, return_intermediaries)
            x = x.reshape((n, self.d))
            return x

    def backward_(self, x, params):
        """Per-dimension derivative of the un-normalised network, shape (n, d)."""
        n = max(len(x), len(params))
        x, params = self.multidim_reshape(x, params)

        return self.nits.backward_(x, params).reshape((n, self.d))

    def cdf(self, x, params):
        """Per-dimension normalised CDF, shape (n, d)."""
        n = max(len(x), len(params))
        x, params = self.multidim_reshape(x, params)

        return self.nits.cdf(x, params).reshape((n, self.d))

    def icdf(self, x, params):
        """Per-dimension inverse CDF, shape (n, d)."""
        n = max(len(x), len(params))
        x, params = self.multidim_reshape(x, params)

        return self.nits.icdf(x, params).reshape((n, self.d))

    def pdf(self, x, params):
        """Per-dimension density, shape (n, d)."""
        n = max(len(x), len(params))
        x, params = self.multidim_reshape(x, params)

        return self.nits.pdf(x, params).reshape((n, self.d))

    def sample(self, n, params):
        """Draw ``n`` samples of shape (n, d); ``params`` must have 1 or ``n`` rows.

        Raises:
            NotImplementedError: if ``len(params)`` is neither 1 nor ``n``.
        """
        if len(params) == 1:
            params = params.reshape(self.d, self.n_params).tile((n, 1))
        elif len(params) == n:
            params = params.reshape(-1, self.n_params)
        else:
            raise NotImplementedError('len(params) should be 1 or {}, but it is {}.'.format(n, len(params)))

        return self.nits.sample(params).reshape((-1, self.d))

    def initialize_parameters(self, n, constraint_type):
        """Draw random flat parameters of shape ``(n, d * n_params)``.

        Each layer's slice (weights and bias) starts as U(0, 1) draws.  For
        ``'clamp'``, ``'exp'`` and ``'tanh'`` it is then rescaled using the fan-in
        bound sqrt(1 / in_features); other constraint types keep the raw draws.
        """
        params = torch.rand((self.d * n, self.n_params))

        def init_constant(params, in_features, constraint_type):
            const = np.sqrt(1 / in_features)
            if constraint_type == 'clamp':
                params = params.abs() * const
            elif constraint_type == 'exp':
                params = params * np.log(const)
            elif constraint_type == 'tanh':
                params = params * np.arctanh(const - 1)

            return params

        cur_idx = 0

        for i, (a1, a2) in enumerate(zip(self.arch[:-1], self.arch[1:])):
            next_idx = cur_idx + (a1 * a2)
            if i < self.last_layer or self.final_layer_constraint != 'softmax':
                next_idx = next_idx + a2
            # a1 is this layer's fan-in (forward_ reshapes weights to (a2, a1))
            params[:, cur_idx:next_idx] = init_constant(params[:, cur_idx:next_idx], a1, constraint_type)
            cur_idx = next_idx

        return params.reshape((n, self.d * self.n_params))

In [None]:
import torch
import numpy as np
# from model import *
# from autograd_model import *

device = 'cpu'

n = 1024
start, end = -2., 2.
arch = [1, 8, 1]
monotonic_const = 1e-2


def autograd_reference(autograd_model, method, zs, params):
    """Evaluate ``autograd_model.<method>`` point by point, as ground truth for MultiDimNITS.

    Args:
        autograd_model: a ``ModelInverse`` built with ``store_weights=False``.
        method: name of the method to call, e.g. ``'apply_layers'``, ``'cdf'`` or ``'pdf'``.
        zs: (rows, d) inputs.
        params: (rows, d * autograd_model.n_params) flat parameters.  Either
            ``zs`` or ``params`` may have a single row, which is broadcast
            against the other (mirroring ``MultiDimNITS.multidim_reshape``).

    Returns:
        A (rows, d) tensor; entry ``[r, j]`` uses dimension ``j``'s parameter block of row ``r``.
    """
    d = zs.shape[1]
    n_rows = max(len(zs), len(params))
    zs, params = zs.expand(n_rows, -1), params.expand(n_rows, -1)
    block = autograd_model.n_params
    fn = getattr(autograd_model, method)

    out = []
    for z_row, param_row in zip(zs, params):
        for d_ in range(d):
            autograd_model.set_params(param_row[d_ * block:(d_ + 1) * block])
            out.append(fn(z_row[d_:d_ + 1][None, :]))

    return torch.cat(out, axis=0).reshape(-1, d)


for d in [1, 2, 10]:
    for constraint_type in ['neg_exp', 'exp']:
        for final_layer_constraint in ['softmax', 'exp']:
            print("""
            Testing configuration:
                d: {}
                constraint_type: {}
                final_layer_constraint: {}
                  """.format(d, constraint_type, final_layer_constraint))
            ############################
            # DEFINE MODELS            #
            ############################

            model = MultiDimNITS(d=d, start=start, end=end, arch=arch,
                                 monotonic_const=monotonic_const, constraint_type=constraint_type,
                                 final_layer_constraint=final_layer_constraint).to(device)
            params = torch.randn((n, d * model.n_params)).to(device)

            ############################
            # SANITY CHECKS            #
            ############################

            # check that the function integrates to 1 over [start, end]
            assert torch.allclose(torch.ones((n, d)).to(device),
                                  model.cdf(model.end, params) - model.cdf(model.start, params), atol=1e-5)

            # check that the pdf is all positive
            z = torch.linspace(start, end, steps=n, device=device)[:, None].tile((1, d))
            assert (model.pdf(z, params) >= 0).all()

            # check that icdf inverts cdf; the error may have either sign, so compare magnitudes
            cdf = model.cdf(z, params[0:1])
            icdf = model.icdf(cdf, params[0:1])
            assert ((z - icdf).abs() <= 1e-3).all()

            ############################
            # COMPARE TO AUTOGRAD NITS #
            ############################
            autograd_model = ModelInverse(arch=arch, start=start, end=end, store_weights=False,
                                          constraint_type=constraint_type, monotonic_const=monotonic_const,
                                          final_layer_constraint=final_layer_constraint)

            # un-normalised network
            autograd_outs = autograd_reference(autograd_model, 'apply_layers', z, params)
            outs = model.forward_(z, params)
            assert torch.allclose(autograd_outs, outs, atol=1e-4)

            # cdf
            autograd_outs = autograd_reference(autograd_model, 'cdf', z, params)
            outs = model.cdf(z, params)
            assert torch.allclose(autograd_outs, outs, atol=1e-4)

            # pdf: many zs, many parameter rows
            autograd_outs = autograd_reference(autograd_model, 'pdf', z, params)
            outs = model.pdf(z, params)
            assert torch.allclose(autograd_outs, outs, atol=1e-4)

            # pdf: single parameter row, many zs
            autograd_outs = autograd_reference(autograd_model, 'pdf', z, params[0:1])
            outs = model.pdf(z, params[0:1])
            assert torch.allclose(autograd_outs, outs, atol=1e-4)

            # pdf: single z, many parameter rows
            autograd_outs = autograd_reference(autograd_model, 'pdf', z[0:1], params)
            outs = model.pdf(z[0:1], params)
            assert torch.allclose(autograd_outs, outs, atol=1e-4)

In [25]:
from discretized_mol import *
print("Testing arch = [1, 10, 1], 'neg_exp' constraint_type, 'softmax' final_layer_constraint "
      "against discretized mixture of logistics.")

# Compare the discretized loss of a 1-D NITS with one 10-unit hidden layer
# against discretized_mix_logistic_loss_1d3 from `discretized_mol`.
# NOTE(review): this cell rebinds the notebook-level `model`, `params` and `z`.
model = MultiDimNITS(d=1, start=-1e5, end=1e5, arch=[1, 10, 1],
                     monotonic_const=0., constraint_type='neg_exp',
                     final_layer_constraint='softmax').to(device)
# keep the inputs on the model's device, as the test loop above does
params = torch.randn((n, model.n_params, 1, 1)).to(device)
z = torch.randn((n, 1, 1, 1)).to(device)

loss1 = discretized_mix_logistic_loss_1d3(z, params)
loss2 = discretized_nits_loss(z, params, arch=[1, 10, 1], nits_model=model)

assert (loss1 - loss2).norm() < 1e-3

print("Finished unit tests. All passed!")

Testing arch = [1, 10, 1], 'neg_exp' constraint_type, 'softmax' final_layer_constraint against discretized mixture of logistics.
Finished unit tests. All passed!


In [None]:
# Debug one configuration.  `d` and `constraint_type` are taken from wherever the
# test loop above stopped (presumably to reproduce a failing configuration).
final_layer_constraint = 'softmax'
model = MultiDimNITS(d=d, start=start, end=end, arch=arch,
                     monotonic_const=monotonic_const, constraint_type=constraint_type,
                     final_layer_constraint=final_layer_constraint).to(device)
params = torch.randn((n, d * model.n_params)).to(device)
autograd_model = ModelInverse(arch=arch, start=start, end=end, store_weights=False,
                              constraint_type=constraint_type, monotonic_const=monotonic_const,
                              final_layer_constraint=final_layer_constraint)
# The mixture-of-logistics cell rebinds `z` to a (n, 1, 1, 1) tensor.  Rebuild the
# (n, d) evaluation grid used by the test loop so the debug cells below get 2-D input.
z = torch.linspace(start, end, steps=n, device=device)[:, None].tile((1, d))

print("""Testing configuration:
                     d: {}
                     constraint_type: {}
                     final_layer_constraint: {}
                  """.format(d, constraint_type, final_layer_constraint))

In [None]:
# Evaluate the autograd reference network at grid point 40 using the first parameter row.
# NOTE(review): set_params asserts len(params[0]) == autograd_model.n_params, which only
# holds when d == 1; for d > 1 this cell needs a per-dimension slice of params[0].
autograd_model.set_params(params[0])
autograd_model.apply_layers(z[40:41])

In [None]:
# Effective first-layer weights under the 'neg_exp' constraint (PositiveLinear uses 1 / exp(pre_weight)).
1 / autograd_model.layers[0].pre_weight.exp()

In [None]:
# Same point through the hand-written MultiDimNITS.forward_, keeping the per-layer
# intermediaries so they can be compared with the autograd model above.
x, pre_activations, As, bs, nonlinearities = model.forward_(z[40:41], params[0:1], return_intermediaries=True)

In [None]:
# Per-layer pre-activation values recorded by forward_.
pre_activations

In [None]:
# Constrained weight matrices per layer (after apply_constraint).
As

In [None]:
# Bias vectors per layer (a final 'softmax' layer has none).
bs