In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
class Linear_attention(nn.Module):
    """Multi-head linear attention, the heart of https://arxiv.org/pdf/2006.16236.pdf

    Queries and keys are passed through the positive feature map
    ``elu(x) + 1`` so that attention can be computed as
    ``q @ (k^T @ v)`` instead of ``softmax(q @ k^T) @ v``.

    Two tensor layouts are supported by two parallel families of helpers:

    * no suffix  -> ``[batch_size, seq_len, n_heads, d_head]`` (used by ``forward``)
    * suffix ``1`` -> ``[batch_size, n_heads, seq_len, d_head]``

    Args:
        d_model: size of the input/output feature dimension.
        n_heads: number of attention heads; must divide ``d_model``.
        eps: small constant added to the normaliser to avoid division by zero.
    """

    def __init__(self, d_model, n_heads, eps=1e-6):
        # nn.Module has to be initialised before any sub-module or parameter
        # is assigned, otherwise the assignments below raise AttributeError.
        super().__init__()

        assert d_model % n_heads == 0, 'd_model must be a multiple of n_heads'
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = self.d_model // self.n_heads
        self.eps = eps
        # A single projection produces q, k and v at once.
        self.w_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        # Trick to get model device. Stolen from:
        # https://stackoverflow.com/questions/58926054/how-to-get-the-device-type-of-a-pytorch-module-conveniently
        self.dummy_param = nn.Parameter(torch.empty(0))

    def dev(self):
        """Return the device this module's parameters live on."""
        return self.dummy_param.device

    def _valid_positions(self, l):
        """Return a bool tensor [batch_size, max(l)] that is True at real tokens."""
        # l -> [batch_size]
        device = self.dev()
        # Compare on the module's device so the mask can be applied to activations.
        l = l.to(device)
        positions = torch.arange(int(l.max()), device=device)
        return positions[None, :] < l[:, None]

    def get_mask1(self, l):
        """Padding mask for the ``[batch_size, n_heads, seq_len, d_head]`` layout.

        Args:
            l: sequence lengths, ``[batch_size]``.

        Returns:
            Bool tensor ``[batch_size, 1, seq_len, 1]``, True at padded positions.
        """
        valid = self._valid_positions(l)
        # valid -> [batch_size, seq_len]
        return torch.logical_not(valid[:, None, :, None])

    def get_mask(self, l):
        """Padding mask for the ``[batch_size, seq_len, n_heads, d_head]`` layout.

        This is the mask shape expected by ``forward``.

        Args:
            l: sequence lengths, ``[batch_size]``.

        Returns:
            Bool tensor ``[batch_size, seq_len, 1, 1]``, True at padded positions.
        """
        valid = self._valid_positions(l)
        # valid -> [batch_size, seq_len]
        return torch.logical_not(valid[..., None, None])

    def zero_out_padded1(self, x, mask):
        """Return a copy of ``x`` with padded positions set to zero.

        x -> [batch_size, n_heads, seq_len, d_head]
        mask -> [batch_size, 1, seq_len, 1]
        """
        # Out-of-place: x is typically a view of a torch.chunk output, and
        # autograd forbids in-place modification of such views.
        return x.masked_fill(mask, 0)

    def zero_out_padded(self, x, mask):
        """Return a copy of ``x`` with padded positions set to zero.

        x -> [batch_size, seq_len, n_heads, d_head]
        mask -> [batch_size, seq_len, 1, 1]
        """
        # Out-of-place for the same reason as in zero_out_padded1.
        return x.masked_fill(mask, 0)

    def split_heads1(self, x):
        """[batch_size, seq_len, d_model] -> [batch_size, n_heads, seq_len, d_head]."""
        batch_size, seq_len = x.shape[:2]
        return x.reshape(
            batch_size, seq_len, self.n_heads, self.d_head).permute(0, 2, 1, 3)

    def split_heads(self, x):
        """[batch_size, seq_len, d_model] -> [batch_size, seq_len, n_heads, d_head]."""
        batch_size, seq_len = x.shape[:2]
        return x.reshape(batch_size, seq_len, self.n_heads, self.d_head)

    def join_heads1(self, x):
        """[batch_size, n_heads, seq_len, d_head] -> [batch_size, seq_len, d_model]."""
        batch_size, seq_len = x.shape[0], x.shape[2]
        # The permuted tensor is generally non-contiguous, so .view would fail;
        # reshape copies only when it has to.
        return x.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.d_model)

    def join_heads(self, x):
        """[batch_size, seq_len, n_heads, d_head] -> [batch_size, seq_len, d_model]."""
        batch_size, seq_len = x.shape[:2]
        return x.reshape(batch_size, seq_len, self.d_model)

    def kernel1(self, x):
        """Positive feature map ``elu(x) + 1`` applied to queries and keys."""
        return F.elu(x) + 1

    def linear_attention1(self, q, k, v):
        """Unnormalised linear attention ``q @ (k^T @ v)``.

        q, k, v -> [batch_size, n_heads, seq_len, d_head]
        result -> [batch_size, n_heads, seq_len, d_head]

        Unlike ``linear_attention`` no normaliser is applied here.
        """
        # Distinct letters for the key (d) and value (e) feature axes: an
        # einsum output may not repeat a subscript.
        kv = torch.einsum('bnsd,bnse->bnde', k, v)
        # kv -> [batch_size, n_heads, d_head, d_head]
        return torch.einsum('bnsd,bnde->bnse', q, kv)

    def linear_attention(self, q, k, v):
        """Normalised linear attention.

        q, k, v -> [batch_size, seq_len, n_heads, d_head]
        result -> [batch_size, seq_len, n_heads, d_head]
        """
        # stolen from
        # https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/kernel_attention.py
        kv = torch.einsum('bsnx,bsnz->bnxz', k, v)
        # kv -> [batch_size, n_heads, d_head, d_head]
        # TODO: add dropout here
        # Normaliser: dot product of every query with the sum of all keys.
        denominator = 1.0 / (torch.einsum('bsnd,bnd->bsn', q, k.sum(dim=1)) + self.eps)
        # denominator -> [batch_size, seq_len, n_heads]

        output = torch.einsum('bsnx,bnxz,bsn->bsnz', q, kv, denominator).contiguous()
        # output -> [batch_size, seq_len, n_heads, d_head]

        return output

    def forward(self, x, mask=None):
        """Apply linear self-attention to ``x``.

        Args:
            x: input, ``[batch_size, seq_len, d_model]``.
            mask: optional bool tensor ``[batch_size, seq_len, 1, 1]`` that is
                True at padded positions (see ``get_mask``). When omitted,
                every position is attended to.

        Returns:
            Tensor ``[batch_size, seq_len, d_model]``.
        """
        q, k, v = torch.chunk(self.w_qkv(x), 3, -1)
        # q, k, v -> [batch_size, seq_len, d_model]

        q = self.kernel1(self.split_heads(q))
        k = self.kernel1(self.split_heads(k))
        v = self.split_heads(v)
        if mask is not None:
            # Keys are masked after the feature map: elu(x) + 1 is strictly
            # positive, so unmasked padded keys would leak into the normaliser.
            k = self.zero_out_padded(k, mask)
            v = self.zero_out_padded(v, mask)
        # q, k, v -> [batch_size, seq_len, n_heads, d_head]

        x = self.linear_attention(q, k, v)
        # x -> [batch_size, seq_len, n_heads, d_head]
        x = self.join_heads(x)
        # x -> [batch_size, seq_len, d_model]

        return x






In [None]:
# Stand-alone illustration of how the padding mask is built from sequence lengths.
lens = torch.tensor([5, 8, 2])
# There is no `self` at notebook top level, so take the device from `lens` itself.
# Row i is True for the first lens[i] positions and False for the padding.
valid = torch.arange(int(lens.max()), device=lens.device)[None, :] < lens[:, None]
valid

tensor([[ True,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True, False, False, False, False, False, False]])

In [None]:
# Toy dimensions for checking how torch.chunk splits a fused q/k/v projection.
batch_size, seq_len, d_model = 2, 4, 5

# Consecutive integers make it easy to see which columns land in which chunk.
x = torch.arange(batch_size * seq_len * 3 * d_model).view(batch_size, seq_len, -1)

# Three equal slices along the last dimension, each [batch_size, seq_len, d_model].
y = x.chunk(3, dim=-1)

print(x)

print(y)



In [None]:
# Check what type the leading dimensions of a tensor's shape come back as.
a, b = x.size(0), x.size(1)
type(a)

int