In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
from torchtext.legacy.data import Field, BucketIterator

import os

from transformer import Transformer
from utils import create_look_ahead_mask, create_padding_mask

In [None]:
# Define data fields.
# SRC is tokenized as English and TRG as German. Both wrap every sentence in
# <sos>/<eos> and lowercase it.
# batch_first=True makes each batch tensor (batch, seq_len). The training loop
# relies on that layout: it slices trg[:, :-1] / trg[:, 1:] along the sequence
# axis and reads trg.shape[1] as the sequence length. With the default
# (sequence-first) layout those slices would cut along the batch axis instead.
SRC = Field(tokenize='spacy', tokenizer_language='en', init_token='<sos>', eos_token='<eos>', lower=True, batch_first=True)
TRG = Field(tokenize='spacy', tokenizer_language='de', init_token='<sos>', eos_token='<eos>', lower=True, batch_first=True)

In [None]:
# Load benchmarking dataset (Multi30k, English -> German).
# The legacy dataset class is used because this notebook is built on the
# legacy Field/BucketIterator API; `.splits` lives on the legacy class.
# `exts` and `fields` are paired positionally: the first pair becomes the
# `src` attribute of each example and the second becomes `trg`. The English
# file therefore has to go with SRC and the German file with TRG, so that
# batch.src is indexed by SRC.vocab (which sizes the model's input embedding)
# and batch.trg by TRG.vocab (which sizes the output layer and the loss).
train_data, valid_data, test_data = torchtext.legacy.datasets.Multi30k.splits(
    exts=('.en', '.de'),
    fields=(SRC, TRG),
)

In [None]:
# Build one vocabulary per field from the training split only,
# passing min_freq=2 to each field's vocabulary builder.
for field in (SRC, TRG):
    field.build_vocab(train_data, min_freq=2)

In [None]:
# Seed torch's random number generator before any model is created, so that
# weight initialisation and dropout are repeatable after a kernel restart.
SEED = 42
torch.manual_seed(SEED)

# Define device: use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model hyperparameters.
# The two vocabulary sizes are read from the vocabularies built on SRC and TRG.
INPUT_DIM, OUTPUT_DIM = len(SRC.vocab), len(TRG.vocab)

# Architecture settings handed to the Transformer constructor.
D_MODEL = 512
D_FF = 2048
N_HEADS = 8
N_LAYERS = 6
DROPOUT = 0.1

In [None]:
# Build the model from the hyperparameters and move it onto `device`.
# N_LAYERS is used for both the encoder and the decoder stack.
model_kwargs = dict(
    input_vocab_size=INPUT_DIM,
    target_vocab_size=OUTPUT_DIM,
    d_model=D_MODEL,
    num_heads=N_HEADS,
    num_encoder_layers=N_LAYERS,
    num_decoder_layers=N_LAYERS,
    dff=D_FF,
    dropout_rate=DROPOUT,
)
transformer = Transformer(**model_kwargs).to(device)

In [None]:
# Loss: cross-entropy that ignores positions whose label is the target
# vocabulary's padding index.
pad_index = TRG.vocab.stoi['<pad>']
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_index)

# Optimizer: Adam over all of the model's parameters.
optimizer = optim.Adam(params=transformer.parameters(), lr=1e-4)

# Batch size and maximum sequence length.
BATCH_SIZE, MAX_SEQ_LEN = 128, 50

In [None]:
# Build one bucketing iterator per split (train / validation / test).
def _src_length(example):
    """Sort key: length of the example's `src` attribute."""
    return len(example.src)


# The same options are applied to all three iterators.
iterator_options = dict(
    batch_size=BATCH_SIZE,
    device=device,
    sort_within_batch=True,
    sort_key=_src_length,
    repeat=False,
    shuffle=True,
)
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    datasets=(train_data, valid_data, test_data),
    **iterator_options
)

In [None]:
# Define number of training epochs
N_EPOCHS = 10

# Directory that receives one checkpoint file per epoch.
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def compute_batch_loss(batch):
    """Run one teacher-forced forward pass and return the loss tensor.

    Shared by the training and validation phases so both build the masks and
    the shifted targets in exactly the same way.

    Args:
        batch: a batch from one of the iterators, exposing `src` and `trg`.
            Assumes both are (batch, seq_len) tensors -- confirm the Fields
            are built with batch_first=True.

    Returns:
        The scalar loss produced by `loss_fn` for this batch.
    """
    src = batch.src
    trg = batch.trg

    # Teacher forcing: the decoder reads every token except the last and is
    # scored against every token except the first.
    trg_input = trg[:, :-1]
    trg_labels = trg[:, 1:]

    # The target masks must be built from the decoder input (length T-1), not
    # from the full target (length T); otherwise the mask and the decoder
    # sequence disagree in length.
    src_mask = create_padding_mask(src)
    trg_pad_mask = create_padding_mask(trg_input)
    look_ahead_mask = create_look_ahead_mask(trg_input.shape[1])
    trg_mask = torch.max(trg_pad_mask, look_ahead_mask.to(device))

    # NOTE(review): the fifth argument is the *target* padding mask, as in the
    # original call. If Transformer uses it for encoder-decoder attention it
    # probably needs the source padding mask instead -- confirm against
    # transformer.py.
    output, _ = transformer(src, trg_input, src_mask, trg_mask, trg_pad_mask)

    # Flatten to (batch * seq, vocab) and (batch * seq,) for the loss function.
    output = output.contiguous().view(-1, output.shape[-1])
    trg_labels = trg_labels.contiguous().view(-1)
    return loss_fn(output, trg_labels)


# Train model
for epoch in range(N_EPOCHS):
    transformer.train()
    train_loss = 0.0
    for i, batch in enumerate(train_iterator):
        optimizer.zero_grad()

        loss = compute_batch_loss(batch)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()

        # Print the running mean loss every 100 batches
        if i % 100 == 0:
            print(f'Epoch: {epoch+1}, Batch: {i+1}/{len(train_iterator)}, Train Loss: {train_loss/(i+1):.4f}')

    # Mean over all batches of the epoch; uses the iterator length rather than
    # the loop variable that leaks out of the for-loop.
    avg_train_loss = train_loss / len(train_iterator)

    # Evaluate on validation set after each epoch
    transformer.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for batch in valid_iterator:
            valid_loss += compute_batch_loss(batch).item()
    avg_valid_loss = valid_loss / len(valid_iterator)

    print(f'Epoch: {epoch+1}, Validation Loss: {avg_valid_loss:.4f}')

    # Write the checkpoint inside CHECKPOINT_DIR (the directory created above)
    # instead of the current working directory.
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_epoch_{epoch}.pt")

    # Save model state and optimizer state
    torch.save({
        'epoch': epoch,
        'model_state_dict': transformer.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': avg_train_loss,
        'valid_loss': avg_valid_loss,
    }, checkpoint_path)
 

In [None]:
# Persist the final model weights together with the optimizer state
# in a single file.
final_state = {
    'model_state_dict': transformer.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(final_state, 'transformer_model.pt')

In [None]:
# Load saved model and optimizer states.
# map_location remaps the stored tensors onto the current `device`, so a file
# written on a GPU machine can still be loaded on a CPU-only one.
checkpoint = torch.load('transformer_model.pt', map_location=device)
transformer.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# To resume from one of the per-epoch checkpoints written by the training
# loop, load that file the same way. Those files additionally carry the keys
# 'epoch', 'train_loss' and 'valid_loss'.

In [None]:
# Testing the model: greedy-decode a translation for a single sentence.
# Define the input sequence
input_sequence = 'This is a test.'

# Preprocess the input the same way the training data was prepared:
# SRC.preprocess tokenizes and applies the field's lowercasing. The
# <sos>/<eos> markers are added by hand because the Field only inserts them
# when it pads a batch.
tokenized_sequence = [SRC.init_token] + SRC.preprocess(input_sequence) + [SRC.eos_token]
numericalized_sequence = [SRC.vocab.stoi[token] for token in tokenized_sequence]

# Shape (1, src_len): a batch of one, with the batch axis first as in the
# training loop. A single sentence needs no padding.
tensor_sequence = torch.LongTensor(numericalized_sequence).unsqueeze(0).to(device)

# Upper bound on the number of generated tokens.
max_seq_len = MAX_SEQ_LEN

sos_idx = TRG.vocab.stoi[TRG.init_token]
eos_idx = TRG.vocab.stoi[TRG.eos_token]

# Run the model on the input sequence.
# The decoder needs a target prefix, so decoding starts from <sos> and appends
# the most likely next token one step at a time until <eos> or the length cap.
# Masks are built exactly as in the training loop.
transformer.eval()
src_mask = create_padding_mask(tensor_sequence)
decoded_indices = [sos_idx]
with torch.no_grad():
    for _ in range(max_seq_len):
        trg_tensor = torch.LongTensor(decoded_indices).unsqueeze(0).to(device)
        trg_pad_mask = create_padding_mask(trg_tensor)
        look_ahead_mask = create_look_ahead_mask(trg_tensor.shape[1])
        trg_mask = torch.max(trg_pad_mask, look_ahead_mask.to(device))

        output, _ = transformer(tensor_sequence, trg_tensor, src_mask, trg_mask, trg_pad_mask)

        # Only the prediction at the last decoder position is new.
        next_idx = output[0, -1].argmax(dim=-1).item()
        if next_idx == eos_idx:
            break
        decoded_indices.append(next_idx)

# Convert the output to tokens, dropping the leading <sos>.
output_tokens = [TRG.vocab.itos[token_idx] for token_idx in decoded_indices[1:]]

# Print the output
print(' '.join(output_tokens))