In [27]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

"""
Setting some hyperparameters
"""
# size of the query, key, value and z vectors 
# in the attention layer (64 was used in paper)
ATTENTION_OUTPUT_DIM = 64
WORD_EMBEDDING_DIM = 512

def scaled_dot_product_attention(queries, keys, values, mask=None):
    scores = queries @ keys.T
    scaled_scores = scores / torch.sqrt(keys.shape[1])
    softmax_scores = F.softmax(scaled_scores, dim=1)
    z = softmax_scores @ values
    return z

def positional_encoding(seq_len, input_dim, device):
    """ FROM https://medium.com/the-dl/transformers-from-scratch-in-pytorch-8777e346ca51
    TODO understand this.
    """
    pos = torch.arange(seq_len, dtype=torch.float, device=device).reshape(1, -1, 1)
    dim = torch.arange(dim_model, dtype=torch.float, device=device).reshape(1, 1, -1)
    phase = pos / 10000 ** (dim // dim_model)

    return torch.where(dim.long() % 2 == 0, torch.sin(phase), torch.cos(phase))

class TwoLayerNN(nn.Module):
    """ Two layer network, without activation function for the last layer like in the paper
    """
    def __init__(self, input_dim, hidden_dim=2048):
        super().__init__()
        self.dense_1 = nn.Linear(input_dim, hidden_dim)
        self.dense_2 = nn.Linear(hidden_dim, input_dim)
        
    def forward(self, x):
        x = self.dense_1(x)
        x = F.relu(x)
        x = self.dense_2(x)
        return x
    
    
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, input_dim, num_heads):
        """
        input_dim : dim of the input, output will have same dimension
        num_heads : number of self attention layers to concatenate, each will have a 
        query, key and value dimension of num_heads // input_dim
        """
        super().__init__()
        self.head_dim = input_dim // num_heads
        self.num_heads = num_heads
        self.W_o = nn.Linear(self.head_dim * num_heads, input_dim, bias=False)
        
    def forward(self, queries, keys, values, mask=None):
        dot_attentions = []
        for i in range(0, self.num_heads * self.head_dim, self.head_dim):
            q = queries[:,i:i+self.head_dim]
            k = keys[:,i:i+self.head_dim]
            v = values[:,i:i+self.head_dim]
            dot_attention = scaled_dot_product_attention(q, k, v)
            dot_attentions.append(dot_attention)
        z = torch.cat(dot_attentions, dim=1)
        output = self.W_o(z)
        return z
    
class EncoderBlock(nn.Module):
    def __init__(self, input_dim, num_heads, dropout_p=0.1, ff_hidden_dim=2048):
        super().__init__()
        self.W_q = nn.Linear(input_dim, input_dim, bias=False)
        self.W_k = nn.Linear(input_dim, input_dim, bias=False)
        self.W_v = nn.Linear(input_dim, input_dim, bias=False)
        self.attention = MultiHeadAttentionLayer(input_dim, num_heads, input_dim)
        self.feed_forward = TwoLayerNN(input_dim, ff_hidden_dim)
        self.dropout = nn.Dropout(dropout_p)
        self.norm_1 = nn.LayerNorm(input_dim)
        self.norm_2 = nn.LayerNorm(input_dim)
        
    def forward(self, x, mask=None):
        x = self.dropout(x)
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v(x)
        z = self.attention(queries, keys, values, mask)
        z = self.dropout(z)
        x = self.norm_1(x + z)
        z = self.feed_forward(normalised)
        z = self.dropout(z)
        x = self.norm_2(x + z)
        return x

class Encoder(nn.Module):
    def __init__(self, vocab_len, device, num_heads=8, embedding_dim=512, n_layers=6, dropout_p=0.1, ff_hidden_dim=2048):
        self.word_embedding = nn.Embedding(vocab_len, embedding_dim)
        self.dropout = nn.Dropout(dropout_p)
        self.encoder_blocks = nn.ModuleList(
            [EncoderBlock(embedding_dim, num_heads, dropout_p, ff_hidden_dim) for i in range(n_layers)]
        )
        self.embedding_dim = embedding_dim
        self.device = device
        
        
    def forward(self, x):
        x = self.word_embedding(x)
        pos_enc = positional_encoding(x.shape[0], self.embedding_dim, self.device)
        x = x + pos_enc
        x = self.dropout(x)
        x = self.encoder_blocks(x)
        return x

SyntaxError: invalid syntax (<ipython-input-27-e1e3d21d637a>, line 21)

In [14]:
test = torch.ones((30, WORD_EMBEDDING_DIM)).float()
multi_headed_attention = MultiHeadAttentionLayer(WORD_EMBEDDING_DIM, 8, 512)

print(multi_headed_attention(test))

NameError: name 'output_dim' is not defined

In [3]:
num_words = 30 
hidden_dim = 64

softmax_scores = torch.ones((30, 30))
z_scores = torch.ones((30, 64))

print((softmax_scores @ z_scores).shape)

torch.Size([30, 64])


In [23]:
1e20

1e+20