In [32]:
# imports
!pip install datasets tokenizers
import os
import math
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from pathlib import Path
from datasets import load_dataset
from tqdm import tqdm
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace



In [33]:
# Output directories for checkpoints and the two trained tokenizers
dirs = ["./malaygpt", "./tokenizer_en", "./tokenizer_my"]
for dir_path in dirs:
  # exist_ok keeps the cell idempotent on re-run; unlike an exists() check it
  # also fails fast if a regular *file* already occupies one of these paths
  os.makedirs(dir_path, exist_ok=True)

In [34]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
print(device)

cuda


# Download Dataset

In [35]:
# English/Malay pairs from HuggingFace
train_dataset = load_dataset("Helsinki-NLP/opus-100", "en-ms", split='train')
validation_dataset = load_dataset("Helsinki-NLP/opus-100", "en-ms", split='validation')

# Limit the amount of data for training purposes.
# The subset is drawn with a fixed seed. The tokenizers below are trained on it,
# so an unseeded split would produce different vocabularies on every kernel
# restart, and a saved checkpoint could then mismatch the rebuilt model.
NUM_TRAIN_EXAMPLES = 1500
NUM_VALIDATION_EXAMPLES = 50
SPLIT_SEED = 42

def _take_random_subset(dataset, size, seed):
  # Returns (subset, remainder) like random_split; size is clamped to len(dataset)
  size = min(size, len(dataset))
  generator = torch.Generator().manual_seed(seed)
  return random_split(dataset, [size, len(dataset) - size], generator=generator)

raw_train_dataset, rt_to_skip = _take_random_subset(train_dataset, NUM_TRAIN_EXAMPLES, SPLIT_SEED)
raw_validation_dataset, vt_to_skip = _take_random_subset(validation_dataset, NUM_VALIDATION_EXAMPLES, SPLIT_SEED)

# Tokenizer

In [36]:
# Lazily stream the sentences of one language out of a translation dataset
def get_ds_iterator(raw_train_dataset, lang):
  """Yield the `lang` side of every translation record, in dataset order.

  Each record is indexed as record["translation"][lang] (the opus-100 layout
  used throughout this notebook).

  Args:
    raw_train_dataset: iterable of translation records.
    lang: language key, e.g. "en" or "ms".

  Yields:
    One sentence string per record.
  """
  yield from (record["translation"][lang] for record in raw_train_dataset)

# One special-token list shared by both trainers, so the tokens are registered
# in the same order in each vocabulary. The dataset code later reuses one
# tokenizer's [CLS]/[SEP]/[PAD] ids for both sides, so their ids must line up.
SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]

def train_bpe_tokenizer(dataset, lang, save_path, trainer):
  """Train a whitespace-pre-tokenized BPE tokenizer on one language of `dataset`.

  Args:
    dataset: iterable of translation records (see get_ds_iterator).
    lang: language key to read from each record ("en" or "ms").
    save_path: JSON file the trained tokenizer is written to.
    trainer: configured BpeTrainer.

  Returns:
    The trained Tokenizer (also saved to `save_path`).
  """
  tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
  # Pre-tokenizer to split input into words before BPE merges are learned
  tokenizer.pre_tokenizer = Whitespace()
  tokenizer.train_from_iterator(get_ds_iterator(dataset, lang), trainer=trainer)
  tokenizer.save(save_path)
  return tokenizer

# Create English source tokenizer
trainer_en = BpeTrainer(min_frequency=2, special_tokens=SPECIAL_TOKENS)
tokenizer_en = train_bpe_tokenizer(raw_train_dataset, "en", "./tokenizer_en/tokenizer_en.json", trainer_en)

# Create Malay target tokenizer
trainer_my = BpeTrainer(min_frequency=2, special_tokens=SPECIAL_TOKENS)
tokenizer_my = train_bpe_tokenizer(raw_train_dataset, "ms", "./tokenizer_my/tokenizer_my.json", trainer_my)

In [37]:
# Reload both tokenizers from the JSON files written above
tokenizer_en, tokenizer_my = (
    Tokenizer.from_file("./tokenizer_en/tokenizer_en.json"),
    Tokenizer.from_file("./tokenizer_my/tokenizer_my.json"),
)

# Vocabulary sizes size the source/target embedding tables and the output projection
source_vocab_size, target_vocab_size = tokenizer_en.get_vocab_size(), tokenizer_my.get_vocab_size()

In [38]:
# Longest tokenized sentence on each side of the training subset; these
# figures justify the fixed max_seq_len chosen in the next cell
max_seq_len_source = 0
max_seq_len_target = 0
for example in raw_train_dataset:
  pair = example["translation"]
  source_len = len(tokenizer_en.encode(pair["en"]).ids)
  target_len = len(tokenizer_my.encode(pair["ms"]).ids)
  if source_len > max_seq_len_source:
    max_seq_len_source = source_len
  if target_len > max_seq_len_target:
    max_seq_len_target = target_len

print("Source vocab max sequence length:", max_seq_len_source)
print("Target vocab max sequence length:", max_seq_len_target)

Source vocab max sequence length: 87
Target vocab max sequence length: 105


In [39]:
# Fixed length of every encoder/decoder tensor. It must fit the special tokens
# EncodeDataset adds: [CLS] + [SEP] around the source, and one [CLS] (decoder
# input) or [SEP] (label) on the target side. The margin above the measured
# training maxima is a buffer; validation sentences may still tokenize longer.
max_seq_len = 155
assert max_seq_len >= max_seq_len_source + 2, "max_seq_len too small for the longest source sentence"
assert max_seq_len >= max_seq_len_target + 1, "max_seq_len too small for the longest target sentence"

# Dataset and Dataloader

In [40]:
# Causal (look-ahead) mask to hide future tokens
def causal_mask(size):
  """Return a boolean mask of shape (1, size, size).

  Entry [0, i, j] is True when position i may attend to position j, i.e. for
  j <= i (the diagonal and everything below it). Future positions (j > i) are
  False.
  """
  allowed = torch.tril(torch.ones(1, size, size))
  return allowed.bool()

In [41]:
# Encode raw dataset to be processed by the model
class EncodeDataset(Dataset):
  """Map-style dataset turning raw en->ms translation pairs into fixed-length tensors.

  Every item is padded to exactly `max_seq_len` tokens, and truncated if a
  sentence is too long, so the default DataLoader collate can stack a batch.
  Uses the module-level `tokenizer_en`, `tokenizer_my` and `causal_mask`.

  Each item is a dict with:
    encoder_input: (max_seq_len,) [CLS] + source ids + [SEP] + padding
    decoder_input: (max_seq_len,) [CLS] + target ids + padding
    target_label:  (max_seq_len,) target ids + [SEP] + padding
    encoder_mask:  (1, 1, max_seq_len) int, 1 at real tokens, 0 at padding
    decoder_mask:  (1, max_seq_len, max_seq_len) int, padding mask AND causal mask
    source_text / target_text: the original sentences
  """

  def __init__(self, raw_dataset, max_seq_len):
    super().__init__()
    # Need room for at least [CLS] + [SEP] around the source sentence
    if max_seq_len < 2:
      raise ValueError(f"max_seq_len must be at least 2, got {max_seq_len}")
    self.raw_dataset = raw_dataset
    self.max_seq_len = max_seq_len

  def __len__(self):
    return len(self.raw_dataset)

  def __getitem__(self, index):
    # Fetch data (in both English and Malay) for the given index
    raw_text = self.raw_dataset[index]
    source_text = raw_text["translation"]["en"]
    target_text = raw_text["translation"]["ms"]

    # Encode, then truncate so the special tokens always fit. Without this, a
    # sentence longer than max_seq_len gives a negative padding count. The
    # resulting tensor is longer than max_seq_len, which later breaks batching
    # and positional encoding.
    source_ids = tokenizer_en.encode(source_text).ids[: self.max_seq_len - 2]  # room for [CLS] + [SEP]
    target_ids = tokenizer_my.encode(target_text).ids[: self.max_seq_len - 1]  # room for [CLS] or [SEP]

    # Special-token ids, each looked up in the tokenizer of the side it is used on
    src_cls = tokenizer_en.token_to_id("[CLS]")
    src_sep = tokenizer_en.token_to_id("[SEP]")
    src_pad = tokenizer_en.token_to_id("[PAD]")
    tgt_cls = tokenizer_my.token_to_id("[CLS]")
    tgt_sep = tokenizer_my.token_to_id("[SEP]")
    tgt_pad = tokenizer_my.token_to_id("[PAD]")

    def ids_tensor(ids):
      return torch.tensor(ids, dtype=torch.int64)

    source_tensor = ids_tensor(source_ids)
    target_tensor = ids_tensor(target_ids)
    encoder_padding = torch.full((self.max_seq_len - len(source_ids) - 2,), src_pad, dtype=torch.int64)
    decoder_padding = torch.full((self.max_seq_len - len(target_ids) - 1,), tgt_pad, dtype=torch.int64)

    # CLS + source encoding + SEP + padding
    encoder_input = torch.cat([ids_tensor([src_cls]), source_tensor, ids_tensor([src_sep]), encoder_padding], dim=0)
    # CLS + target encoding + padding
    decoder_input = torch.cat([ids_tensor([tgt_cls]), target_tensor, decoder_padding], dim=0)
    # Target encoding + SEP + padding: decoder_input shifted left by one (teacher forcing)
    target_label = torch.cat([target_tensor, ids_tensor([tgt_sep]), decoder_padding], dim=0)

    # Masks to ignore padding
    encoder_mask = (encoder_input != src_pad).unsqueeze(0).unsqueeze(0).int()
    # Combine with the causal mask so the decoder sees neither padding nor future tokens
    decoder_mask = (decoder_input != tgt_pad).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0))

    return {
        "encoder_input": encoder_input,
        "decoder_input": decoder_input,
        "target_label": target_label,
        "encoder_mask": encoder_mask,
        "decoder_mask": decoder_mask,
        "source_text": source_text,
        "target_text": target_text
    }

In [42]:
# Wrap both raw subsets so that every item is a dict of fixed-length tensors
train_ds, val_ds = (EncodeDataset(split, max_seq_len) for split in (raw_train_dataset, raw_validation_dataset))

# Mini-batches of 5 for training. Validation uses batch_size=1 because
# run_validation greedily decodes a single sentence per batch.
train_dataloader = DataLoader(train_ds, shuffle=True, batch_size=5)
val_dataloader = DataLoader(val_ds, shuffle=True, batch_size=1)

In [43]:
# Inspect one encoded training example. Keys returned by EncodeDataset:
#   encoder_input: [CLS] + source ids + [SEP] + padding
#   decoder_input: [CLS] + target ids + padding
#   target_label:  target ids + [SEP] + padding (decoder_input shifted left by one)
#   encoder_mask:  (1, 1, max_seq_len), 1 at real tokens, 0 at padding
#   decoder_mask:  (1, max_seq_len, max_seq_len), padding mask combined with the causal mask
#   source_text:   original English sentence
#   target_text:   original Malay sentence
train_ds[0]

{'encoder_input': tensor([   2,   54,  220,  114, 1306,  109,    3,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           

# Input Embedding and Positional Encoding

In [44]:
# Token embedding layer whose output is scaled by sqrt(d_model)
class EmbeddingLayer(nn.Module):
  """Look up a d_model-dimensional vector for each token id.

  The vectors are multiplied by sqrt(d_model), following the Transformer paper
  convention. This is a constant scaling, not a normalization.

  Input:  integer tensor of token ids, shape (...).
  Output: float tensor of shape (..., d_model).
  """

  def __init__(self, d_model: int, vocab_size: int):
    super().__init__()
    self.d_model = d_model
    # Learnable lookup table with vocab_size rows of width d_model
    self.embedding = nn.Embedding(vocab_size, d_model)

  def forward(self, input):
    scale = math.sqrt(self.d_model)
    return self.embedding(input) * scale

In [45]:
# Sinusoidal positional encoding layer
class PositionalEncoding(nn.Module):
  """Add fixed sinusoidal position information to embeddings, then apply dropout.

  pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
  pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

  The table covers `max_seq_len` positions. It is stored as a non-trainable
  buffer named "pe" of shape (1, max_seq_len, d_model). An odd d_model is
  supported; the last column is then a sine column.
  """

  def __init__(self, d_model: int, max_seq_len: int, dropout_rate: float):
    super().__init__()
    self.dropout = nn.Dropout(dropout_rate)

    pe = torch.zeros(max_seq_len, d_model)
    # pos: (max_seq_len, 1)
    pos = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
    # div_term[i] = 1 / 10000^(2i / d_model), computed in log space; length ceil(d_model / 2)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

    # Even columns get sines, odd columns cosines
    pe[:, 0::2] = torch.sin(pos * div_term)
    # There are only floor(d_model / 2) odd columns; drop the extra frequency when d_model is odd
    pe[:, 1::2] = torch.cos(pos * div_term[: d_model // 2])

    # Add batch dimension: (1, max_seq_len, d_model)
    pe = pe.unsqueeze(0)
    # Buffer: saved in state_dict and moved by .to(device), but not a trainable parameter
    self.register_buffer("pe", pe)

  def forward(self, input_embedding):
    # input_embedding: (batch_size, seq_len, d_model) with seq_len <= max_seq_len
    input_embedding = input_embedding + self.pe[:, :input_embedding.shape[1], :]
    return self.dropout(input_embedding)

# Multi-Head Attention

In [46]:
# Multi-head scaled dot-product attention (used for both self- and cross-attention)
class MultiHeadAttention(nn.Module):
  """Project q/k/v, attend independently in `num_heads` heads, merge, and project back.

  Args:
    d_model: model width; must be divisible by num_heads.
    num_heads: number of attention heads (each of width d_k = d_model // num_heads).
    dropout_rate: dropout applied to the attention weights.

  forward(q, k, v, encoder_mask):
    q: (batch, q_len, d_model); k, v: (batch, k_len, d_model).
    encoder_mask: any mask (padding or causal) broadcastable to
      (batch, num_heads, q_len, k_len); positions where it is 0 are excluded.
      Pass None for no masking.
    Returns (batch, q_len, d_model).
  """

  def __init__(self, d_model: int, num_heads: int, dropout_rate: float):
    super().__init__()
    self.dropout = nn.Dropout(dropout_rate)
    self.num_heads = num_heads

    # Every head gets an equal slice of the model dimension
    assert d_model % num_heads == 0
    self.d_k = d_model // num_heads

    # Bias-free projections for queries, keys, values and the merged output
    self.W_q = nn.Linear(d_model, d_model, bias=False)
    self.W_k = nn.Linear(d_model, d_model, bias=False)
    self.W_v = nn.Linear(d_model, d_model, bias=False)
    self.W_o = nn.Linear(d_model, d_model, bias=False)

  def _split_heads(self, x):
    # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
    batch, seq_len = x.shape[0], x.shape[1]
    return x.view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)

  def forward(self, q, k, v, encoder_mask):
    query = self._split_heads(self.W_q(q))
    key = self._split_heads(self.W_k(k))
    value = self._split_heads(self.W_v(v))

    # scores: (batch, num_heads, q_len, k_len), scaled by sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)

    # A large negative score makes masked positions vanish after the softmax
    if encoder_mask is not None:
      scores.masked_fill_(encoder_mask == 0, -1e9)

    weights = self.dropout(scores.softmax(dim=-1))

    # context: (batch, num_heads, q_len, d_k)
    context = torch.matmul(weights, value)

    # Merge heads back: (batch, q_len, num_heads * d_k)
    merged = context.transpose(1, 2).contiguous().view(context.shape[0], -1, self.num_heads * self.d_k)
    return self.W_o(merged)

# Feedforward, Layer Normalization, and AddAndNorm

In [47]:
# Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear
class FeedForward(nn.Module):
  """Apply the same two-layer MLP independently at every position.

  Shapes: (..., d_model) -> (..., d_ff) -> (..., d_model).
  """

  def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
    super().__init__()
    self.dropout = nn.Dropout(dropout_rate)
    self.layer_1 = nn.Linear(d_model, d_ff)  # expand to the hidden width
    self.layer_2 = nn.Linear(d_ff, d_model)  # project back to the model width

  def forward(self, input):
    hidden = torch.relu(self.layer_1(input))
    hidden = self.dropout(hidden)
    return self.layer_2(hidden)

In [48]:
# Layer normalization with learnable scale (gamma) and shift (beta)
class LayerNorm(nn.Module):
  """Normalize the last dimension to zero mean / unit std, then scale and shift.

  Args:
    eps: added to the standard deviation to avoid division by zero.
    d_model: size of the last (feature) dimension; sets the length of gamma and
      beta. Defaults to 512, this notebook's model width, so existing callers
      and checkpoints are unaffected.

  Note: this uses the *unbiased* standard deviation (the Tensor.std default) and
  adds eps to the std rather than to the variance. Results therefore differ
  slightly from torch.nn.LayerNorm.
  """

  def __init__(self, eps: float = 1e-5, d_model: int = 512):
    super().__init__()
    self.eps = eps
    # One scale/shift weight per feature
    self.gamma = nn.Parameter(torch.ones(d_model))
    self.beta = nn.Parameter(torch.zeros(d_model))

  def forward(self, input):
    mean = input.mean(dim=-1, keepdim=True)
    std = input.std(dim=-1, keepdim=True)
    return self.gamma * (input - mean) / (std + self.eps) + self.beta

In [49]:
# Pre-norm residual wrapper: input + Dropout(sub_layer(LayerNorm(input)))
class AddAndNorm(nn.Module):
  """Residual connection around a sub-layer that normalizes the sub-layer's *input*.

  The layer norm runs before `sub_layer` (pre-norm). The un-normalized `input`
  is then added back to the dropped-out sub-layer output.
  """

  def __init__(self, dropout_rate: float):
    super().__init__()
    self.dropout = nn.Dropout(dropout_rate)
    self.layer_norm = LayerNorm()

  def forward(self, input, sub_layer):
    # sub_layer: callable whose output can be added to `input`
    normalized = self.layer_norm(input)
    residual = self.dropout(sub_layer(normalized))
    return input + residual

# Encoder Block and Encoder

In [50]:
# One encoder layer: self-attention, then feed-forward, each in a pre-norm residual
class EncoderBlock(nn.Module):
  """Self-attention sub-layer followed by a feed-forward sub-layer.

  forward(encoder_input, encoder_mask) returns a tensor shaped like encoder_input.
  """

  def __init__(self, multihead_attention: MultiHeadAttention, feed_forward: FeedForward, dropout_rate: float) -> None:
    super().__init__()
    self.multihead_attention = multihead_attention
    self.feed_forward = feed_forward
    self.addnorm_1 = AddAndNorm(dropout_rate)
    self.addnorm_2 = AddAndNorm(dropout_rate)

  def forward(self, encoder_input, encoder_mask):
    def self_attention(x):
      # Queries, keys and values all come from the same normalized sequence
      return self.multihead_attention(x, x, x, encoder_mask)

    attended = self.addnorm_1(encoder_input, self_attention)
    return self.addnorm_2(attended, self.feed_forward)

In [51]:
# Stack of encoder blocks followed by a final layer norm
class Encoder(nn.Module):
  """Run the input through every encoder block in order, then normalize.

  The blocks return un-normalized residual sums (pre-norm design), so the final
  LayerNorm is what normalizes the encoder output.
  """

  def __init__(self, encoderblocklist: nn.ModuleList) -> None:
    super().__init__()
    self.encoderblocklist = encoderblocklist
    self.layer_norm = LayerNorm()

  def forward(self, encoder_input, encoder_mask):
    hidden = encoder_input
    for block in self.encoderblocklist:
      hidden = block(hidden, encoder_mask)
    return self.layer_norm(hidden)

# Decoder Block, Decoder, and Projection Layer

In [52]:
# One decoder layer: masked self-attention, cross-attention over the encoder output,
# then feed-forward, each wrapped in a pre-norm residual
class DecoderBlock(nn.Module):
  """Masked self-attention, encoder-decoder cross-attention and a feed-forward sub-layer.

  forward(decoder_input, encoder_output, encoder_mask, decoder_mask) returns a
  tensor shaped like decoder_input.
  """

  def __init__(self, masked_multihead_attention: MultiHeadAttention, cross_multihead_attention: MultiHeadAttention, feed_forward: FeedForward, dropout_rate: float) -> None:
    super().__init__()
    # Attends over the target prefix under the causal decoder_mask
    self.masked_multihead_attention = masked_multihead_attention
    # Attends from decoder positions over the encoder output
    self.cross_multihead_attention = cross_multihead_attention
    self.feed_forward = feed_forward
    self.addnorm_1 = AddAndNorm(dropout_rate)
    self.addnorm_2 = AddAndNorm(dropout_rate)
    self.addnorm_3 = AddAndNorm(dropout_rate)

  def forward(self, decoder_input, encoder_output, encoder_mask, decoder_mask):
    def masked_self_attention(x):
      return self.masked_multihead_attention(x, x, x, decoder_mask)

    def cross_attention(x):
      # Queries from the decoder; keys/values are the encoder output as passed in
      return self.cross_multihead_attention(x, encoder_output, encoder_output, encoder_mask)

    hidden = self.addnorm_1(decoder_input, masked_self_attention)
    hidden = self.addnorm_2(hidden, cross_attention)
    return self.addnorm_3(hidden, self.feed_forward)

In [53]:
# Stack of decoder blocks followed by a final layer norm
class Decoder(nn.Module):
  """Run the target sequence through every decoder block in order, then normalize.

  Every block receives the same encoder_output and masks.
  """

  def __init__(self, decoderblocklist: nn.ModuleList) -> None:
    super().__init__()
    self.decoderblocklist = decoderblocklist
    self.layer_norm = LayerNorm()

  def forward(self, decoder_input, encoder_output, encoder_mask, decoder_mask):
    hidden = decoder_input
    for block in self.decoderblocklist:
      hidden = block(hidden, encoder_output, encoder_mask, decoder_mask)
    return self.layer_norm(hidden)

In [54]:
# Final linear layer mapping decoder states to vocabulary logits
class ProjectionLayer(nn.Module):
  """Map (batch, seq_len, d_model) decoder output to (batch, seq_len, vocab_size) logits.

  No softmax is applied. nn.CrossEntropyLoss expects raw logits, and greedy
  decoding only needs the argmax.
  """

  def __init__(self, d_model: int, vocab_size: int) -> None:
    super().__init__()
    self.projection_layer = nn.Linear(d_model, vocab_size)

  def forward(self, decoder_output):
    logits = self.projection_layer(decoder_output)
    return logits

# Transformer

In [55]:
# Full encoder-decoder model: embed + encode the source, embed + decode the target,
# and project decoder states onto the target vocabulary
class Transformer(nn.Module):
  """Container wiring embeddings, positional encodings, encoder, decoder and projection.

  encode(encoder_input, encoder_mask) -> encoder states
  decode(encoder_output, encoder_mask, decoder_input, decoder_mask) -> decoder states
  project(decoder_output) -> vocabulary logits
  """

  def __init__(self, encoder: Encoder, decoder: Decoder, source_embed: EmbeddingLayer, target_embed: EmbeddingLayer, source_pos: PositionalEncoding, target_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
    super().__init__()
    # Registration order fixes parameters()/state_dict() order, and therefore the
    # order in which build_model initializes weights; keep it as is.
    self.source_embed = source_embed
    self.source_pos = source_pos
    self.encoder = encoder
    self.target_embed = target_embed
    self.target_pos = target_pos
    self.decoder = decoder
    self.projection_layer = projection_layer

  def encode(self, encoder_input, encoder_mask):
    positioned = self.source_pos(self.source_embed(encoder_input))
    return self.encoder(positioned, encoder_mask)

  def decode(self, encoder_output, encoder_mask, decoder_input, decoder_mask):
    positioned = self.target_pos(self.target_embed(decoder_input))
    return self.decoder(positioned, encoder_output, encoder_mask, decoder_mask)

  def project(self, decoder_output):
    return self.projection_layer(decoder_output)

In [56]:
def build_model(source_vocab_size, target_vocab_size, source_seq_len, target_seq_len, d_model=512, num_blocks=6, num_heads=8, dropout_rate=0.1, d_ff=2048):
  """Assemble an encoder-decoder Transformer and Xavier-initialize its weight matrices.

  Args:
    source_vocab_size / target_vocab_size: sizes of the two embedding tables
      (the target size also sets the projection output width).
    source_seq_len / target_seq_len: positions covered by each positional encoding.
    d_model, num_blocks, num_heads, dropout_rate, d_ff: usual Transformer hyperparameters.

  NOTE: the LayerNorm modules created inside the blocks use LayerNorm's default
  feature size, so d_model must match that default.

  Returns:
    A Transformer (on the CPU; the caller moves it to a device).
  """
  def make_attention():
    return MultiHeadAttention(d_model, num_heads, dropout_rate)

  def make_feed_forward():
    return FeedForward(d_model, d_ff, dropout_rate)

  # Construction order is kept identical so random initialization is reproducible
  source_embed = EmbeddingLayer(d_model, source_vocab_size)
  target_embed = EmbeddingLayer(d_model, target_vocab_size)
  source_pos = PositionalEncoding(d_model, source_seq_len, dropout_rate)
  target_pos = PositionalEncoding(d_model, target_seq_len, dropout_rate)

  encoder = Encoder(nn.ModuleList(
      [EncoderBlock(make_attention(), make_feed_forward(), dropout_rate) for _ in range(num_blocks)]
  ))
  decoder = Decoder(nn.ModuleList(
      [DecoderBlock(make_attention(), make_attention(), make_feed_forward(), dropout_rate) for _ in range(num_blocks)]
  ))
  projection_layer = ProjectionLayer(d_model, target_vocab_size)

  model = Transformer(encoder, decoder, source_embed, target_embed, source_pos, target_pos, projection_layer)

  # Xavier-uniform for every weight matrix; 1-D params (biases, LayerNorm gamma/beta) keep their defaults
  for param in model.parameters():
    if param.dim() > 1:
      nn.init.xavier_uniform_(param)

  return model

In [57]:
# Build the model from the measured vocab sizes and the fixed sequence length, then move it to the device
model = build_model(source_vocab_size, target_vocab_size, max_seq_len, max_seq_len, d_model=512)
model = model.to(device)
print(model)

Transformer(
  (source_embed): EmbeddingLayer(
    (embedding): Embedding(2028, 512)
  )
  (source_pos): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): Encoder(
    (encoderblocklist): ModuleList(
      (0-5): 6 x EncoderBlock(
        (multihead_attention): MultiHeadAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (W_q): Linear(in_features=512, out_features=512, bias=False)
          (W_k): Linear(in_features=512, out_features=512, bias=False)
          (W_v): Linear(in_features=512, out_features=512, bias=False)
          (W_o): Linear(in_features=512, out_features=512, bias=False)
        )
        (feed_forward): FeedForward(
          (dropout): Dropout(p=0.1, inplace=False)
          (layer_1): Linear(in_features=512, out_features=2048, bias=True)
          (layer_2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (addnorm_1): AddAndNorm(
          (dropout): Dropout(p=0.1, inplace=False)
         

# Validation

In [58]:
def _greedy_decode(model, encoder_input, encoder_mask, cls_id, sep_id, max_seq_len, device):
  """Greedily decode one source sequence; return a 1-D id tensor starting with [CLS].

  Stops after emitting [SEP] or once the sequence reaches max_seq_len tokens.
  Assumes a batch of exactly one sequence.
  """
  encoder_output = model.encode(encoder_input, encoder_mask)

  # Decoder input starts with the beginning-of-sentence token
  decoder_input = torch.full((1, 1), cls_id, dtype=encoder_input.dtype, device=device)

  # A bounded loop: never grows past max_seq_len, even if max_seq_len is tiny
  while decoder_input.size(1) < max_seq_len:
    # Recreate the causal mask for the current prefix length
    decoder_mask = causal_mask(decoder_input.size(1)).type_as(encoder_mask).to(device)

    out = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
    # Logits for the next token only
    logits = model.project(out[:, -1])

    # Greedily pick the most likely next token and append it
    next_word = logits.argmax(dim=1)
    decoder_input = torch.cat([decoder_input, next_word.view(1, 1).to(decoder_input.dtype)], dim=1)

    if next_word.item() == sep_id:
      break

  return decoder_input.squeeze(0)


def run_validation(model, validation_ds, tokenizer_en, tokenizer_my, max_seq_len, device, print_msg, global_step, num_examples=2):
  """Greedily translate a few validation sentences and print them next to the reference.

  Args:
    model: Transformer exposing encode/decode/project.
    validation_ds: iterable of EncodeDataset batches with batch_size == 1
      (decoding builds a single-sequence decoder input and reads item [0]).
    tokenizer_en: unused; kept for call compatibility.
    tokenizer_my: target tokenizer, used for [CLS]/[SEP] ids and for decoding.
    max_seq_len: cap on the generated length (including [CLS]).
    device: torch device the inputs are moved to.
    print_msg: callable that emits each output line (e.g. tqdm.write).
    global_step: unused; kept for call compatibility.
    num_examples: how many validation sentences to translate (default 2).

  Leaves the model in eval mode; the caller must switch back with model.train().
  """
  model.eval()

  # Start/end of sentence ids never change, so look them up once
  cls_id = tokenizer_my.token_to_id("[CLS]")
  sep_id = tokenizer_my.token_to_id("[SEP]")

  # Don't calculate gradients during evaluation
  with torch.no_grad():
    for count, batch in enumerate(validation_ds, start=1):
      encoder_input = batch["encoder_input"].to(device)
      encoder_mask = batch["encoder_mask"].to(device)

      model_out = _greedy_decode(model, encoder_input, encoder_mask, cls_id, sep_id, max_seq_len, device)

      # Get source text, target text, and predicted text
      source_text = batch["source_text"][0]
      target_text = batch["target_text"][0]
      model_out_text = tokenizer_my.decode(model_out.cpu().tolist())

      print_msg("-" * 55)
      print_msg(f"Source Text: {source_text}")
      print_msg(f"Target Text: {target_text}")
      print_msg(f"Predicted by MalayGPT: {model_out_text}")

      if count >= num_examples:
        break

# Training

In [59]:
def train_model(preload_epoch=None, epochs=100):
  """Train the global `model` on `train_dataloader`, validating and checkpointing every epoch.

  Args:
    preload_epoch: if given, resume from the checkpoint written at the end of
      that epoch (model weights, optimizer state, epoch counter, global step).
    epochs: total number of epochs; training resumes at preload_epoch + 1.

  Checkpoints are written to ./malaygpt/model_{epoch}.pt.
  """
  initial_epoch = 0
  global_step = 0

  def checkpoint_path(epoch):
    # Single source of truth for the filename, so saving and resuming always agree
    # (the save and load paths previously used different extensions, .pt vs .pth)
    return f"./malaygpt/model_{epoch}.pt"

  optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, eps=1e-9)

  # Start at preloaded epoch, weights, and optimizer
  if preload_epoch is not None:
    # map_location lets a checkpoint saved on GPU be resumed on CPU (and vice versa)
    state = torch.load(checkpoint_path(preload_epoch), map_location=device)
    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])
    initial_epoch = state["epoch"] + 1
    global_step = state["global_step"]

  # Labels are target-language ids, so ignore the *target* tokenizer's padding id
  target_vocab_size = tokenizer_my.get_vocab_size()
  loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_my.token_to_id("[PAD]"), label_smoothing=0.1).to(device)

  for epoch in range(initial_epoch, epochs):
    # run_validation leaves the model in eval mode, so switch back every epoch
    model.train()
    batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
    for batch in batch_iterator:
      # batch_size x seq_len
      encoder_input = batch["encoder_input"].to(device)
      # batch_size x seq_len
      decoder_input = batch["decoder_input"].to(device)
      # batch_size x 1 x 1 x seq_len
      encoder_mask = batch["encoder_mask"].to(device)
      # batch_size x 1 x seq_len x seq_len
      decoder_mask = batch["decoder_mask"].to(device)
      # batch_size x seq_len
      target_label = batch["target_label"].to(device)

      # batch_size x seq_len x d_model
      encoder_output = model.encode(encoder_input, encoder_mask)
      decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
      # batch_size x seq_len x vocab_size
      projection_output = model.project(decoder_output)

      # Flatten batch and sequence dimensions for token-level cross-entropy
      loss = loss_fn(projection_output.view(-1, target_vocab_size), target_label.view(-1))
      batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})
      loss.backward()

      optimizer.step()
      optimizer.zero_grad(set_to_none=True)

      global_step += 1

    # Run validation after every epoch
    run_validation(model, val_dataloader, tokenizer_en, tokenizer_my, max_seq_len, device, lambda msg: batch_iterator.write(msg), global_step)

    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "global_step": global_step
    }, checkpoint_path(epoch))

In [None]:
# Train from scratch; pass an epoch number as preload_epoch to resume from a saved checkpoint instead
train_model()