In [1]:
import torch as t
import torch.nn as nn

from fancy_einsum import einsum

In [2]:
def multihead_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, num_heads: int):
    '''
    Implements multihead causally-masked attention on the matrices Q, K and V.

    Each query position may only attend to itself and earlier key positions.
    Attention scores are scaled by 1/sqrt(headsize) before the softmax.

    Q: shape (batch, seq, nheads*headsize)
    K: shape (batch, seq, nheads*headsize)
    V: shape (batch, seq, nheads*headsize)
    num_heads: number of heads; must divide the last dimension of Q.

    Return: shape (batch, seq, nheads*headsize)

    Raises ValueError if the last dimension of Q is not divisible by num_heads.
    '''
    batch, seq_len, hidden = Q.shape
    if hidden % num_heads != 0:
        raise ValueError(
            f"last dimension of Q ({hidden}) must be divisible by num_heads ({num_heads})"
        )
    headsize = hidden // num_heads

    # Split the last dimension into (heads, headsize).
    Q = Q.reshape(batch, seq_len, num_heads, headsize)
    K = K.reshape(batch, seq_len, num_heads, headsize)
    V = V.reshape(batch, seq_len, num_heads, headsize)

    # Scaled scores, shape (batch, heads, query_pos, key_pos).
    scores = t.einsum('bqhd,bkhd->bhqk', Q, K) / headsize ** 0.5

    # Boolean causal mask, True strictly above the diagonal (key_pos > query_pos).
    # One (seq, seq) mask broadcasts over batch and heads, so no full
    # (batch, heads, seq, seq) -inf tensor has to be allocated.
    causal_mask = t.ones(seq_len, seq_len, dtype=t.bool, device=Q.device).triu(1)
    scores = scores.masked_fill(causal_mask, float('-inf'))
    # The diagonal is never masked, so every row has a finite entry and the softmax is well-defined.
    attention_pattern = t.softmax(scores, dim=-1)

    # Weighted sum of values, shape (batch, query_pos, heads, headsize).
    attention_values = t.einsum('bhqk,bkhd->bqhd', attention_pattern, V)
    return attention_values.reshape(batch, seq_len, num_heads * headsize)

In [7]:
class MultiheadMaskedAttention(nn.Module):
    '''
    Multihead masked self-attention layer.

    A single bias-free linear layer projects the input to concatenated
    queries, keys and values. These are passed to multihead_masked_attention,
    and the result is mixed by a bias-free output projection.
    '''
    W_QKV: nn.Linear  # hidden_size -> 3 * hidden_size, laid out as [Q | K | V]
    W_O: nn.Linear    # hidden_size -> hidden_size output projection

    def __init__(self, hidden_size: int, num_heads: int):
        '''
        hidden_size: model dimension; must be divisible by num_heads.
        num_heads: number of attention heads.
        '''
        assert hidden_size % num_heads == 0, "hidden_size should be divisible by num_heads"
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.W_QKV = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
        self.W_O = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)

        Return: shape (batch, seq, hidden_size)
        '''
        # The packed projection is split into three equal hidden_size-wide chunks: Q, K, V.
        Q, K, V = self.W_QKV(x).split(self.hidden_size, dim=-1)
        attention_values = multihead_masked_attention(Q, K, V, self.num_heads)
        return self.W_O(attention_values)

In [23]:
# Seed first so the randomly initialised projection weights, and hence the output below, are reproducible.
t.manual_seed(420)
m = MultiheadMaskedAttention(hidden_size=6, num_heads=2)
# Deterministic input: 36 evenly spaced values in [0, 42], viewed as (batch=2, seq=3, hidden=6).
x = t.linspace(0, 42, steps=36).view(2, 3, 6)
m(x)

tensor([[[  0.9091,  -1.2757,   0.7524,  -0.4398,   0.5692,  -0.0323],
         [  1.0873,  -1.7496,   0.8778,  -0.5991,   1.0275,   0.3806],
         [  1.7786,  -3.5476,   1.3684,  -1.2170,   2.8066,   2.0091]],

        [[  0.2038, -13.1862,  -1.1567,   0.3344,  -1.8587, -11.3195],
         [  0.7182, -14.5300,  -0.7923,  -0.1254,  -0.5352, -10.1120],
         [  1.2383, -15.8890,  -0.4238,  -0.5903,   0.8032,  -8.8908]]],
       grad_fn=<UnsafeViewBackward0>)