# In [1]:
import torch
import torch.nn as nn

# In [8]:
class Generator(nn.Module):
    """Map feature vectors to a probability distribution over the vocabulary.

    Args:
        ni: Size of the last dimension of the input features.
        vocab: Number of output classes (vocabulary size).
    """

    def __init__(self, ni, vocab):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise attribute assignment raises AttributeError.
        super().__init__()
        self.linear = nn.Linear(ni, vocab)

    def forward(self, x):
        """Return softmax probabilities over the last dimension of ``linear(x)``."""
        logits = self.linear(x)
        return nn.functional.softmax(logits, dim=-1)

class FeedForwardLayer(nn.Module):
    """Two-layer position-wise feed-forward network: Linear -> act -> Linear.

    Args:
        d_model: Size of the input and output feature dimension.
        n_hidden: Size of the hidden layer.
        act: Activation *class* (not instance); it is instantiated with no
            arguments. Defaults to ``nn.ReLU``.
    """

    def __init__(self, d_model, n_hidden, act = nn.ReLU):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise attribute assignment raises AttributeError.
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(d_model, n_hidden),
            act(),
            nn.Linear(n_hidden, d_model),
        )

    def forward(self, x):
        """Apply the feed-forward stack to ``x`` and return the result."""
        return self.layers(x)

class MultiHeadCrossAttention(nn.Module):
    """Encoder-decoder (cross) attention.

    Queries are projected from the decoder state and keys/values are projected
    from the encoder output, so each decoder position attends over the encoder
    sequence. The output therefore has the same sequence length as the decoder
    state, and source and target sequences may differ in length.

    Args:
        d_model: Feature dimension of both inputs and of the output.
        num_heads: Number of attention heads (forwarded to
            ``nn.MultiheadAttention``).
        batch_first: Forwarded to ``nn.MultiheadAttention``.
    """

    def __init__(self, d_model, num_heads, batch_first = True):
        super().__init__()
        self.d_model, self.num_heads = d_model, num_heads
        # queries come from the decoder state ...
        self.q_proj = nn.Linear(d_model, d_model)
        # ... keys and values come from the encoder output (one fused projection)
        self.kv_proj = nn.Linear(d_model, 2 * d_model)
        self.attention_layer = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=batch_first)

    def forward(self, encoder_output, v_output, mask=None):
        """Attend from the decoder state over the encoder output.

        Args:
            encoder_output: Encoder output; source of the keys and values.
            v_output: Decoder state (output of the decoder self-attention
                sub-layer); source of the queries. The parameter name is kept
                for backward compatibility with existing callers.
            mask: Optional mask passed as ``attn_mask`` to
                ``nn.MultiheadAttention``; ``None`` disables masking.

        Returns:
            The attention output; the attention weights are discarded.
        """
        q = self.q_proj(v_output)

        # split the fused projection into the individual key / value tensors
        k, v = self.kv_proj(encoder_output).chunk(2, dim=-1)

        # grab the attention output
        output, _ = self.attention_layer(query=q, key=k, value=v, attn_mask=mask)
        return output
    
class DecoderBlock(nn.Module):
    """One decoder block: self-attention, cross-attention, feed-forward.

    Each sub-layer is followed by a residual connection and a layer norm over
    the feature (``embed_dim``) dimension.

    Args:
        seq_len: Accepted for backward compatibility; not used, because the
            layer norms normalize over the feature dimension only.
        embed_dim: Feature dimension of the target and encoder tensors.
        num_heads: Number of attention heads for both attention layers.
        batch_first: Forwarded to both attention layers.
    """

    def __init__(self, seq_len, embed_dim, num_heads, batch_first = True):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise attribute assignment raises AttributeError.
        super().__init__()

        # our first attention layer is self-attention over the target sequence
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.attention_layer = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=batch_first)
        # LayerNorm's first argument is the normalized shape and the second is
        # eps, so only the feature dimension is passed here.
        self.layernorm1 = nn.LayerNorm(embed_dim)

        # now we set up our cross-attention layer
        self.cross_attention_layer = MultiHeadCrossAttention(d_model=embed_dim, num_heads=num_heads, batch_first=batch_first)
        self.layernorm2 = nn.LayerNorm(embed_dim)

        # now our feed-forward mechanism
        self.feed_forward_layer = FeedForwardLayer(d_model=embed_dim, n_hidden=embed_dim * 4)
        self.layernorm3 = nn.LayerNorm(embed_dim)

    def forward(self, trgt, trgt_mask, encoder_output, memory_mask=None):
        """Run the block on the target sequence.

        Args:
            trgt: Target-side input tensor.
            trgt_mask: Mask passed as ``attn_mask`` to the self-attention
                layer.
            encoder_output: Encoder output handed to the cross-attention layer.
            memory_mask: Optional mask handed to the cross-attention layer.
                Defaults to ``None`` (no masking); the target mask is not
                reused there because it describes target-to-target positions.

        Returns:
            The block output after the final residual connection and layer norm.
        """
        # 1. project the target into the query / key / value tensors
        qkv = self.qkv_proj(trgt)
        q, k, v = qkv.chunk(3, dim=-1)

        # 2. masked self-attention over the target, then residual + norm
        v_output, _ = self.attention_layer(query=q, key=k, value=v, attn_mask=trgt_mask)
        v_output = self.layernorm1(v_output + trgt)

        # 3. perform cross-attention against the encoder output
        cross_output = self.cross_attention_layer(encoder_output, v_output, memory_mask)
        cross_output = self.layernorm2(cross_output + v_output)

        # 4. feed into our feed-forward module
        x = self.feed_forward_layer(cross_output)
        return self.layernorm3(x + cross_output)
    
class Decoder(nn.Module):
    """Stack of ``num_blocks`` decoder blocks applied one after another.

    Args:
        seq_len: Forwarded to every ``DecoderBlock``.
        embed_dim: Feature dimension, forwarded to every ``DecoderBlock``.
        num_heads: Number of attention heads per block.
        num_blocks: Number of decoder blocks in the stack.
    """

    def __init__(self, seq_len, embed_dim, num_heads = 8, num_blocks = 12):
        super().__init__()
        self.seq_len, self.embed_dim, self.num_heads, self.num_blocks = seq_len, embed_dim, num_heads, num_blocks
        # nn.Sequential can only thread a single input through its children,
        # but every block needs (trgt, trgt_mask, encoder_output). A ModuleList
        # registers the blocks the same way and lets forward() call them
        # explicitly. The attribute name is kept so parameter names are stable.
        self.sequential_decoder = nn.ModuleList(
            [DecoderBlock(seq_len=seq_len, embed_dim=embed_dim, num_heads=num_heads)
             for _ in range(num_blocks)])

    def forward(self, trgt, trgt_mask, encoder_output):
        """Pass ``trgt`` through every block in order and return the result.

        The same ``trgt_mask`` and ``encoder_output`` are given to each block;
        only the target representation is updated from block to block.
        """
        for block in self.sequential_decoder:
            trgt = block(trgt, trgt_mask, encoder_output)
        return trgt