In [None]:
from torch import nn
from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_
from nltk.translate.bleu_score import corpus_bleu


In [None]:
# Model configuration.
# The keys of each dict mirror the constructor arguments of Encoder / Decoder,
# because main() unpacks them with ** when building the model.
ENC_CONFIG = dict(
    input_dim=10000,  # placeholder; main() overwrites it with len(de_vocab)
    emb_dim=256,
    hid_dim=512,
    dropout=0.5,
)

DEC_CONFIG = dict(
    output_dim=10000,  # placeholder; main() overwrites it with len(en_vocab)
    emb_dim=256,
    hid_dim=512,
    dropout=0.5,
)

# Training configuration consumed by Trainer and main().
TRAIN_CONFIG = dict(
    batch_size=128,
    epochs=20,
    learning_rate=0.001,
    teacher_forcing_ratio=0.5,
    clip=1.0,  # max gradient norm handed to clip_grad_norm_
    optimizer=torch.optim.Adam,  # optimizer class; instantiated by Trainer
    # Index 1 is the padding value used by collate_fn, so padded target
    # positions do not contribute to the loss.
    loss_fn=nn.CrossEntropyLoss(ignore_index=1),
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

In [None]:
# Translation direction: German source, English target.
SRC_LANGUAGE, TGT_LANGUAGE = 'de', 'en'

# Tokenizers obtained through torchtext's get_tokenizer with the 'spacy'
# backend, one per language pipeline name (German first, then English).
de_tokenizer, en_tokenizer = (
    get_tokenizer('spacy', language=pipeline_name)
    for pipeline_name in ('de_core_news_sm', 'en_core_web_sm')
)

def build_vocab(dataset, language):
    """Build a torchtext vocabulary for one side of a parallel corpus.

    Args:
        dataset: iterable of pairs; element 0 is tokenized with
            ``de_tokenizer`` and element 1 with ``en_tokenizer``.
        language: ``SRC_LANGUAGE`` selects element 0 of each pair; any other
            value selects element 1.

    Returns:
        The vocabulary produced by ``build_vocab_from_iterator`` with the four
        special tokens registered and ``<unk>`` set as the default index, so
        out-of-vocabulary lookups resolve to ``<unk>``.
    """
    # Decide once which column and tokenizer apply instead of per example.
    if language == SRC_LANGUAGE:
        tokenize, column = de_tokenizer, 0
    else:
        tokenize, column = en_tokenizer, 1

    token_lists = (tokenize(pair[column]) for pair in dataset)

    vocab = build_vocab_from_iterator(
        token_lists,
        min_freq=2,
        specials=['<unk>', '<pad>', '<sos>', '<eos>'],
    )
    vocab.set_default_index(vocab['<unk>'])
    return vocab

def get_datasets():
    """Load the Multi30k train/valid splits and build both vocabularies.

    Each split is materialised into a list exactly once. The training data is
    traversed three times overall (source vocab, target vocab, and the
    DataLoader built by the caller). Depending on the torchtext version,
    ``Multi30k`` may return a single-pass iterator, in which case every
    traversal after the first would silently see no examples. Holding the
    pairs in a list makes repeated traversal safe on every version.

    Returns:
        tuple: ``(train_data, valid_data, de_vocab, en_vocab)`` where the two
        data objects are lists of ``(de_text, en_text)`` pairs and the
        vocabularies are built from the training split only.
    """
    language_pair = (SRC_LANGUAGE, TGT_LANGUAGE)
    train_data = list(Multi30k(split='train', language_pair=language_pair))
    valid_data = list(Multi30k(split='valid', language_pair=language_pair))

    # Vocabularies come from the training split only, one per language.
    de_vocab = build_vocab(train_data, SRC_LANGUAGE)
    en_vocab = build_vocab(train_data, TGT_LANGUAGE)

    return train_data, valid_data, de_vocab, en_vocab

def collate_fn(batch, de_vocab, en_vocab, device):
    """Turn a list of (German, English) text pairs into padded index tensors.

    Every sentence is tokenized, mapped to vocabulary indices and wrapped in
    ``<sos>`` / ``<eos>``. Sequences are then padded with index 1 to the
    longest one in the batch.

    Args:
        batch: iterable of ``(de_text, en_text)`` pairs.
        de_vocab: source vocabulary; called on a token list and indexed by
            special-token string.
        en_vocab: target vocabulary, same protocol as ``de_vocab``.
        device: device both tensors are moved to before returning.

    Returns:
        tuple: ``(src_batch, tgt_batch)`` long tensors laid out as
        ``(seq_len, batch)`` (``pad_sequence`` default, time-major).
    """
    def encode(text, tokenizer, vocab):
        # <sos> + token ids + <eos> as a 1-D int64 tensor.
        ids = [vocab['<sos>']] + vocab(tokenizer(text)) + [vocab['<eos>']]
        return torch.tensor(ids, dtype=torch.long)

    sources, targets = [], []
    for de_text, en_text in batch:
        sources.append(encode(de_text, de_tokenizer, de_vocab))
        targets.append(encode(en_text, en_tokenizer, en_vocab))

    # padding_value=1 is the <pad> index (second entry of the specials list).
    padded_src = nn.utils.rnn.pad_sequence(sources, padding_value=1)
    padded_tgt = nn.utils.rnn.pad_sequence(targets, padding_value=1)

    return padded_src.to(device), padded_tgt.to(device)

In [None]:
class Encoder(nn.Module):
    """Embeds source token indices and runs them through a unidirectional GRU.

    Args:
        input_dim: number of rows in the embedding table (source vocab size).
        emb_dim: embedding width, also the GRU input size.
        hid_dim: GRU hidden size.
        dropout: dropout probability applied to the embeddings.
    """

    def __init__(self, input_dim, emb_dim, hid_dim, dropout):
        super().__init__()
        # Attribute names double as state_dict key prefixes; keep them stable.
        self.embedding = nn.Embedding(num_embeddings=input_dim, embedding_dim=emb_dim)
        self.gru = nn.GRU(input_size=emb_dim, hidden_size=hid_dim, bidirectional=False)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, src):
        """Encode ``src`` and return ``(outputs, hidden)`` exactly as the GRU yields them.

        ``src`` holds token indices; the GRU is not batch_first, so the first
        dimension is treated as time.
        """
        token_vectors = self.embedding(src)
        return self.gru(self.dropout(token_vectors))

class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over time-major encoder outputs.

    Args:
        hid_dim: hidden size shared by the decoder state and encoder outputs;
            the scoring layer takes their concatenation (``hid_dim * 2``).
    """

    def __init__(self, hid_dim):
        super().__init__()
        self.attn = nn.Linear(hid_dim * 2, hid_dim)
        self.v = nn.Linear(hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Score every source position against the current decoder state.

        Args:
            hidden: decoder state, expected as ``(1, batch, hid_dim)`` so that
                tiling along dim 0 lines it up with ``encoder_outputs``.
            encoder_outputs: ``(src_len, batch, hid_dim)``.

        Returns:
            Weights of shape ``(src_len, batch)``, normalised with a softmax
            over the source-length dimension (dim 0).
        """
        # Tile the decoder state once per source position.
        tiled_state = hidden.repeat(encoder_outputs.size(0), 1, 1)
        pair_features = torch.cat((tiled_state, encoder_outputs), dim=2)
        energy = torch.tanh(self.attn(pair_features))
        scores = self.v(energy).squeeze(2)
        return torch.softmax(scores, dim=0)

class Decoder(nn.Module):
    """Single-step GRU decoder that attends over the encoder outputs.

    Args:
        output_dim: target vocabulary size (embedding rows and logit width).
        emb_dim: embedding width for target tokens.
        hid_dim: GRU hidden size; must equal the encoder output width because
            the context vector is concatenated with the GRU output.
        dropout: dropout probability applied to the token embedding.
        attention: module called as ``attention(hidden, encoder_outputs)``
            that returns weights shaped ``(src_len, batch)``.
    """

    def __init__(self, output_dim, emb_dim, hid_dim, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.gru = nn.GRU(emb_dim + hid_dim, hid_dim)
        self.fc = nn.Linear(hid_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        """Decode one time step.

        Args:
            input: ``(batch,)`` indices of the previous target tokens.
            hidden: ``(1, batch, hid_dim)`` previous decoder state.
            encoder_outputs: ``(src_len, batch, hid_dim)`` time-major encoder
                outputs.

        Returns:
            tuple: ``(prediction, hidden)`` where ``prediction`` is
            ``(batch, output_dim)`` logits and ``hidden`` is the updated
            ``(1, batch, hid_dim)`` state.
        """
        embedded = self.dropout(self.embedding(input.unsqueeze(0)))  # (1, batch, emb)

        # The attention module yields (src_len, batch). torch.bmm needs the
        # batch dimension first, so transpose before adding the singleton
        # "query" dimension: (batch, 1, src_len).
        attn_weights = self.attention(hidden, encoder_outputs)
        attn_weights = attn_weights.transpose(0, 1).unsqueeze(1)

        # (batch, 1, src_len) @ (batch, src_len, hid) -> (batch, 1, hid)
        context = torch.bmm(attn_weights, encoder_outputs.transpose(0, 1))
        # Back to time-major (1, batch, hid) so it lines up with `embedded`
        # for the GRU input concatenation.
        context = context.transpose(0, 1)

        gru_input = torch.cat((embedded, context), dim=2)
        output, hidden = self.gru(gru_input, hidden)

        # Both operands are (batch, hid) after dropping the length-1 time dim.
        prediction = self.fc(torch.cat((output.squeeze(0), context.squeeze(0)), dim=1))
        return prediction, hidden

class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that unrolls the decoder one target step at a time.

    Args:
        encoder: module returning ``(encoder_outputs, hidden)`` for a source batch.
        decoder: module exposing ``output_dim`` and returning
            ``(logits, hidden)`` for one decoding step.
        device: device on which the output buffer is allocated.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Produce per-step logits for every position of ``trg``.

        Args:
            src: source batch passed straight to the encoder.
            trg: ``(trg_len, batch)`` target indices; row 0 seeds the decoder.
            teacher_forcing_ratio: probability, drawn once per step for the
                whole batch, of feeding the gold token instead of the model's
                own argmax prediction.

        Returns:
            Tensor of shape ``(trg_len, batch, decoder.output_dim)``; row 0 is
            left as zeros because decoding starts at step 1.
        """
        trg_len, batch_size = trg.shape[0], trg.shape[1]
        logits = torch.zeros(
            trg_len, batch_size, self.decoder.output_dim, device=self.device
        )

        encoder_outputs, hidden = self.encoder(src)

        decoder_input = trg[0, :]
        for step in range(1, trg_len):
            step_logits, hidden = self.decoder(decoder_input, hidden, encoder_outputs)
            logits[step] = step_logits
            if torch.rand(1) < teacher_forcing_ratio:
                decoder_input = trg[step]
            else:
                decoder_input = step_logits.argmax(1)

        return logits

In [None]:
class Trainer:
    """Runs training epochs and validation (loss + corpus BLEU) for a seq2seq model.

    Args:
        model: module called as ``model(src, tgt, teacher_forcing_ratio=...)``
            and returning ``(trg_len, batch, vocab)`` logits.
        train_loader: iterable of ``(src, tgt)`` batches with ``len()``.
        valid_loader: iterable of ``(src, tgt)`` batches with ``len()``.
        config: dict with required keys ``'optimizer'`` (optimizer class),
            ``'learning_rate'``, ``'loss_fn'`` and ``'clip'``. Optional keys:
            ``'teacher_forcing_ratio'`` (default 0.5), ``'pad_idx'``
            (default 1) and ``'eos_idx'`` (default 3). The index defaults
            follow the specials order
            ``['<unk>', '<pad>', '<sos>', '<eos>']`` used for the vocabularies.
    """

    def __init__(self, model, train_loader, valid_loader, config):
        self.model = model
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.config = config
        self.optimizer = config['optimizer'](model.parameters(), lr=config['learning_rate'])
        self.loss_fn = config['loss_fn']

    @staticmethod
    def _trim_sequence(token_ids, eos_idx, pad_idx):
        """Return ``token_ids`` cut at the first ``eos_idx`` with ``pad_idx`` removed."""
        trimmed = []
        for token in token_ids:
            if token == eos_idx:
                break
            if token != pad_idx:
                trimmed.append(token)
        return trimmed

    def train_epoch(self):
        """Run one pass over ``train_loader`` and return the mean batch loss."""
        self.model.train()
        epoch_loss = 0.0
        # Honour the configured ratio instead of silently relying on the
        # model's own default argument.
        tf_ratio = self.config.get('teacher_forcing_ratio', 0.5)

        for src, tgt in tqdm(self.train_loader, desc="Training"):
            self.optimizer.zero_grad()
            output = self.model(src, tgt, teacher_forcing_ratio=tf_ratio)
            # Step 0 only holds the <sos> seed, so it is excluded from the loss.
            output = output[1:].reshape(-1, output.shape[-1])
            tgt = tgt[1:].reshape(-1)
            loss = self.loss_fn(output, tgt)
            loss.backward()
            clip_grad_norm_(self.model.parameters(), self.config['clip'])
            self.optimizer.step()
            epoch_loss += loss.item()

        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        return epoch_loss / max(len(self.train_loader), 1)

    def evaluate(self):
        """Evaluate on ``valid_loader`` without teacher forcing.

        Returns:
            tuple: ``(mean_batch_loss, bleu)``. BLEU is computed by
            ``corpus_bleu`` over token-index sequences, one hypothesis and one
            reference per sentence, each cut at ``<eos>`` and stripped of
            padding.
        """
        self.model.eval()
        pad_idx = self.config.get('pad_idx', 1)
        eos_idx = self.config.get('eos_idx', 3)
        epoch_loss = 0.0
        references = []
        hypotheses = []

        with torch.no_grad():
            for src, tgt in self.valid_loader:
                output = self.model(src, tgt, teacher_forcing_ratio=0)
                logits = output[1:]  # (trg_len - 1, batch, vocab)
                gold = tgt[1:]       # (trg_len - 1, batch)

                loss = self.loss_fn(logits.reshape(-1, logits.shape[-1]), gold.reshape(-1))
                epoch_loss += loss.item()

                # Tensors are time-major; transpose so each row is one
                # sentence before handing sequences to BLEU.
                predicted_rows = logits.argmax(-1).transpose(0, 1).tolist()
                gold_rows = gold.transpose(0, 1).tolist()

                for hyp_ids, ref_ids in zip(predicted_rows, gold_rows):
                    references.append([self._trim_sequence(ref_ids, eos_idx, pad_idx)])
                    hypotheses.append(self._trim_sequence(hyp_ids, eos_idx, pad_idx))

        bleu = corpus_bleu(references, hypotheses)
        return epoch_loss / max(len(self.valid_loader), 1), bleu

In [None]:
def main():
    """Train the attention seq2seq model on Multi30k and save its weights.

    Steps: load data and vocabularies, size the embedding/output layers from
    the vocabularies, build the model on the configured device, train for
    ``TRAIN_CONFIG['epochs']`` epochs while printing train/validation loss and
    BLEU, then write the state dict to ``seq2seq_multi30k.pth``.

    Side effects: overwrites ``ENC_CONFIG['input_dim']`` and
    ``DEC_CONFIG['output_dim']`` in place, prints to stdout, and writes the
    checkpoint file to the working directory.
    """
    # Load dataset and vocabularies
    train_iter, valid_iter, de_vocab, en_vocab = get_datasets()

    # Update config with actual vocab sizes
    ENC_CONFIG['input_dim'] = len(de_vocab)
    DEC_CONFIG['output_dim'] = len(en_vocab)

    device = TRAIN_CONFIG['device']

    # Initialize model. collate_fn places every batch on `device`, so the
    # parameters must live there too; otherwise a CUDA run fails with a
    # device-mismatch error on the first forward pass.
    attention = Attention(ENC_CONFIG['hid_dim'])
    encoder = Encoder(**ENC_CONFIG)
    decoder = Decoder(attention=attention, **DEC_CONFIG)
    model = Seq2Seq(encoder, decoder, device).to(device)

    # One collate closure shared by both loaders.
    def collate(batch):
        return collate_fn(batch, de_vocab, en_vocab, device)

    # Create data loaders. Only the training split is shuffled, so each epoch
    # sees the examples in a different order.
    train_loader = DataLoader(list(train_iter),
                              batch_size=TRAIN_CONFIG['batch_size'],
                              shuffle=True,
                              collate_fn=collate)

    valid_loader = DataLoader(list(valid_iter),
                              batch_size=TRAIN_CONFIG['batch_size'],
                              collate_fn=collate)

    # Initialize trainer
    trainer = Trainer(model, train_loader, valid_loader, TRAIN_CONFIG)

    # Training loop
    for epoch in range(TRAIN_CONFIG['epochs']):
        train_loss = trainer.train_epoch()
        valid_loss, bleu = trainer.evaluate()

        print(f"Epoch {epoch+1}/{TRAIN_CONFIG['epochs']}")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {valid_loss:.4f}")
        print(f"BLEU Score: {bleu:.4f}")

    torch.save(model.state_dict(), 'seq2seq_multi30k.pth')


# Entry point: running this cell executes the whole pipeline defined in
# main() (data loading, training, evaluation, checkpoint save).
main()