In [33]:
import numpy as np
import torch
from copy import copy
import seaborn as sns

# Seed both random number generators so the inputs below are reproducible.
# np.random.seed is a function and must be *called*; assigning to it
# (np.random.seed = 0) would replace the function and leave NumPy unseeded.
np.random.seed(0)
torch.manual_seed(0)

# Problem dimensions shared by every cell below.
batch_size = 2
seq_len = 4
d_model = 8
kind = 3  # number of input tensors (arr1, arr2, arr3) stacked per sequence position

# Three random inputs of shape (batch_size, seq_len, d_model) holding
# integer-valued floats in [0, 10].
arr1 = torch.round(torch.rand(batch_size, seq_len, d_model)*10)
arr2 = torch.round(torch.rand(batch_size, seq_len, d_model)*10)
arr3 = torch.round(torch.rand(batch_size, seq_len, d_model)*10)

# Self

In [68]:
import math
def get_self_mask(kind, seq_len):
    """Build an additive block-diagonal attention mask.

    The result is a square tensor of side ``kind * seq_len`` made of
    ``seq_len`` diagonal blocks of size ``kind x kind``. Entries inside a
    block are 0 and every entry outside the blocks is ``-inf``.
    """
    ones_block = torch.ones(kind, kind)
    layout = torch.block_diag(*(ones_block for _ in range(seq_len)))
    # 1 marks an allowed pair, 0 (the block_diag padding) a forbidden one.
    return torch.where(layout == 1, 0, -torch.inf)

def get_cross_mask(kind, seq_len):
    """Build an additive mask for one query row per group of ``kind`` keys.

    The result has shape ``(seq_len, kind * seq_len)``. Row ``i`` is 0 on
    columns ``i * kind`` .. ``(i + 1) * kind - 1`` and ``-inf`` elsewhere.
    """
    ones_row = torch.ones(1, kind)
    layout = torch.block_diag(*(ones_row for _ in range(seq_len)))
    # 1 marks an allowed pair, 0 (the block_diag padding) a forbidden one.
    return torch.where(layout == 1, 0, -torch.inf)


def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> "tuple[torch.Tensor, torch.Tensor]":
    """Reference scaled dot-product attention that also returns the weights.

    Args:
        query: Tensor of shape (..., L, E).
        key: Tensor of shape (..., S, E).
        value: Tensor of shape (..., S, Ev).
        attn_mask: Optional mask broadcastable to (..., L, S). A boolean mask
            keeps the positions that are True; a mask of any other dtype is
            added to the scaled scores (so ``-inf`` removes a position).
        dropout_p: Dropout probability applied to the attention weights.
            Dropout is always applied in training mode when non-zero.
        is_causal: If True, query position i only attends to key positions
            <= i. Cannot be combined with ``attn_mask``.
        scale: Multiplier for the raw scores; defaults to ``1 / sqrt(E)``.

    Returns:
        A tuple ``(output, attn_weight)`` where ``output`` is
        ``attn_weight @ value`` and ``attn_weight`` is the softmax over the
        last dimension of the biased, scaled scores.
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # Allocate the bias next to the query so the function also works off-CPU.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None, "is_causal cannot be combined with attn_mask"
        keep = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias = attn_bias.masked_fill(keep.logical_not(), float("-inf"))

    if attn_mask is not None:
        # Combine out of place: an in-place update of the (L, S) bias would
        # reject masks that carry leading batch/head dimensions.
        if attn_mask.dtype == torch.bool:
            mask_bias = torch.zeros_like(attn_mask, dtype=query.dtype).masked_fill(
                attn_mask.logical_not(), float("-inf")
            )
            attn_bias = attn_bias + mask_bias
        else:
            attn_bias = attn_bias + attn_mask.to(query.dtype)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if dropout_p != 0.0:
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value, attn_weight

### MHA

In [92]:
# Single-head self-attention over the flattened token axis, restricted by the
# block-diagonal mask to the `kind` tokens that share a sequence position.
torch.manual_seed(0)

# (batch, seq_len, kind, d_model) -> (batch, seq_len * kind, d_model)
arr = torch.stack((arr1, arr2, arr3), dim=-2)
arr = arr.view(batch_size, -1, d_model)

# Projections are created in Q, K, V order so the seeded weights are reproducible.
Q_linear, K_linear, V_linear = (torch.nn.Linear(d_model, d_model) for _ in range(3))
Q, K, V = Q_linear(arr), K_linear(arr), V_linear(arr)

attn_output, attn_weight = scaled_dot_product_attention(
    query=Q,
    key=K,
    value=V,
    attn_mask=get_self_mask(kind, seq_len),
)

# Show the shapes and the last token of the last batch element.
print("attn_output:", attn_output.shape)
print("attn_weight:", attn_weight.shape)
print(attn_weight[-1][-1])
print(attn_output[-1][-1])

attn_output: torch.Size([2, 12, 8])
attn_weight: torch.Size([2, 12, 12])
tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 1.1302e-01, 8.8690e-01, 7.9086e-05],
       grad_fn=<SelectBackward0>)
tensor([-5.2547, -2.5219, -3.8461, -2.2512, -0.8428,  3.0889, -4.2951, -0.1270],
       grad_fn=<SelectBackward0>)


### Naive

In [93]:
import math
# Same masked self-attention as the cell above, written out step by step.
torch.manual_seed(0)

# (batch, seq_len, kind, d_model) -> (batch, seq_len * kind, d_model)
arr = torch.stack((arr1, arr2, arr3), dim=-2)
arr = arr.view(batch_size, -1, d_model)

# Projections are created in Q, K, V order so the seeded weights are reproducible.
Q_linear, K_linear, V_linear = (torch.nn.Linear(d_model, d_model) for _ in range(3))
Q, K, V = Q_linear(arr), K_linear(arr), V_linear(arr)

# attn_weight = softmax(Q K^T / sqrt(d_model) + mask)
# Scaled scores: Q K^T / sqrt(d_model)
QKt = torch.matmul(Q, K.transpose(1, 2)) / math.sqrt(d_model)
# Additive mask: -inf outside each diagonal block
QKt = QKt + get_self_mask(kind, seq_len)
# Normalise over the key axis
softmax = torch.nn.Softmax(dim=-1)
attn_weight = softmax(QKt)
# Weighted sum of the values
attn_output = torch.matmul(attn_weight, V)

# Show the shapes and the last token of the last batch element.
print("attn_output:", attn_output.shape)
print("attn_weight:", attn_weight.shape)
print(attn_weight[-1][-1])
print(attn_output[-1][-1])

attn_output: torch.Size([2, 12, 8])
attn_weight: torch.Size([2, 12, 12])
tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 1.1302e-01, 8.8690e-01, 7.9086e-05],
       grad_fn=<SelectBackward0>)
tensor([-5.2547, -2.5219, -3.8461, -2.2512, -0.8428,  3.0889, -4.2951, -0.1270],
       grad_fn=<SelectBackward0>)


### Block

In [64]:
# Block formulation: keep the `kind` axis separate so attention runs
# independently inside every sequence position and no mask is needed.
torch.manual_seed(0)

# Shape (batch, seq_len, kind, d_model)
arr = torch.stack((arr1, arr2, arr3), dim=-2)

# Projections are created in Q, K, V order so the seeded weights are reproducible.
Q_linear, K_linear, V_linear = (torch.nn.Linear(d_model, d_model) for _ in range(3))
Q, K, V = Q_linear(arr), K_linear(arr), V_linear(arr)

# softmax(Q K^T / sqrt(d_model)) over the last (kind) axis, then weight the values
attn_weight = torch.softmax(
    torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_model),
    dim=-1,
)
attn_output = torch.matmul(attn_weight, V)

# Show the shapes and the last position of the last batch element.
print("attn_output:", attn_output.shape)
print("attn_weight:", attn_weight.shape)
print(attn_weight[-1][-1])
print(attn_output[-1][-1])

attn_output: torch.Size([2, 4, 3, 8])
attn_weight: torch.Size([2, 4, 3, 3])
tensor([[2.6076e-01, 7.3909e-01, 1.5296e-04],
        [2.2207e-03, 9.9652e-01, 1.2559e-03],
        [1.1302e-01, 8.8690e-01, 7.9086e-05]], grad_fn=<SelectBackward0>)
tensor([[-5.6034, -2.4412, -3.4502, -1.4781, -0.6058,  2.9580, -4.1531,  0.1394],
        [-4.9994, -2.5835, -4.1425, -2.8293, -1.0179,  3.1918, -4.4018, -0.3230],
        [-5.2547, -2.5219, -3.8461, -2.2512, -0.8428,  3.0889, -4.2951, -0.1270]],
       grad_fn=<SelectBackward0>)


# Cross

### MHA

In [102]:
# Cross-attention: queries come from arr1 only, keys/values from the flattened
# stack; the cross mask limits each query to the keys of its own position.
torch.manual_seed(0)

# (batch, seq_len, kind, d_model) -> (batch, seq_len * kind, d_model)
arr = torch.stack((arr1, arr2, arr3), dim=-2)
arr = arr.view(batch_size, -1, d_model)

# Projections are created in Q, K, V order so the seeded weights are reproducible.
Q_linear, K_linear, V_linear = (torch.nn.Linear(d_model, d_model) for _ in range(3))
Q, K, V = Q_linear(arr1), K_linear(arr), V_linear(arr)

attn_output, attn_weight = scaled_dot_product_attention(
    query=Q,
    key=K,
    value=V,
    attn_mask=get_cross_mask(kind, seq_len),
)

# Show the shapes and the last query of the last batch element.
print("attn_output:", attn_output.shape)
print("attn_weight:", attn_weight.shape)
print(attn_weight[-1][-1])
print(attn_output[-1][-1])

attn_output: torch.Size([2, 4, 8])
attn_weight: torch.Size([2, 4, 12])
tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 2.6076e-01, 7.3909e-01, 1.5296e-04],
       grad_fn=<SelectBackward0>)
tensor([-5.6034, -2.4412, -3.4502, -1.4781, -0.6058,  2.9580, -4.1531,  0.1394],
       grad_fn=<SelectBackward0>)


### Naive

In [103]:
import math
# Same masked cross-attention as the cell above, written out step by step.
torch.manual_seed(0)

# (batch, seq_len, kind, d_model) -> (batch, seq_len * kind, d_model)
arr = torch.stack((arr1, arr2, arr3), dim=-2)
arr = arr.view(batch_size, -1, d_model)

# Projections are created in Q, K, V order so the seeded weights are reproducible.
Q_linear, K_linear, V_linear = (torch.nn.Linear(d_model, d_model) for _ in range(3))
Q, K, V = Q_linear(arr1), K_linear(arr), V_linear(arr)

# attn_weight = softmax(Q K^T / sqrt(d_model) + mask)
# Scaled scores: Q K^T / sqrt(d_model)
QKt = torch.matmul(Q, K.transpose(1, 2)) / math.sqrt(d_model)
# Additive mask: -inf for keys that belong to another position
QKt = QKt + get_cross_mask(kind, seq_len)
# Normalise over the key axis
softmax = torch.nn.Softmax(dim=-1)
attn_weight = softmax(QKt)
# Weighted sum of the values
attn_output = torch.matmul(attn_weight, V)

# Show the shapes and the last query of the last batch element.
print("attn_output:", attn_output.shape)
print("attn_weight:", attn_weight.shape)
print(attn_weight[-1][-1])
print(attn_output[-1][-1])

attn_output: torch.Size([2, 4, 8])
attn_weight: torch.Size([2, 4, 12])
tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 2.6076e-01, 7.3909e-01, 1.5296e-04],
       grad_fn=<SelectBackward0>)
tensor([-5.6034, -2.4412, -3.4502, -1.4781, -0.6058,  2.9580, -4.1531,  0.1394],
       grad_fn=<SelectBackward0>)


### Block

In [104]:
# Block formulation of the cross-attention: one query per sequence position
# attends to the `kind` keys of that same position, so no mask is needed.
torch.manual_seed(0)

# Shape (batch, seq_len, kind, d_model)
arr = torch.stack((arr1, arr2, arr3), dim=-2)

# Projections are created in Q, K, V order so the seeded weights are reproducible.
Q_linear, K_linear, V_linear = (torch.nn.Linear(d_model, d_model) for _ in range(3))
# The query gets a singleton axis so it lines up with the `kind` axis of K and V.
Q, K, V = Q_linear(arr1).unsqueeze(-2), K_linear(arr), V_linear(arr)

# softmax(Q K^T / sqrt(d_model)) over the last (kind) axis, then weight the values
attn_weight = torch.softmax(
    torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_model),
    dim=-1,
)
attn_output = torch.matmul(attn_weight, V)

# Show the shapes and the last position of the last batch element.
print("attn_output:", attn_output.shape)
print("attn_weight:", attn_weight.shape)
print(attn_weight[-1][-1])
print(attn_output[-1][-1])

attn_output: torch.Size([2, 4, 1, 8])
attn_weight: torch.Size([2, 4, 1, 3])
tensor([[2.6076e-01, 7.3909e-01, 1.5296e-04]], grad_fn=<SelectBackward0>)
tensor([[-5.6034, -2.4412, -3.4502, -1.4781, -0.6058,  2.9580, -4.1531,  0.1394]],
       grad_fn=<SelectBackward0>)


# Multihead

### MHA

In [160]:
import math
def get_self_mask(kind, seq_len):
    """Return a ``(kind * seq_len)``-square additive mask.

    ``seq_len`` blocks of size ``kind x kind`` sit on the diagonal and hold 0;
    every other entry is ``-inf``.
    """
    group = torch.ones(kind, kind)
    pattern = torch.block_diag(*[group for _ in range(seq_len)])
    # Entries equal to 1 lie inside a diagonal block and stay unmasked.
    inside_block = pattern.eq(1)
    return torch.where(inside_block, 0, float("-inf"))

def get_cross_mask(kind, seq_len):
    """Return a ``(seq_len, kind * seq_len)`` additive mask.

    Row ``i`` holds 0 on the ``kind`` consecutive columns starting at
    ``i * kind`` and ``-inf`` on all other columns.
    """
    group = torch.ones(1, kind)
    pattern = torch.block_diag(*[group for _ in range(seq_len)])
    # Entries equal to 1 lie inside a row's own group and stay unmasked.
    inside_block = pattern.eq(1)
    return torch.where(inside_block, 0, float("-inf"))


def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    """Reference scaled dot-product attention that also returns the weights.

    Works for any number of leading dimensions, e.g. (batch, heads, L, E).

    Args:
        query: Tensor of shape (..., L, E).
        key: Tensor of shape (..., S, E).
        value: Tensor of shape (..., S, Ev).
        attn_mask: Optional mask broadcastable to (..., L, S). A boolean mask
            keeps the positions that are True; a mask of any other dtype is
            added to the scaled scores (so ``-inf`` removes a position).
        dropout_p: Dropout probability applied to the attention weights.
            With the default 0.0 no dropout is performed; a non-zero value is
            always applied in training mode.
        is_causal: If True, query position i only attends to key positions
            <= i. Cannot be combined with ``attn_mask``.
        scale: Multiplier for the raw scores; defaults to ``1 / sqrt(E)``.

    Returns:
        A tuple ``(output, attn_weight)`` where ``output`` is
        ``attn_weight @ value`` and ``attn_weight`` is the softmax over the
        last dimension of the biased, scaled scores.
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # Allocate the bias next to the query so the function also works off-CPU.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None, "is_causal cannot be combined with attn_mask"
        keep = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias = attn_bias.masked_fill(keep.logical_not(), float("-inf"))

    if attn_mask is not None:
        # Combine out of place: an in-place update of the (L, S) bias would
        # reject masks that carry leading batch/head dimensions.
        if attn_mask.dtype == torch.bool:
            mask_bias = torch.zeros_like(attn_mask, dtype=query.dtype).masked_fill(
                attn_mask.logical_not(), float("-inf")
            )
            attn_bias = attn_bias + mask_bias
        else:
            attn_bias = attn_bias + attn_mask.to(query.dtype)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    # Honour dropout_p instead of silently ignoring the argument.
    if dropout_p != 0.0:
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value, attn_weight

In [161]:
# Multi-head self-attention over the flattened token axis with the
# block-diagonal mask shared by every head.
torch.manual_seed(0)

nhead = 2

# (batch, seq_len, kind, d_model) -> (batch, seq_len * kind, d_model)
arr = torch.stack((arr1, arr2, arr3), dim=-2)
arr = arr.view(batch_size, -1, d_model)

# Projections are created in Q, K, V order so the seeded weights are reproducible.
Q_linear, K_linear, V_linear = (torch.nn.Linear(d_model, d_model) for _ in range(3))
Q, K, V = Q_linear(arr), K_linear(arr), V_linear(arr)

# Split heads: (batch, tokens, d_model) -> (batch, nhead, tokens, d_model // nhead)
Q, K, V = (
    proj.view(batch_size, -1, nhead, d_model // nhead).transpose(1, 2)
    for proj in (Q, K, V)
)

attn_output, attn_weight = scaled_dot_product_attention(
    query=Q,
    key=K,
    value=V,
    attn_mask=get_self_mask(kind, seq_len),
)
# Merge heads back: (batch, nhead, tokens, head_dim) -> (batch, tokens, d_model)
attn_output = attn_output.transpose(1, 2).reshape(batch_size, -1, d_model)

# Show the shapes, the last token of the last head, and the last output token.
print("attn_output:", attn_output.shape)
print("attn_weight:", attn_weight.shape)
print(attn_weight[-1][-1][-1])
print(attn_output[-1][-1])

attn_output: torch.Size([2, 12, 8])
attn_weight: torch.Size([2, 2, 12, 12])
tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.1066, 0.8446, 0.0488], grad_fn=<SelectBackward0>)
tensor([-5.6971, -2.4192, -3.3432, -1.2691, -0.7473,  3.2825, -4.3113,  0.0101],
       grad_fn=<SelectBackward0>)


### Block

In [166]:
# Multi-head block formulation: attention runs independently inside every
# sequence position and head, so no mask is needed.
# Uses `nhead` as set in the multi-head cell above.
torch.manual_seed(0)

# Shape (batch, seq_len, kind, d_model)
arr = torch.stack((arr1, arr2, arr3), dim=-2)

# Projections are created in Q, K, V order so the seeded weights are reproducible.
Q_linear, K_linear, V_linear = (torch.nn.Linear(d_model, d_model) for _ in range(3))
Q, K, V = Q_linear(arr), K_linear(arr), V_linear(arr)

# Split heads:
# (batch, seq_len, kind, d_model) -> (batch, nhead, seq_len, kind, d_model // nhead)
Q, K, V = (
    proj.view(batch_size, seq_len, kind, nhead, d_model // nhead).permute(0, 3, 1, 2, 4)
    for proj in (Q, K, V)
)

# softmax(Q K^T / sqrt(head_dim)) over the last (kind) axis, then weight the values
attn_weight = torch.softmax(
    torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_model // nhead),
    dim=-1,
)
attn_output = torch.matmul(attn_weight, V)
# Merge heads back: (batch, nhead, seq_len, kind, head_dim) -> (batch, seq_len, kind, d_model)
attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, seq_len, kind, d_model)

# Show the shapes, the last position of the last head, and the last output position.
print("attn_output:", attn_output.shape)
print("attn_weight:", attn_weight.shape)
print(attn_weight[-1][-1][-1])
print(attn_output[-1][-1])

attn_output: torch.Size([2, 4, 3, 8])
attn_weight: torch.Size([2, 2, 4, 3, 3])
tensor([[1.0351e-01, 8.2873e-01, 6.7763e-02],
        [2.0962e-04, 9.9912e-01, 6.6725e-04],
        [1.0660e-01, 8.4464e-01, 4.8763e-02]], grad_fn=<SelectBackward0>)
tensor([[-6.5140, -2.2299, -2.4150,  0.5435, -0.7110,  3.3585, -4.3182,  0.0625],
        [-6.3079, -2.4043, -2.9771, -0.5100, -1.0224,  3.1913, -4.4036, -0.3284],
        [-5.6971, -2.4192, -3.3432, -1.2691, -0.7473,  3.2825, -4.3113,  0.0101]],
       grad_fn=<SelectBackward0>)
