In [1]:
import torch
from torchtext.data.utils import get_tokenizer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.nn import Transformer
from datasets import load_dataset
import itertools


In [2]:
# Load the dataset
iwslt_dataset = load_dataset('iwslt2017', 'iwslt2017-en-de')

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.backends.mps.is_available():
    device = 'mps'  # For Apple Silicon GPUs
    torch.mps.set_per_process_memory_fraction(0.0) # Only use if RAM >= 32GB 

BATCH_SIZE = 4


In [3]:

# Initialize tokenizers
tokenizer_en = get_tokenizer('spacy', language='en_core_web_sm')
tokenizer_de = get_tokenizer('spacy', language='de_core_news_sm')

def tokenize(batch):
    en_texts = [item['en'] for item in batch['translation']]
    de_texts = [item['de'] for item in batch['translation']]
    batch['tokenized_en'] = [list(map(str, tokenizer_en(text))) for text in en_texts]
    batch['tokenized_de'] = [list(map(str, tokenizer_de(text))) for text in de_texts]
    return batch

# Tokenize the data
iwslt_dataset = iwslt_dataset.map(tokenize, batched=True, batch_size=1000, num_proc=4)

# Function to load the checkpoint
def load_checkpoint(filepath):
    checkpoint = torch.load(filepath, map_location=device)
    d_model=checkpoint['settings']['d_model']
    model = Transformer(
        d_model=d_model,
        nhead=checkpoint['settings']['nhead'],
        num_encoder_layers=checkpoint['settings']['num_encoder_layers'],
        num_decoder_layers=checkpoint['settings']['num_decoder_layers'],
        dim_feedforward=checkpoint['settings']['dim_feedforward']
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])

    output_projection = torch.nn.Linear(
        checkpoint['settings']['d_model'], checkpoint['settings']['output_vocab_size']
    )
    # Ensure to load the state dict for the linear module correctly
    if 'output_projection_state_dict' in checkpoint:
        output_projection.load_state_dict(checkpoint['output_projection_state_dict'])
    else:
        print("No saved state_dict for output_projection found in checkpoint.")

    vocab_en = checkpoint['vocab_en']
    vocab_de = checkpoint['vocab_de']
    
    return model, output_projection, vocab_en, vocab_de, d_model


In [4]:
model_path = 'model_ckpt/transformer_checkpoint_epoch0.pth'


model, output_projection, EN_VOCAB, DE_VOCAB, d_model = load_checkpoint(model_path)
output_projection = output_projection.to(device)

en_embedding = torch.nn.Embedding(len(EN_VOCAB), d_model).to(device)
de_embedding = torch.nn.Embedding(len(DE_VOCAB), d_model).to(device)

# Function for collating batches
def collate_fn(batch):
    en_batch = [item['tokenized_en'] for item in batch]
    de_batch = [['<sos>'] + item['tokenized_de'] + ['<eos>'] for item in batch]
    en_indices = [[EN_VOCAB.get(token, EN_VOCAB['<unk>']) for token in sentence] for sentence in en_batch]
    de_indices = [[DE_VOCAB.get(token, DE_VOCAB['<unk>']) for token in sentence] for sentence in de_batch]
    en_tensor = pad_sequence([torch.tensor(seq) for seq in en_indices], padding_value=EN_VOCAB['<pad>'], batch_first=True)
    de_tensor = pad_sequence([torch.tensor(seq) for seq in de_indices], padding_value=DE_VOCAB['<pad>'], batch_first=True)
    return {'en': en_tensor.to(device), 'de': de_tensor.to(device)}






In [5]:

# Validation DataLoader
test_loader = DataLoader(iwslt_dataset['test'], batch_size=BATCH_SIZE, collate_fn=collate_fn)

# Evaluate the model
model.eval()
total_loss = 0
with torch.no_grad():
    for batch in test_loader:
        src_tensor = batch['en'] #.to(device)
        tgt_tensor = batch['de'] #.to(device)
        
        src = en_embedding(src_tensor)
        tgt = de_embedding(tgt_tensor)

        # Adjust the sequence length compatibility as before
        if src.shape[1] > tgt.shape[1]:
            src = src[:, :tgt.shape[1], :]
        elif src.shape[1] < tgt.shape[1]:
            pad_size = tgt.shape[1] - src.shape[1]
            src = torch.nn.functional.pad(src, (0, 0, 0, pad_size), value=EN_VOCAB['<pad>'])

        out = model(src, tgt)
        out = output_projection(out) 
        
        target_mask = (tgt_tensor != DE_VOCAB['<pad>']).view(-1)
        loss = torch.nn.functional.cross_entropy(out.view(-1, len(DE_VOCAB)), tgt_tensor.view(-1), reduction='none')
        loss = (loss * target_mask).sum() / target_mask.sum()
        print('Loss:', loss.item())
        total_loss += loss.item()

    print(f'Average Validation Loss: {total_loss / len(test_loader)}')


Loss: 10.543373107910156
Loss: 10.775322914123535
Loss: 10.665635108947754
Loss: 10.39967155456543
Loss: 11.030360221862793
Loss: 10.139505386352539
Loss: 10.74277114868164
Loss: 10.546963691711426
Loss: 10.690202713012695
Loss: 10.856400489807129
Loss: 10.612676620483398
Loss: 10.891348838806152
Loss: 10.694480895996094
Loss: 11.119407653808594
Loss: 10.707167625427246
Loss: 10.802388191223145
Loss: 10.844335556030273
Loss: 10.504132270812988
Loss: 10.431581497192383
Loss: 10.647234916687012
Loss: 10.350074768066406
Loss: 10.486842155456543
Loss: 11.102161407470703
Loss: 10.970832824707031
Loss: 10.745494842529297
Loss: 10.637165069580078
Loss: 10.67344856262207
Loss: 10.720795631408691
Loss: 10.39254379272461
Loss: 10.525296211242676
Loss: 10.61292839050293
Loss: 10.650269508361816
Loss: 11.22901439666748
Loss: 10.362801551818848
Loss: 10.853147506713867
Loss: 10.357525825500488
Loss: 10.602944374084473
Loss: 10.464288711547852
Loss: 10.849983215332031
Loss: 11.027608871459961
Loss: 