In [1]:
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

[1]: https://twitter.com/abhi1thakur/status/1470406495678439426
[2]: https://arxiv.org/pdf/2207.09238.pdf

### Attention is all you need
1. [Thread][1]
2. [Paper][2]

#### Basic Building Blocks

This is the `EDTransformer` (encoder-decoder) architecture as described in [2] used for seq2seq inference.

#### Unsused

In [2]:
# missing embedding
class Transformer(nn.Module):
    def __init__(
        self,
        d_model=512,
        num_heads=8,  # hyperparam H in [2]
        num_encoders=6,  # hyperparam l_enc in [2]
        num_decoders=6,  # hyperparam l_dec in [2]
    ):
        super(Transformer, self).__init__()
        self.encoder = Encoder(d_model, num_heads, num_encoders)
        self.decoder = Decoder(d_model, num_heads, num_decoders)
    
    def forward(self, source, target, source_mask, target_mask):
        # source, target = Z, X in [2] (in case of EDTransformer)
        # source = Z, target = X (context in the decoder)
        encoder_output = self.encoder(source, source_mask)
        return self.decoder(target, encoder_output, source_mask, target_mask)

In [21]:
class Decoder(nn.Module):
    def __init__(self, d_model, num_heads, num_decoders):
        super(Decoder, self).__init__()
        self.decoder_layers = nn.ModuleList(
            [
                DecoderLayer(d_model, num_heads)
                for _ in range(num_decoders)
            ]
        )
        
    def forward(self, target, encoder, target_mask, encoder_mask):
        # Decoder takes the output from last the encoder layer
        # along with target embeddings and target mask as its input.
        # Encoder mask is still the source mask.
        output = None
        for layer in self.decoder_layers:
            output = layer(
                output if output is not None else target, encoder, target_mask, encoder_mask
            )
            
        return output
    
    
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.3):
        super(DecoderLayer, self).__init__()
        
        # This multi-head attention takes key, value (K, V) from
        # the final encoder output (K, V are equal). Query
        # comes from the output of the masked multi-head attention
        self.masked_attention = MultiHeadAttention(
            d_model, num_heads, dropout=dropout
        )

        self.attention = MultiHeadAttention(
            d_model, num_heads, dropout=dropout
        )
        
        # feed forward part, same as the encoder
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(d_ff, d_model)
        )
        
        self.attention_norm = nn.LayerNorm(d_model)
        self.masked_attention_norm = nn.LayerNorm(d_model)
        self.ffn_norm = nn.LayerNorm(d_model)
        
    def forward(self, target, encoder, target_mask, encoder_mask):
        x = target
        # line 14 of alg 8 from [2], q, k, v == target == X
        # Target mask := [[t_z <= t_x]] only tokens preceeding the
        # current token are used as the context (unidirectional/masked self-attention).
        x = self.masked_attention(q=x, k=x, v=x, mask=target_mask)
        x = self.masked_attention_norm(x)
        # Query comes from the output of the masked attention.
        # key, value come from the last encoder layer.
        # line 16 of alg 8 from [2]
        # q is X (target), k, v is the source (context) Z encoded by the encoder
        x = self.attention(q=x, k=encoder, v=encoder, mask=encoder_mask)
        x = self.attention_norm(x)
        x = self.ffn(x)
        x = self.ffn(norm)
        return x   

#### Used for ETransformer

In [4]:
class Encoder(nn.Module):
    
    def __init__(self, d_model, num_heads, num_encoders):
        # Encoder consists of `num_encoder` layers (L_enc in [2]).
        # ModuleList holds a list of submodules, can be indexed as Python list.
        # `linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])`
        super(Encoder, self).__init__()
        self.encoder_layers = nn.ModuleList(
            [
                EncoderLayer(d_model, num_heads)
                for _ in range(num_encoders)
            ]
        )
        
    def forward(self, source, source_mask):
        output = None
        # source shape (batch_size, sequence_len, embedding_dim)
        # source is the tensor of embedded tokens
        for layer in self.encoder_layers:
            # The output of one encoder goes into the next one.
            # Source mask stays the same.
            output = layer(output if output is not None else source, source_mask)
        # output shape (batch_size, sequence_len, d_model == embedding_dim)
            
        return output

In [16]:
class EncoderLayer(nn.Module):
    # layer l
    def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.3):
        super(EncoderLayer, self).__init__()
        
        # params \mathcal{W}_l^enc from [2] (H heads)
        self.attention = MultiHeadAttention(
            d_model, num_heads, dropout=dropout
        )
        
        # feed forward part of the encoder, after attention
        self.ffn = nn.Sequential(
            # Linear (Wx + b) is the same as Keras Dense layer without activations.
            nn.Linear(d_model, d_ff),  # W_mlp1^l, b_mlp2^l from [2]
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),  # W_mlp2^l, b_mlp2^l from [2]
            nn.Dropout(dropout),
        )
        
        self.attention_norm = nn.LayerNorm(d_model)  # β_l^1, γ_l^1 norm parameters
        self.ffn_norm = nn.LayerNorm(d_model)  # β_l^1, γ_l^1 norm parameters
        
    def forward(self, source, source_mask):
        x = source
        # x shape (batch_size, sequence_len, d_model)
        x = self.attention(q=x, k=x, v=x, mask=source_mask)  # line 5 of alg 8 from [2]
        x = self.attention_norm(x)  # line 6
        # x shape (batch_size, sequence_len, d_model)
        x = self.ffn(x)  # line 7
        x = self.ffn_norm(x)  # line 8
        # x shape (batch_size, sequence_len, d_model)
        return x

In [17]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout = dropout
        
        self.attention_output_size = d_model // num_heads
        # 64 = 512 // 8
      
        # H layers of attentions,veach attention layer h
        # consists of the following params
        # (formula 4 and alg 5 from [2]):
        # W_q^h, b_q^h query params
        # W_k^h, b_k^h key params
        # W_v^h, b_v^h value params
        self.attentions = nn.ModuleList(
            [
                SelfAttention(d_model, self.attention_output_size)
                for _ in range(num_heads)
            ]
        )
        
        # W_O in formula (4) and alg 5 from [2]
        self.output = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v, mask):
        # Output of each layer is concatenated.
        # line 3 of alg 5 from [2]
        # Y_h shape (n, m) (fill in n, m)!!!!!
        # Y_1, Y_h, ..., Y_H -> Y shape (n, h * m)
        # q, k, v shape (batch_size, sequence_len, d_model)
        x = torch.cat(
            [
                # line 1, 2 of alg 5 from [2]
                layer(q, k, v, mask) for layer in self.attentions
            ],
            dim=-1,
        )
        
        # Final dense layer after the attentions.
        # shape (batch_size, sequence_len, d_model)
        # 8 heads each producing shape (batch_size, sequence_len, 64)
        # concatenated into (batch_size, sequence_len, 512 == 8 * 64)
        x = self.output(x)
        # shape (batch_size, sequence_len, d_model)
        return x

In [18]:
class SelfAttention(nn.Module):
    # Self attention -> attention on the same sequence (X = Z)
    # It tells us which other tokens from the sequence are relevant
    # to current token.
    
    def __init__(self, d_model, output_size, dropout=0.3):
        super(SelfAttention, self).__init__()
        
        # Params of alg 4 from [2]
        # W_q, b_q query params
        # W_k, b_k key params
        # W_v, b_v value params
        self.query = nn.Linear(d_model, output_size)
        self.key = nn.Linear(d_model, output_size)
        self.value = nn.Linear(d_model, output_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, q, k, v, mask=None):
        query = self.query(q)
        key = self.key(k)
        value = self.value(v)
        # q, k, v are the same tensors in case of self-attention
        
        batch_size = q.shape[0]
        # target and sequence length is the same
        # in case of self-attention
        target_len = q.shape[1]
        sequence_len = k.shape[1]
        
        dim_k = key.size(-1)
        
        # batch matrix multiplication
        # (b, n, m) x (b, m, p) -> (b, n, p) where b is the size of the batch
        # mat_1_i x mat_2_i = mat_out_i where for i-th matrix pair in the batch
        # S / sqrt(d_attn) in line 6 of alg 4 from  [2] (S = QK)
        scores = torch.bmm(query, key.transpose(1, 2)) / np.sqrt(dim_k)
        # scores shape (batch_size, sequence_len, sequence_len)
        # query, key shape (batch_size, sequence_len, 512 // 8)
        # key.transpose(1, 2) shape (batch_size, 512 // 8, sequence_len)
        
        if mask is not None:
            # mask shape is (batch_size, sequence_len)
            # expanded mask shape is (batch_size, target_len, sequence_len)
            expanded_mask = mask[:, None, :].expand(batch_size, target_len, sequence_len)
            # triu -> upper triangular part of the matrix. Values below the diagonal of the
            # input tensor are set to 0. This creates the unidirectional self-attention.
            # Only the current token or the tokens preceeding the current token are considered
            # as the context. The model does not look into the future.
            subsequent_mask = 1 - torch.triu(
                torch.ones((target_len, target_len), device=mask.device, dtype=torch.unit8),
                # diagonal=1 means we are including the current token in the context
                diagonal=1
            )

            # Replaces a score with -inf if the mask in the same position is equal to 0.
            # line 5 of alg 4 from [2]
            scores = scores.masked_fill(expanded_mask == 0, -float("inf"))
            scores = scores.masked_fill(subsequent_mask == 0, -float("int"))
            
        # applying softmax row-wise, line 6 of alg 4 from [2]        
        weights = F.softmax(scores, dim=-1)
        # weights.shape == scores.shape
        # V^~ = V * softmax(S / sqrt(d_attn))
        # V^~ = V * weights, line 6 of alg 4 from [2]
        # value shape (batch_size, sequence_len, 512 // 8)
        # weight shape (batch_size, sequence_len, sequence_len)
        # bmm(weights, value) shape (batch_size, sequence_len, 512 // 8)
        return torch.bmm(weights, value)

In [19]:
class Embedding(nn.Module):
    
    def __init__(self, vocabulary_size, embedding_dim, max_sequence_len):
        super(Embedding, self).__init__()
        
        self.embedding = nn.Embedding(
            vocabulary_size, embedding_dim
        )
        
        # alg 2 from [2]
        # We are using learned positional embeddings.
        # Example https://pytorch.org/tutorials/beginner/transformer_tutorial.html
        # is using hard-coded positional embeddings (similar to the sin/cos formula from [2]).
        self.positional_embedding = nn.Embedding(
            max_sequence_len, embedding_dim
        )
        
    def forward(self, x):
        # input: shape (batch_size, sequence_len)
        # x = self.embeddings(x)
        # there is an embedding vector for every token in every sequence
        # output shape (batch_size, sequence_len, embedding_dim)
        
        batch_size, sequence_len = x.shape
        positions = torch.arange(sequence_len, dtype=torch.int32).repeat(batch_size, 1)
        
        return self.embedding(x) + self.positional_embedding(positions)

#### BERT (ETransformer)
Based on the `Encoder` from the `EDTransformer` architecture.

In [22]:
class ETransformer(nn.Module):
    def __init__(
        self,
        d_model=256,
        num_heads=2,
        num_encoders=2,
        vocabulary_size=100_000,
        max_sequence_len=128,
    ):
        super(ETransformer, self).__init__()
        
        # Add an id for the mask out token (used during the training).
        # See "Encoder-only transformer: BERT" paragraph in the section 6 of [2].
        vocabulary_size += 1
        self.vocabulary_size = vocabulary_size
        
        embedding_dim = d_model
        
        self.embedding = Embedding(vocabulary_size, embedding_dim, max_sequence_len)
        self.encoder = Encoder(d_model, num_heads, num_encoders)
        
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 1024),
            nn.GELU(),
            nn.Linear(1024, d_model)
        )
    
        self.ffn_norm = nn.LayerNorm(d_model)
        
        # We want the final output to be softmax probability for every
        # token in the vocabulary. output_size == vocabulary_size
        self.unembedding = nn.Linear(d_model, vocabulary_size)
    
    def forward(self, tokenized_source):
        # source = X in [2] (in case of ETransformer)
        # source_mask = None because ETransformer is always using mask = 1 everywhere.
        # tokenized_source shape (batch_size, sequence_len), rows of sequences
        source = self.embedding(tokenized_source)
        # source shape (batch_size, sequence_len, embedding_dim)
        # each token in each sequence embedded into a vector
        encoder_output = self.encoder(source, source_mask=None)
        # encoder_output shape (batch_size, sequence_len, d_model == embedding_dim)
        self.ffn(encoder_output)
        # There is an additional dense layer + norm.
        output = encoder_output + self.ffn(encoder_output)
        # output shape (batch_size, sequence_len, d_model)
        output = self.ffn_norm(output)
        output = self.unembedding(output)
        # output shape (batch_size, sequence_len, vocabulary_size)
        # Because we are predicting probability that the i-th token of the j-th
        # is the k-th entry in the vocabulary.
        return F.softmax(output, dim=2)

In [224]:
et = ETransformer()
x = torch.zeros([2, 4], device="cpu", dtype=torch.int32)

# make sure that the vocabulary dimension adds up to 1
et(x).sum(dim=2)

tensor([[1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000]], grad_fn=<SumBackward1>)

In [10]:
def process_to_flat(dataset):
    # Creates one lond 1-D tensor containing the whole dataset as a long "tokenized" string.
    return torch.cat(
        [
            torch.tensor(vocab(tokenizer(item)), dtype=torch.long)
            for item in dataset
            if len(item) > 0
        ]
    )

def create_sequences(flat_data, sequence_len=128):
    n_sequences = flat_data.size(0) // sequence_len
    
    # trim data to fit and reshape
    return flat_data[:n_sequences * sequence_len].reshape(n_sequences, sequence_len)

class SequenceBatchGenerator:
    
    def __init__(self, sequences, batch_size):
        self._sequences = sequences
        self._batch_size = batch_size
        self._n_batches = self._sequences.size(0) // batch_size    
        
    def __len__(self):
        return self._n_batches      
        
    def __iter__(self):
        for i in range(self._n_batches):
            start = i * self._batch_size
            end = start + self._batch_size
            yield i, self._sequences[start:end]

In [11]:
dataset = WikiText2(split="test")
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, dataset), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])

sequences = create_sequences(process_to_flat(dataset))
sequence_generator = SequenceBatchGenerator(sequences, 64)

In [38]:
def train_one_epoch(sequence_generator, model, loss_fn, optimizer, p_mask=0.05):
    size = len(sequence_generator._sequences)
    for i, sequence_batch in sequence_generator:
        # sequence batch := tensor with containing a batch of tokenized sequences.
        # sequence batch shape (batch_size, sequence_len)

        # Deletes random tokens from the sequences with the probability p_mask.
        # Replaces them with the masked_token_id. Is that correct? What should be the replacement?
        # line 5 of alg 12 and section 6 from [2]
        mask = torch.rand(*sequence_batch.shape) < p_mask
        masked_sequence_batch = sequence_batch.clone()

        # Defines a special id for the masked out token.
        # We added +1 in the model.__init__ to the vocabulary_size
        # to make room for the masked out token.
        # max token id == vocab length - 1
        masked_token_id = model.vocabulary_size - 1
        masked_sequence_batch[mask] = masked_token_id

        # y_hat shape (batch_size, sequence_len, vocabulary_size)
        y_hat = model(masked_sequence_batch)
        
        # This will one-hot encodes the token ids (entries in the tensor)
        # which reshapes tensor
        # from (batch_size, sequence_len)
        # to (batch_size, sequence_len, vocabulary_size).
        # We are using the unmasked sequences with the original token ids as the target.
        y = F.one_hot(sequence_batch, num_classes=model.vocabulary_size).float()
        
        # To compute the loss, we use only only the probabilities
        # of the masked out tokens being the original token. That's 
        # line 7 of alg 12 from [2]
        # mask shape (batch_size, sequence_len)
        # Applying a 2-d mask to a 3-d tensor flattens the output tensor to 2-d.
        # y_hat_ shape (number of masked out tokens, vocabulary_size)
        y_hat_ = y_hat[mask]
        # y_ shape (number of masked out token, vocabulary_size)
        y_ = y[mask]
        
        loss = loss_fn(y_hat_, y_)
        
        # backprop
        # reset param gradients at the start of each iteration
        # with torch.autograd.set_detect_anomaly(True):
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if i % 5 == 0:
            loss_, current = loss.item(), i * len(sequence_batch)
            print(f"loss: {loss_:.7f}, [{current}/{size}]")

In [39]:
len(vocab)

12455

In [40]:
len(sequences)

1889

In [43]:
model = ETransformer(vocabulary_size=len(vocab))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
epochs = 1
for i in range(epochs):
    print(f"Epoch: {i}")
    train_one_epoch(sequence_generator, model, loss_fn, optimizer)

Epoch: 0


In [36]:
special_token = 5
x = torch.Tensor([[0, 2, 1], [3, 4, 4]]).long()

masked_x = x.clone()
mask = torch.Tensor([[True, True, True], [False, True, True]]).bool()
masked_x[mask] = special_token


et = ETransformer(vocabulary_size=5)

y_hat = et(masked_x)

In [193]:
y_hat.shape

torch.Size([2, 3, 6])

In [37]:
y_hat[mask].shape

torch.Size([5, 6])

In [234]:
F.one_hot(x)[mask].double()

tensor([[1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1.]], dtype=torch.float64)

In [200]:
F.one_hot(x[mask])

tensor([[1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1]])