This is a re-implementation of the original transformers. This is for my general practice, for more of a tutorial structure. Consider going through my previous implementation.  

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

In [2]:
class MultiHeadAttention(nn.Module):

    def __init__(self, num_heads, d_model):
        super().__init__()
        assert d_model%num_heads == 0

        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model//num_heads

        self.W_Q = nn.Linear(self.d_model, self.d_model)
        self.W_K = nn.Linear(self.d_model, self.d_model)
        self.W_V = nn.Linear(self.d_model, self.d_model)
        self.W_O = nn.Linear(self.d_model, self.d_model)


    @staticmethod
    def scaled_dot_product_attention(Q,K,V, mask):
        """
        Args:
            query: (batch_size, num_heads, seq_len_q, d_q)
            key: (batch_size, num_heads, seq_len_k, d_k)
            value: (batch_size, num_heads, seq_len_v, d_v)
            mask: Optional mask to prevent attention to certain positions
        """
        assert Q.shape[-1] == K.shape[-1] #query and key dimension should be equal

        attention_score  = (torch.matmul(Q,K.transpose(-2,-1)))/torch.sqrt(torch.tensor(K.shape[-1]))
    
        # if mask:
        #     upper_mask = torch.tril(torch.ones(attention_score.shape[-2], attention_score.shape[-1]))
        #     upper_mask.masked_fill_(upper_mask==0, float('-inf'))
        #     attention_score = attention_score + upper_mask

        # Your mask logic:
        # upper_mask = torch.tril(torch.ones(3, 3))  # Lower triangular
        # print("tril result:")
        # print(upper_mask)
        # # Output:
        # # [[1, 0, 0],
        # #  [1, 1, 0], 
        # #  [1, 1, 1]]

        # upper_mask.masked_fill_(upper_mask==0, float('-inf'))
        # print("After masked_fill:")
        # print(upper_mask)
        # # Output:
        # # [[1, -inf, -inf],
        # #  [1, 1, -inf],
        # #  [1, 1, 1]]

        # # Then you ADD this to attention scores:
        # attention_score = attention_score + upper_mask

        # The above has a problem
        # The problems:
        # You're adding 1's to allowed positions (should add 0)
        # You're adding -inf to masked positions (this is correct)

        # FIX
        # if mask:
        #     mask_matrix = torch.tril(torch.ones(attention_score.shape[-2], attention_score.shape[-1]))
        #     mask_matrix = mask_matrix.masked_fill(mask_matrix == 0, float('-inf'))
        #     mask_matrix = mask_matrix.masked_fill(mask_matrix == 1, 0.0)  # Don't change allowed positions
        #     attention_score = attention_score + mask_matrix

        # Simpler fix 
        if mask:
            mask_matrix = torch.triu(torch.ones(attention_score.shape[-2], attention_score.shape[-1]), diagonal=1)
            attention_score = attention_score.masked_fill(mask_matrix == 1, float('-inf'))

        attention_weights  = F.softmax(attention_score, dim=-1)
        assert attention_weights.shape == (Q.shape[0], Q.shape[1], Q.shape[2], K.shape[2])

        Z = torch.einsum('bhqk,bhkd -> bhqd', attention_weights, V)
        
        return Z
    
    # def forward(self, input_matrix_1, input_matrix_2, input_matrix_3, mask = 0):
    #     batch_size, seq_len = input_matrix_1.shape[0], input_matrix_1.shape[1]
    #     # The above is a bug 
    #     # The problem: You're using input_matrix_1.shape[1] for seq_len, but in the decoder, input_matrix_1, input_matrix_2, and input_matrix_3 might have different sequence lengths!

    #     self.Q = self.W_Q(input_matrix_1).reshape(batch_size, seq_len, self.num_heads, self.d_k).transpose(1,2)
    #     self.K = self.W_K(input_matrix_2).reshape(batch_size, seq_len, self.num_heads, self.d_k).transpose(1,2)
    #     self.V = self.W_V(input_matrix_3).reshape(batch_size, seq_len, self.num_heads, self.d_k).transpose(1,2)

    #     attention_score = self.scaled_dot_product_attention(self.Q, self.K, self.V, mask=mask).transpose(1,2).reshape(batch_size, seq_len,self.d_model)
    #     return self.W_O(attention_score)

    def forward(self, input_matrix_1, input_matrix_2, input_matrix_3, mask = 0):
        batch_size = input_matrix_1.shape[0]

        # Get sequence length from each matrix individually
        seq_len_q = input_matrix_1.shape[1]  # Query sequence length
        seq_len_k = input_matrix_2.shape[1]  # Key sequence length  
        seq_len_v = input_matrix_3.shape[1]  # Value sequence length

        self.Q = self.W_Q(input_matrix_1).reshape(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1,2)
        self.K = self.W_K(input_matrix_2).reshape(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1,2)
        self.V = self.W_V(input_matrix_3).reshape(batch_size, seq_len_v, self.num_heads, self.d_k).transpose(1,2)

        attention_score = self.scaled_dot_product_attention(self.Q, self.K, self.V, mask=mask)
        attention_score = attention_score.transpose(1,2).reshape(batch_size, seq_len_q, self.d_model)

        return self.W_O(attention_score)

In [3]:
# Q, K, V = 0


# Z = torch.einsum('bhqk,bhkd -> bhqd', F.softmax((torch.matmul(Q,K.transpose(-2,-1)))/torch.sqrt(torch.tensor(K.shape[-1])), dim=-1), V)

In [4]:
# def test_attention():
#     # Small test case
#     batch_size, num_heads, seq_len, d_k = 2, 4, 6, 8
#     d_model = num_heads * d_k  # 32
    
#     print("=== Testing Scaled Dot Product Attention ===")
    
#     # Create sample tensors for scaled_dot_product_attention
#     Q = torch.randn(batch_size, num_heads, seq_len, d_k)
#     K = torch.randn(batch_size, num_heads, seq_len, d_k)
#     V = torch.randn(batch_size, num_heads, seq_len, d_k)
    
#     # Test without mask
#     output = MultiHeadAttention.scaled_dot_product_attention(Q, K, V, mask=False)
#     print(f"SDPA Output shape: {output.shape}")
#     print(f"SDPA Expected: {(batch_size, num_heads, seq_len, d_k)}")
    
#     # Test with mask
#     output_masked = MultiHeadAttention.scaled_dot_product_attention(Q, K, V, mask=True)
#     print(f"SDPA Masked output shape: {output_masked.shape}")
    
#     print("✅ SDPA tests passed!")
    
#     print("\n=== Testing Full MultiHeadAttention ===")
    
#     # Create MultiHeadAttention module
#     mha = MultiHeadAttention(num_heads=num_heads, d_model=d_model)
    
#     # Create input tensor (batch_size, seq_len, d_model)
#     input_tensor = torch.randn(batch_size, seq_len, d_model)
#     print(f"Input shape: {input_tensor.shape}")
    
#     # Test forward pass
#     try:
#         mha_output = mha.forward(input_tensor)
#         print(f"MHA Output shape: {mha_output.shape}")
#         print(f"MHA Expected: {(batch_size, seq_len, d_model)}")
        
#         # Check if output has reasonable values (not NaN or inf)
#         if torch.isnan(mha_output).any():
#             print("❌ Output contains NaN values!")
#         elif torch.isinf(mha_output).any():
#             print("❌ Output contains infinite values!")
#         else:
#             print("✅ Output values look reasonable!")
            
#         print("✅ Full MHA test passed!")
        
#     except Exception as e:
#         print(f"❌ MHA test failed with error: {e}")
#         print("Check your forward() method implementation")
    
#     print("\n=== Testing with different input sizes ===")
    
#     # Test with different sequence length
#     seq_len_2 = 10
#     input_tensor_2 = torch.randn(batch_size, seq_len_2, d_model)
    
#     try:
#         mha_output_2 = mha.forward(input_tensor_2)
#         print(f"Different seq_len input: {input_tensor_2.shape}")
#         print(f"Different seq_len output: {mha_output_2.shape}")
#         print("✅ Variable sequence length test passed!")
#     except Exception as e:
#         print(f"❌ Variable sequence length test failed: {e}")

# # Run the test
# test_attention()

In [5]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, seq_len):
        super().__init__()
        
        pe = torch.ones([seq_len,d_model])
        pos = torch.arange(seq_len).unsqueeze(1)
        dim = torch.arange(d_model).unsqueeze(0) # More efficient will be to have d_model//2 right here

        # self.pe[:, ::2] = torch.sin(self.pos/10000**(2*self.dim[:d_model//2]/d_model)) # This slices about row, not column, So I still have the full tensor
        # self.pe[:, 1::2] = torch.cos(self.pos/10000**(2*self.dim[:d_model//2]/d_model))
        
        pe[:, ::2] = torch.sin(pos/10000**(2*dim[:, :d_model//2]/d_model)) # This slices it in half, this is inefficient tho. Takes more space
        pe[:, 1::2] = torch.cos(pos/10000**(2*dim[:, :d_model//2]/d_model))

        # Register as buffer - won't be updated during training
        self.register_buffer('pe', pe)

    def forward(self, input_matrix):
        # return input_matrix + self.pe # This has a problem, Our inputs have variable length. It will always give fixed length pe 
        seq_len = input_matrix.size(1)
        return input_matrix + self.pe[:seq_len, :].unsqueeze(0)

In [6]:
# # Test the fix
# pe = PositionalEncoding(d_model=4, seq_len=10)
# test_input = torch.randn(2, 5, 4)
# output = pe(test_input)
# print("Success! Output shape:", output.shape)
# print("PE buffer accessible:", hasattr(pe, 'pe'))
# print("PE buffer shape:", pe.pe.shape)

In [7]:
# # test Positional Encoding 

# #
# def pe_matrix(d_model, seq_len):

#     pe = torch.ones([seq_len,d_model])
#     pos = torch.arange(seq_len).unsqueeze(1)
#     dim = torch.arange(d_model).unsqueeze(0)

#     pe[:, ::2] = torch.sin(pos/10000**(2*dim[:, :d_model//2]/d_model))
#     pe[:, 1::2] = torch.cos(pos/10000**(2*dim[:, :d_model//2]/d_model))

#     return pe

# # Clearer approach
# # def pe_matrix(d_model, seq_len):
# #     pe = torch.zeros([seq_len, d_model])  
# #     pos = torch.arange(seq_len).unsqueeze(1)
# #     dim = torch.arange(d_model//2).unsqueeze(0)  
    
    
# #     angles = pos / 10000**(2*dim/d_model)
    
# #     pe[:, ::2] = torch.sin(angles)
# #     pe[:, 1::2] = torch.cos(angles)
    
# #     return pe

# def test_pe_matrix(d_model, seq_len):
#     print(f"Testing PE with d_model={d_model}, seq_len={seq_len}")
    
#     # Generate PE matrix
#     pe = pe_matrix(d_model, seq_len)
    
#     # Test 1: Shape check
#     print(f"Shape: {pe.shape} (expected: ({seq_len}, {d_model}))")
#     assert pe.shape == (seq_len, d_model), f"Shape mismatch!"
    
#     # Test 2: Print the matrix to visually inspect
#     print("PE Matrix:")
#     print(pe)
    
#     # Test 3: Manual verification for position 0, dimensions 0 and 1
#     pos_0 = 0
#     expected_dim_0 = torch.sin(torch.tensor(pos_0 / 10000**(2*0/d_model)))  # i=0, so 2i=0
#     expected_dim_1 = torch.cos(torch.tensor(pos_0 / 10000**(2*0/d_model)))  # same i=0
    
#     print(f"\nManual check for position 0:")
#     print(f"PE[0,0] = {pe[0,0]:.6f}, expected = {expected_dim_0:.6f}")
#     print(f"PE[0,1] = {pe[0,1]:.6f}, expected = {expected_dim_1:.6f}")
    
#     # Test 4: Check that values are different across positions
#     if seq_len > 1:
#         print(f"\nDifferent positions check:")
#         print(f"PE[0,0] = {pe[0,0]:.6f}")
#         print(f"PE[1,0] = {pe[1,0]:.6f}")
#         print(f"Different? {not torch.allclose(pe[0,0], pe[1,0])}")
    
#     print("✓ All tests passed!")

# # Run the test
# test_pe_matrix(4, 3)


In [8]:
class FeedForwardNetwork(nn.Module):

    def __init__(self,d_model,d_ff,dropout=0.1):
        super().__init__()
        self.ff_model = nn.Sequential(
            nn.Linear(d_model,d_ff), 
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self,input_matrix):
        return self.ff_model(input_matrix)

In [9]:
class EncoderLayer(nn.Module):

    def __init__(self, num_heads, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(num_heads=num_heads,d_model=d_model)
        self.ffn = FeedForwardNetwork(d_model=d_model,d_ff=d_ff)
        self.layer_norm_1 = nn.LayerNorm(d_model)
        self.layer_norm_2 = nn.LayerNorm(d_model)
        self.drop_out = nn.Dropout(dropout)

    def forward(self, input_matrix):
        x = self.mha(input_matrix, input_matrix, input_matrix)
        x = self.drop_out(input_matrix + x)
        x_1 = self.layer_norm_1(x)
        x_1 = self.ffn(x_1)
        x_2 = self.drop_out(x + x_1)
        x_2 = self.layer_norm_2(x_2)
        return x_2


In [10]:
class Encoder(nn.Module):

    def __init__(self, num_heads, d_model, d_ff, num_layers, seq_len, vocab_size):
        super().__init__()
        self.embedding_model = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self.pe = PositionalEncoding(d_model=d_model, seq_len=seq_len)
        self.encoders = nn.ModuleList([EncoderLayer(num_heads=num_heads, d_model=d_model, d_ff=d_ff) for i in range(num_layers)])
        

    def forward(self, input_tokens):
        input_matrix = self.embedding_model(input_tokens)
        input_matrix = self.pe(input_matrix)

        for encoder_layer in self.encoders:
            input_matrix = encoder_layer(input_matrix)

        return input_matrix

In [11]:
class DecoderLayer(nn.Module):

    def __init__(self, num_heads, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(num_heads=num_heads,d_model=d_model)
        self.cross_mha = MultiHeadAttention(num_heads=num_heads,d_model=d_model)
        self.ffn = FeedForwardNetwork(d_model=d_model,d_ff=d_ff)
        self.layer_norm_1 = nn.LayerNorm(d_model)
        self.layer_norm_2 = nn.LayerNorm(d_model)
        self.layer_norm_3 = nn.LayerNorm(d_model)
        self.drop_out = nn.Dropout(dropout)

    def forward(self, encoder_matrix, input_matrix):
        x = self.mha(input_matrix, input_matrix, input_matrix, mask=True)
        x = self.drop_out(x + input_matrix)
        x = self.layer_norm_1(x)
        x_1 = self.cross_mha(x, encoder_matrix, encoder_matrix, mask=False)
        x_1 = self.drop_out(x_1 + x)
        x_1 = self.layer_norm_2(x_1)
        x_2 = self.ffn(x_1)
        x_2 = self.drop_out(x_2 + x_1)
        x_2 = self.layer_norm_3(x_2)
        return x_2

In [12]:
class Decoder(nn.Module):

    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, seq_len):
        super().__init__()
        self.embedding_model = nn.Embedding(embedding_dim=d_model, num_embeddings=vocab_size)
        self.pe = PositionalEncoding(d_model=d_model,seq_len=seq_len)
        self.decoders = nn.ModuleList([DecoderLayer(num_heads=num_heads, d_model=d_model, d_ff=d_ff) for i in range(num_layers)])

    def forward(self, output_tokens, encoder_output):
        output_matrix = self.embedding_model(output_tokens)
        output_matrix = self.pe(output_matrix)

        for decoder_layer in self.decoders:
            output_matrix = decoder_layer(encoder_output, output_matrix)

        return output_matrix

In [35]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, seq_len):
        super().__init__()
        self.encoder = Encoder(num_heads=num_heads, d_model=d_model, d_ff=d_ff, num_layers=num_layers, seq_len=seq_len, vocab_size=src_vocab_size)
        self.decoder = Decoder(num_heads=num_heads, d_model=d_model, d_ff=d_ff, num_layers=num_layers, seq_len=seq_len, vocab_size=tgt_vocab_size)
        #Vocab sizes may be different
        self.final_layer = nn.Linear(d_model, tgt_vocab_size)
        
    def forward(self, src_tokens, tgt_tokens):
        encoder_output = self.encoder(src_tokens) 
        
        decoder_output = self.decoder(tgt_tokens,encoder_output)
        
        logits = self.final_layer(decoder_output) #Notice how we do not apply softmax here, as the loss function does that internally
        
        return logits

In [36]:
# def create_vocab(dataset):
#     words = dataset.split()
#     vocab = {}
#     count = 0
#     for word in words: 
#         if word not in vocab.keys():
#             vocab[word] = count
#             count+=1
#     return vocab

# def tokenize(sentence, vocab):
#     words = sentence.split()
#     converted_sentence = []
#     for word in words:
#         converted_sentence.append(vocab[word])
#     return converted_sentence

In [37]:
# big_sentence = "HI how are you , you look good."
# print(create_vocab(big_sentence))


In [38]:
def create_vocab(dataset, min_freq=1):
    from collections import Counter
    
    # Count word frequencies
    # words = [sentence.split() for sentence in dataset] #This has error, it creates a list. Counter doesn't work with lists 
    
    #Fix
    # Flatten all sentences into one big list of words
    # all_words = []
    # for sentence in dataset:
    #     all_words.extend(sentence.split()) # Look at the difference between extend and append

    # Simpler fix
    words = ' '.join(dataset).split()  # Join all sentences, then split
    word_counts = Counter(words)
    
    # Create word to index mapping
    word_to_idx = {
        '<PAD>': 0,
        '<UNK>': 1, 
        '<BOS>': 2,
        '<EOS>': 3
    }
    
    # Add words that appear at least min_freq times
    idx = 4
    for word, count in word_counts.items():
        if count >= min_freq:
            word_to_idx[word] = idx
            idx += 1
    
    # Create reverse mapping - this is the efficient part!
    idx_to_word = {idx: word for word, idx in word_to_idx.items()}
    
    return word_to_idx, idx_to_word

def tokenize(sentence, word_to_idx, add_special_tokens=True):
    words = sentence.split()
    tokens = []
    
    if add_special_tokens:
        tokens.append(word_to_idx['<BOS>'])
    
    for word in words:
        tokens.append(word_to_idx.get(word, word_to_idx['<UNK>']))
    
    if add_special_tokens:
        tokens.append(word_to_idx['<EOS>'])
    
    return tokens

def detokenize(tokens, idx_to_word):
    words = []
    for token in tokens:
        words.append(idx_to_word[token])
    return ' '.join(words)

In [39]:
# sentence = "hello world"
# word_to_idx, idx_to_word = create_vocab(sentence)
# tokens = tokenize(sentence, word_to_idx)
# reconstructed = detokenize(tokens, idx_to_word)
# print(f"Original: {sentence}")
# print(f"Tokens: {tokens}")
# print(f"Reconstructed: {reconstructed}")

In [40]:
class TranslationDataset(torch.utils.data.Dataset):
    def __init__(self, src_sentences, tgt_sentences, src_vocab, tgt_vocab):
        self.src_sentences = src_sentences
        self.tgt_sentences = tgt_sentences
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
    
    def __len__(self):
        return len(self.src_sentences)
    
    def __getitem__(self, idx):
        src_tokens = tokenize(self.src_sentences[idx], self.src_vocab)
        tgt_tokens = tokenize(self.tgt_sentences[idx], self.tgt_vocab)
        return src_tokens, tgt_tokens

In [41]:
# from torch.nn.utils.rnn import pad_sequence

# def collate_fn(batch):
#     src_tokens, tgt_tokens = zip(*batch)
#     tgt_input = tgt_tokens[:-1]
#     tgt_output = tgt_tokens[1:]
#     padded_src = pad_sequence(torch.tensor(src_tokens))
#     padded_tgt_input = pad_sequence(torch.tensor(tgt_input))
#     padded_tgt_output = pad_sequence(torch.tensor(tgt_output))
    
#     return {
#         'src': padded_src,
#         'tgt_input': padded_tgt_input,
#         'tgt_output': padded_tgt_output
#     }

In [42]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    src_tokens, tgt_tokens = zip(*batch)
    
    # Convert each sequence to tensor first
    src_tensors = [torch.tensor(seq) for seq in src_tokens]
    tgt_tensors = [torch.tensor(seq) for seq in tgt_tokens]
    
    # Create input and output targets for each sequence
    tgt_input_tensors = [seq[:-1] for seq in tgt_tensors]  # Remove last token
    tgt_output_tensors = [seq[1:] for seq in tgt_tensors]  # Remove first token
    
    # Pad sequences
    padded_src = pad_sequence(src_tensors, batch_first=True, padding_value=0)
    padded_tgt_input = pad_sequence(tgt_input_tensors, batch_first=True, padding_value=0)
    padded_tgt_output = pad_sequence(tgt_output_tensors, batch_first=True, padding_value=0)
    
    return {
        'src': padded_src,
        'tgt_input': padded_tgt_input,
        'tgt_output': padded_tgt_output
    }

In [43]:
# # Mock data
# batch = [
#     ([2, 5, 8, 3], [2, 10, 11, 3]),      # src, tgt
#     ([2, 12, 7, 9, 15, 3], [2, 20, 21, 22, 3])  # longer sequences
# ]

# result = collate_fn(batch)
# print("Shapes:")
# print(f"src: {result['src'].shape}")
# print(f"tgt_input: {result['tgt_input'].shape}")
# print(f"tgt_output: {result['tgt_output'].shape}")

In [44]:
# def lr_scheduler(d_model,step_num, warmup_steps):
#     l_rate = d_model**(-1/2)*np.min(step_num**(-1/2), step_num*warmup_steps**(-3/2))
#     return l_rate

class TransformerLRScheduler:
    def __init__(self, optimizer, d_model, warmup_steps):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0
        
    def step(self):
        self.step_num += 1
        lr = self.d_model ** (-0.5) * min(self.step_num ** (-0.5), 
                                         self.step_num * self.warmup_steps ** (-1.5))
        
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

In [45]:
# def compute_loss(logits, targets, pad_token_id=0):
#     # logits: (batch_size, seq_len, vocab_size)
#     # targets: (batch_size, seq_len)
    
#     # Reshape for loss calculation
#     logits = logits.view(-1, logits.size(-1))  # (batch_size * seq_len, vocab_size)
#     targets = targets.view(-1)  # (batch_size * seq_len,)
    
#     # CrossEntropyLoss with ignore_index
#     criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)
#     loss = criterion(logits, targets)
    
#     return loss

In [46]:
# # Understanding how compute loss works 

# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import numpy as np

# # Let's create a concrete example to understand the loss calculation

# # Example setup
# batch_size = 2
# seq_len = 4
# vocab_size = 6
# pad_token_id = 0

# # Example logits from model output
# # Shape: (batch_size, seq_len, vocab_size)
# logits = torch.randn(batch_size, seq_len, vocab_size)
# print(f"Original logits shape: {logits.shape}")
# print(f"Logits for first position of first batch:\n{logits[0, 0, :]}")

# # Example targets (ground truth tokens)
# # Shape: (batch_size, seq_len)
# targets = torch.tensor([
#     [1, 2, 3, 0],  # First sequence: word_ids [1,2,3] + padding [0]
#     [4, 5, 0, 0]   # Second sequence: word_ids [4,5] + padding [0,0]
# ])
# print(f"\nOriginal targets shape: {targets.shape}")
# print(f"Targets:\n{targets}")

# # Step 1: Reshape for CrossEntropyLoss
# # CrossEntropyLoss expects:
# # - Input: (N, C) where N = number of samples, C = number of classes
# # - Target: (N,) where each value is the class index

# logits_reshaped = logits.view(-1, vocab_size)  # (batch_size * seq_len, vocab_size)
# targets_reshaped = targets.view(-1)            # (batch_size * seq_len,)

# print(f"\nAfter reshaping:")
# print(f"Logits shape: {logits_reshaped.shape}")
# print(f"Targets shape: {targets_reshaped.shape}")
# print(f"Reshaped targets: {targets_reshaped}")

# # Step 2: Apply CrossEntropyLoss
# criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)
# loss = criterion(logits_reshaped, targets_reshaped)

# print(f"\nFinal loss: {loss.item()}")

# # Let's manually see what happens with ignore_index
# print(f"\nWhich positions are ignored (pad_token_id={pad_token_id}):")
# mask = targets_reshaped != pad_token_id
# print(f"Valid positions: {mask}")
# print(f"Valid targets: {targets_reshaped[mask]}")

# # Manual calculation to show the math
# print(f"\n--- Manual CrossEntropy Calculation ---")

# # For each valid position, calculate -log(softmax(logits)[target])
# valid_losses = []
# for i in range(len(targets_reshaped)):
#     if targets_reshaped[i] != pad_token_id:  # Skip padding
#         # Get logits for this position
#         position_logits = logits_reshaped[i]
#         true_class = targets_reshaped[i]
        
#         # Apply softmax
#         softmax_probs = F.softmax(position_logits, dim=0)
        
#         # Cross entropy: -log(probability of true class)
#         ce_loss = -torch.log(softmax_probs[true_class])
#         valid_losses.append(ce_loss.item())
        
#         print(f"Position {i}: target={true_class}, prob={softmax_probs[true_class]:.4f}, loss={ce_loss:.4f}")

# manual_loss = np.mean(valid_losses)
# print(f"\nManual average loss: {manual_loss:.4f}")
# print(f"PyTorch loss: {loss.item():.4f}")

# # Show the mathematical formula
# print(f"\n--- Mathematical Formula ---")
# print("CrossEntropyLoss(x, y) = -log(softmax(x)[y])")
# print("where:")
# print("- x is the logits vector for one sample")
# print("- y is the true class index")
# print("- softmax(x)[y] is the predicted probability for the true class")
# print("\nFor multiple samples: average over all valid (non-padded) positions")

# # Example with actual word meanings
# print(f"\n--- Intuitive Example ---")
# vocab = {0: '<PAD>', 1: 'hello', 2: 'world', 3: 'how', 4: 'are', 5: 'you'}
# print("If our vocabulary is:", vocab)
# print("And our target sequence is: [hello, world, how, <PAD>]")
# print("The model needs to predict:")
# print("- Position 0: 'hello' (class 1)")
# print("- Position 1: 'world' (class 2)")  
# print("- Position 2: 'how' (class 3)")
# print("- Position 3: <PAD> (ignored)")
# print("\nThe loss measures how well the model's probability distribution")
# print("matches the true next word at each position (ignoring padding).")

In [47]:
# The below is AI generated and doesn't make much sense, write in ur own words
# Also no need for this, https://docs.pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html they added label_smoothing to pytorch

class LabelSmoothingLoss(nn.Module):
    def __init__(self, vocab_size, smoothing=0.1, ignore_index=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.smoothing = smoothing
        self.ignore_index = ignore_index
        
    def forward(self, logits, targets):
        # logits: (batch_size, seq_len, vocab_size)
        # targets: (batch_size, seq_len) - still hard labels initially
        x
        batch_size, seq_len = targets.shape
        
        # Create smoothed labels
        smoothed_targets = torch.zeros(batch_size, seq_len, self.vocab_size)
        
        # Fill with smoothing value
        smoothed_targets.fill_(self.smoothing / (self.vocab_size - 1))
        
        # Set correct class to (1 - smoothing)
        smoothed_targets.scatter_(2, targets.unsqueeze(-1), 1.0 - self.smoothing)
        
        # Mask out padding positions
        mask = (targets != self.ignore_index).unsqueeze(-1)
        smoothed_targets = smoothed_targets * mask
        
        # Compute KL divergence loss
        log_probs = F.log_softmax(logits, dim=-1)
        loss = F.kl_div(log_probs, smoothed_targets, reduction='batchmean')
        
        return loss

In [48]:
from datasets import load_dataset

dataset = load_dataset("bentrevett/multi30k")
train_data = dataset['train']
valid_data = dataset['validation']
test_data = dataset['test']

# Extract sentences
src_sentences = [item['de'] for item in train_data]
tgt_sentences = [item['en'] for item in train_data]

In [49]:
from torch.utils.data import DataLoader

word_to_idx_src, idx_to_words_src = create_vocab(src_sentences)
word_to_idx_tgt, idx_to_words_tgt = create_vocab(tgt_sentences)

# Assuming you have your source and target sentences
train_dataset = TranslationDataset(src_sentences, tgt_sentences, word_to_idx_src, word_to_idx_tgt)

train_loader = DataLoader(
    train_dataset,
    batch_size=32, 
    shuffle=True,
    collate_fn=collate_fn
)

In [50]:
# def train_epoch(model, dataloader, optimizer, criterion, scheduler=None):
#     model.train()
#     total_loss = 0
    
#     for batch in dataloader:
#         src = batch['src']
#         tgt_input = batch['tgt_input'] 
#         tgt_output = batch['tgt_output']
        
#         # Forward pass
#         logits = model(src, tgt_input)
        
#         # Loss calculation (this needs special handling!)
#         loss = compute_loss(logits=logits, targets=tgt_output)
        
#         # Backward pass
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
        
#         if scheduler:
#             scheduler.step()
            
#         total_loss += loss.item()
    
#     return total_loss / len(dataloader)

In [51]:
# Quick test to ensure model doesn't crash
def test_model():
    # Small test parameters
    vocab_size = 1000
    d_model = 128
    num_heads = 8
    num_layers = 2
    d_ff = 512
    seq_len = 50
    
    model = Transformer(vocab_size, d_model, num_heads, num_layers, d_ff, seq_len)
    
    # Create dummy data
    batch_size = 4
    src_tokens = torch.randint(0, vocab_size, (batch_size, 20))
    tgt_tokens = torch.randint(0, vocab_size, (batch_size, 15))
    
    # Test forward pass
    try:
        output = model(src_tokens, tgt_tokens)
        print(f"✅ Model works! Output shape: {output.shape}")
        print(f"Expected shape: (batch_size={batch_size}, seq_len=15, vocab_size={vocab_size})")
        return True
    except Exception as e:
        print(f"❌ Model failed: {e}")
        return False

# Run the test
test_model()

TypeError: Transformer.__init__() missing 1 required positional argument: 'seq_len'

In [None]:
def train_epoch(model, dataloader, optimizer, scheduler=None):
    model.train()
    total_loss = 0
    
    criterion = LabelSmoothingLoss(vocab_size=len(word_to_idx_tgt), smoothing=0.1)
    
    for batch in dataloader:
        src = batch['src']
        tgt_input = batch['tgt_input'] 
        tgt_output = batch['tgt_output']
        
        # Forward pass
        logits = model(src, tgt_input)
        
        # Loss with label smoothing
        loss = criterion(logits, tgt_output)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if scheduler:
            scheduler.step()
            
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

In [None]:
# Extract validation sentences  
val_src_sentences = [item['de'] for item in valid_data]
val_tgt_sentences = [item['en'] for item in valid_data]

# Create validation dataset using the SAME vocabulary as training
val_dataset = TranslationDataset(val_src_sentences, val_tgt_sentences, 
                                word_to_idx_src, word_to_idx_tgt)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [53]:
src_vocab_size = len(word_to_idx_src)
tgt_vocab_size = len(word_to_idx_tgt)
d_model = 128
num_heads = 8
num_layers = 2
d_ff = 512
seq_len = 50

model = Transformer(src_vocab_size=src_vocab_size, tgt_vocab_size=tgt_vocab_size, d_model=d_model, num_heads=num_heads, num_layers=num_layers, d_ff=d_ff, seq_len=seq_len)

def test_data_pipeline():
    print("=== Testing Data Pipeline ===")
    
    # Test one batch
    for batch in train_loader:
        src = batch['src']
        tgt_input = batch['tgt_input'] 
        tgt_output = batch['tgt_output']
        
        print(f"Batch shapes:")
        print(f"  src: {src.shape}")
        print(f"  tgt_input: {tgt_input.shape}")
        print(f"  tgt_output: {tgt_output.shape}")
        
        # Test with model
        try:
            logits = model(src, tgt_input)
            print(f"  model output: {logits.shape}")
            print("✅ Data pipeline works!")
        except Exception as e:
            print(f"❌ Data pipeline failed: {e}")
        
        break  # Only test first batch

# Run this after you fix vocabulary
test_data_pipeline()

=== Testing Data Pipeline ===
Batch shapes:
  src: torch.Size([32, 22])
  tgt_input: torch.Size([32, 20])
  tgt_output: torch.Size([32, 20])
  model output: torch.Size([32, 20, 15460])
✅ Data pipeline works!


In [57]:
def quick_training_test():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Create model with correct vocab size
    model = Transformer(
        src_vocab_size=len(word_to_idx_src),
        tgt_vocab_size=len(word_to_idx_tgt), 
        d_model=128,
        num_heads=8, 
        num_layers=2,
        d_ff=256,
        seq_len=50
    )
    
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    
    # Use CrossEntropyLoss with label smoothing
    criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
    
    # Train on just 3 batches
    model.train()
    for i, batch in enumerate(train_loader):
        if i >= 3:
            break
            
        src = batch['src'].to(device)
        tgt_input = batch['tgt_input'].to(device)
        tgt_output = batch['tgt_output'].to(device)
        
        optimizer.zero_grad()
        logits = model(src, tgt_input)
        
        # Reshape for CrossEntropyLoss
        # logits: (batch_size, seq_len, vocab_size) -> (batch_size * seq_len, vocab_size)
        # targets: (batch_size, seq_len) -> (batch_size * seq_len)
        logits_flat = logits.view(-1, logits.size(-1))
        targets_flat = tgt_output.view(-1)
        
        loss = criterion(logits_flat, targets_flat)
        loss.backward()
        optimizer.step()
        
        print(f"Batch {i+1}/3: Loss = {loss.item():.4f}")
    
    print("✅ Quick training test passed!")

quick_training_test()

Using device: cpu
Batch 1/3: Loss = 9.6717
Batch 2/3: Loss = 9.7118
Batch 3/3: Loss = 9.6744
✅ Quick training test passed!


In [None]:
def train_epoch(model, dataloader, optimizer, scheduler=None):
    model.train()
    total_loss = 0
    
    # Use CrossEntropyLoss with label smoothing instead of custom loss
    criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
    
    for batch in dataloader:
        src = batch['src']
        tgt_input = batch['tgt_input'] 
        tgt_output = batch['tgt_output']
        
        # Forward pass
        logits = model(src, tgt_input)
        
        # Reshape for CrossEntropyLoss
        logits_flat = logits.view(-1, logits.size(-1))
        targets_flat = tgt_output.view(-1)
        
        # Loss with label smoothing
        loss = criterion(logits_flat, targets_flat)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if scheduler:
            scheduler.step()
            
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

In [58]:
## Clean this us later, read it thoroughly too!!

Excellent! Now that everything is working, let's set up proper training with evaluation and monitoring.

## Step 1: Set up validation data

```python
# Create validation dataset
val_src_sentences = [item['de'] for item in valid_data]
val_tgt_sentences = [item['en'] for item in valid_data]

val_dataset = TranslationDataset(val_src_sentences, val_tgt_sentences, 
                                word_to_idx_src, word_to_idx_tgt)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
```

## Step 2: Enhanced evaluation function

```python
def evaluate_model(model, dataloader, device, idx_to_word_tgt, max_batches=50):
    model.eval()
    total_loss = 0
    num_batches = 0
    
    criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
    
    predictions = []
    references = []
    
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if i >= max_batches:  # Don't evaluate entire dataset
                break
                
            src = batch['src'].to(device)
            tgt_input = batch['tgt_input'].to(device)
            tgt_output = batch['tgt_output'].to(device)
            
            # Forward pass
            logits = model(src, tgt_input)
            
            # Compute loss
            logits_flat = logits.view(-1, logits.size(-1))
            targets_flat = tgt_output.view(-1)
            loss = criterion(logits_flat, targets_flat)
            
            total_loss += loss.item()
            num_batches += 1
            
            # Collect predictions for BLEU score (first 5 batches only)
            if i < 5:
                pred_ids = torch.argmax(logits, dim=-1)
                
                for j in range(min(2, src.size(0))):  # First 2 samples per batch
                    # Convert to words
                    pred_tokens = pred_ids[j].cpu().numpy()
                    ref_tokens = tgt_output[j].cpu().numpy()
                    
                    # Convert to words and remove special tokens
                    pred_words = []
                    ref_words = []
                    
                    for token in pred_tokens:
                        word = idx_to_word_tgt.get(token, '<UNK>')
                        if word not in ['<PAD>', '<BOS>', '<EOS>']:
                            pred_words.append(word)
                    
                    for token in ref_tokens:
                        word = idx_to_word_tgt.get(token, '<UNK>')
                        if word not in ['<PAD>', '<BOS>', '<EOS>']:
                            ref_words.append(word)
                    
                    predictions.append(' '.join(pred_words))
                    references.append(' '.join(ref_words))
    
    avg_loss = total_loss / num_batches
    bleu_score = compute_bleu(predictions, references) if predictions else 0
    
    return avg_loss, bleu_score, predictions[:5], references[:5]

def compute_bleu(predictions, references):
    """Simple BLEU approximation"""
    total_score = 0
    for pred, ref in zip(predictions, references):
        pred_words = set(pred.split())
        ref_words = set(ref.split())
        
        if len(pred_words) == 0:
            score = 0
        else:
            common_words = pred_words & ref_words
            precision = len(common_words) / len(pred_words)
            recall = len(common_words) / len(ref_words) if len(ref_words) > 0 else 0
            score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        
        total_score += score
    
    return total_score / len(predictions)
```

## Step 3: Training function with monitoring

```python
def train_model(model, train_loader, val_loader, num_epochs=10, save_path='best_model.pth'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    scheduler = TransformerLRScheduler(optimizer, d_model=128, warmup_steps=4000)
    
    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'bleu_scores': [],
        'learning_rates': []
    }
    
    best_bleu = 0.0
    
    for epoch in range(num_epochs):
        print(f"\n{'='*50}")
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"{'='*50}")
        
        # Training
        print("Training...")
        train_loss = train_epoch(model, train_loader, optimizer, scheduler)
        
        # Evaluation
        print("Evaluating...")
        val_loss, bleu, pred_samples, ref_samples = evaluate_model(
            model, val_loader, device, idx_to_words_tgt
        )
        
        # Update history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['bleu_scores'].append(bleu)
        history['learning_rates'].append(optimizer.param_groups[0]['lr'])
        
        # Print results
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")
        print(f"BLEU Score: {bleu:.4f}")
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
        
        # Save best model
        if bleu > best_bleu:
            best_bleu = bleu
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'bleu_score': bleu,
            }, save_path)
            print(f"💾 New best model saved! BLEU: {bleu:.4f}")
        
        # Show sample translations
        print("\nSample translations:")
        for i, (pred, ref) in enumerate(zip(pred_samples[:3], ref_samples[:3])):
            print(f"  {i+1}. Pred: {pred}")
            print(f"     Ref:  {ref}")
            print()
    
    return history
```

## Step 4: Create and start training

```python
# Create model
model = Transformer(
    src_vocab_size=len(word_to_idx_src),
    tgt_vocab_size=len(word_to_idx_tgt),
    d_model=256,      # Slightly larger model
    num_heads=8,
    num_layers=4,     # More layers
    d_ff=1024,        # Larger feedforward
    seq_len=100       # Longer sequences
)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Start training
history = train_model(model, train_loader, val_loader, num_epochs=5)
```

## Step 5: Visualization

```python
import matplotlib.pyplot as plt

def plot_training_curves(history):
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 10))
    
    # Loss curves
    ax1.plot(history['train_loss'], label='Train Loss', color='blue')
    ax1.plot(history['val_loss'], label='Val Loss', color='red')
    ax1.set_title('Training and Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)
    
    # BLEU score
    ax2.plot(history['bleu_scores'], label='BLEU Score', color='green')
    ax2.set_title('BLEU Score Over Time')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('BLEU Score')
    ax2.legend()
    ax2.grid(True)
    
    # Learning rate
    ax3.plot(history['learning_rates'], label='Learning Rate', color='orange')
    ax3.set_title('Learning Rate Schedule')
    ax3.set_xlabel('Step')
    ax3.set_ylabel('Learning Rate')
    ax3.legend()
    ax3.grid(True)
    
    # Loss difference
    loss_diff = [abs(t - v) for t, v in zip(history['train_loss'], history['val_loss'])]
    ax4.plot(loss_diff, label='|Train - Val| Loss', color='purple')
    ax4.set_title('Train-Val Loss Difference')
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('Loss Difference')
    ax4.legend()
    ax4.grid(True)
    
    plt.tight_layout()
    plt.show()

# After training completes:
# plot_training_curves(history)
```

## Your tasks:

1. **Set up validation data** - run the validation dataset creation
2. **Start with a small test** - maybe 2 epochs first to see if everything works
3. **Monitor the training** - watch the loss and BLEU scores

**Questions**:
1. What's the total number of parameters in your model?
2. How long does one epoch take on your machine?
3. Are you seeing the train/val loss decreasing?

Let me know how the training goes!