# Transformer Machine Translation Inference
This notebook demonstrates how to use the trained transformer model for English to Italian translation.

In [2]:
import torch
from model import build_transformer
from config import get_config, latest_weights_file_path
from tokenizers import Tokenizer
from train_wb import greedy_decode
import warnings
warnings.filterwarnings('ignore')

In [3]:
# Load the project configuration (tokenizer paths, languages, seq_len, d_model, ...)
config = get_config()

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU
if torch.cuda.is_available():
    device_name = 'cuda'
elif torch.backends.mps.is_available():
    device_name = 'mps'
else:
    device_name = 'cpu'
device = torch.device(device_name)
print(f"Using device: {device}")

Using device: mps


In [None]:
# Load the source/target tokenizers saved during training
tokenizer_src = Tokenizer.from_file(config['tokenizer_file'].format(config['lang_src']))
tokenizer_tgt = Tokenizer.from_file(config['tokenizer_file'].format(config['lang_tgt']))

# Build the model exactly once; vocab sizes come from the tokenizers so the
# embedding/projection shapes match the checkpoint.
model = build_transformer(
    tokenizer_src.get_vocab_size(),
    tokenizer_tgt.get_vocab_size(),
    config['seq_len'],
    config['seq_len'],
    d_model=config['d_model']
).to(device)

# Locate the latest trained checkpoint. Translating with random weights is
# meaningless, so a missing checkpoint is a hard error rather than a warning.
model_path = latest_weights_file_path(config)
if model_path is None:
    raise FileNotFoundError("No trained model weights found!")

# map_location remaps tensors onto the current device, so a checkpoint saved
# on CUDA can be loaded on MPS/CPU (and vice versa).
state = torch.load(model_path, map_location=device)
model.load_state_dict(state['model_state_dict'])
print(f"Loaded weights from {model_path}")

In [None]:
def translate(sentence: str) -> str:
    """Translate an English sentence to Italian using greedy decoding.

    Relies on the module-level ``model``, ``tokenizer_src``, ``tokenizer_tgt``,
    ``config`` and ``device`` created in earlier cells.

    Args:
        sentence: English source text.

    Returns:
        The decoded Italian text with ``[SOS]``/``[EOS]`` markers stripped, or
        an empty string when ``sentence`` is empty/whitespace-only or encodes
        to no tokens.
    """
    # Nothing to translate: an empty token sequence would give the encoder a
    # zero-length input and fail inside the model.
    if not sentence or not sentence.strip():
        return ""

    seq_len = config['seq_len']

    # Tokenize the source text.
    # NOTE(review): raw token ids are fed without [SOS]/[EOS]/padding; confirm
    # this matches how encoder inputs were built during training.
    encoder_input = tokenizer_src.encode(sentence)
    # The model was built with seq_len as its source length, so cap the input
    # there instead of passing a longer sequence to the model.
    source_ids = encoder_input.ids[:seq_len]
    if not source_ids:
        return ""
    encoder_input_ids = torch.tensor([source_ids], dtype=torch.int64, device=device)

    # No padding is used, so every source position is attendable: all-ones
    # mask of shape (1, 1, src_len).
    encoder_mask = torch.ones(1, 1, len(source_ids), device=device)

    # Translate without tracking gradients; eval() disables dropout etc.
    model.eval()
    with torch.no_grad():
        translated_tokens = greedy_decode(
            model,
            encoder_input_ids,
            encoder_mask,
            tokenizer_src,
            tokenizer_tgt,
            seq_len,
            device
        )

    # Tokenizer.decode expects a sequence of Python ints
    translated_text = tokenizer_tgt.decode(translated_tokens.detach().cpu().tolist())

    # Clean up any special-token markers that survive decoding
    translated_text = translated_text.replace('[SOS]', '').replace('[EOS]', '').strip()

    return translated_text

In [None]:
# Smoke-test the model on a few fixed English sentences
test_sentences = [
    "Hello, how are you?",
    "I love reading books.",
    "The weather is beautiful today.",
    "Can you help me find my way to the train station?"
]

separator = "-" * 50
print("English to Italian Translation Examples:")
print(separator)
for english_text in test_sentences:
    italian_text = translate(english_text)
    print(f"English: {english_text}")
    print(f"Italian: {italian_text}")
    print(separator)

In [None]:
# Interactive translation: enter English text; 'q' (or EOF / interrupt) quits.
# NOTE: this cell blocks on input(), so skip it during Restart & Run All.
while True:
    try:
        text = input("Enter English text to translate (or 'q' to quit): ")
    except (EOFError, KeyboardInterrupt):
        # Closed stdin or an interrupted kernel should end the loop cleanly
        break

    text = text.strip()
    if text.lower() == 'q':
        break
    if not text:
        # Blank line: prompt again instead of translating nothing
        continue

    translation = translate(text)
    print(f"Italian translation: {translation}\n")