In [1]:
import numpy as np
from jonigrad.layers import *

In [51]:
class MultiHeadAttention(Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads

        self.wq = Linear(d_model, d_model)
        self.wk = Linear(d_model, d_model)
        self.wv = Linear(d_model, d_model)

        self.attention = ScaledDPAttention(self.depth)

        self.linear = Linear(d_model, d_model)
        self.norm = LayerNorm(d_model)

    def split_heads(self, x, batch_size):        
        x = x.reshape(batch_size,  self.num_heads, -1, self.depth)
        return x.transpose(0, 2, 1, 3)  # (batch_size, num_heads, seq_len, depth)
    
    def forward(self, q, k, v):
        batch_size = q.shape[0]
        
        q = self.split_heads(self.wq(q), batch_size)
        k = self.split_heads(self.wk(k), batch_size)
        v = self.split_heads(self.wv(v), batch_size)
        
        scaled_attention, _ = self.attention.forward(q, k, v)
        scaled_attention = scaled_attention.transpose(0, 2, 1, 3).reshape(batch_size, -1, self.d_model)  # (batch_size, seq_len, d_model)
        
        output = self.linear(scaled_attention)
        
        return self.norm(output + q.transpose(0, 2, 1, 3).reshape(output.shape))  

class ScaledDPAttention(Module):
    def __init__(self, d_model=512):
        super().__init__()
        self.scale = d_model ** 0.5
        self.softmax = Softmax()
    
    def forward(self, q, k, v):
        scores = (q @ k.transpose((0, 1, 3, 2))) / self.scale
        attn = self.softmax(scores, dim=-1)
        output = attn @ v
        return output, attn
        
class LinearLayer(Module):
    def __init__(self, d_model=512):
        self.fc1 = Linear(d_model, d_model)
        self.relu = ReLU()
        self.fc2 = Linear(d_model, d_model)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class TransformerEncoder(Module):
    def __init__(self, vocab_size=1000, d_model=512, num_heads=8):
        super().__init__()
        self.input_embedding = Embedding(vocab_size, d_model)
        self.positional_embedding = Embedding(5000, d_model)
        self.multi_head_attention = MultiHeadAttention(d_model, num_heads)
        self.linear_layer = LinearLayer(d_model)
    
    def forward(self, x):
        batch_size, seq_len = x.shape
        pos = np.arange(seq_len)
        x = self.input_embedding(x) + self.positional_embedding(pos)[None, :, :]
        x = x.transpose(1, 0, 2)  # (batch_size, seq_len, d_model) -> (seq_len, batch_size, d_model)
        x = self.multi_head_attention(x, x, x)
        x = x.transpose(1, 0, 2)  # (seq_len, batch_size, d_model) -> (batch_size, seq_len, d_model)
        x = self.linear_layer(x)
        return x

In [52]:
vocab_size = 1000
d_model = 512
num_heads = 8
seq_len = 10
batch_size = 4

encoder = TransformerEncoder(vocab_size, d_model, num_heads)

# Create a sample input (batch_size, seq_len)
sample_input = np.random.randint(0, vocab_size, (batch_size, seq_len))

# Forward pass
output = encoder(sample_input)

# Check the output shape
assert output.shape == (batch_size, seq_len, d_model), f"Output shape mismatch: expected {(batch_size, seq_len, d_model)}, got {output.shape}"

print("Transformer Encoder test passed!")

(10, 8, 4, 64)
Transformer Encoder test passed!


In [None]:
class TransformerDecoder(Module):
    def __init__(self, vocab_size=1000, d_model=512, num_heads=8):
        super().__init__()
        self.input_embedding = Embedding(vocab_size, d_model)
        self.positional_embedding = Embedding(5000, d_model)
        self.masked_multi_head_attention = MultiHeadAttention(d_model, num_heads)
        self.multi_head_attention = MultiHeadAttention(d_model, num_heads)
        self.linear_layer = Linear(d_model, d_model)
        self.softmax = Softmax()
        self.linear_layer = LinearLayer(d_model)
        
    def forward(self, x):
        batch_size, seq_len = x.shape
        pos = np.arange(seq_len)
        x = self.input_embedding(x) + self.positional_embedding(pos)[None, :, :]
        x = x.transpose(1, 0, 2)  # (batch_size, seq_len, d_model) -> (seq_len, batch_size, d_model)
        x = self.multi_head_attention(x, x, x)
        x = x.transpose(1, 0, 2)  # (seq_len, batch_size, d_model) -> (batch_size, seq_len, d_model)
        x = self.linear_layer(x)
        return x