In [13]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

import models
import prepare

In [14]:

src_vocab_file = 'dat/vocab-eng/basic.model'
tgt_vocab_file = 'dat/vocab-kor/basic.model'

src_vocab = prepare.load_tokenizer(src_vocab_file)
tgt_vocab = prepare.load_tokenizer(tgt_vocab_file)

data_file = 'dat/raw.txt'
dataset = prepare.BPEDataset(data_file, src_vocab, tgt_vocab, src_max_len=16, tgt_max_len=16)

# Split dataset into training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [15]:

model = models.Transformer(
    num_encoder_layers=3, 
    num_decoder_layers=3,
    d_model=32, 
    num_heads=16, 
    dff=2048,
    input_vocab_size=len(src_vocab),  # Use source vocab size
    target_vocab_size=len(tgt_vocab),  # Use target vocab size
    pe_input=1000, 
    pe_target=1000,
    dropout_rate=0.1
).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.0005)
loss_fn = nn.CrossEntropyLoss()


In [16]:

num_iterations = 1000
print_every = 100
validate_every = 10

best_val_loss = float('inf')
for iteration in range(num_iterations):
    model.train()
    total_loss = 0
    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()
        
        src = batch['source'].to(device)
        tgt = batch['target'].to(device)
        
        tgt_input = tgt[:, :].to(device)
        targets = tgt[:, :].contiguous().view(-1).to(device)
        
        # Create masks
        enc_padding_mask = models.create_padding_mask(src, src_vocab.special_tokens['<pad>']).to(device)
        look_ahead_mask = models.create_look_ahead_mask(tgt_input.size(1)).to(device)
        dec_padding_mask = models.create_padding_mask(tgt, tgt_vocab.special_tokens['<pad>']).to(device)  # Often, dec_padding_mask is the same as enc_padding_mask

        output = model(src, tgt_input, enc_padding_mask, look_ahead_mask, dec_padding_mask)
        output = output.view(-1, output.size(-1))
        
        loss = loss_fn(output, targets)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        if (batch_idx + 1) % print_every == 0:
            print(f'Iteration {iteration}, Batch {batch_idx + 1}, Loss: {total_loss / print_every:.4f}')
            total_loss = 0

    if (iteration + 1) % validate_every == 0:
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                src = batch['source'].to(device)
                tgt = batch['target'].to(device)
                
                tgt_input = tgt[:, :]
                targets = tgt[:, :].contiguous().view(-1)
                
                output = model(src, tgt_input, None, look_ahead_mask, None)
                output = output.view(-1, output.size(-1))
                loss = loss_fn(output, targets)
                
                total_val_loss += loss.item()
        
        avg_val_loss = total_val_loss / len(val_loader)
        print(f'Validation Loss after {iteration + 1} iterations: {avg_val_loss:.4f}')

        # Save model checkpoint if it has the best validation loss so far
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            checkpoint_path = os.path.join('out', f'model_checkpoint_{iteration + 1}.pt')
            torch.save({
                'iteration': iteration + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_val_loss,
            }, checkpoint_path)
            print(f'Model checkpoint saved to {checkpoint_path}')

Iteration 0, Batch 100, Loss: 4.0663
Iteration 1, Batch 100, Loss: 2.0619
Iteration 2, Batch 100, Loss: 1.0907
Iteration 3, Batch 100, Loss: 0.6096
Iteration 4, Batch 100, Loss: 0.3615
Iteration 5, Batch 100, Loss: 0.2291
Iteration 6, Batch 100, Loss: 0.1572
Iteration 7, Batch 100, Loss: 0.1136
Iteration 8, Batch 100, Loss: 0.0889
Iteration 9, Batch 100, Loss: 0.0688
Validation Loss after 10 iterations: 0.0317
Model checkpoint saved to out/model_checkpoint_10.pt
Iteration 10, Batch 100, Loss: 0.0573
Iteration 11, Batch 100, Loss: 0.0465
Iteration 12, Batch 100, Loss: 0.0394
Iteration 13, Batch 100, Loss: 0.0348
Iteration 14, Batch 100, Loss: 0.0303
Iteration 15, Batch 100, Loss: 0.0269
Iteration 16, Batch 100, Loss: 0.0232
Iteration 17, Batch 100, Loss: 0.0216
Iteration 18, Batch 100, Loss: 0.0191
Iteration 19, Batch 100, Loss: 0.0172
Validation Loss after 20 iterations: 0.0076
Model checkpoint saved to out/model_checkpoint_20.pt
Iteration 20, Batch 100, Loss: 0.0164
Iteration 21, Batc