In [1]:
from pathlib import Path
import torch
import lightning as L
import torch.nn as nn
from config import get_config, get_weights_file_path
from train import get_model, get_ds, run_validation

In [2]:
# Thin LightningModule wrapper around the seq2seq model. Because the model is
# stored as `self.model`, its parameters show up under the `model.` prefix in
# this module's state dict, which is the key layout the saved checkpoint uses.
class TransformerModule(L.LightningModule):
    """Lightning wrapper that chains the wrapped model's encode/decode/project.

    Args:
        model: object exposing ``encode``, ``decode`` and ``project`` methods
            (the only ones ``forward`` uses); presumably the Transformer built
            by ``get_model`` — confirm.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, encoder_input, decoder_input, encoder_mask, decoder_mask):
        """Run a full encoder -> decoder -> projection pass.

        Returns whatever ``self.model.project`` produces for the decoder output
        (presumably per-position vocabulary logits — TODO confirm).
        """
        net = self.model
        memory = net.encode(encoder_input, encoder_mask)
        hidden = net.decode(memory, encoder_mask, decoder_input, decoder_mask)
        return net.project(hidden)

In [3]:
# Build the data pipeline and model from the project config, then load the
# final trained weights for inference on the selected device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
config = get_config()
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
og_model = get_model(
    config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()
)

# Move the wrapper (and therefore the wrapped model) to the target device up
# front, and switch to eval mode: this notebook only runs inference, so any
# dropout / normalization layers should use their inference behavior.
lightning_mod = TransformerModule(og_model).to(device)
lightning_mod.eval()

# Load the pretrained weights.
# - map_location: lets a checkpoint saved on GPU load on a CPU-only machine
#   (and places tensors directly on `device`).
# - weights_only: avoids unpickling arbitrary objects; sufficient here because
#   the file is a plain state dict fed straight into load_state_dict.
model_filename = get_weights_file_path(config, "final")
state = torch.load(model_filename, map_location=device, weights_only=True)
lightning_mod.load_state_dict(state)

Using device: cuda
Max length of source sentence: 309
Max length of target sentence: 274


  init.xavier_uniform(parameter)


<All keys matched successfully>

In [4]:
# Qualitative check: print a few validation examples (SOURCE / TARGET /
# PREDICTED) produced with the weights loaded in the previous cell.
# NOTE(review): the 7th positional argument is passed as None — confirm against
# run_validation's signature which parameter this fills / disables.
# NOTE(review): the recorded output shows degenerate predictions (a long run of
# ","), so investigate the checkpoint / decoding before trusting these results.
run_validation(
    lightning_mod, val_dataloader,   # model wrapper and validation data
    tokenizer_src, tokenizer_tgt,    # source- and target-language tokenizers
    config["seq_len"], device,       # sequence-length cap from config, device
    None,
    num_examples=5,
)

--------------------------------------------------------------------------------
SOURCE: What aim, what purpose, what ambition in life have you now?"
TARGET: Insomma, qual'è lo scopo della vostra vita?
PREDICTED: , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , ,
--------------------------------------------------------------------------------
SOURCE: 