In [1]:
import torch
from config import config
from model import Transformer
from data_utils import build_vocab, token_transform, create_src_mask, tensor_transform
from torchtext.datasets import Multi30k
import sacrebleu
from torchtext.data.functional import to_map_style_dataset
import shutil
import os

In [2]:
vocab_transform = build_vocab()

In [3]:
config["src_vocab_size"] = len(vocab_transform['de'])
config["tgt_vocab_size"] = len(vocab_transform['en'])

In [4]:
model = Transformer(
    src_vocab_size=config["src_vocab_size"],
    tgt_vocab_size=config["tgt_vocab_size"],
    model_dim=config["model_dim"],
    num_heads=config["num_heads"],
    ff_dim=config["ff_dim"],
    num_layers=config["num_layers"],
    max_seq_length=config["max_seq_length"],
    dropout=config["dropout"]
)

In [5]:
model.load_state_dict(torch.load("checkpoints/transformer_best.pt"))
model.eval()

Transformer(
  (src_embedding): InputEmbedding(
    (embedding): Embedding(18669, 512)
  )
  (tgt_embedding): InputEmbedding(
    (embedding): Embedding(9795, 512)
  )
  (src_positional_encoding): PositionalEncoding()
  (tgt_positional_encoding): PositionalEncoding()
  (encoder): Encoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (attention): MultiHeadAttention(
          (Wq): Linear(in_features=512, out_features=512, bias=True)
          (Wk): Linear(in_features=512, out_features=512, bias=True)
          (Wv): Linear(in_features=512, out_features=512, bias=True)
          (Wo): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (ffn): FeedForward(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (relu): ReLU()
         

In [6]:
def greedy_decode(model, src, src_mask, max_len, bos_idx, eos_idx, config):
    """
    Generate a translation using greedy decoding.

    Args:
        model: Trained Transformer model
        src: (1, src_seq_len) input tensor (already tokenized and indexed)
        src_mask: (1, 1, 1, src_seq_len) mask for source
        max_len: maximum length of the generated sentence
        bos_idx: index of <bos> token in target vocab
        eos_idx: index of <eos> token in target vocab
        config: config dict for accessing padding and max_seq_length

    Returns:
        output: (1, generated_seq_len) tensor of predicted token IDs
    """
    model.eval()

    # Step 1: Encode source
    with torch.no_grad():
        src_emb = model.src_embedding(src)
        src_emb = model.src_positional_encoding(src_emb)
        memory = model.encoder(src_emb, src_mask)

    # Step 2: Start decoding with BOS token
    ys = torch.ones((1, 1), dtype=torch.long).fill_(bos_idx).to(src.device)

    for _ in range(max_len - 1):
        tgt_emb = model.tgt_embedding(ys)
        tgt_emb = model.tgt_positional_encoding(tgt_emb)

        tgt_mask = (ys != config["pad_idx"]).unsqueeze(1).unsqueeze(2)
        tgt_sub_mask = torch.tril(torch.ones((ys.size(1), ys.size(1)), device=ys.device)).bool()
        combined_mask = tgt_mask & tgt_sub_mask.unsqueeze(0).unsqueeze(1)

        # Decode
        with torch.no_grad():
            out = model.decoder(tgt_emb, memory, src_mask, combined_mask)
            logits = model.projection_layer(out[:, -1])  # (1, vocab_size)
            next_token = torch.argmax(logits, dim=-1).unsqueeze(1)  # (1, 1)

        ys = torch.cat([ys, next_token], dim=1)

        # Stop if EOS is generated
        if next_token.item() == eos_idx:
            break

    return ys


In [7]:
test_iter = list(Multi30k(split='valid'))
de_sentence, en_sentence = test_iter[0]

src_text = de_sentence.lower()
src_tokens = token_transform['de'](src_text)
src_ids = [vocab_transform['de'][tok] for tok in src_tokens]
src_tensor = tensor_transform(src_ids).unsqueeze(0)
src_mask = create_src_mask(src_tensor, config["pad_idx"])

In [8]:
with torch.no_grad():
    output_ids = greedy_decode(
        model,
        src_tensor,
        src_mask,
        max_len=config["max_seq_length"],
        bos_idx=config["bos_idx"],
        eos_idx=config["eos_idx"],
        config=config
    )

output_tokens = [vocab_transform['en'].lookup_token(tok) for tok in output_ids[0]]
output_sentence = " ".join([tok for tok in output_tokens if tok not in ["<bos>", "<eos>", "<pad>"]])


In [9]:
print(f"\nSource (DE):     {src_text}")
print(f"Generated (EN):  {output_sentence}")
print(f"Reference (EN):  {en_sentence}")


Source (DE):     eine gruppe von männern lädt baumwolle auf einen lastwagen
Generated (EN):  a group of men loading vons shopping bags on a truck .
Reference (EN):  A group of men are loading cotton onto a truck


In [10]:
bleu = sacrebleu.corpus_bleu([output_sentence], [[en_sentence]])
print(f"\n✅ Corpus BLEU score: {bleu.score:.2f}")


✅ Corpus BLEU score: 16.59
