In [56]:
import torch
from torch import nn
import torch.nn.functional as f

In [57]:
def attention_dot_product(query, key, value):
    """
    Scaled dot-product attention.

    The dot product of every query with every key gives the initial
    dependence score between two embeddings. The scores are divided by the
    square root of the size of the last dimension of ``query`` and passed
    through a softmax over the key axis. The resulting weights are then used
    to take a weighted sum of the value vectors.

    The last two dimensions of each argument are treated as
    (sequence, feature). Any leading dimensions are batch dimensions and are
    broadcast by ``matmul``, so both (batch, seq, dim) and
    (batch, heads, seq, dim) inputs are accepted.

    Args:
        query: tensor of shape (..., q_len, d).
        key: tensor of shape (..., k_len, d).
        value: tensor of shape (..., k_len, d_v).

    Returns:
        Tensor of shape (..., q_len, d_v).
    """
    # Transpose the last two axes (rather than fixed axes 1 and 2) so inputs
    # with extra leading batch dimensions work as well.
    scores = query.matmul(key.transpose(-2, -1))
    scale = query.size(-1) ** 0.5
    weights = f.softmax(scores / scale, dim=-1)
    return weights.matmul(value)

In [58]:
class AttentionBlock(nn.Module):
    """
    A single attention head.

    Query, key and value inputs are each passed through their own linear
    projection before scaled dot-product attention is applied. The value
    projection uses ``k_dim`` as its output size.
    """
    def __init__(self, input_dim, q_dim, k_dim):
        super().__init__()
        # One learned projection per role; the attribute names are part of
        # the module's state_dict layout.
        self.q = nn.Linear(input_dim, q_dim)
        self.k = nn.Linear(input_dim, k_dim)
        self.v = nn.Linear(input_dim, k_dim)

    def forward(self, query, key, value):
        """Project the three inputs and return the attention output."""
        projected_query = self.q(query)
        projected_key = self.k(key)
        projected_value = self.v(value)
        return attention_dot_product(projected_query, projected_key, projected_value)

In [59]:
class MultiHeadAttention(nn.Module):
    """
    Several attention heads run side by side, one per representation
    subspace, so the model can focus on different positions at once.

    The outputs of all heads are concatenated along the last dimension and
    mapped back to ``input_dim`` by a final linear layer.
    """
    def __init__(self, num_heads, input_dim, q_dim, k_dim):
        super().__init__()
        attention_heads = [
            AttentionBlock(input_dim, q_dim, k_dim) for _ in range(num_heads)
        ]
        self.heads = nn.ModuleList(attention_heads)

        # Each head emits k_dim features, so the concatenation is
        # num_heads * k_dim wide.
        self.linear = nn.Linear(num_heads * k_dim, input_dim)

    def forward(self, query, key, value):
        """Run every head on the same inputs, concatenate, and project."""
        per_head = [head(query, key, value) for head in self.heads]
        concatenated = torch.cat(per_head, dim=-1)
        return self.linear(concatenated)

In [60]:
def positionalEncoding(seq_len, model_dim, device = 'cpu'):
    """
    Sinusoidal positional encoding as described by Vaswani et al.

    Produces a vector for every position that can be added to the input
    embeddings to encode position. For position ``pos`` and channel ``j``
    with ``i = j // 2``:

        even j:  sin(pos / 10000 ** (2 * i / model_dim))
        odd  j:  cos(pos / 10000 ** (2 * i / model_dim))

    Args:
        seq_len: number of positions to encode.
        model_dim: number of channels per position.
        device: device (string or ``torch.device``) to create the tensor on.

    Returns:
        Float tensor of shape (1, seq_len, model_dim).
    """

    device = torch.device(device)
    pos = torch.arange(seq_len, dtype=torch.float, device=device).reshape(1, -1, 1)
    dimensions = torch.arange(model_dim, device=device).reshape(1, 1, -1)

    # Channels 2i and 2i+1 share the exponent 2i / model_dim. Floor-dividing
    # the channel index by model_dim (as was done before) is always 0, which
    # gave every channel the same frequency.
    pair_index = torch.div(dimensions, 2, rounding_mode="floor")
    exponent = (2 * pair_index).float() / model_dim
    phase = pos / (1e4 ** exponent)

    return torch.where(dimensions % 2 == 0, torch.sin(phase), torch.cos(phase))



In [61]:
def feedForward(input_dim = 512, ff_dim = 2048):
    """
    Build the position-wise feed-forward network: a linear layer that widens
    ``input_dim`` to ``ff_dim``, a ReLU, and a linear layer that maps back to
    ``input_dim``.
    """
    expand = nn.Linear(input_dim, ff_dim)
    activation = nn.ReLU()
    project = nn.Linear(ff_dim, input_dim)
    return nn.Sequential(expand, activation, project)

In [62]:
class Residuals(nn.Module):
    """
    Wrap a sublayer with dropout, a residual connection and layer
    normalisation: ``norm(first_input + dropout(sublayer(*inputs)))``.
    """
    def __init__(self, sublayer, dimensions, dropout = 0.1):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(dimensions)
        self.dropout = nn.Dropout(dropout)

    def forward(self, *tensors):
        "Assuming tensors listed are Query, Key, and Value in the respective order."
        # The first tensor is the one carried around the sublayer by the
        # residual connection.
        residual = tensors[0]
        transformed = self.sublayer(*tensors)
        return self.norm(residual + self.dropout(transformed))

In [70]:
class EncoderLayer(nn.Module):
    """
    One encoder layer: multi-head self-attention followed by a feed-forward
    network, each wrapped in ``Residuals``.
    """
    def __init__(self, model_dim= 512, num_heads= 6, ff_dim= 2048, dropout= 0.1):
        super().__init__()
        # Query and key projections share one per-head width (at least 1).
        head_dim = max(model_dim // num_heads, 1)

        self_attention = MultiHeadAttention(num_heads, model_dim, head_dim, head_dim)
        self.attention = Residuals(self_attention, dimensions=model_dim, dropout=dropout)

        feed_forward = feedForward(model_dim, ff_dim)
        self.feedForward = Residuals(feed_forward, dimensions=model_dim, dropout=dropout)

    def forward(self, src):
        """Apply self-attention (src as query, key and value), then feed-forward."""
        attended = self.attention(src, src, src)
        return self.feedForward(attended)


In [64]:
class Encoder(nn.Module):
    """
    Stack of ``num_layers`` encoder layers applied to the input after the
    positional encoding has been added.
    """
    def __init__(self, num_layers= 6, model_dim= 512, num_heads= 8, ff_dim= 2048, dropout= 0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            [EncoderLayer(model_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]
        )

    def forward(self, src):
        """
        Encode ``src``.

        ``src.size(1)`` is used as the sequence length and ``src.size(2)``
        as the model dimension. The input tensor is not modified.
        """
        seq_len, dimensions = src.size(1), src.size(2)
        # Build the encoding on the input's device and in its dtype, and add
        # it out of place: an in-place ``+=`` would silently change the
        # caller's tensor.
        encoding = positionalEncoding(seq_len, dimensions, device=src.device)
        src = src + encoding.to(src.dtype)
        for layer in self.layers:
            src = layer(src)

        return src

In [65]:
class DecoderLayer(nn.Module):
    """
    One decoder layer: self-attention over the target, attention over the
    encoder memory, then a feed-forward network, each wrapped in
    ``Residuals``. No attention mask is passed to either attention step.
    """
    def __init__(self, model_dim= 512, num_heads= 6, ff_dim= 2048, dropout= 0.1):
        super().__init__()
        # Query and key projections share one per-head width (at least 1).
        head_dim = max(model_dim // num_heads, 1)

        self_attention = MultiHeadAttention(num_heads, model_dim, head_dim, head_dim)
        self.attention1 = Residuals(self_attention, dimensions=model_dim, dropout=dropout)

        memory_attention = MultiHeadAttention(num_heads, model_dim, head_dim, head_dim)
        self.attention2 = Residuals(memory_attention, dimensions=model_dim, dropout=dropout)

        feed_forward = feedForward(model_dim, ff_dim)
        self.feedForward = Residuals(feed_forward, dimensions=model_dim, dropout=dropout)

    def forward(self, tgt, memory):
        """Attend over ``tgt`` itself, then over ``memory``, then feed-forward."""
        self_attended = self.attention1(tgt, tgt, tgt)
        # The target supplies the queries; the memory supplies keys and values.
        memory_attended = self.attention2(self_attended, memory, memory)

        return self.feedForward(memory_attended)


In [66]:
class Decoder(nn.Module):
    """
    Stack of ``num_layers`` decoder layers followed by a linear layer and a
    softmax over the last dimension.
    """
    def __init__(self, num_layers=6, model_dim= 512, num_heads= 8, ff_dim= 2048, dropout= 0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            [DecoderLayer(model_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]
        )
        self.linear = nn.Linear(model_dim, model_dim)

    def forward(self, tgt, memory):
        """
        Decode ``tgt`` against the encoder output ``memory``.

        ``tgt.size(1)`` is used as the sequence length and ``tgt.size(2)``
        as the model dimension. The input tensor is not modified.
        """
        seq_len, dimensions = tgt.size(1), tgt.size(2)
        # Build the encoding on the input's device and in its dtype, and add
        # it out of place: an in-place ``+=`` would silently change the
        # caller's tensor.
        encoding = positionalEncoding(seq_len, dimensions, device=tgt.device)
        tgt = tgt + encoding.to(tgt.dtype)
        for layer in self.layers:
            tgt = layer(tgt, memory)

        return torch.softmax(self.linear(tgt), dim=-1)

In [67]:
class Transformer(nn.Module):
    """
    Encoder-decoder transformer: the source is encoded, and the target is
    decoded against the encoder output.

    The ``activation`` argument is accepted but not used by this class.
    """
    def __init__(self, num_enc_layers = 6, num_dec_layers = 6, model_dim = 512, num_heads = 6, ff_dim = 2048, dropout = 0.1, activation = nn.ReLU()):
        super().__init__()
        # Settings shared by the encoder and decoder stacks.
        shared = dict(model_dim=model_dim, num_heads=num_heads, ff_dim=ff_dim, dropout=dropout)
        self.encoder = Encoder(num_layers=num_enc_layers, **shared)
        self.decoder = Decoder(num_layers=num_dec_layers, **shared)

    def forward(self, src, tgt):
        """Encode ``src`` and return the decoder output for ``tgt``."""
        memory = self.encoder(src)
        return self.decoder(tgt, memory)

In [71]:
# Smoke test: run random source and target batches through a
# default-configured Transformer and print the output shape.

src = torch.rand(64, 32, 512)
tgt = torch.rand(64, 32, 512)
model = Transformer()
out = model(src, tgt)
print(out.shape)

torch.Size([64, 32, 512])
