In [4]:
from src.transformer_model import Transformer

import torch
import torch.nn as nn
import torch.optim as optim

In [9]:
# --- Configuration ---------------------------------------------------------
# Model hyperparameters.
src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1

# Synthetic-data / reproducibility settings.
batch_size = 64
seed = 42

# Prefer the GPU when present, otherwise fall back to the CPU. Every tensor
# created below must use this same `device` so the notebook also runs on
# CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNGs (all devices) before building the model so that weight
# initialisation and the random sample data below are reproducible across
# kernel restarts (up to any nondeterministic GPU kernels).
torch.manual_seed(seed)

# NOTE(review): arguments are passed positionally; confirm this order matches
# the parameter order of src.transformer_model.Transformer.__init__.
transformer = Transformer(
    src_vocab_size,
    tgt_vocab_size,
    d_model,
    num_heads,
    d_ff,
    num_layers,
    dropout,
    max_seq_length
).to(device)

# Generate random sample data on the selected device. Token ids are drawn from
# [1, vocab_size), so id 0 never appears in the data.
src_data = torch.randint(1, src_vocab_size, (batch_size, max_seq_length), device=device)  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (batch_size, max_seq_length), device=device)  # (batch_size, seq_length)

In [10]:
# Token-level cross-entropy; any target position whose id is 0 is excluded
# from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)

# Adam with beta2 = 0.98 and a very small epsilon.
optimizer = optim.Adam(
    transformer.parameters(),
    lr=1e-4,
    betas=(0.9, 0.98),
    eps=1e-9,
)

# Shift the target sequence by one position: the decoder is fed tokens
# [0, T-1) and scored against tokens [1, T). tgt_data does not change during
# training, so both slices are computed once, outside the loop.
decoder_input = tgt_data[:, :-1]
decoder_target = tgt_data[:, 1:].contiguous().view(-1)

transformer.train()

for epoch in range(100):
    optimizer.zero_grad()

    output = transformer(src_data, decoder_input)
    # Flatten to (N, tgt_vocab_size) for CrossEntropyLoss; output is presumably
    # (batch, seq_len - 1, tgt_vocab_size) — TODO confirm against the model.
    flat_logits = output.contiguous().view(-1, tgt_vocab_size)
    loss = criterion(flat_logits, decoder_target)

    loss.backward()
    optimizer.step()

    epoch_number = epoch + 1
    print(f"Epoch: {epoch_number}, Loss: {loss.item()}")

Epoch: 1, Loss: 8.668381690979004
Epoch: 2, Loss: 8.542909622192383
Epoch: 3, Loss: 8.473676681518555
Epoch: 4, Loss: 8.42013931274414
Epoch: 5, Loss: 8.363524436950684
Epoch: 6, Loss: 8.294819831848145
Epoch: 7, Loss: 8.21558952331543
Epoch: 8, Loss: 8.130414009094238
Epoch: 9, Loss: 8.04964828491211
Epoch: 10, Loss: 7.965593338012695
Epoch: 11, Loss: 7.89173698425293
Epoch: 12, Loss: 7.807148456573486
Epoch: 13, Loss: 7.727425575256348
Epoch: 14, Loss: 7.638842582702637
Epoch: 15, Loss: 7.559998035430908
Epoch: 16, Loss: 7.481446743011475
Epoch: 17, Loss: 7.393308639526367
Epoch: 18, Loss: 7.316126823425293
Epoch: 19, Loss: 7.226165294647217
Epoch: 20, Loss: 7.152009010314941
Epoch: 21, Loss: 7.070465564727783
Epoch: 22, Loss: 6.985903263092041
Epoch: 23, Loss: 6.916940212249756
Epoch: 24, Loss: 6.846054553985596
Epoch: 25, Loss: 6.759749412536621
Epoch: 26, Loss: 6.692696571350098
Epoch: 27, Loss: 6.612946033477783
Epoch: 28, Loss: 6.546875476837158
Epoch: 29, Loss: 6.46886205673217

In [12]:
# Switch layers such as dropout to inference behaviour before evaluating.
transformer.eval()

# A fresh random batch, drawn independently of the training data.
val_shape = (64, max_seq_length)  # (batch_size, seq_length)
val_src_data = torch.randint(1, src_vocab_size, val_shape, device=device)
val_tgt_data = torch.randint(1, tgt_vocab_size, val_shape, device=device)

# Gradients are not needed for evaluation. `criterion` comes from the training
# cell, so that cell must have been run first.
with torch.no_grad():
    val_output = transformer(val_src_data, val_tgt_data[:, :-1])
    val_logits = val_output.contiguous().view(-1, tgt_vocab_size)
    val_targets = val_tgt_data[:, 1:].contiguous().view(-1)
    val_loss = criterion(val_logits, val_targets)
    print(f"Validation Loss: {val_loss.item()}")

Validation Loss: 8.8159761428833
