In [6]:
import torch
from torch import nn
from torch.nn import functional as F
import math

In [7]:
def scaled_dot_prod_attention(Q, K, V, mask=None):
	# softmax(Q@KT / sqrt(d_K))@V
	d_K = K.size(-1)
	scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_K)
	if mask is not None:
		scores = scores.masked_fill(mask==0, -1e9)
	attn_weights = F.softmax(scores)
	return attn_weights @ V

In [8]:
B = 2
L = 5
D = 3
Q = torch.randn(B, L, D)
K = torch.randn(B, L, D)
V = torch.randn(B, L, D)

In [9]:
scaled_dot_prod_attention(Q, K, V)

  attn_weights = F.softmax(scores)


tensor([[[-1.1489, -2.9056,  2.1917],
         [-0.5000, -1.5153,  1.3965],
         [ 0.1936, -0.9806,  0.4869],
         [-1.1397, -1.3985,  1.5173],
         [-1.1454, -1.8547,  1.8464]],

        [[ 0.8482,  0.1189,  0.4005],
         [ 1.8838,  0.5028,  0.6647],
         [ 2.1943, -0.0124,  0.3628],
         [ 1.7696,  1.1031,  1.0317],
         [ 1.5680,  1.2566,  0.9846]]])

In [10]:
class MultiHeadAttention(nn.Module):
	def __init__(self, d_model, num_heads, dropout=0.1):
		super().__init__()

		assert d_model % num_heads == 0

		self.d_model = d_model
		self.num_heads = num_heads
		self.d_k = d_model // num_heads
		
		self.W_q = nn.Linear(d_model, d_model)
		self.W_k = nn.Linear(d_model, d_model)
		self.W_v = nn.Linear(d_model, d_model)

		self.W_o = nn.Linear(d_model, d_model)
		self.dropout = nn.Dropout(dropout)
	
	def split_heads(self, x):
		B, L, _ = x.size()
		x = x.view(B, L, self.num_heads, self.d_k)
		return x.transpose(1, 2)
	
	def combine_heads(self, x):
		B, _, L, _ = x.size()
		x = x.transpose(1, 2).contiguous() # B, L, num_heads, d_k
		return x.view(B, L, self.d_model)

	def forward(self, x, mask=None):
		Q = self.W_q(x)
		K = self.W_k(x)
		V = self.W_v(x)

		Q = self.split_heads(Q)
		K = self.split_heads(K)
		V = self.split_heads(V)

		scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

		if mask is not None:
			scores = scores.masked_fill(mask==0, -1e9)
		
		attn_weights = F.softmax(scores, dim=-1)
		attn_weights = self.dropout(attn_weights)

		attn_output = torch.matmul(attn_weights, V)

		output = self.combine_heads(attn_output)
		output = self.W_o(output)

		return output, attn_weights

In [11]:
d_model = 100
num_heads = 10
dropout = 0.1

X = torch.randn(5, 15, d_model)

In [12]:
mha = MultiHeadAttention(d_model, num_heads, dropout)
op, _ = mha.forward(X)
op.size()

torch.Size([5, 15, 100])

In [13]:
class ein_MultiHeadAttention(nn.Module):
	def __init__(self, D, H, dropout=0.1):
		super().__init__()

		assert D % H == 0
		self.D = D
		self.H = H
		self.d_K = D // H

		self.W_q = nn.Linear(self.D, self.D)
		self.W_k = nn.Linear(self.D, self.D)
		self.W_v = nn.Linear(self.D, self.D)

		self.W_o = nn.Linear(self.D, self.D)
		self.dropout = nn.Dropout(dropout)
	
	def split_heads(self, x):
		B, L, _ = x.size()
		x = x.view(B, L, self.H, self.d_K)
		return torch.einsum('ijkl->ikjl', [x])
	
	def combine_heads(self, x):
		B, _, L, _ = x.size()
		x = torch.einsum('ijkl->ikjl', [x])
		return x.reshape(B, L, self.D)
	
	def forward(self, x, mask = None):
		Q = self.W_q(x)
		K = self.W_k(x)
		V = self.W_v(x)

		Q = self.split_heads(Q)
		K = self.split_heads(K)
		V = self.split_heads(V)

		scores = torch.einsum('bhld,bhdk->bhlk', [Q, torch.einsum('bhld->bhdl', [K])]) / math.sqrt(self.d_K)
		if mask is not None:
			scores = scores.masked_fill(mask==0, -1e9)
		attn_weights = F.softmax(scores, dim=-1)
		attn_weights = self.dropout(attn_weights)

		attn_output = torch.einsum('bhlk,bhkd->bhld', [attn_weights, V])

		attn_output = self.combine_heads(attn_output)

		output = self.W_o(attn_output)
		return output, attn_weights

In [14]:
mha = ein_MultiHeadAttention(d_model, num_heads, dropout)
op, _ = mha.forward(X)
op.size()

torch.Size([5, 15, 100])

In [15]:
f"Num of parameters: {sum([p.numel() for p in mha.parameters()]): ,}"

'Num of parameters:  40,400'