# In [1]:
import torch 
from torch import nn

# In [4]:
class LinearMasked(nn.Module):
    """Linear layer whose weight is multiplied elementwise by a binary mask.

    Every output (hidden) unit ``k`` is assigned a number ``self.m[k]`` in
    ``1 .. num_input_features - 1`` (the "m function" of the paper this code
    follows -- presumably MADE, Germain et al. 2015; confirm). After
    :meth:`set_mask` is called with the previous layer's m-values, unit ``k``
    only receives inputs ``d`` with ``m_previous_layer[d] <= self.m[k]``.
    Until then the mask is all ones and the layer acts like ``nn.Linear``.
    """

    def __init__(self, in_features, out_features, num_input_features, bias=True):
        """
        Parameters
        ----------
        in_features : int
            Size of each input sample.
        out_features : int
            Size of each output sample (number of hidden units of this layer).
        num_input_features : int
            Number of features of the model's input X (D).
            These are needed for all masked layers.
        bias : bool
            Whether the wrapped ``nn.Linear`` has a bias term.
        """
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
        self.num_input_features = num_input_features

        assert (
            out_features >= num_input_features
        ), "To ensure autoregression, there should be enough hidden nodes: h >= in."

        # m function: every hidden node gets a number in 1 .. D-1, and every
        # value 1 .. D-1 must be used at least once. Placing each required
        # value once and drawing the remaining slots uniformly guarantees that
        # coverage (the assertion above ensures there are enough slots), then a
        # random permutation spreads the values over the hidden units.
        required = torch.arange(1, num_input_features, dtype=torch.int32)
        extra = torch.randint(
            1,
            num_input_features,
            size=(out_features - required.numel(),),
            dtype=torch.int32,
        )
        self.m = torch.cat([required, extra])[torch.randperm(out_features)]

        # Registered unconditionally so set_mask/forward can always rely on it;
        # as a buffer it follows the module through .to(device) and state_dict.
        self.register_buffer(
            "mask", torch.ones_like(self.linear.weight).type(torch.uint8)
        )

    def set_mask(self, m_previous_layer):
        """
        Sets mask matrix of the current layer.

        ``mask[k, d]`` becomes 1 when ``m_previous_layer[d] <= self.m[k]``,
        i.e. output unit ``k`` may see input ``d``. The transpose gives the
        ``(out_features, in_features)`` layout of ``self.linear.weight``.

        Parameters
        ----------
        m_previous_layer : tensor
            m values of the previous layer (one per input of this layer).
            For the first layer the values should be incremental, except for
            the last value, as the model does not make a prediction
            P(x_{D+1} | x_{<D+1}). The last prediction is P(x_D | x_{<D}).
        """
        self.mask[...] = (m_previous_layer[:, None] <= self.m[None, :]).T

    def forward(self, x):
        """Apply the linear map using the masked weight ``weight * mask``."""
        # nn.functional.linear accepts bias=None, so the (possibly absent)
        # bias is passed through unchanged.
        return nn.functional.linear(
            x, self.linear.weight * self.mask, self.linear.bias
        )

# In [5]:
# Small demo layer: 5 inputs -> 5 hidden units for a model over 3 input features.
lin_masked = LinearMasked(in_features=5, out_features=5, num_input_features=3)