In [None]:
    def forward(self, x):
        proj = self.c_attn(x)
        query, key, value = proj.chunk(3, -1)
        _, seq_len, _ = query.shape
        selected_mask = self.causal_mask[:seq_len, :seq_len]
        selected_bias = self.pos_bias[:seq_len, :seq_len]
        selected_bias = selected_bias.unsqueeze(0)
        masked_bias = selected_bias.masked_fill(~selected_mask, float('-inf'))

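        # subtract maxima before exp to avoid overflow; both shifts cancel in num / denom.
        # key: max over the sequence dim (dim=1), i.e. one shift per (batch, feature) shared by all positions j
        maxk = key.max(dim=1, keepdim=True)[0]
        key = key - maxk
        # bias: max over j for each output row i (masked -inf entries stay -inf)
        maxpb = masked_bias.max(dim=-1, keepdim=True)[0]
        masked_bias = masked_bias - maxpb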

        key = torch.exp(key)
        expbias = torch.exp(masked_bias)
        num = torch.einsum('bij, bjd -> bid', expbias, key * value)
        denom = torch.einsum('bij, bjd -> bid', expbias, key)

        y = torch.sigmoid(query) * (num / denom)
        return self.c_proj(y)
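The cell above is a method lifted out of its class (the definitions of `c_attn`, `c_proj`, `causal_mask` and `pos_bias` are not shown), so it will not run by itself. Below is a minimal, self-contained sketch of the same computation written as free functions. Assumptions: the names `aft_full_causal` and `aft_reference` are hypothetical, the `c_attn`/`c_proj` projections are left out, and the causal mask is built on the fly with `torch.tril` rather than sliced from a stored buffer. The slow reference computes an explicit softmax over positions j ≤ i, which is what `num / denom` should equal.

```python
import torch


def aft_full_causal(query, key, value, pos_bias):
    """Core of forward() above, minus the c_attn / c_proj projections.

    query, key, value: (batch, seq_len, dim); pos_bias: (seq_len, seq_len).
    """
    seq_len = query.shape[1]
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=query.device)
    )
    bias = pos_bias.masked_fill(~causal, float("-inf"))  # position i sees j <= i

    # Subtract maxima that cancel in num / denom:
    #   key:  one max per (batch, feature) over positions j -> shape (batch, 1, dim)
    #   bias: one max per output row i over j             -> shape (seq_len, 1)
    key = key - key.max(dim=1, keepdim=True).values
    bias = bias - bias.max(dim=-1, keepdim=True).values

    exp_key = torch.exp(key)
    exp_bias = torch.exp(bias)  # masked entries: exp(-inf) = 0
    num = torch.einsum("ij,bjd->bid", exp_bias, exp_key * value)
    denom = torch.einsum("ij,bjd->bid", exp_bias, exp_key)
    return torch.sigmoid(query) * (num / denom)


def aft_reference(query, key, value, pos_bias):
    """Slow reference: explicit softmax over positions j <= i for every row i."""
    out = torch.empty_like(value)
    for i in range(query.shape[1]):
        logits = key[:, : i + 1, :] + pos_bias[i, : i + 1].reshape(1, -1, 1)
        weights = torch.softmax(logits, dim=1)  # normalise over positions j
        out[:, i, :] = (weights * value[:, : i + 1, :]).sum(dim=1)
    return torch.sigmoid(query) * out


torch.manual_seed(0)
q, k, v = (torch.randn(2, 5, 4, dtype=torch.float64) for _ in range(3))
w = torch.randn(5, 5, dtype=torch.float64)
print(torch.allclose(aft_full_causal(q, k, v, w), aft_reference(q, k, v, w)))
# Adding 1000 to every key would make a plain exp() overflow to inf in float64;
# the max subtraction removes the shift, so the output is unchanged.
print(torch.allclose(aft_full_causal(q, k + 1000, v, w), aft_full_causal(q, k, v, w)))
```

The key maximum is taken over the sequence dimension (`dim=1`), giving one shift per (batch, feature) pair that is shared by every position j, so it factors out of both `num` and `denom`. A per-token maximum over the feature dimension (`dim=-1`) would scale each position's contribution by a different factor and change the result. One caveat: with a causal mask the maximum can come from a future position, so very large key gaps could still underflow the denominator of early rows.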

In [46]:
import torch

In [47]:
batch, seq_len, dim = 2, 3, 4

In [48]:
q,k,v = torch.rand(batch, seq_len, dim), torch.rand(batch, seq_len, dim), torch.rand(batch, seq_len, dim)

In [49]:
mask = torch.rand(seq_len, seq_len)

In [50]:
k = torch.exp(k)
mask = mask.unsqueeze(0)
mask = torch.exp(mask)

In [51]:
kv = k * v

In [52]:
num = torch.einsum('bij, bjd -> bid', mask, kv)

In [53]:
imperative_num = torch.zeros(batch, seq_len, dim)

In [None]:
for b in range(batch):
    for i in range(seq_len):
        for d in range(dim):
            total = 0
            for j in range(seq_len):
                total += mask[0, i, j] * kv[b, j, d]
            imperative_num[b, i, d] = total

In [55]:
# check that the einsum and the explicit loops agree
print(torch.allclose(num, imperative_num))

True
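Two quick cross-checks of the `bij, bjd -> bid` contraction, sketched with freshly drawn random tensors of the same shapes as above (not the notebook's exact values). Because `mask` has a batch dimension of size 1, the einsum broadcasts it, so the result equals the 2-D contraction `ij, bjd -> bid` with `mask[0]`; and summing over j is just a batched matrix product, so `mask @ kv` gives the same result.

```python
import torch

batch, seq_len, dim = 2, 3, 4
# Same construction as the cells above, with new random draws.
mask = torch.exp(torch.rand(seq_len, seq_len)).unsqueeze(0)  # (1, seq_len, seq_len)
kv = torch.exp(torch.rand(batch, seq_len, dim)) * torch.rand(batch, seq_len, dim)

num = torch.einsum('bij, bjd -> bid', mask, kv)

# mask's batch dimension has size 1 and is broadcast: every batch element uses the same weights
num_shared = torch.einsum('ij, bjd -> bid', mask[0], kv)
# summing mask[i, j] * kv[b, j, d] over j is a batched matrix product
num_matmul = mask @ kv

print(num.shape)  # torch.Size([2, 3, 4])
print(torch.allclose(num, num_shared), torch.allclose(num, num_matmul))
```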


In [56]:
mask

tensor([[[1.1534, 1.6675, 2.5779],
         [2.2789, 2.5981, 1.2872],
         [1.1746, 1.9078, 1.0497]]])

In [57]:
kv

tensor([[[0.6336, 1.1607, 0.8870, 2.2571],
         [0.0812, 0.0491, 0.0094, 0.7999],
         [0.7740, 0.4729, 1.8156, 0.7394]],

        [[0.8664, 0.2987, 1.2297, 0.8572],
         [0.7729, 1.5015, 2.2767, 0.9715],
         [0.0795, 0.9834, 1.1632, 0.3573]]])

In [58]:
num

tensor([[[ 2.8617,  2.6400,  5.7191,  5.8435],
         [ 2.6513,  3.3816,  4.3828,  8.1739],
         [ 1.7117,  1.9536,  2.9657,  4.9536]],

        [[ 2.4932,  5.3832,  8.2133,  3.5297],
         [ 4.0849,  5.8474, 10.2146,  4.9374],
         [ 2.5757,  4.2476,  7.0089,  3.2353]]])

# Finally Figured it Out!

So what that fancy-looking einsum is really doing is surprisingly simple. Take, for instance, the following `kv` (two batch elements, each with 3 tokens × 4 features):

        [[0.6336, 1.1607, 0.8870, 2.2571],
         [0.0812, 0.0491, 0.0094, 0.7999],
         [0.7740, 0.4729, 1.8156, 0.7394]],

        [[0.8664, 0.2987, 1.2297, 0.8572],
         [0.7729, 1.5015, 2.2767, 0.9715],
         [0.0795, 0.9834, 1.1632, 0.3573]]

which is the element-wise product of `k` (already exponentiated in an earlier cell) and `v`.

Now suppose we have the following mask:

        [[[1.1534, 1.6675, 2.5779],
         [2.2789, 2.5981, 1.2872],
         [1.1746, 1.9078, 1.0497]]]

Recall that in a transformer architecture, each token takes its context into account through a weighted sum of the value vectors of the tokens it attends to.

In this case, the first token of the first batch element is represented by its row of `kv`:

        [0.6336, 1.1607, 0.8870, 2.2571]

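After "attention" has been applied, the first element of this token's output is a weighted sum of the first elements of all 3 token vectors:

$w_1 \cdot 0.6336 + w_2 \cdot 0.0812 + w_3 \cdot 0.7740$

where 0.0812 comes from the second token's vector

[0.0812, 0.0491, 0.0094, 0.7999]

and 0.7740 comes from the third token's vector

[0.7740, 0.4729, 1.8156, 0.7394]

The same pattern holds for every feature d and every output token i: element (i, d) of the output is the sum over j of mask[i, j] · kv[j, d]. This is exactly what the einsum `bij, bjd -> bid` computes.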
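The same weighted sum written out in plain Python, as a sketch for checking by hand. It uses the rounded values printed above for the first batch element, so small differences from `num` come from the four-decimal rounding of the inputs.

```python
# Printed (rounded) values from the cells above, first batch element only.
weights = [1.1534, 1.6675, 2.5779]  # first row of `mask`: w_1, w_2, w_3
kv_b0 = [
    [0.6336, 1.1607, 0.8870, 2.2571],  # token 1
    [0.0812, 0.0491, 0.0094, 0.7999],  # token 2
    [0.7740, 0.4729, 1.8156, 0.7394],  # token 3
]

# Output feature d for token 1 = sum over tokens j of w_j * kv[j][d]
out = [sum(w * row[d] for w, row in zip(weights, kv_b0)) for d in range(4)]
print([round(x, 4) for x in out])  # compare with num[0, 0]: [2.8617, 2.6400, 5.7191, 5.8435]
```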

The weights for the first token always come from the first row of the mask (for both batch elements, since the mask has a batch dimension of 1 and is broadcast):

[1.1534, 1.6675, 2.5779]

so $w_1 = 1.1534$, $w_2 = 1.6675$, and $w_3 = 2.5779$.
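Because the mask is shared across the batch, the same three weights also produce the first token's output in the second batch element. A plain-Python check with the printed (rounded) values, again only approximate because of the rounding:

```python
# Printed (rounded) values from the cells above.
weights = [1.1534, 1.6675, 2.5779]  # first row of `mask`
kv_b1 = [  # second batch element of `kv`
    [0.8664, 0.2987, 1.2297, 0.8572],
    [0.7729, 1.5015, 2.2767, 0.9715],
    [0.0795, 0.9834, 1.1632, 0.3573],
]

# Same weights, different value rows
out = [sum(w * row[d] for w, row in zip(weights, kv_b1)) for d in range(4)]
print([round(x, 4) for x in out])  # compare with num[1, 0]: [2.4932, 5.3832, 8.2133, 3.5297]
```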

For example, the second element of the output vector for the first token is:

$1.1534 \cdot 1.1607 + 1.6675 \cdot 0.0491 + 2.5779 \cdot 0.4729 \approx 2.6397$

For reference, the computed vector is:

[ 2.8617,  2.6400,  5.7191,  5.8435]

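Notice that, in the printout at least, the second element shows as 2.6400 rather than the hand-computed 2.6397: the printed tensors are rounded to four decimal places, so a calculation from the printed inputs matches the einsum only approximately.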