## Reference

https://github.com/idiap/fast-transformers/tree/2ad36b97e64cb93862937bd21fcc9568d989561f/fast_transformers

In [2]:
from functools import partial

import torch
from torch.nn import Module

### Masks

In [3]:
class BaseMask(object):
    @property
    def bool_matrix(self):
        """Return a bool (uint8) matrix with 1s to all places that should be
        kept."""
        raise NotImplementedError()

    @property
    def float_matrix(self):
        """Return the bool matrix as a float to be used as a multiplicative
        mask for non softmax attentions."""
        if not hasattr(self, "_float_matrix"):
            with torch.no_grad():
                self._float_matrix = self.bool_matrix.float()
        return self._float_matrix

    @property
    def lengths(self):
        """If the matrix is of the following form

            1 1 1 0 0 0 0
            1 0 0 0 0 0 0
            1 1 0 0 0 0 0

        then return it as a vector of integers

            3 1 2.
        """
        if not hasattr(self, "_lengths"):
            with torch.no_grad():
                lengths = self.bool_matrix.long().sum(dim=-1)
                # make sure that the mask starts with 1s and continues with 0s
                # this should be changed to something more efficient, however,
                # I chose simplicity over efficiency since the LengthMask class
                # will be used anyway (and the result is cached)
                m = self.bool_matrix.view(-1, self.shape[-1])
                for i, l in enumerate(lengths.view(-1)):
                    if not torch.all(m[i, :l]):
                        raise ValueError("The mask is not a length mask")
                self._lengths = lengths
        return self._lengths

    @property
    def shape(self):
        """Return the shape of the boolean mask."""
        return self.bool_matrix.shape

    @property
    def additive_matrix(self):
        """Return a float matrix to be added to an attention matrix before
        softmax."""
        if not hasattr(self, "_additive_matrix"):
            with torch.no_grad():
                self._additive_matrix = torch.log(self.bool_matrix.float())
        return self._additive_matrix

    @property
    def additive_matrix_finite(self):
        """Same as additive_matrix but with -1e24 instead of infinity."""
        if not hasattr(self, "_additive_matrix_finite"):
            with torch.no_grad():
                self._additive_matrix_finite = (
                    (~self.bool_matrix).float() * (-1e24)
                )
        return self._additive_matrix_finite

    @property
    def all_ones(self):
        """Return true if the mask is all ones."""
        if not hasattr(self, "_all_ones"):
            with torch.no_grad():
                self._all_ones = torch.all(self.bool_matrix)
        return self._all_ones

    @property
    def lower_triangular(self):
        """Return true if the attention is a triangular causal mask."""
        if not hasattr(self, "_lower_triangular"):
            self._lower_triangular = False
            with torch.no_grad():
                try:
                    lengths = self.lengths
                    if len(lengths.shape) == 1:
                        target = torch.arange(
                            1,
                            len(lengths)+1,
                            device=lengths.device
                        )
                        self._lower_triangular = torch.all(lengths == target)
                except ValueError:
                    pass
        return self._lower_triangular


class FullMask(BaseMask):
    """Thin wrapper over a pytorch tensor that provides the BaseMask
    interface.

    The arguments can be given both by keyword arguments and positional
    arguments. To imitate function overloading, the constructor checks the type
    of the first argument and if it is a tensor it treats it as the mask.
    otherwise it assumes that it was the N argument.

    Arguments
    ---------
        mask: The mask as a PyTorch tensor.
        N: The rows of the all True mask to be created if the mask argument is
           not provided.
        M: The columns of the all True mask to be created if the mask argument
           is not provided. If N is given M defaults to N.
        device: The device to create the mask in (defaults to cpu)
    """
    def __init__(self, mask=None, N=None, M=None, device="cpu"):
        # mask is a tensor so we ignore N and M
        if mask is not None and isinstance(mask, torch.Tensor):
            if mask.dtype != torch.bool:
                raise ValueError("FullMask expects the mask to be bool")
            with torch.no_grad():
                self._mask = mask.clone()
            return

        # mask is an integer, N is an integer and M is None so assume they were
        # passed as N, M
        if mask is not None and M is None and isinstance(mask, int):
            M = N
            N = mask

        if N is not None:
            M = M or N
            with torch.no_grad():
                self._mask = torch.ones(N, M, dtype=torch.bool, device=device)
            self._all_ones = True
            return

        raise ValueError("Either mask or N should be provided")

    @property
    def bool_matrix(self):
        return self._mask


class LengthMask(BaseMask):
    """Provide a BaseMask interface for lengths. Mostly to be used with
    sequences of different lengths.

    Arguments
    ---------
        lengths: The lengths as a PyTorch long tensor
        max_len: The maximum length for the mask (defaults to lengths.max())
        device: The device to be used for creating the masks (defaults to
                lengths.device)
    """
    def __init__(self, lengths, max_len=None, device=None):
        self._device = device or lengths.device
        with torch.no_grad():
            self._lengths = lengths.clone().to(self._device)
        self._max_len = max_len or self._lengths.max()

        self._bool_matrix = None
        self._all_ones = torch.all(self._lengths == self._max_len).item()

    @property
    def bool_matrix(self):
        if self._bool_matrix is None:
            with torch.no_grad():
                indices = torch.arange(self._max_len, device=self._device)
                self._bool_matrix = (
                    indices.view(1, -1) < self._lengths.view(-1, 1)
                )
        return self._bool_matrix


class TriangularCausalMask(LengthMask):
    """A square matrix with everything masked out above the diagonal.

    Arguments
    ---------
        N: The size of the matrix
        device: The device to create the mask in (defaults to cpu)
    """
    def __init__(self, N, device="cpu"):
        lengths = torch.arange(1, N+1, device=device)
        super(TriangularCausalMask, self).__init__(lengths, N, device)
        self._lower_triangular = True

### ELU Feature Map

In [4]:



class FeatureMap(Module):
    """Define the FeatureMap interface."""
    def __init__(self, query_dims):
        super().__init__()
        self.query_dims = query_dims

    def new_feature_map(self, device):
        """Create a new instance of this feature map. In particular, if it is a
        random feature map sample new parameters."""
        raise NotImplementedError()

    def forward_queries(self, x):
        """Encode the queries `x` using this feature map."""
        return self(x)

    def forward_keys(self, x):
        """Encode the keys `x` using this feature map."""
        return self(x)

    def forward(self, x):
        """Encode x using this feature map. For symmetric feature maps it
        suffices to define this function, but for asymmetric feature maps one
        needs to define the `forward_queries` and `forward_keys` functions."""
        raise NotImplementedError()

    @classmethod
    def factory(cls, *args, **kwargs):
        """Return a function that when called with the query dimensions returns
        an instance of this feature map.

        It is inherited by the subclasses so it is available in all feature
        maps.
        """
        def inner(query_dims):
            return cls(query_dims, *args, **kwargs)
        return inner


class ActivationFunctionFeatureMap(FeatureMap):
    """Define a feature map that is simply an element-wise activation
    function."""
    def __init__(self, query_dims, activation_function):
        super().__init__(query_dims)
        self.activation_function = activation_function

    def new_feature_map(self, device):
        return

    def forward(self, x):
        return self.activation_function(x)


elu_feature_map = ActivationFunctionFeatureMap.factory(
    lambda x: torch.nn.functional.elu(x) + 1
)

### Linear Attention Code

In [5]:
class LinearAttention(Module):
    """Implement unmasked attention using dot product of feature maps in
    O(N D^2) complexity.

    Given the queries, keys and values as Q, K, V instead of computing

        V' = softmax(Q.mm(K.t()), dim=-1).mm(V),

    we make use of a feature map function Φ(.) and perform the following
    computation

        V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V).

    The above can be computed in O(N D^2) complexity where D is the
    dimensionality of Q, K and V and N is the sequence length. Depending on the
    feature map, however, the complexity of the attention might be limited.

    Arguments
    ---------
        feature_map: callable, a callable that applies the feature map to the
                     last dimension of a tensor (default: elu(x)+1)
        eps: float, a small number to ensure the numerical stability of the
             denominator (default: 1e-6)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, query_dimensions, feature_map=None, eps=1e-6, event_dispatcher=""):
        super(LinearAttention, self).__init__()
        self.feature_map = elu_feature_map(query_dimensions)
        self.eps = eps

    def forward(self, queries, keys, values, attn_mask, query_lengths, key_lengths):
        """
        queries: (N, L, H, E)
        keys: (N, S, H, E)
        values: (N, S, H, D)
        attn_mask: (L, S)
        query_lengths: (N, L)
        key_lengths: (N, S)

        where
            - N: batch-size
            - L: seq len for queries
            - S: seq len for keys & values
            - H: number of heads
            - E: key & query dim
            - D: value dim
        """
        # Apply the feature map to the queries and keys
        self.feature_map.new_feature_map(queries.device)
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)

        # Apply the key padding mask and make sure that the attn_mask is
        # all_ones
        if not attn_mask.all_ones:
            raise RuntimeError(("LinearAttention does not support arbitrary attention masks"))

        # key_lengths => (N, S, 1, 1)
        K = K * key_lengths.float_matrix[:, :, None, None]

        # Compute the KV matrix, namely the dot product of keys and values so
        # that we never explicitly compute the attention matrix and thus
        # decrease the complexity
        KV = torch.einsum("nshd,nshm->nhmd", K, values)

        # Compute the normalizer
        Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps)

        # Finally compute and return the new values
        V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)

        return V.contiguous()

In [11]:

class TestLinearAttention(object): #(unittest.TestCase):
    
    def _get_inputs(self, N=10, L=5, S=8, H=4, E=32, D=64, device="cpu"):
        return (
            torch.rand(N, L, H, E).to(device), # Q
            torch.rand(N, S, H, E).to(device), # K
            torch.rand(N, S, H, D).to(device), # V
            FullMask(L, S, device=device), # m1
            FullMask(N, L, device=device), # m2
            FullMask(N, S, device=device) # m3
        )

    # TODO: JPK added
    def get_their_forward(self):
        att = LinearAttention(32)
        q, k, v, m1, m2, m3 = self._get_inputs()
        v = att(q, k, v, m1, m2, m3)
        return v
    
    def test_forward(self):
        att = LinearAttention(32)
        q, k, v, m1, m2, m3 = self._get_inputs()
        v = att(q, k, v, m1, m2, m3)
        self.assertTrue(v.is_contiguous())

    def test_masking(self):
        att = LinearAttention(32)
        q, k, v, m1, m2, m3 = self._get_inputs()

        # Make sure that we raise an error if m1 is not all ones
        with self.assertRaises(RuntimeError):
            att(q, k, v, FullMask(torch.rand(*m1.shape) > 0.5), m2, m3)

        # Make sure that the key lengths is paid attention to
        q, k, v, m1, m2, m3 = self._get_inputs(S=10, D=1)
        m3 = LengthMask(torch.tensor(list(range(10)))+1)
        for i in range(9):
            v[i, i+1:] = 1e9
        v_new = att(q, k, v, m1, m2, m3)
        self.assertLess(v_new.max().item(), 1)

    def test_benchmark_cpu(self):
        q, k, v, m1, m2, m3 = self._get_inputs(L=1024, S=1024, E=64, D=64)
        att = LinearAttention(64)

        # warmup the cache
        for i in range(10):
            v_new = att(q, k, v, m1, m2, m3)

        # measure
        start = time.time()
        for i in range(10):
            v_new = att(q, k, v, m1, m2, m3)
        end = time.time()
        print("CPU time taken:", (end-start)*1000, "(ms)")


In [12]:


class MyLinearAttention(Module):
    """Implement unmasked attention using dot product of feature maps in
    O(N D^2) complexity.

    Given the queries, keys and values as Q, K, V instead of computing

        V' = softmax(Q.mm(K.t()), dim=-1).mm(V),

    we make use of a feature map function Φ(.) and perform the following
    computation

        V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V).

    The above can be computed in O(N D^2) complexity where D is the
    dimensionality of Q, K and V and N is the sequence length. Depending on the
    feature map, however, the complexity of the attention might be limited.

    Arguments
    ---------
        feature_map: callable, a callable that applies the feature map to the
                     last dimension of a tensor (default: elu(x)+1)
        eps: float, a small number to ensure the numerical stability of the
             denominator (default: 1e-6)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, query_dimensions, feature_map=None, eps=1e-6, event_dispatcher=""):
        super(MyLinearAttention, self).__init__()
        self.feature_map = elu_feature_map(query_dimensions)
        self.eps = eps

    def forward(self, queries, keys, values, attn_mask, query_lengths, key_lengths):
        """
        queries: (N, L, H, E)
        keys: (N, S, H, E)
        values: (N, S, H, D)
        attn_mask: (L, S)
        query_lengths: (N, L)
        key_lengths: (N, S)

        where
            - N: batch-size
            - L: seq len for queries
            - S: seq len for keys & values
            - H: number of heads
            - E: key & query dim
            - D: value dim
        """
        # Apply the feature map to the queries and keys
        self.feature_map.new_feature_map(queries.device)
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)
        # change the shapes so we broadcast across the right dims
        N, L, H, E = Q.shape
        _, S, _, _ = K.shape

        # (N, L, H, E) => (N, H, L, E)
        Q = Q.transpose(1, 2)
        # (N, S, H, E) => (N, H, S, E)
        K = K.transpose(1, 2)
        # (N, S, H, D) => (N, H, S, D)
        values = values.transpose(1, 2)

        # reshape K & V to get KV
        # (N, H, S, E) => (N, H, S, E, 1)
        K = K[:, :, :, :, None]
        # (N, H, S, D) => (N, H, S, 1, D)
        values = values[:, :, :, None, :]

        # (N, H, S, E, 1) x (N, H, S, 1, D) = (N, H, E, D)
        KV = torch.sum(K @ values, dim=2)
        # (N, H, L, E) x (N, H, E, D) = (N, H, L, D) 
        QKV = Q @ KV 
        # (N, H, E)
        K_sum = torch.sum(K, dim=2).squeeze()
        # (N, H, L, E) x (N, H, E, 1) = > (N, H, L, 1) 
        Z = 1 / (Q @ K_sum[:, :, :, None] + self.eps).squeeze()
        # (N, H, L, D) x (N, H, L, 1) = (N, H, L, D) 
        out = QKV * Z[:, :, :, None]
        # (N, H, L, D) => (N, L, H, D)
        out = out.transpose(1, 2)
        return out


In [92]:


class EinsumLinearAttention(Module):
    
    def __init__(self, query_dimensions, feature_map=None, eps=1e-6, event_dispatcher=""):
        super().__init__()
        self.feature_map = elu_feature_map(query_dimensions)
        self.eps = eps

    def forward(self, queries, keys, values, attn_mask, query_lengths, key_lengths):
        """
        queries: (N, L, H, E)
        keys: (N, S, H, E)
        values: (N, S, H, D)
        attn_mask: (L, S)
        query_lengths: (N, L)
        key_lengths: (N, S)

        where
            - N: batch-size
            - L: seq len for queries
            - S: seq len for keys & values
            - H: number of heads
            - E: key & query dim
            - D: value dim
        """
        # Apply the feature map to the queries and keys
        self.feature_map.new_feature_map(queries.device)
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)

        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = values.transpose(1, 2)

        # (N, H, E, D)
        KV = torch.einsum("...ki,...kj->...ij", [K, V])
        # (N, H, L, E) x (N, H, E, D) = (N, H, L, D)
        QKV = torch.einsum("...ik,...kj->...ij", [Q, KV])
        # (N, H, E)
        # TODO: they do something w/ key lengths that you should look at
        K_sum = torch.sum(K, dim=-2)
        # (N, H, L, E) dot prod (N, H, E) = (N, H, L) 
        Z = 1 / (torch.einsum("...ij,...j->...i", [Q, K_sum]) + self.eps) 
        # (N, H, L, D) * (N, H, L, 1) = (N, H, L, D)
        out = QKV * Z[..., None]
        #  (N, H, L, D) =>  (N, L, H, D)
        out = out.transpose(1, 2)
        return out
        

In [82]:
qkv_temp = torch.rand(10, 4, 5, 64)
z_temp = torch.rand(10, 4, 5)

In [83]:
lin_tester = TestLinearAttention()

In [84]:
q, k, v, m1, m2, m3 = lin_tester._get_inputs()

In [85]:
mine = MyLinearAttention(32)

In [86]:
theirs = LinearAttention(32)

In [87]:
ein = EinsumLinearAttention(32)

In [88]:
my_out = mine(q, k, v, m1, m2, m3)
their_out = theirs(q, k, v, m1, m2, m3)
ein_out = ein(q, k, v, m1, m2, m3)

Q shape: torch.Size([10, 4, 5, 32])
V shape: torch.Size([10, 4, 8, 64])
K shape: torch.Size([10, 4, 8, 32])
KV shape: torch.Size([10, 4, 32, 64])
QKV shape: torch.Size([10, 4, 5, 64])
K sum shape: torch.Size([10, 4, 32])
Z shape: torch.Size([10, 4, 5])


In [89]:
torch.allclose(my_out, their_out)

True

In [90]:
torch.allclose(ein_out, their_out)

True

### Linear Attention Tests

### Causal Linear Attention Code

In [5]:
class CausalLinearAttention(Module):
    """Implement causally masked attention using dot product of feature maps in
    O(N D^2) complexity.

    See fast_transformers.attention.linear_attention.LinearAttention for the
    general concept of replacing the softmax with feature maps. In addition to
    that, we also make use of the fact that causal masking is a triangular mask
    which allows us to apply the masking and still compute the attention in O(N
    D^2) complexity.

    Arguments
    ---------
        feature_map: callable, a callable that applies the feature map to the
                     last dimension of a tensor (default: elu(x)+1)
        eps: float, a small number to ensure the numerical stability of the
             denominator (default: 1e-6)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
                 event_dispatcher=""):
        super(CausalLinearAttention, self).__init__()
        self.feature_map = (
            feature_map(query_dimensions) if feature_map else
            elu_feature_map(query_dimensions)
        )
        self.eps = eps
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def _make_sizes_compatible(self, Q, K):
        """Either slice or pad K in case that the sizes do not match between Q
        and K."""
        N, L, H, E = Q.shape
        _, S, _, _ = K.shape
        if L == S:
            return Q, K

        if L < S:
            return Q, K[:, :L, :, :]

        if L > S:
            return Q, torch.cat([K, K.new_zeros(N, L-S, H, E)], dim=1)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Apply the feature map to the queries and keys
        self.feature_map.new_feature_map(queries.device)
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)

        # Apply the key padding mask and make sure the attn_mask is a
        # lower triangular causal mask
        if not attn_mask.lower_triangular:
            raise RuntimeError(("CausalLinearAttention only supports full "
                                "lower triangular masks"))
        K = K * key_lengths.float_matrix[:, :, None, None]

        # Ensure that Q and K have compatible sizes for the following
        # computations, namely L == S
        Q, K = self._make_sizes_compatible(Q, K)

        # TODO: Shall we divide the Q and K with a relatively large number to
        #       avoid numerical instabilities in computing the denominator?
        #       We used to divide each with the max norm of all q and k but
        #       that seems relatively costly for a simple normalization.

        # Compute the normalizers
        Z = 1/(torch.einsum("nlhi,nlhi->nlh", Q, K.cumsum(1)) + self.eps)

        # Compute the unnormalized result
        V = causal_linear(
            Q,
            K,
            values
        )

        return V * Z[:, :, :, None]


### Causal Linear Attention Tests

In [6]:
class TestCausalLinearAttention(unittest.TestCase):
    def _get_inputs(self, N=10, L=5, S=8, H=4, E=32, D=64, device="cpu"):
        return (
            torch.rand(N, L, H, E).to(device),
            torch.rand(N, S, H, E).to(device),
            torch.rand(N, S, H, D).to(device),
            TriangularCausalMask(L, device=device),
            FullMask(N, L, device=device),
            FullMask(N, S, device=device)
        )

    def test_forward(self):
        att = CausalLinearAttention(32)
        q, k, v, m1, m2, m3 = self._get_inputs(L=5, S=5)
        v = att(q, k, v, m1, m2, m3)
        self.assertTrue(v.is_contiguous())

        q, k, v, m1, m2, m3 = self._get_inputs(L=5, S=10)
        v = att(q, k, v, m1, m2, m3)
        self.assertTrue(v.is_contiguous())

        q, k, v, m1, m2, m3 = self._get_inputs(L=10, S=5)
        v = att(q, k, v, m1, m2, m3)
        self.assertTrue(v.is_contiguous())

NameError: name 'unittest' is not defined