In [5]:
import torch
import torch.distributions as td
import torch.nn.functional as F
from torch import nn
from abc import ABC, abstractmethod

In [6]:
class Bijector(nn.Module, ABC):
    @abstractmethod
    def forward(self, u):
        raise NotImplementedError

    @abstractmethod
    def inverse(self, x):
        raise NotImplementedError
        
    @abstractmethod        
    def inverse_log_det_jacobian(self, x):
        raise NotImplementedError
        
        
class TransformedDistribution(nn.Module, ABC):
    def __init__(self, base_distribution, bijector):
        super(TransformedDistribution, self).__init__()
        self.base_distribution = base_distribution
        self.bijector = bijector
        return
    
    @abstractmethod
    def forward(self, u):
        raise NotImplementedError

    
class Flow(nn.Sequential):
    def forward(self, x, y):
        acc_log_abs_det_jacobian = 0
        for bijector in self:
            x, log_abs_det_jacobian = bijector(x, y)
            acc_log_abs_det_jacobian += log_abs_det_jacobian
        return x, acc_log_abs_det_jacobian

    def inverse(self, u, y):
        acc_log_abs_det_jacobian = 0
        for module in reversed(self):
            u, log_abs_det_jacobian = module.inverse(u, y)
            acc_log_abs_det_jacobian += log_abs_det_jacobian
        return u, acc_log_abs_det_jacobian
    
    def log_prob(self, x, y=None):
        raise NotImplementedError

In [25]:
"""
This module contains the most commonly used bijectors, i.e. parameterised distribution transformations
"""


class Affine(Bijector, nn.Linear):
    """
    Scales and shifts into a standard multivariate Gaussian
    """
    def forward(self, u, y=None):
        x = F.linear(u, self.weight, self.bias)
        if y is not None:
            x += F.linear(y, self.weight, self.bias)
        return x
    
    def inverse(self, x):
        return 
    
    def inverse_log_det_jacobian(self, x):
        raise NotImplementedError

In [26]:
class MaskedAutoregressiveFLow(Flow):
    def __init__(self, input_dim, hidden_dim, n_layers):
        self.base_distribution = td.Normal(torch.tensor(0.), torch.tensor(1.))
        bijectors = [Affine(input_dim, hidden_dim) for _ in range(n_layers)]
        super(MaskedAutoregressiveFLow, self).__init__(*bijectors)
        return
    
    def log_prob(self, x, y=None):
        u, acc_log_abs_det_jacobians = self.forward(x, y)
        return torch.sum(self.base_distribution.log_prob(u) + acc_log_abs_det_jacobians, dim=1)            

In [27]:
maf = MaskedAutoregressiveFLow(32, 32, 4)

In [28]:
maf(torch.randn(32, 32), None)

ValueError: too many values to unpack (expected 2)