In [17]:
import torch

# Choose the compute device once for the whole notebook: CUDA when PyTorch
# reports a usable GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using {device}")

Using cuda


In [18]:
from torch import nn

class Embedding(nn.Module):
    """Learnable token-embedding table whose lookups are scaled by sqrt(d_model).

    Args:
        vocab_size: number of rows in the table (one per token id).
        d_model: width of each embedding vector.
    """

    def __init__(self, vocab_size: int = 1024, d_model: int = 512):
        super().__init__()

        self.d_model = d_model
        # Store the nn.Parameter itself. Calling `.to(...)` on a Parameter returns
        # a plain, non-leaf tensor whenever a copy is needed (e.g. CPU -> CUDA),
        # which would leave the table out of `parameters()` / `state_dict()` and
        # therefore untrained. Device placement is handled by `Module.to(...)`.
        self.emb = nn.Parameter(torch.randn((vocab_size, d_model), dtype=torch.float32))  # initialize randomly for testing purposes

    def forward(self, X):
        """
        Look up the embedding vector of every token id.

        X: (batch_size, seq_len) integer token ids used to index the table.

        Returns the indexed rows of the table, i.e. X's shape with a trailing
        d_model axis appended, multiplied by sqrt(d_model).
        """
        # multiply embeddings by sqrt(d_model) as was said in the paper section 3.4
        return self.emb[X] * self.d_model ** 0.5

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch of embeddings.

    The table is precomputed for `max_len` positions and stored as the buffer
    `pe` with shape (1, max_len, d_model). As a registered buffer it follows the
    module through `.to(...)` and is not a trainable parameter.

    Args:
        d_model: width of the embeddings the encoding is added to.
        max_len: largest sequence length the table covers.
    """

    def __init__(self, d_model: int = 512, max_len: int = 1024):
        super().__init__()

        positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        # One divisor per (sin, cos) column pair: 10000 ** (2i / d_model).
        divisors = 10000 ** (torch.arange(0, d_model, 2, dtype=torch.float32) / d_model)
        angles = positions / divisors  # (max_len, ceil(d_model / 2))

        pe = torch.zeros((max_len, d_model), dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim the angle matrix to the number of odd columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading axis of size 1 lets the table broadcast over the batch axis.
        pe = pe.unsqueeze(0)

        self.register_buffer('pe', pe)
        self.d_model = d_model


    def forward(self, embeddings):
        """Return `embeddings` plus the encodings of their first seq_len positions.

        `embeddings` is indexed as (batch, seq_len, d_model): axis 1 is taken as
        the sequence length and must not exceed the precomputed `max_len`.
        """
        seq_len = embeddings.shape[1]
        assert seq_len <= self.pe.shape[1], (
            f"sequence length {seq_len} exceeds max_len {self.pe.shape[1]}"
        )
        return embeddings + self.pe[:, :seq_len, :]

In [19]:
import torch.nn.functional as F
import math
class AttentionHead(nn.Module):
    """One scaled dot-product attention head with its own Q/K/V projections.

    Args:
        d_model: width of the incoming query/key/value vectors.
        d_k: width of the projected queries, keys and values (the head size).
    """

    def __init__(self, d_model: int = 512, d_k: int = 64):
        super().__init__()

        self.d_k = d_k
        self.WQ = nn.Linear(d_model, d_k, bias=False)
        self.WK = nn.Linear(d_model, d_k, bias=False)
        self.WV = nn.Linear(d_model, d_k, bias=False)

    def forward(self, Q, K, V, mask=None):
        """Compute softmax(QW_Q (KW_K)^T / sqrt(d_k)) (VW_V).

        Args:
            Q, K, V: tensors whose last axis has size d_model; the second-to-last
                axis is the sequence axis. Any leading axes are batched over.
            mask: optional boolean tensor broadcastable to the score matrix
                (..., q_len, k_len). Positions that are True are filled with
                -inf before the softmax and so receive zero attention weight.
                `None` disables masking.

        Returns:
            Tensor shaped like Q with its last axis replaced by d_k.

        Note:
            A query row whose keys are all masked is softmaxed over only -inf
            values and yields NaN.
        """
        scores = (self.WQ(Q) @ self.WK(K).transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        # Normalise over the key axis. dim=-1 (rather than a fixed index) keeps
        # this correct for unbatched or extra-batched inputs.
        attn = F.softmax(scores, dim=-1)
        return attn @ self.WV(V)
    

class MultiHeadAttention(nn.Module):
    """Runs `n_h` independent attention heads and mixes them with a linear layer.

    Args:
        d_model: width of the incoming vectors and of the returned vectors.
        d_k: per-head projection size.
        n_h: number of attention heads.
        mask: when True, a causal mask is applied so that query position i
            cannot attend to key positions j > i.
    """

    def __init__(self, d_model: int = 512, d_k: int = 64, n_h: int = 8, mask: bool = False):
        super().__init__()

        self.mask = mask
        self.WO = nn.Linear(d_k * n_h, d_model, bias=False)
        self.heads = nn.ModuleList([
            AttentionHead(d_model, d_k) for _ in range(n_h)
        ])

    def forward(self, Q, K, V):
        """Apply every head to (Q, K, V), concatenate the results and project.

        Q, K and V are indexed as (batch, seq_len, d_model); axis 1 of Q and of K
        give the query and key lengths, which may differ (cross-attention).
        """
        q_len, k_len = Q.shape[1], K.shape[1]

        # The mask must match the score matrix (q_len, k_len), not (q_len, q_len),
        # otherwise cross-attention with a different key length cannot be masked.
        # It is created on Q's device so no global device or host copy is needed.
        if self.mask:
            query_pos = torch.arange(q_len, device=Q.device).unsqueeze(1)
            key_pos = torch.arange(k_len, device=Q.device).unsqueeze(0)
            mask = key_pos > query_pos  # True strictly above the diagonal
        else:
            mask = torch.zeros((q_len, k_len), dtype=torch.bool, device=Q.device)
        mask = mask.unsqueeze(0)  # broadcast over the batch axis

        head_outputs = [head(Q, K, V, mask) for head in self.heads]
        out = self.WO(torch.cat(head_outputs, dim=-1))
        return out

In [20]:
class Encoder(nn.Module):
    """One encoder layer: unmasked self-attention followed by a feed-forward net.

    Each sub-layer's output goes through dropout, is added to the sub-layer's
    input, and the sum is layer-normalised.

    Args:
        d_model: width of the layer's input and output vectors.
        d_k: per-head projection size of the attention block.
        n_h: number of attention heads.
        d_ff: hidden width of the feed-forward network.
        p_drop: dropout probability applied to each sub-layer's output.
    """

    def __init__(self, d_model: int = 512, d_k: int = 64, n_h: int = 8, d_ff: int = 2048, p_drop: float = 0.1):
        super().__init__()

        self.multi_head_attention = MultiHeadAttention(d_model, d_k, n_h, mask=False)
        self.dropout = nn.Dropout(p=p_drop)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)

        # Position-wise feed-forward network; its trailing Dropout plays the
        # same role for this sub-layer as `self.dropout` does for attention.
        self.FFN = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(p=p_drop)
        )

    def forward(self, X):
        """Transform X (last axis d_model) and return a tensor of the same shape."""
        # Self-attention sub-layer: queries, keys and values are all X.
        attn_out = self.dropout(self.multi_head_attention(X, X, X))
        X = self.layer_norm1(attn_out + X)

        # Feed-forward sub-layer with its own residual connection.
        ffn_out = self.FFN(X)
        X = self.layer_norm2(ffn_out + X)
        return X
    
class Decoder(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention, FFN.

    Each sub-layer's output goes through dropout, is added to the sub-layer's
    input, and the sum is layer-normalised.

    Args:
        d_model: width of the layer's input and output vectors.
        d_k: per-head projection size of both attention blocks.
        n_h: number of attention heads in each attention block.
        d_ff: hidden width of the feed-forward network.
        p_drop: dropout probability applied to each sub-layer's output.
    """

    def __init__(self, d_model: int = 512, d_k: int = 64, n_h: int = 8, d_ff: int = 2048, p_drop: float = 0.1):
        super().__init__()

        # Causally masked self-attention over the decoder's own sequence.
        self.multi_head_attention = MultiHeadAttention(d_model, d_k, n_h, mask=True)
        # Unmasked attention whose keys and values come from the encoder output.
        self.multi_head_attention_combined = MultiHeadAttention(d_model, d_k, n_h, mask=False)
        self.FFN = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(p=p_drop)
        )

        self.dropout1 = nn.Dropout(p=p_drop)
        self.dropout2 = nn.Dropout(p=p_drop)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.layer_norm3 = nn.LayerNorm(d_model)


    def forward(self, X, encoder_out):
        """Transform the decoder stream X using `encoder_out` as attention memory.

        Args:
            X: decoder-side activations (last axis d_model).
            encoder_out: encoder activations used as keys and values of the
                second attention block.

        Returns:
            Tensor with the same shape as X.
        """
        # 1) Masked self-attention over the decoder sequence.
        attn_out = self.dropout1(self.multi_head_attention(X, X, X))
        X = self.layer_norm1(attn_out + X)

        # 2) Cross-attention: queries from the decoder, keys/values from the encoder.
        attn_out = self.dropout2(self.multi_head_attention_combined(X, encoder_out, encoder_out))
        X = self.layer_norm2(attn_out + X)

        # 3) Position-wise feed-forward network.
        ffn_out = self.FFN(X)
        X = self.layer_norm3(ffn_out + X)

        return X

In [21]:
class Transformer(nn.Module):
    """Stack of N encoder layers and N decoder layers with a linear output head.

    The model contains no embedding layer: `forward` receives tensors whose last
    axis already has size d_model.

    Args:
        d_model: width of the vectors flowing through every layer.
        d_k: per-head projection size.
        n_h: number of attention heads per attention block.
        d_ff: hidden width of each feed-forward network.
        N: number of encoder layers and of decoder layers.
        vocab_size: size of the last axis produced by the output projection.
        p_drop: dropout probability used throughout the model.
    """

    def __init__(self, d_model: int = 512, d_k: int = 64, n_h: int = 8, d_ff: int = 2048, N: int = 6, vocab_size: int = 2, p_drop: float = 0.1):
        super().__init__()

        self.N = N
        self.encoders = nn.ModuleList([Encoder(d_model, d_k, n_h, d_ff, p_drop) for _ in range(N)])
        self.decoders = nn.ModuleList([Decoder(d_model, d_k, n_h, d_ff, p_drop) for _ in range(N)])
        # Dropout applied to the incoming encoder / decoder streams.
        self.encoder_dropout = nn.Dropout(p=p_drop)
        self.decoder_dropout = nn.Dropout(p=p_drop)


        # Projects decoder activations to one score per vocabulary entry.
        self.linear = nn.Linear(d_model, vocab_size)

        print(f"Number of parameters: {sum(p.numel() for p in self.parameters())}")

    def forward(self, inputs, outputs):
        """Run the encoder stack on `inputs`, then the decoder stack on `outputs`.

        Args:
            inputs: encoder-side tensor with last axis d_model.
            outputs: decoder-side tensor with last axis d_model.

        Returns:
            The decoder activations projected by `self.linear`, i.e. the decoder
            stream's shape with its last axis replaced by vocab_size. No softmax
            is applied.
        """
        encoder_prev = self.encoder_dropout(inputs)
        decoder_prev = self.decoder_dropout(outputs)

        for encoder in self.encoders:
            encoder_prev = encoder(encoder_prev)

        # Every decoder layer attends to the output of the final encoder layer.
        for decoder in self.decoders:
            decoder_prev = decoder(decoder_prev, encoder_prev)

        out = self.linear(decoder_prev)

        return out

In [22]:
# Hyperparameters for a small end-to-end shape check of the model.
vocab_size = 2
d_model = 512
max_len = 1024
d_ff = 2048
d_k = 64
n_h = 8
N = 6
p_drop = 0.1

# Token embedding and positional encoding, shared by both sequences.
emb = Embedding(vocab_size, d_model).to(device)
pos_emb = PositionalEncoding(d_model, max_len).to(device)

# Encoder-side token ids: 4 sequences of 8 tokens, created directly on `device`.
input_ids = torch.tensor(
    [
        [0, 1, 0, 1, 0, 1, 0, 1],
        [1, 1, 0, 0, 0, 1, 0, 1],
        [0, 0, 0, 1, 0, 1, 0, 1],
        [1, 1, 1, 1, 0, 1, 0, 1],
    ],
    dtype=torch.int64,
    device=device,
)
inputs = pos_emb(emb(input_ids)).to(device)

# Decoder-side token ids with the same batch size and length.
output_ids = torch.tensor(
    [
        [0, 0, 0, 1, 0, 1, 0, 1],
        [1, 1, 1, 1, 0, 1, 0, 1],
        [0, 1, 0, 1, 0, 1, 0, 1],
        [1, 1, 0, 0, 0, 1, 0, 1],
    ],
    dtype=torch.int64,
    device=device,
)
outputs = pos_emb(emb(output_ids)).to(device)

model = Transformer(
    d_model=d_model,
    d_k=d_k,
    n_h=n_h,
    d_ff=d_ff,
    N=N,
    vocab_size=vocab_size,
    p_drop=p_drop,
).to(device)

# One forward pass; the printed shape is (batch, seq_len, vocab_size).
out = model(inputs, outputs)

print(out.shape)

Number of parameters: 44102658
torch.Size([4, 8, 2])
