In [5]:
%reload_ext autoreload
%autoreload 2

# test model loading and vocab loading
import torch
from transformer.transformer import Transformer
import torch.nn as nn
from trainer import Trainer
from data.translation_data import TranslationData
import sentencepiece as spm

# Test loading a model and its vocabulary, then run validation

In [6]:
# Pick the GPU when one is present, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Read the best checkpoint; the hyperparameters are stored under 'args'.
model_path = './checkpoints/en_fr_large_512_long/best_model.pt'
checkpoint = torch.load(model_path, map_location=device)
args = checkpoint['args']

# Load the SentencePiece model referenced by the saved hyperparameters.
sp_tokenizer = spm.SentencePieceProcessor()
sp_tokenizer.load(args['sp_model_path'])
vocab_size = sp_tokenizer.get_piece_size()
print(f"SP piece size ('vocab size'): {vocab_size}")

# Rebuild the architecture from the saved hyperparameters. The encoder and
# decoder share the same depth, and the feed-forward width is 4 * d_model.
transformer_kwargs = dict(
    vocab_size=vocab_size,
    d_model=args['d_model'],
    n_heads=args['n_heads'],
    max_len=args['max_len'],
    dropout_rate=args['dropout_rate'],
    hidden_ff_d=args['d_model'] * 4,
    num_encoder_layers=args['num_layers'],
    num_decoder_layers=args['num_layers'],
    encoding_type=args['encoding_type'],
)
model = Transformer(**transformer_kwargs).to(device)

# Prepare the en -> fr dataset; only the second of the three loaders
# (assigned to valid_loader) is used in this notebook.
data_module = TranslationData(
    src_lang='en',
    tgt_lang='fr',
    batch_size=args['batch_size'],
    max_len=args['max_len'],
    tokenizer=sp_tokenizer,
)
data_module.prepare_data()
_, valid_loader, _ = data_module.get_dataloaders()

# Build a Trainer for evaluation only; padding positions are excluded from the loss.
pad_id = data_module.special_tokens['<pad>']
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)
trainer = Trainer(
    model=model,
    val_loader=valid_loader,
    loss_fn=loss_fn,
    tokenizer=sp_tokenizer,
)
# NOTE(review): presumably this restores the trained weights into `model` — confirm in Trainer.
trainer.load_checkpoint(path=model_path)

# Single validation pass over the validation loader.
val_loss, bleu_score = trainer.validate()
print(f"Val Loss: {val_loss:.04f} | BLEU Score: {bleu_score:.02f}")

SP piece size ('vocab size'): 16000
Loading dataset...
Data num_workers: 4
Data Loaders ready
Cuda available: True


                                                              

Val Loss: 4.1187 | BLEU Score: 4.80


In [None]:
import torch
import torch.onnx
import os

# --- Preparation ---
# Put the model in eval mode before tracing/exporting.
model.eval()

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Dummy inputs used only to drive tracing/export; the values are random token ids.
batch_size = args['batch_size']
max_len = args['max_len']
# Read the vocabulary size from the tokenizer rather than hard-coding 16000,
# so the dummy ids stay in range if the SentencePiece model changes.
vocab_size = sp_tokenizer.get_piece_size()
pad_token_id = data_module.special_tokens['<pad>']

# Create the tensors directly on the target device (no extra host->device copy).
src = torch.randint(0, vocab_size, (batch_size, max_len), dtype=torch.long, device=device)
tgt = torch.randint(0, vocab_size, (batch_size, max_len), dtype=torch.long, device=device)

# Create masks (must be tensors and on same device)
src_mask = trainer.create_src_mask(src, pad_token_id=pad_token_id).to(device)
tgt_mask = trainer.create_tgt_mask(tgt, pad_token_id=pad_token_id).to(device)

example_inputs = (src, tgt, src_mask, tgt_mask)
torchscript_path = "transformer_script.pt"
onnx_path = "transformer.onnx"

# --- Tracing and Export ---
# TorchScript trace. Gradients are not needed for export, so trace under
# no_grad to avoid recording autograd state.
with torch.no_grad():
    traced_model = torch.jit.trace(model, example_inputs)

# Save TorchScript model (optional)
traced_model.save(torchscript_path)

# Export to ONNX
# NOTE(review): only the sequence axis of src/tgt/logits is declared dynamic.
# The batch axis and the mask axes are exported with the dummy shapes above;
# if the masks' shapes depend on src_len/tgt_len they likely need dynamic
# entries too — confirm against create_src_mask/create_tgt_mask.
torch.onnx.export(
    traced_model,
    example_inputs,
    onnx_path,
    input_names=["src", "tgt", "src_mask", "tgt_mask"],
    output_names=["logits"],
    dynamic_axes={
        "src": {1: "src_len"},
        "tgt": {1: "tgt_len"},
        "logits": {1: "tgt_len"}  # logits: (batch, tgt_len, vocab)
    },
    dynamo=True
)
print(f"✅ ONNX export completed: {onnx_path}")


  torch.onnx.export(


[torch.onnx] Obtain model graph for `Transformer([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `Transformer([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 18 of general pattern rewrite rules.
✅ ONNX export completed: transformer.onnx


# Test inference

In [None]:
from scripts.infer import translate_sentences_non_batched

# Sentences to translate with beam search (beam width 3); attention is
# requested as well so it is available for later inspection.
sentences = ['I like to eat pizza']
translated_sentences_beam, in_tokens, out_tokens, attention = translate_sentences_non_batched(
    trainer,
    sentences=sentences,
    decode_type='beam',
    beam_size=3,
    return_attention=True,
)

# Length of the first output token sequence, then all output tokens.
print(len(out_tokens[0]))
print(out_tokens)

# Print each source sentence next to its beam-search translation.
# Assumes one result entry per input sentence — TODO confirm against
# translate_sentences_non_batched; zip() stops at the shorter sequence.
for source_sentence, beam_translation in zip(sentences, translated_sentences_beam):
    print(f'Source sentences: {source_sentence}')
    print(f"Predicted Translation Beam: {beam_translation}")
    print("="*50)

KeyboardInterrupt: 

In [4]:

# Show the beam-search translations returned by the inference cell above.
print(translated_sentences_beam)

[["J'aime manger de la pizza que j'aime manger en mange à une pizza pizza."]]
