In [2]:
from tqdm import tqdm
from pathlib import Path
import os  # Add this import

import torch

from torch.utils.data import DataLoader

from transformer import Transformer, FrenchEnglishDataset
from tokenizer.bpe_tokenizer import BPETokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Single source of truth for where the model and the checkpoint tensors live.
device = "cpu"

# The model's vocabulary size is taken from the tokenizer's vocabulary.
tokenizer = BPETokenizer()
vocab_size = len(tokenizer.vocab)

# NOTE(review): these hyperparameters presumably have to match the ones the
# checkpoint was trained with, otherwise load_state_dict below will reject the
# weights — confirm against the training script.
model = Transformer(vocab_size=vocab_size,
                    num_layers=2,
                    num_heads=2,
                    ffn_hidden_dim=64,
                    embedding_dim=64,
                    qk_length=64,
                    value_length=64,
                    max_length=5000,
                    dropout=0.1)
model.to(device)

# Load the checkpoint onto the same device as the model. The loaded object is
# handed straight to load_state_dict, i.e. it is a plain state dict, so
# weights_only=True is sufficient; it also avoids unpickling arbitrary objects
# and silences the torch.load FutureWarning about the default loader.
checkpoint = torch.load('ckpts/model_final_small.pt',
                        map_location=torch.device(device),
                        weights_only=True)

# Copy the saved parameters into the freshly constructed model.
model.load_state_dict(checkpoint)


# Sanity check: run the loaded model on one batch of the training data, report
# the loss, and print the per-position argmax predictions as text.
train_dataset = FrenchEnglishDataset(Path("en-fr-small.csv"), tokenizer=tokenizer, train=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=FrenchEnglishDataset.collate_fn)

# Only the first batch is inspected; shuffle=True means its order varies per run.
src, tgt = next(iter(train_loader))

# NOTE(review): ignore_index=0 presumably masks padding positions — confirm that
# the tokenizer's pad id is 0.
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=0)

# Move to the model's device and cast to int64 in one step.
src = src.to(device=device, dtype=torch.int64)
tgt = tgt.to(device=device, dtype=torch.int64)

model.eval()
with torch.no_grad():
    # NOTE(review): the decoder input and the loss labels are the same,
    # unshifted `tgt`. Unless Transformer or collate_fn shifts internally, the
    # usual setup is model(src, tgt[:, :-1]) scored against tgt[:, 1:] — confirm
    # against the training loop.
    outputs = model(src, tgt)

    # CrossEntropyLoss wants (N, vocab) logits and (N,) labels. reshape (unlike
    # view) also works when the tensors are not contiguous. `tgt` and `outputs`
    # are left untouched so the cell can be re-run and inspected afterwards.
    flat_logits = outputs.reshape(-1, vocab_size)
    flat_targets = tgt.reshape(-1)
    loss = loss_fn(flat_logits, flat_targets)
    print(loss)

    # Regroup the logits per source sentence and take the most likely token at
    # every position. This is an argmax over a single teacher-forced forward
    # pass, not an autoregressive greedy decode.
    batch_size = src.shape[0]
    predicted_ids = flat_logits.reshape(batch_size, -1, vocab_size).argmax(dim=-1)

    # One decoded string per sentence in the batch.
    decoded_outputs = [tokenizer.decode(token_ids) for token_ids in predicted_ids]

    # Print the decoded translations
    for i, translation in enumerate(decoded_outputs):
        print(f"Translation {i + 1}: {translation}")

  checkpoint = torch.load('ckpts/model_final_small.pt', map_location=torch.device('cpu'))
6it [00:00, 43539.49it/s]
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
  inputs = [torch.tensor(b[0]) for b in batch]
  targets = [torch.tensor(b[1]) for b in batch]


tensor(5.0797)
Translation 1:  the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the
Translation 2:  the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the
Translation 3:  the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the
Translation 4:  the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the
Translation 5:  the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the
Translation 6:  the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the
