# Transformer Machine Translation Inference
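This notebook shows how to load the trained Transformer model and use it to translate English sentences into Italian. Trained weights are required: without them the model runs with random initialization and its translations are meaningless.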

In [8]:
import os
import warnings

import torch
from tokenizers import Tokenizer

from config import get_config, latest_weights_file_path
from model import build_transformer
from train_wb import greedy_decode
warnings.filterwarnings('ignore')

In [9]:
# Project configuration (tokenizer paths, language codes, seq_len, d_model, ...).
config = get_config()

# Pick the best available backend: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cpu


In [11]:
# One tokenizer per language: config['tokenizer_file'] is a filename pattern
# that is filled in with the source / target language code.
tokenizer_src, tokenizer_tgt = (
    Tokenizer.from_file(config['tokenizer_file'].format(config[lang_key]))
    for lang_key in ('lang_src', 'lang_tgt')
)

# Build the model and move it to the selected device. The positional arguments
# are, presumably, (src vocab size, tgt vocab size, src seq_len, tgt seq_len) --
# confirm against model.build_transformer.
model = build_transformer(
    *(tok.get_vocab_size() for tok in (tokenizer_src, tokenizer_tgt)),
    config['seq_len'],
    config['seq_len'],
    d_model=config['d_model'],
).to(device)

# Locate the weights file to load, prompting the user only when necessary.
def get_model_weights_path(config, interactive: bool = True):
    """Return the path of the weights file to load, or ``None``.

    The path returned by ``latest_weights_file_path(config)`` is used as-is
    when it is truthy. Otherwise, if ``interactive`` is true, the user may type
    a custom path or choose to continue without weights.

    Args:
        config: configuration object passed through to
            ``latest_weights_file_path``.
        interactive: when ``False``, never prompt; return ``None`` right away
            if no configured weights exist (useful for unattended
            "Restart & Run All" executions).

    Returns:
        A path string, or ``None`` to continue with the uninitialized model.
        A user-typed path is returned only if it names an existing regular
        file. ``None`` is also returned when no stdin is available.
    """
    model_path = latest_weights_file_path(config)
    if model_path:
        return model_path

    print("No model weights found in the configured location.")
    if not interactive:
        return None

    print("Options:")
    print("1. Enter a custom path to weights file")
    print("2. Continue with uninitialized model (not recommended)")
    try:
        choice = input("Select option (1/2): ").strip()
        if choice != "1":
            return None
        custom_path = input("Enter the full path to the weights file: ")
    except (EOFError, NotImplementedError):
        # EOFError: stdin is closed. NotImplementedError also covers IPython's
        # StdinNotImplementedError, raised when the frontend cannot prompt.
        print("No interactive input available; continuing without weights.")
        return None

    # Tolerate surrounding whitespace/quotes (common when pasting a path) and `~`.
    custom_path = os.path.expanduser(custom_path.strip().strip('"\''))
    # isfile rather than exists: reject directories here instead of letting
    # torch.load fail on them later.
    if os.path.isfile(custom_path):
        return custom_path
    print(f"File not found: {custom_path}")
    return None

# Load the trained weights (if any) into `model`. Failures are reported and the
# notebook continues with the randomly initialized model so later cells still
# run -- check `weights_loaded` before trusting any translation.
def load_model_weights(module: torch.nn.Module, model_path: str, device: torch.device) -> None:
    """Load the checkpoint at ``model_path`` into ``module`` in place.

    Accepts either a training checkpoint dict that stores the weights under the
    ``'model_state_dict'`` key, or a bare ``state_dict``.

    Args:
        module: model whose parameters are overwritten.
        model_path: path to a file written with ``torch.save``.
        device: passed as ``map_location`` so tensors land on that device.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise, e.g. ``RuntimeError``
        on missing/unexpected keys or shape mismatches.
    """
    # NOTE(review): depending on the PyTorch version's `weights_only` default,
    # torch.load may unpickle arbitrary objects; only load trusted checkpoints.
    state = torch.load(model_path, map_location=device)
    if isinstance(state, dict) and 'model_state_dict' in state:
        state = state['model_state_dict']
    module.load_state_dict(state)


weights_loaded = False
try:
    model_path = get_model_weights_path(config)
    if model_path:
        load_model_weights(model, model_path, device)
        weights_loaded = True
        print(f"Loaded weights from {model_path}")
except Exception as e:
    # Deliberate best-effort: report the failure (with its type, so e.g. a
    # checkpoint/architecture mismatch is recognizable) and keep going.
    print(f"Error loading model weights: {type(e).__name__}: {e}")

if not weights_loaded:
    print("Warning: Running with uninitialized model weights!")
    print("Translation results may be nonsensical.")

No model weights found in the configured location.
Options:
1. Enter a custom path to weights file
2. Continue with uninitialized model (not recommended)


File not found: opus_books_weights/mach-trans-model29.pt
Translation results may be nonsensical.


In [12]:
def translate(sentence: str) -> str:
    """Translate an English sentence to Italian with greedy decoding.

    The sentence is tokenized with ``tokenizer_src``, framed with the
    tokenizer's ``[SOS]``/``[EOS]`` ids when those tokens exist, truncated so
    the framed sequence fits in ``config['seq_len']``, and decoded with
    ``greedy_decode``. ``[SOS]``/``[EOS]`` markers are stripped from the text.

    Args:
        sentence: English text. Blank or whitespace-only input returns ``""``
            without running the model.

    Returns:
        The Italian translation as plain text.
    """
    if not sentence or not sentence.strip():
        return ""

    seq_len = config['seq_len']
    sos_id = tokenizer_src.token_to_id('[SOS]')
    eos_id = tokenizer_src.token_to_id('[EOS]')

    # NOTE(review): assumes training framed source sentences as
    # [SOS] tokens [EOS] (the usual setup for this model family); confirm
    # against the dataset code in train_wb. Feeding bare token ids would then
    # be a train/inference mismatch.
    n_special = (sos_id is not None) + (eos_id is not None)
    # The model was built with config['seq_len'] as its sequence length, so
    # truncate the content to leave room for the special markers.
    content_ids = tokenizer_src.encode(sentence).ids[:max(seq_len - n_special, 0)]
    source_ids = (
        ([sos_id] if sos_id is not None else [])
        + content_ids
        + ([eos_id] if eos_id is not None else [])
    )

    encoder_input_ids = torch.tensor([source_ids], dtype=torch.int64, device=device)
    # No padding is added, so every source position is attended to.
    encoder_mask = torch.ones(1, 1, len(source_ids), device=device)

    model.eval()
    with torch.no_grad():
        translated_tokens = greedy_decode(
            model,
            encoder_input_ids,
            encoder_mask,
            tokenizer_src,
            tokenizer_tgt,
            seq_len,
            device,
        )

    # tokenizers' decode expects a sequence of Python ints, not a numpy array.
    translated_text = tokenizer_tgt.decode(translated_tokens.detach().cpu().tolist())

    # Remove any special-token markers that survived decoding.
    return translated_text.replace('[SOS]', '').replace('[EOS]', '').strip()

In [13]:
# Smoke test: translate a handful of fixed English sentences.
test_sentences = [
    "Hello, how are you?",
    "I love reading books.",
    "The weather is beautiful today.",
    "Can you help me find my way to the train station?",
]

print("English to Italian Translation Examples:", "-" * 50, sep="\n")
for sentence in test_sentences:
    translation = translate(sentence)
    print(f"English: {sentence}", f"Italian: {translation}", "-" * 50, sep="\n")

English to Italian Translation Examples:
--------------------------------------------------
Vocabulary size: 22463
SOS token ID: 2, EOS token ID: 3, Comma token ID: 4
Step 1 - Top logits: tensor([[0.7545, 0.7489, 0.7475, 0.7425, 0.7414, 0.7344, 0.7295, 0.7271, 0.7202,
         0.7173]])
Step 1 - Top tokens: ['registri', 'morirà', 'togliendo', 'rientrata', 'colte', 'effetto', 'perderono', 'accettati', 'perché', 'commiserava']
Selected token: 1589 (tolse)
Step 2 - Top logits: tensor([[0.7545, 0.7488, 0.7475, 0.7425, 0.7414, 0.7344, 0.7295, 0.7271, 0.7202,
         0.7173]])
Step 2 - Top tokens: ['registri', 'morirà', 'togliendo', 'rientrata', 'colte', 'effetto', 'perderono', 'accettati', 'perché', 'commiserava']
Selected token: 680 (messo)
Step 3 - Top logits: tensor([[0.7545, 0.7489, 0.7475, 0.7425, 0.7414, 0.7344, 0.7295, 0.7271, 0.7202,
         0.7173]])
Step 3 - Top tokens: ['registri', 'morirà', 'togliendo', 'rientrata', 'colte', 'effetto', 'perderono', 'accettati', 'perché', 'comm

In [None]:
# Interactive translation: type English text and get the Italian output.
# Enter 'q' to stop. The loop also exits cleanly on EOF, on a kernel interrupt,
# or when the frontend cannot provide stdin (e.g. headless "Run All").
while True:
    try:
        text = input("Enter English text to translate (or 'q' to quit): ")
    except (EOFError, KeyboardInterrupt, NotImplementedError):
        # NotImplementedError covers IPython's StdinNotImplementedError.
        break
    text = text.strip()
    if text.lower() == 'q':
        break
    if not text:
        # Nothing to translate; prompt again instead of running the model.
        continue

    translation = translate(text)
    print(f"Italian translation: {translation}\n")

Vocabulary size: 22463
SOS token ID: 2, EOS token ID: 3, Comma token ID: 4
Step 1 - Top logits: tensor([[0.7545, 0.7489, 0.7475, 0.7425, 0.7414, 0.7344, 0.7295, 0.7271, 0.7202,
         0.7173]])
Step 1 - Top tokens: ['registri', 'morirà', 'togliendo', 'rientrata', 'colte', 'effetto', 'perderono', 'accettati', 'perché', 'commiserava']
Selected token: 6940 (interrompendo)
Step 2 - Top logits: tensor([[0.7545, 0.7489, 0.7475, 0.7425, 0.7414, 0.7344, 0.7295, 0.7271, 0.7202,
         0.7173]])
Step 2 - Top tokens: ['registri', 'morirà', 'togliendo', 'rientrata', 'colte', 'effetto', 'perderono', 'accettati', 'perché', 'commiserava']
Selected token: 13990 (disprezzate)
Step 3 - Top logits: tensor([[0.7545, 0.7489, 0.7475, 0.7425, 0.7415, 0.7344, 0.7295, 0.7271, 0.7202,
         0.7173]])
Step 3 - Top tokens: ['registri', 'morirà', 'togliendo', 'rientrata', 'colte', 'effetto', 'perderono', 'accettati', 'perché', 'commiserava']
Selected token: 6540 (striscia)
Step 4 - Top logits: tensor([[0.75