In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def dot_product_attention(q, K, V):
    """Unscaled dot-product attention for a single query vector.

    Every key is scored against the query with an inner product, the scores
    are normalised with a softmax over the keys, and the result is the
    softmax-weighted sum of the value rows.

    Args:
        q: query vector, shape [k].
        K: key matrix, shape [m, k] (one row per key).
        V: value matrix, shape [m, v] (one row per value).

    Returns:
        Attention output vector, shape [v].
    """
    scores = torch.einsum("k, mk->m", q, K)       # one score per key, [m]
    attn = scores.softmax(dim=0)                  # normalise across the m keys
    output = torch.einsum("m,mv->v", attn, V)     # weighted sum of value rows
    return output
    


In [None]:
D = 512
SRC_LEN = D * 2
# Fixed seed so Restart & Run All (or re-running this cell) reproduces the same tensors.
torch.manual_seed(0)
q = torch.randn(D)
K = torch.randn(SRC_LEN, D)
V = torch.randn(SRC_LEN, D)

In [None]:
y = dot_product_attention(q=q, K=K, V=V)

In [None]:
assert y.shape == (D,)

In [None]:
y.size()

torch.Size([512])

In [None]:
def multihead_attention(x, M, P_q, P_k, P_v, P_o):
    """Multi-head attention for a single query vector.

    The query vector and every memory row are projected into h heads. Each
    head runs (unscaled) dot-product attention independently. The per-head
    outputs are projected back to d dimensions and summed over heads.

    Args:
        x: query input vector, shape [d].
        M: memory matrix, shape [m, d]; m is the source sequence length.
        P_q: query projection tensor, shape [h, d, k].
        P_k: key projection tensor, shape [h, d, k].
        P_v: value projection tensor, shape [h, d, v].
        P_o: output projection tensor, shape [h, d, v]; maps each head's
            [v]-sized output back to [d] before the heads are summed.

    Returns:
        y: output vector, shape [d].
    """
    q = torch.einsum("d,hdk->hk", x, P_q)        # per-head query, [h, k]
    K = torch.einsum("md,hdk->hmk", M, P_k)      # per-head keys, [h, m, k]
    V = torch.einsum("md,hdv->hmv", M, P_v)      # per-head values, [h, m, v]
    logits = torch.einsum("hk,hmk->hm", q, K)    # per-head scores, [h, m]
    weights = F.softmax(logits, dim=-1)          # normalise over the m keys of each head
    o = torch.einsum("hm,hmv->hv", weights, V)   # per-head output, [h, v]
    y = torch.einsum("hv,hdv->d", o, P_o)        # project back to [d] and sum over heads
    return y
    
    

In [None]:
D, H = 512, 16   # model width and number of heads
K = V = D // H   # per-head key / value width
SRC_LEN = 256    # source (memory) length

In [None]:
# Fixed seed so Restart & Run All (or re-running this cell) reproduces the same tensors.
torch.manual_seed(0)
x = torch.randn((D,))
M = torch.randn((SRC_LEN, D))
P_q = torch.randn((H, D, K))
P_k = torch.randn((H, D, K))
P_v = torch.randn((H, D, V))
P_o = torch.randn((H, D, V))

In [None]:
y = multihead_attention(x, M, P_q=P_q, P_k=P_k, P_v=P_v, P_o=P_o)

In [None]:
assert tuple(y.shape) == (D,)

In [None]:
def multihead_attention_batched(X, M, mask, P_q, P_k, P_v, P_o):
    """Batched multi-head attention.

    Args:
        X: query inputs, shape [b, n, d].
        M: memory, shape [b, m, d].
        mask: additive mask that is added to the attention logits before the
            softmax. Shape [b, h, n, m], or any shape broadcastable to it.
            Use 0 to keep a position and a large negative value (e.g. -inf)
            to block it. A row whose positions are all -inf yields NaN.
            Pass None to apply no mask.
        P_q: query projection tensor, shape [h, d, k].
        P_k: key projection tensor, shape [h, d, k].
        P_v: value projection tensor, shape [h, d, v].
        P_o: output projection tensor, shape [h, d, v].

    Returns:
        Y: output tensor, shape [b, n, d].
    """
    Q = torch.einsum("bnd,hdk->bhnk", X, P_q)       # per-head queries
    K = torch.einsum("bmd,hdk->bhmk", M, P_k)       # per-head keys
    V = torch.einsum("bmd,hdv->bhmv", M, P_v)       # per-head values
    logits = torch.einsum("bhnk,bhmk->bhnm", Q, K)  # [b, h, n, m]
    if mask is not None:
        # Additive mask; broadcasting lets callers pass e.g. [1, 1, n, m].
        logits = logits + mask
    weights = F.softmax(logits, dim=-1)             # normalise over the m memory positions
    O = torch.einsum("bhnm,bhmv->bhnv", weights, V)
    Y = torch.einsum("bhnv,hdv->bnd", O, P_o)       # project back and sum over heads
    return Y

In [None]:
B = 32                          # batch size
SRC_LEN, TGT_LEN = 256, 128     # memory length, query length
D, H = 512, 16                  # model width, number of heads
K = V = D // H                  # per-head key / value width

In [None]:
# Fixed seed so Restart & Run All (or re-running this cell) reproduces the same tensors.
torch.manual_seed(0)
x = torch.randn((B, TGT_LEN, D))
M = torch.randn((B, SRC_LEN, D))
# Random noise used only to exercise the mask's shape; not a meaningful attention mask.
MASK = torch.randn((B, H, TGT_LEN, SRC_LEN))
P_q = torch.randn((H, D, K))
P_k = torch.randn((H, D, K))
P_v = torch.randn((H, D, V))
P_o = torch.randn((H, D, V))

In [None]:
y = multihead_attention_batched(X=x, M=M, mask=MASK, P_q=P_q, P_k=P_k, P_v=P_v, P_o=P_o)

In [None]:
assert tuple(y.shape) == (B, TGT_LEN, D)