# `014` Attention mechanisms

Requirements: 010 Embeddings, 013 LSTM

⚠️ WIP

There is a fundamental problem with LSTMs, GRUs and other RNNs: they are recurrent rather than feedforward, so the individual timesteps of a sequence cannot be processed in parallel. This time dependency makes training slower. RNNs also typically only capture dependencies that span on the order of 100 timesteps.

There is a different kind of layer called `Attention` that can be used to solve this problem. The idea is to have a layer that can look at the entire sequence at once and decide which parts of the sequence are important for the current timestep. This way, the layer can be parallelized and can capture long-term dependencies. The underlying idea is to convert every element of the sequence into a linear combination of all the elements in the sequence, where the weights of the linear combination are computed from learned parameters. Let's see an implementation in code:

In [1]:
from json import loads
from matplotlib import pyplot as plt
from string import ascii_letters, digits
from time import time
from unicodedata import category, normalize
import torch

# Run on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Character-level vocabulary: ASCII letters, digits and a few punctuation marks.
vocabulary = ascii_letters + digits + ' .,;\'!'
# Lookup tables between an id (position in `vocabulary`) and its character, in both directions.
i2c = dict(enumerate(vocabulary))
c2i = {char: index for index, char in i2c.items()}

def vectorize_sentence(s):
	"""Convert a string into the list of vocabulary ids of its characters.

	The string is NFD-normalized first, so accented letters are split into a
	base letter followed by combining marks. Combining marks (Unicode category
	'Mn') and every character that is not part of `vocabulary` are dropped.
	"""
	ids = []
	for char in normalize('NFD', s):
		if category(char) == 'Mn' or char not in vocabulary:
			continue
		ids.append(c2i[char])
	return ids

In [2]:
# Toy input: one sentence converted to character ids.
input = torch.tensor([vectorize_sentence('Hello, World!')])  # (batch_size=1, context_length=13)

# Random embedding table with one row per vocabulary entry.
embedding_channels = 8
embeddings = torch.randn(len(vocabulary), embedding_channels)  # (vocabulary_size, embedding_channels=8)
input = embeddings[input]  # (batch_size=1, context_length=13, embedding_channels=8)

head_size = 16
# Projection into query space: what every token is searching for.
W_q = torch.randn(embedding_channels, head_size)  # (embedding_channels=8, head_size=16)
# Projection into key space: what every token advertises to the others.
W_k = torch.randn(embedding_channels, head_size)  # (embedding_channels=8, head_size=16)

q = torch.matmul(input, W_q)  # (batch_size=1, context_length=13, head_size=16)
k = torch.matmul(input, W_k)  # (batch_size=1, context_length=13, head_size=16)

# Every query is dotted with every key: swapping the last two axes of k yields a
# square (context_length x context_length) score matrix. The softmax over the last
# axis turns each row into weights that sum to 1: how much a token attends to each
# token of the sequence. Note that the scores are used as-is (no scaling is applied).
attention_weights = torch.matmul(q, k.transpose(1, 2)).softmax(dim=-1)  # (batch_size=1, context_length=13, context_length=13)

# Projection into value space: what every token contributes once it is attended to.
W_v = torch.randn(embedding_channels, head_size)  # (embedding_channels=8, head_size=16)

v = torch.matmul(input, W_v)  # (batch_size=1, context_length=13, head_size=16)

# Each output row is the attention-weighted sum of all the value rows.
output = torch.matmul(attention_weights, v)  # (batch_size=1, context_length=13, head_size=16)

In [10]:
class SimpleAttention(torch.nn.Module):
	"""Single attention head built from three learnable projection matrices.

	`W_q`, `W_k` and `W_v` each have shape (embedding_channels, head_size) and
	are initialized from a standard normal distribution. The attention scores
	are plain dot products followed by a softmax; no scaling is applied.
	"""

	def __init__(self, embedding_channels, head_size):
		super().__init__()
		projection_shape = (embedding_channels, head_size)
		# Created in query, key, value order.
		self.W_q = torch.nn.Parameter(torch.randn(projection_shape))
		self.W_k = torch.nn.Parameter(torch.randn(projection_shape))
		self.W_v = torch.nn.Parameter(torch.randn(projection_shape))

	def forward(self, input):
		"""Return the attention output for `input` of shape (..., context_length, embedding_channels).

		The result has shape (..., context_length, head_size).
		"""
		queries = torch.matmul(input, self.W_q)
		keys = torch.matmul(input, self.W_k)
		values = torch.matmul(input, self.W_v)
		# (context_length x context_length) matrix of query-key dot products.
		scores = torch.matmul(queries, keys.transpose(-2, -1))
		# Normalize each row so the weights over the attended tokens sum to 1.
		attention_weights = scores.softmax(dim=-1)
		return torch.matmul(attention_weights, values)
	
class DigitClassifier(torch.nn.Module):
	"""Sequence classifier: embedding -> single attention head -> mean pooling -> linear layer.

	Args:
		vocabulary_size: number of distinct token ids the embedding accepts.
		embedding_channels: size of each token embedding.
		head_size: output size of the attention head.
		num_classes: number of output logits (defaults to 10, the previously hard-coded value).
	"""

	def __init__(self, vocabulary_size, embedding_channels, head_size, num_classes=10):
		super().__init__()
		self.embedding = torch.nn.Embedding(vocabulary_size, embedding_channels)
		self.attention = SimpleAttention(embedding_channels, head_size)
		self.fc = torch.nn.Linear(head_size, num_classes)

	def forward(self, input):
		"""Return class logits for token ids of shape (context_length,) or (batch_size, context_length).

		The output has shape (num_classes,) or (batch_size, num_classes) respectively.
		"""
		embedding = self.embedding(input)
		attention = self.attention(embedding)
		# Pool over the context axis, which is always the second-to-last one.
		# Using dim=1 would average the feature axis for unbatched input and only
		# work by accident when context_length happens to equal head_size.
		return self.fc(attention.mean(dim=-2))

# Character-level classifier with 8 embedding channels and a 16-wide attention head,
# moved to the selected device.
model = DigitClassifier(
	vocabulary_size=len(vocabulary),
	embedding_channels=8,
	head_size=16,
).to(device)

In [4]:
# Load the corpus: a JSON object mapping each language to a list of sentences.
with open('custom-data/sentences.json', encoding='utf-8') as f:
	data = loads(f.read())

languages = list(data.keys())

# Slide a fixed-size window over every vectorized sentence. Each window is one
# sample, labelled with the position of its language in `languages`. Sentences
# shorter than `sentence_size` yield no window.
X, Y = [], []
sentence_size = 16
# enumerate() follows the same key order as `languages`, so `label` equals
# languages.index(language) without a list search for every window.
for label, (language, sentences) in enumerate(data.items()):
	for sentence in sentences:
		sentence = vectorize_sentence(sentence)
		for i in range(len(sentence) - sentence_size + 1):
			X.append(sentence[i:i+sentence_size])
			Y.append(label)

# Shuffle samples and labels with the same permutation. Indexing the tensors
# with the permutation avoids a Python-level loop over every sample.
ix = torch.randperm(len(X))
X = torch.tensor(X)[ix].to(device)
Y = torch.tensor(Y)[ix].to(device)

print(f'Loaded {len(X)} sentences from {len(languages)} languages using {len(vocabulary)} different characters')
print(X[0], Y[0], '->', ''.join(i2c[i.item()] for i in X[0]), languages[Y[0]])

Loaded 6752415 sentences from 8 languages using 68 different characters
tensor([ 1, 24, 62, 45,  0, 13,  6, 62, 51,  0,  8, 24,  0, 13,  6, 64]) tensor(4) -> by Tang Zaiyang, en


In [11]:
# Forward pass on a single sample: X[0] is one row of X, so it carries no batch dimension.
model(X[0])

tensor([-0.4104,  0.2441, -1.0296,  0.5545,  0.7851, -0.5065, -0.4514, -0.9293,
         1.0055,  0.2061], grad_fn=<ViewBackward0>)