In [1]:
import numpy as np
from jonigrad.layers import *

In [55]:
class MultiHeadAttention(Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads

        self.wq = Linear(d_model, d_model)
        self.wk = Linear(d_model, d_model)
        self.wv = Linear(d_model, d_model)

        self.attention = ScaledDPAttention(self.depth)

        self.linear = Linear(d_model, d_model)
        self.norm = LayerNorm(d_model)

    def split_heads(self, x, batch_size):        
        x = x.reshape(batch_size, -1, self.num_heads, self.depth)
        return x.transpose(0, 2, 1, 3)  # (batch_size, num_heads, seq_len, depth)
    
    def forward(self, q, k, v, mask=None):
        batch_size = q.shape[0]
        
        q = self.split_heads(self.wq(q), batch_size)
        k = self.split_heads(self.wk(k), batch_size)
        v = self.split_heads(self.wv(v), batch_size)
        
        scaled_attention, _ = self.attention(q, k, v, mask)
        scaled_attention = scaled_attention.transpose(0, 2, 1, 3).reshape(batch_size, -1, self.d_model)  # (batch_size, seq_len, d_model)
        
        output = self.linear(scaled_attention)
        
        return self.norm(output + q.transpose(0, 2, 1, 3).reshape(output.shape))  

class ScaledDPAttention(Module):
    def __init__(self, depth):
        super().__init__()
        self.scale = np.sqrt(depth)
        self.softmax = Softmax()
    
    def forward(self, q, k, v, mask=None):
        scores = np.matmul(q, k.transpose(0, 1, 3, 2)) / self.scale
        
        if mask is not None:

            mask = mask[:, :, :scores.shape[-2], :scores.shape[-1]]
            scores = np.where(mask, scores, -1e9)
        
        attn = self.softmax(scores)
        output = np.matmul(attn, v)
        return output, attn

class LinearLayer(Module):
    def __init__(self, d_model=512):
        super().__init__()
        self.fc1 = Linear(d_model, d_model)
        self.relu = ReLU()
        self.fc2 = Linear(d_model, d_model)
        self.norm = LayerNorm(d_model)

    def forward(self, x):
        residual = x
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return self.norm(x + residual)


def create_look_ahead_mask(size):
    mask = np.tril(np.ones((size, size), dtype=bool))
    return mask  # Shape (seq_len, seq_len)

def create_padding_mask(seq):
    mask = (seq == 0)
    return mask[np.newaxis, np.newaxis, :, :]


class TransformerDecoder(Module):
    def __init__(self, input_vocab_size=1000, output_vocab_size=1000, d_model=512, num_heads=8, seq_len=10):
        super().__init__()
        self.num_heads = num_heads
        self.input_embedding = Embedding(input_vocab_size, d_model)
        self.positional_embedding = Embedding(seq_len, d_model)
        self.masked_multi_head_attention = MultiHeadAttention(d_model, num_heads)
        self.multi_head_attention = MultiHeadAttention(d_model, num_heads)
        self.linear_layer = LinearLayer(d_model)
        self.final_linear = Linear(d_model, output_vocab_size)  # Final linear layer to project to vocab size
        self.softmax = Softmax()
        
    def forward(self, x, k, v):
        batch_size, seq_len = x.shape
        look_ahead_mask = create_look_ahead_mask(seq_len)
        padding_mask = create_padding_mask(x)

        pos = np.tile(np.arange(seq_len), (batch_size, 1))
        x = self.input_embedding(x) + self.positional_embedding(pos)        
        x = x.transpose(1, 0, 2)  # (batch_size, seq_len, d_model) -> (seq_len, batch_size, d_model)
        
        # Masked multi-head self-attention
        x = self.masked_multi_head_attention(x, x, x, look_ahead_mask[np.newaxis, np.newaxis, :, :])
        
        # Multi-head attention with encoder output
        x = self.multi_head_attention(x, k.transpose(1, 0, 2), v.transpose(1, 0, 2), padding_mask)
        
        x = x.transpose(1, 0, 2)  # (seq_len, batch_size, d_model) -> (batch_size, seq_len, d_model)
        
        x = self.linear_layer(x)
        
        logits = self.final_linear(x)  # Project to vocab size
        
        return self.softmax(logits, -1)

class TransformerEncoder(Module):
    def __init__(self, vocab_size=1000, d_model=512, num_heads=8, seq_len=10):
        super().__init__()
        self.input_embedding = Embedding(vocab_size, d_model)
        self.positional_embedding = Embedding(seq_len, d_model)
        self.multi_head_attention = MultiHeadAttention(d_model, num_heads)
        self.linear_layer = LinearLayer(d_model)
    
    def forward(self, x):
        batch_size, seq_len = x.shape
        pos = np.tile(np.arange(seq_len), (batch_size, 1))
        x = self.input_embedding(x) + self.positional_embedding(pos)

        x = x.transpose(1, 0, 2)  # (batch_size, seq_len, d_model) -> (seq_len, batch_size, d_model)
        x = self.multi_head_attention(x, x, x)
        x = x.transpose(1, 0, 2)  # (seq_len, batch_size, d_model) -> (batch_size, seq_len, d_model)
        x = self.linear_layer(x)
        return x


In [60]:
vocab_size1 = 8919
vocab_size2 = 15243
seq_len1 = 127
seq_len2 = 57

d_model = 512
num_heads = 8
batch_size = 4

encoder = TransformerEncoder(vocab_size1, d_model, num_heads, seq_len1)

# Create a sample input (batch_size, seq_len)
sample_input = np.random.randint(0, vocab_size1, (batch_size, seq_len1))
print(sample_input.shape)
# Forward pass
encoder_output = encoder(sample_input)

# Check the encoder_output shape
assert encoder_output.shape == (batch_size, seq_len1, d_model), f"Output shape mismatch: expected {(batch_size, seq_len1, d_model)}, got {encoder_output.shape}"

print("Transformer Encoder test passed!")

# Initialize the decoder
decoder = TransformerDecoder(vocab_size1, vocab_size2, d_model, num_heads, seq_len1)

# Create a sample input for the decoder (batch_size, seq_len)
sample_input = np.random.randint(0, vocab_size2, (batch_size, seq_len2))
print(sample_input.shape, encoder_output.shape)
# Forward pass through the decoder
output = decoder(sample_input, encoder_output, encoder_output)

# Check the output shape
assert output.shape == (batch_size, seq_len2, vocab_size2), f"Output shape mismatch: expected {(batch_size, seq_len2, vocab_size2)}, got {output.shape}"

print("Transformer Decoder test passed!")


(4, 127)
Transformer Encoder test passed!
(4, 57) (4, 127, 512)


IndexError: index 11401 is out of bounds for axis 0 with size 8919

In [None]:
output.shape

(4, 127, 8919)

In [7]:
from jonigrad.utils import load_fi_en_translations
from tqdm import tqdm

g = np.random.default_rng()  # create a random generator

D_MODEL = 512
NUM_HEADS = 8
BATCH_SIZE = 16
ITERS = 2
LR = 0.001
THRESHOLD = 5

en_data, en_vocab, fi_data, fi_vocab = load_fi_en_translations(debug=False)


INPUT_VOCAB = len(en_vocab)
OUTPUT_VOCAB = len(fi_vocab)

print(OUTPUT_VOCAB)

encoder = TransformerEncoder(INPUT_VOCAB, D_MODEL, NUM_HEADS)
decoder = TransformerDecoder(INPUT_VOCAB, D_MODEL, NUM_HEADS)

loss_f = CrossEntropyLoss()
losses = []

encoder.train()
decoder.train()
for iter in tqdm(range(ITERS)):
    ix = g.integers(low=0, high=en_data.shape[0], size=BATCH_SIZE)
    Xb, Yb = en_data[ix], fi_data[ix]

    encoder.zero_grad()

    print(Xb.shape)
    encoder_output = encoder(Xb)
    print(encoder_output.shape, Yb.shape)
    decoder_output = decoder(Yb, encoder_output, encoder_output)

    loss = loss_f(decoder_output, Yb)
    dL_dy = loss_f.backward()

    dL_dy_decoder = dL_dy

    dL_dy_decoder, dh, dc = decoder.backward(dL_dy_decoder)
    dL_dy_encoder = encoder.backward(dL_dy_decoder, dh, dc)

    decoder.clip_grad(THRESHOLD, BATCH_SIZE)
    encoder.clip_grad(THRESHOLD, BATCH_SIZE)

    encoder.step(LR)
    decoder.step(LR)

    losses.append(loss.item())

  from .autonotebook import tqdm as notebook_tqdm


15243


  0%|          | 0/2 [00:00<?, ?it/s]

(16, 127)


  0%|          | 0/2 [00:00<?, ?it/s]

(16, 127, 512) (16, 57)





IndexError: index 13388 is out of bounds for axis 0 with size 8919