In [1]:
import os
import torch

In [2]:
# Start from a clean ./data directory and copy in the prepared corpus files
# (raw pairs, src/trg train/valid splits and the SentencePiece models) from the
# Kaggle input dataset.
!rm -rf data
!mkdir data
!cp -r /kaggle/input/create-phomt-700k-pairs/data/* ./data/

In [3]:
# Show the copied directory tree to confirm the expected layout (raw_data.*, sp/, src/, trg/).
!ls -R data

data:
raw_data.src  raw_data.trg  sp	src  trg

data/sp:
src_sp.model  src_sp.vocab  trg_sp.model  trg_sp.vocab

data/src:
train.txt  valid.txt

data/trg:
train.txt  valid.txt


In [4]:
%%writefile constants.py
# Shared configuration imported by custom_data.py, layers.py and main.py.
import torch

# Path or parameters for data
DATA_DIR = 'data'
SP_DIR = f'{DATA_DIR}/sp'  # SentencePiece .model/.vocab files
SRC_DIR = 'src'  # sub-directory of DATA_DIR holding source-side split files
TRG_DIR = 'trg'  # sub-directory of DATA_DIR holding target-side split files
SRC_RAW_DATA_NAME = 'raw_data.src'
TRG_RAW_DATA_NAME = 'raw_data.trg'
TRAIN_NAME = 'train.txt'
VALID_NAME = 'valid.txt'
TEST_NAME = 'test.txt'  # NOTE(review): no test.txt appears in the copied data listing

# Parameters for sentencepiece tokenizer
# Special-token ids used for padding and sequence boundaries; presumably they
# match the ids the SentencePiece models were trained with — TODO confirm.
pad_id = 0
sos_id = 1
eos_id = 2
unk_id = 3
src_model_prefix = 'src_sp'  # loaded as {SP_DIR}/src_sp.model and .vocab
trg_model_prefix = 'trg_sp'  # loaded as {SP_DIR}/trg_sp.model and .vocab
# NOTE(review): the three options below are not read by the visible code;
# presumably they were used when training the SentencePiece models.
sp_vocab_size = 16000
character_coverage = 1.0
model_type = 'unigram'

# Parameters for Transformer & training
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
learning_rate = 1e-4  # Adam learning rate
batch_size = 64
seq_len = 128  # every src/trg id sequence is padded or truncated to this length
num_heads = 8
num_layers = 6
d_model = 512
d_ff = 2048  # hidden width of the position-wise feed-forward layer
d_k = d_model // num_heads  # per-head dimension
drop_out_rate = 0.1
num_epochs = 5  # number of epochs run by one call to Manager.train()
beam_size = 5  # default beam width for beam search
ckpt_dir = 'saved_model'  # directory where checkpoints are written

# Others
# Selects the attention implementation used by the encoder/decoder layers.
attention_type = 'luong' # 'bahdanau' or 'luong' or 'scaled_dot_product'
start_epoch = 1  # first epoch number when not resuming from a checkpoint

Writing constants.py


In [5]:
%%writefile custom_data.py
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from constants import SP_DIR, DATA_DIR, SRC_DIR, TRG_DIR, src_model_prefix, trg_model_prefix, batch_size, seq_len, pad_id, sos_id, eos_id

import torch
import sentencepiece as spm
import numpy as np

# Load the source/target SentencePiece models once, at import time; they are
# shared by process_src and process_trg below.
src_sp = spm.SentencePieceProcessor()
trg_sp = spm.SentencePieceProcessor()
src_sp.Load(f"{SP_DIR}/{src_model_prefix}.model")
trg_sp.Load(f"{SP_DIR}/{trg_model_prefix}.model")


def get_data_loader(file_name, shuffle=True):
    """Build a DataLoader over one parallel split (e.g. train.txt / valid.txt).

    Reads ``{DATA_DIR}/{SRC_DIR}/{file_name}`` and ``{DATA_DIR}/{TRG_DIR}/{file_name}``
    line by line, tokenizes and pads both sides to ``seq_len`` and wraps the
    result in a ``CustomDataset``.

    Args:
        file_name: Split file name present in both the src and trg directories.
        shuffle: Whether the DataLoader reshuffles samples every epoch
            (default True, the previous behaviour).

    Returns:
        A ``DataLoader`` yielding ``(src, trg_input, trg_output)`` LongTensor
        batches of at most ``batch_size`` rows each.

    Raises:
        ValueError: If the source and target files have different line counts.
    """
    print(f"Getting source/target {file_name}...")
    # Read explicitly as UTF-8 so non-ASCII text does not depend on the
    # platform's default locale encoding.
    with open(f"{DATA_DIR}/{SRC_DIR}/{file_name}", 'r', encoding='utf-8') as f:
        src_text_list = f.readlines()

    with open(f"{DATA_DIR}/{TRG_DIR}/{file_name}", 'r', encoding='utf-8') as f:
        trg_text_list = f.readlines()

    # Fail fast, before the (slow) tokenization pass, if the split is misaligned.
    if len(src_text_list) != len(trg_text_list):
        raise ValueError(
            f"{file_name}: source has {len(src_text_list)} lines but target has "
            f"{len(trg_text_list)} lines."
        )

    print("Tokenizing & Padding src data...")
    src_list = process_src(src_text_list) # (sample_num, L)
    print(f"The shape of src data: {np.shape(src_list)}")

    print("Tokenizing & Padding trg data...")
    input_trg_list, output_trg_list = process_trg(trg_text_list) # (sample_num, L)
    print(f"The shape of input trg data: {np.shape(input_trg_list)}")
    print(f"The shape of output trg data: {np.shape(output_trg_list)}")

    dataset = CustomDataset(src_list, input_trg_list, output_trg_list)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    return dataloader


def pad_or_truncate(tokenized_text, max_len=None, pad=None):
    """Return a fixed-length copy of a list of token ids.

    Args:
        tokenized_text: List of token ids. It is not modified.
        max_len: Target length; defaults to ``seq_len`` from constants.
        pad: Id used for padding; defaults to ``pad_id`` from constants.

    Returns:
        A new list of exactly ``max_len`` ids: the input followed by ``pad``
        ids when it is shorter, otherwise its first ``max_len`` ids.
    """
    if max_len is None:
        max_len = seq_len
    if pad is None:
        pad = pad_id

    if len(tokenized_text) < max_len:
        # Build a new list (instead of ``+=``) so the caller's list is not
        # mutated as a side effect.
        return tokenized_text + [pad] * (max_len - len(tokenized_text))
    return tokenized_text[:max_len]


def process_src(text_list):
    """Encode source lines into fixed-length id lists.

    Each line is stripped, encoded with the source SentencePiece model,
    terminated with ``eos_id`` and padded/truncated to ``seq_len``.

    Args:
        text_list: Raw source sentences (trailing newlines are fine).

    Returns:
        One id list per input line.
    """
    return [
        pad_or_truncate(src_sp.EncodeAsIds(line.strip()) + [eos_id])
        for line in tqdm(text_list)
    ]

def process_trg(text_list):
    """Build shifted decoder inputs and targets from target lines.

    For each stripped line with SentencePiece ids ``ids``, the decoder input is
    ``[sos_id] + ids`` and the expected output is ``ids + [eos_id]``; both are
    padded/truncated to ``seq_len``.

    Args:
        text_list: Raw target sentences (trailing newlines are fine).

    Returns:
        Tuple ``(input_tokenized_list, output_tokenized_list)``.
    """
    decoder_inputs, decoder_targets = [], []
    for line in tqdm(text_list):
        ids = trg_sp.EncodeAsIds(line.strip())
        decoder_inputs.append(pad_or_truncate([sos_id] + ids))
        decoder_targets.append(pad_or_truncate(ids + [eos_id]))
    return decoder_inputs, decoder_targets


class CustomDataset(Dataset):
    """In-memory dataset of padded ``(src, trg_input, trg_output)`` id rows.

    The three input lists are converted to LongTensors and must all have the
    same (num_samples, L) shape.
    """

    def __init__(self, src_list, input_trg_list, output_trg_list):
        super().__init__()
        self.src_data = torch.LongTensor(src_list)
        self.input_trg_data = torch.LongTensor(input_trg_list)
        self.output_trg_data = torch.LongTensor(output_trg_list)

        input_shape = np.shape(input_trg_list)
        assert np.shape(src_list) == input_shape, "The shape of src_list and input_trg_list are different."
        assert input_shape == np.shape(output_trg_list), "The shape of input_trg_list and output_trg_list are different."

    def make_mask(self):
        """Return ``(e_mask, d_mask)`` for every sample at once.

        ``e_mask`` is (num_samples, 1, L) and True at non-pad source positions.
        ``d_mask`` is (num_samples, L, L): the non-pad target positions combined
        with a lower-triangular (no-peek) mask.

        NOTE(review): not called in the visible code, and it materialises a
        (num_samples, L, L) tensor, which is very large for big corpora.
        """
        enc_mask = (self.src_data != pad_id).unsqueeze(1)
        no_peek = torch.ones([1, seq_len, seq_len], dtype=torch.bool).tril()
        dec_mask = (self.input_trg_data != pad_id).unsqueeze(1) & no_peek
        return enc_mask, dec_mask

    def __getitem__(self, idx):
        columns = (self.src_data, self.input_trg_data, self.output_trg_data)
        return tuple(column[idx] for column in columns)

    def __len__(self):
        return self.src_data.size(0)

Writing custom_data.py


In [6]:
%%writefile data_structure.py
import heapq


class BeamNode():
    """A beam-search hypothesis whose ordering is decided by ``prob`` alone.

    Attributes:
        cur_idx: Index supplied by the caller (presumably the current decoding
            position — TODO confirm against the caller).
        prob: Score compared by every rich-comparison operator.
        decoded: Sequence decoded so far, as supplied by the caller.
        is_finished: Completion flag, initialised to False.

    Because ``__eq__`` is defined without ``__hash__``, instances are unhashable.
    """

    def __init__(self, cur_idx, prob, decoded):
        self.cur_idx, self.prob, self.decoded = cur_idx, prob, decoded
        self.is_finished = False

    # Every comparison applies the same operator to the two ``prob`` values.
    def __gt__(self, other):
        return self.prob > other.prob

    def __ge__(self, other):
        return self.prob >= other.prob

    def __lt__(self, other):
        return self.prob < other.prob

    def __le__(self, other):
        return self.prob <= other.prob

    def __eq__(self, other):
        return self.prob == other.prob

    def __ne__(self, other):
        return self.prob != other.prob

    def print_spec(self):
        """Print a one-line debug summary of this node."""
        print(f"ID: {self} || cur_idx: {self.cur_idx} || prob: {self.prob} || decoded: {self.decoded}")
    

class PriorityQueue():
    """Binary min-heap of objects keyed on their ``prob`` attribute.

    ``get`` returns the object with the *lowest* ``prob``. Entries are stored
    as ``(obj.prob, obj)`` tuples, so equal scores fall back to comparing the
    objects themselves.
    """

    def __init__(self):
        self.queue = []

    def put(self, obj):
        entry = (obj.prob, obj)
        heapq.heappush(self.queue, entry)

    def get(self):
        _, obj = heapq.heappop(self.queue)
        return obj

    def qsize(self):
        return len(self.queue)

    def print_scores(self):
        print([score for score, _ in self.queue])

    def print_objs(self):
        print([obj for _, obj in self.queue])

Writing data_structure.py


In [7]:
%%writefile layers.py
from torch import nn
from constants import d_model, drop_out_rate, num_heads, d_k, d_ff, seq_len, device, attention_type

import torch
import math


class EncoderLayer(nn.Module):
    """Pre-norm Transformer encoder layer: self-attention then feed-forward.

    Each sub-layer reads a LayerNorm-ed copy of ``x``, and its dropout-ed output
    is added back to ``x`` as a residual. The attention implementation is
    selected by ``attention_type`` from constants.

    Raises:
        ValueError: If ``attention_type`` is not 'bahdanau',
            'scaled_dot_product' or 'luong'.
    """

    def __init__(self):
        super().__init__()
        # Resolve the attention class first so an unsupported setting fails here
        # with a clear message; previously the attribute was simply never set
        # and forward() crashed later with an AttributeError.
        attention_classes = {
            'bahdanau': BahdanauMultiheadAttention,
            'scaled_dot_product': MultiheadAttention,
            'luong': LuongMultiheadAttention,
        }
        if attention_type not in attention_classes:
            raise ValueError(
                f"Unsupported attention_type {attention_type!r}; "
                f"expected one of {sorted(attention_classes)}."
            )

        # Submodules are registered in the original order so parameter order
        # is unchanged.
        self.layer_norm_1 = LayerNormalization()
        self.multihead_attention = attention_classes[attention_type]()
        self.drop_out_1 = nn.Dropout(drop_out_rate)

        self.layer_norm_2 = LayerNormalization()
        self.feed_forward = FeedFowardLayer()
        self.drop_out_2 = nn.Dropout(drop_out_rate)

    def forward(self, x, e_mask):
        """Apply the layer.

        Args:
            x: Encoder hidden states, (B, L, d_model).
            e_mask: Source padding mask passed to self-attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x_1 = self.layer_norm_1(x) # (B, L, d_model)
        x = x + self.drop_out_1(
            self.multihead_attention(x_1, x_1, x_1, mask=e_mask)
        ) # (B, L, d_model)
        x_2 = self.layer_norm_2(x) # (B, L, d_model)
        x = x + self.drop_out_2(self.feed_forward(x_2)) # (B, L, d_model)

        return x # (B, L, d_model)


class DecoderLayer(nn.Module):
    """Pre-norm Transformer decoder layer.

    Masked self-attention, then encoder-decoder (cross) attention, then a
    feed-forward block. Each sub-layer reads a LayerNorm-ed copy of ``x`` and
    adds its dropout-ed output back as a residual. Both attention sub-layers use
    the implementation selected by ``attention_type`` from constants.

    Raises:
        ValueError: If ``attention_type`` is not 'bahdanau',
            'scaled_dot_product' or 'luong'.
    """

    def __init__(self):
        super().__init__()
        # Resolve the attention class first so an unsupported setting fails here;
        # previously both attention attributes stayed None and forward() crashed.
        attention_classes = {
            'bahdanau': BahdanauMultiheadAttention,
            'scaled_dot_product': MultiheadAttention,
            'luong': LuongMultiheadAttention,
        }
        if attention_type not in attention_classes:
            raise ValueError(
                f"Unsupported attention_type {attention_type!r}; "
                f"expected one of {sorted(attention_classes)}."
            )
        attention_cls = attention_classes[attention_type]

        # Submodules are registered in the original order so parameter order
        # is unchanged.
        self.layer_norm_1 = LayerNormalization()
        self.masked_multihead_attention = attention_cls()
        self.drop_out_1 = nn.Dropout(drop_out_rate)

        self.layer_norm_2 = LayerNormalization()
        self.multihead_attention = attention_cls()
        self.drop_out_2 = nn.Dropout(drop_out_rate)

        self.layer_norm_3 = LayerNormalization()
        self.feed_forward = FeedFowardLayer()
        self.drop_out_3 = nn.Dropout(drop_out_rate)

    def forward(self, x, e_output, e_mask,  d_mask):
        """Apply the layer.

        Args:
            x: Decoder hidden states, (B, L, d_model).
            e_output: Encoder output, used as keys/values in cross-attention.
            e_mask: Source padding mask for cross-attention.
            d_mask: Target mask (padding combined with no-peek) for self-attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x_1 = self.layer_norm_1(x) # (B, L, d_model)
        x = x + self.drop_out_1(
            self.masked_multihead_attention(x_1, x_1, x_1, mask=d_mask)
        ) # (B, L, d_model)
        x_2 = self.layer_norm_2(x) # (B, L, d_model)
        x = x + self.drop_out_2(
            self.multihead_attention(x_2, e_output, e_output, mask=e_mask)
        ) # (B, L, d_model)
        x_3 = self.layer_norm_3(x) # (B, L, d_model)
        x = x + self.drop_out_3(self.feed_forward(x_3)) # (B, L, d_model)

        return x # (B, L, d_model)


class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into ``num_heads``
    heads of width ``d_k``, attended per head with
    ``softmax(Q K^T / sqrt(d_k)) V`` and merged back through ``w_0``.
    """

    def __init__(self):
        super().__init__()
        self.inf = 1e9

        # Input projections (W^Q, W^K, W^V in the paper).
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(drop_out_rate)
        self.attn_softmax = nn.Softmax(dim=-1)

        # Output projection applied after the heads are concatenated.
        self.w_0 = nn.Linear(d_model, d_model)

    def _split_heads(self, projection, tensor, batch_size):
        # Project, then reshape (B, L, d_model) -> (B, num_heads, L, d_k).
        return projection(tensor).view(batch_size, -1, num_heads, d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        batch_size = q.shape[0]

        q_heads = self._split_heads(self.w_q, q, batch_size)
        k_heads = self._split_heads(self.w_k, k, batch_size)
        v_heads = self._split_heads(self.w_v, v, batch_size)

        attended = self.self_attention(q_heads, k_heads, v_heads, mask=mask)  # (B, num_heads, L, d_k)
        merged = attended.transpose(1, 2).contiguous().view(batch_size, -1, d_model)  # (B, L, d_model)
        return self.w_0(merged)

    def self_attention(self, q, k, v, mask=None):
        """Scaled dot-product attention over tensors shaped (B, num_heads, L, d_k)."""
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L, L)

        if mask is not None:
            # (B, 1, L) -> (B, 1, 1, L) or (B, L, L) -> (B, 1, L, L) to broadcast over heads.
            scores = scores.masked_fill_(mask.unsqueeze(1) == 0, -1 * self.inf)

        weights = self.dropout(self.attn_softmax(scores))
        return torch.matmul(weights, v)  # (B, num_heads, L, d_k)


class FeedFowardLayer(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff, bias=True)  # d_model -> d_ff
        self.relu = nn.ReLU()
        self.linear_2 = nn.Linear(d_ff, d_model, bias=True)  # d_ff -> d_model
        self.dropout = nn.Dropout(drop_out_rate)

    def forward(self, x):
        hidden = self.dropout(self.relu(self.linear_1(x)))  # (B, L, d_ff)
        return self.linear_2(hidden)  # (B, L, d_model)


class LayerNormalization(nn.Module):
    """Thin wrapper around ``nn.LayerNorm`` over the last (``d_model``) axis.

    Args:
        eps: Value added to the variance for numerical stability.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.layer = nn.LayerNorm([d_model], elementwise_affine=True, eps=eps)

    def forward(self, x):
        return self.layer(x)


class PositionalEncoder(nn.Module):
    """Scale embeddings by sqrt(d_model) and add sinusoidal position codes.

    Uses the encoding from "Attention Is All You Need" (section 3.5)::

        PE[pos, 2j]   = sin(pos / 10000 ** (2j / d_model))
        PE[pos, 2j+1] = cos(pos / 10000 ** (2j / d_model))

    so each sin/cos column pair shares one frequency. The previous loop put
    the raw column index ``i`` (not the pair index) in the exponent, so
    frequencies decayed twice as fast and paired columns used different
    frequencies.

    The table is stored in the ``positional_encoding`` buffer (same name and
    shape as before). Checkpoints that already contain this buffer therefore
    restore their saved table when loaded with ``load_state_dict``.

    Args:
        max_len: Number of positions to precompute; defaults to ``seq_len``.
        dim: Embedding size; defaults to ``d_model``.
    """

    def __init__(self, max_len=None, dim=None):
        super().__init__()
        if max_len is None:
            max_len = seq_len
        if dim is None:
            dim = d_model
        self.d_model = dim

        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (L, 1)
        pair_offsets = torch.arange(0, dim, 2, dtype=torch.float)  # 2j for j = 0, 1, ...
        inv_freq = torch.pow(10000.0, -pair_offsets / dim)  # 1 / 10000^(2j / dim)
        angles = position * inv_freq  # (L, ceil(dim / 2))

        pe_matrix = torch.zeros(max_len, dim)  # (L, dim)
        pe_matrix[:, 0::2] = torch.sin(angles)
        # Slice so an odd ``dim`` (one fewer cos column) also works.
        pe_matrix[:, 1::2] = torch.cos(angles[:, : dim // 2])

        self.register_buffer('positional_encoding', pe_matrix.unsqueeze(0))  # (1, L, dim)

    def forward(self, x):
        """Return ``x * sqrt(d_model) + PE`` for ``x`` shaped (B, L, d_model).

        Inputs shorter than ``max_len`` use the first ``L`` positions.
        """
        x = x * math.sqrt(self.d_model)  # (B, L, d_model)
        return x + self.positional_encoding[:, : x.size(1)]  # (B, L, d_model)

class BahdanauMultiheadAttention(nn.Module):
    """Multi-head additive (Bahdanau-style) attention.

    For each head, the score between projected query ``q_i`` and key ``k_j`` is
    ``sum(v * tanh(q_i + k_j))``, where ``v`` is a learned per-head vector.
    Masked scores are set to ``-inf`` (``-1e9``) before the softmax.
    """

    def __init__(self):
        super().__init__()
        self.inf = 1e9

        # Linear projections for Q, K, V
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        # Learned scoring vector v, one per head.
        # Shape (1, num_heads, 1, 1, d_k) for broadcasting
        self.v = nn.Parameter(torch.rand(1, num_heads, 1, 1, d_k))
        nn.init.xavier_uniform_(self.v)

        self.dropout = nn.Dropout(drop_out_rate)
        self.attn_softmax = nn.Softmax(dim=-1)

        # Final output linear transformation
        self.w_0 = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size = q.shape[0]

        # Project and split heads: (B, L, d_model) -> (B, H, L, d_k).
        q_heads = self.w_q(q).view(batch_size, -1, num_heads, d_k).transpose(1, 2)
        k_heads = self.w_k(k).view(batch_size, -1, num_heads, d_k).transpose(1, 2)
        v_heads = self.w_v(v).view(batch_size, -1, num_heads, d_k).transpose(1, 2)

        # Broadcast (B, H, L_q, 1, d_k) + (B, H, 1, L_k, d_k) -> (B, H, L_q, L_k, d_k).
        # NOTE: this materialises a 5-D tensor; memory grows with L_q * L_k * d_k.
        energy = torch.tanh(q_heads.unsqueeze(3) + k_heads.unsqueeze(2))
        scores = (self.v * energy).sum(dim=-1)  # (B, H, L_q, L_k)

        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1 * self.inf)

        weights = self.dropout(self.attn_softmax(scores))
        context = torch.matmul(weights, v_heads)  # (B, H, L_q, d_k)

        merged = context.transpose(1, 2).contiguous().view(batch_size, -1, d_model)
        return self.w_0(merged)

class LuongMultiheadAttention(nn.Module):
    """Multi-head attention with Luong's "general" (bilinear) score.

    For head ``h`` the score between projected query ``q_i`` and key ``k_j``
    is ``q_i^T W_a[h] k_j``, with a separate learned ``W_a`` per head. Scores
    are deliberately not divided by sqrt(d_k), following the original Luong
    "general" formulation.
    """

    def __init__(self):
        super().__init__()
        self.inf = 1e9

        # Standard projections
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        # Luong "general" matrix W_a in score(h_t, h_s) = h_t^T * W_a * h_s.
        # Each head gets its own W_a, so the shape is (num_heads, d_k, d_k).
        self.w_a = nn.Parameter(torch.rand(num_heads, d_k, d_k))
        nn.init.xavier_uniform_(self.w_a)

        self.dropout = nn.Dropout(drop_out_rate)
        self.attn_softmax = nn.Softmax(dim=-1)

        self.w_0 = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size = q.shape[0]

        # Project and split heads: (B, L, d_model) -> (B, H, L, d_k).
        q_heads = self.w_q(q).view(batch_size, -1, num_heads, d_k).transpose(1, 2)
        k_heads = self.w_k(k).view(batch_size, -1, num_heads, d_k).transpose(1, 2)
        v_heads = self.w_v(v).view(batch_size, -1, num_heads, d_k).transpose(1, 2)

        # Q * W_a with a per-head matrix: 'bhld' (queries) x 'hde' (W_a) -> 'bhle'.
        weighted_q = torch.einsum('bhld,hde->bhle', q_heads, self.w_a)
        # Then multiply by K^T: (B, H, L_q, L_k). No sqrt(d_k) scaling on purpose.
        scores = weighted_q @ k_heads.transpose(-2, -1)

        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1 * self.inf)

        weights = self.dropout(self.attn_softmax(scores))
        context = weights @ v_heads  # (B, H, L_q, d_k)

        merged = context.transpose(1, 2).contiguous().view(batch_size, -1, d_model)
        return self.w_0(merged)

Writing layers.py


In [8]:
%%writefile main.py
from tqdm import tqdm
from constants import (
    SP_DIR,
    src_model_prefix,
    trg_model_prefix,
    seq_len,
    pad_id,
    sos_id,
    eos_id,
    learning_rate,
    device,
    num_epochs,
    ckpt_dir,
    beam_size,
    start_epoch,
)
from constants import TRAIN_NAME, VALID_NAME
from custom_data import get_data_loader, pad_or_truncate
from transformer import Transformer
from torch import nn
import torch.nn.functional as F

import torch
import sys
import os
import numpy as np
import argparse
import datetime
import sentencepiece as spm


class Manager:
    def __init__(self, is_train=True, ckpt_name=None):
        """Load vocabularies and build the model/optimizer, plus data for training.

        Args:
            is_train: If True, also create the NLL loss and the train/valid
                DataLoaders.
            ckpt_name: Optional file name inside ``ckpt_dir``. When given, the
                model and optimizer state and the best validation loss are
                restored from it, along with the next epoch to run if recorded.

        Raises:
            AssertionError: If ``ckpt_name`` is given but the file is missing.
        """
        # Load vocabs
        print("Loading vocabs...")
        self.src_i2w = self._load_vocab(f"{SP_DIR}/{src_model_prefix}.vocab")
        self.trg_i2w = self._load_vocab(f"{SP_DIR}/{trg_model_prefix}.vocab")

        print(
            f"The size of src vocab is {len(self.src_i2w)} and that of trg vocab is {len(self.trg_i2w)}."
        )

        # Load Transformer model & Adam optimizer
        print("Loading Transformer model & Adam optimizer...")
        self.model = Transformer(
            src_vocab_size=len(self.src_i2w), trg_vocab_size=len(self.trg_i2w)
        )
        if torch.cuda.device_count() > 1:
            print(f"Detecting {torch.cuda.device_count()} GPUs. Using DataParallel!")
            self.model = nn.DataParallel(self.model)
            # Unwrapped model: used for attribute access and state-dict (de)serialisation.
            self.model_core = self.model.module
        else:
            self.model_core = self.model
        self.model = self.model.to(device)
        self.optim = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        self.best_loss = sys.float_info.max
        self.start_epoch = start_epoch
        print(f"Starting from epoch {self.start_epoch}")

        if ckpt_name is not None:
            ckpt_path = f"{ckpt_dir}/{ckpt_name}"
            assert os.path.exists(ckpt_path), (
                f"There is no checkpoint named {ckpt_name}."
            )

            print("Loading checkpoint...")
            # SECURITY: weights_only=False unpickles arbitrary Python objects, so
            # only load trusted checkpoints. It is kept because checkpoints from
            # older runs may hold the loss as a numpy scalar, which
            # weights_only=True typically rejects.
            checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
            self.model_core.load_state_dict(checkpoint["model_state_dict"])
            self.optim.load_state_dict(checkpoint["optim_state_dict"])
            # "loss" is the validation loss of the saved epoch; "best_loss" is the
            # best seen so far. Using "loss" when resuming from a non-best
            # checkpoint could let a worse model overwrite best_ckpt.tar later.
            self.best_loss = checkpoint.get("best_loss", checkpoint["loss"])

            if "epoch" in checkpoint:
                # train() saves the epoch it has just completed; continue with the
                # next one instead of repeating it (and overwriting its checkpoint).
                self.start_epoch = checkpoint["epoch"] + 1
                print(f"Resuming training from epoch {self.start_epoch}")
            else:
                print(
                    f"No epoch info in checkpoint, starting from epoch {self.start_epoch} (but with loaded weights)."
                )

        else:
            print("Initializing the model...")
            for p in self.model.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)

        if is_train:
            # Load loss function
            print("Loading loss function...")
            self.criterion = nn.NLLLoss(ignore_index=pad_id)

            # Load dataloaders
            print("Loading dataloaders...")
            self.train_loader = get_data_loader(TRAIN_NAME)
            self.valid_loader = get_data_loader(VALID_NAME)

        print("Setting finished.")

    @staticmethod
    def _load_vocab(vocab_path):
        """Map line index -> piece (first tab-separated field) of a ``.vocab`` file."""
        i2w = {}
        with open(vocab_path, encoding="utf-8") as f:
            for i, line in enumerate(f):
                i2w[i] = line.strip().split("\t")[0]
        return i2w

    def train(self):
        """Train for ``num_epochs`` epochs starting at ``self.start_epoch``.

        After each epoch the model is validated and a checkpoint
        ``ckpt_epoch{epoch}.tar`` is written to ``ckpt_dir``. The same state is
        also saved as ``best_ckpt.tar`` whenever the validation loss beats
        ``self.best_loss``.
        """
        print("Training starts.")

        for epoch in range(self.start_epoch, self.start_epoch + num_epochs):
            self.model.train()

            train_losses = []
            start_time = datetime.datetime.now()

            # Iterating the DataLoader directly lets tqdm show a total and ETA.
            for batch in tqdm(self.train_loader):
                src_input, trg_input, trg_output = [t.to(device) for t in batch]

                e_mask, d_mask = self.make_mask(src_input, trg_input)

                output = self.model(
                    src_input, trg_input, e_mask, d_mask
                )  # (B, L, vocab_size)

                self.optim.zero_grad()
                loss = self.criterion(
                    output.view(-1, self.model_core.trg_vocab_size),
                    trg_output.view(-1),  # (B * L)
                )

                loss.backward()
                self.optim.step()

                train_losses.append(loss.item())

                del src_input, trg_input, trg_output, e_mask, d_mask, output

            # total_seconds() (unlike timedelta.seconds) also counts whole days.
            elapsed = int((datetime.datetime.now() - start_time).total_seconds())
            hours, remainder = divmod(elapsed, 3600)
            minutes, seconds = divmod(remainder, 60)

            mean_train_loss = np.mean(train_losses)
            print(f"#################### Epoch: {epoch} ####################")
            print(
                f"Train loss: {mean_train_loss} || One epoch training time: {hours}hrs {minutes}mins {seconds}secs"
            )

            valid_loss, valid_time = self.validation()
            # Store plain Python floats so checkpoints contain no numpy scalars.
            valid_loss = float(valid_loss)

            os.makedirs(ckpt_dir, exist_ok=True)

            is_best = valid_loss < self.best_loss
            if is_best:
                self.best_loss = valid_loss
                print(
                    f"***** Epoch {epoch} has best valid loss: {self.best_loss} *****"
                )
            state_dict = {
                "model_state_dict": self.model_core.state_dict(),
                "optim_state_dict": self.optim.state_dict(),
                "loss": valid_loss,
                "best_loss": float(self.best_loss),
                "epoch": epoch,
            }
            torch.save(state_dict, f"{ckpt_dir}/ckpt_epoch{epoch}.tar")
            print(f"Saved checkpoint: ckpt_epoch{epoch}.tar")

            if is_best:
                torch.save(state_dict, f"{ckpt_dir}/best_ckpt.tar")
                print("***** Updated best_ckpt.tar *****")

            print(f"Best valid loss: {self.best_loss}")
            print(f"Valid loss: {valid_loss} || One epoch validation time: {valid_time}")

        print("Training finished!")

    def validation(self):
        """Compute the mean loss over ``self.valid_loader`` without gradients.

        Returns:
            Tuple ``(mean_valid_loss, elapsed)``, where ``elapsed`` is a string
            formatted as ``"{h}hrs {m}mins {s}secs"``.
        """
        print("Validation processing...")
        self.model.eval()

        valid_losses = []
        started = datetime.datetime.now()

        with torch.no_grad():
            for _, batch in tqdm(enumerate(self.valid_loader)):
                src_input, trg_input, trg_output = [t.to(device) for t in batch]

                e_mask, d_mask = self.make_mask(src_input, trg_input)

                output = self.model(src_input, trg_input, e_mask, d_mask)  # (B, L, vocab_size)

                flat_targets = trg_output.view(trg_output.shape[0] * trg_output.shape[1])
                loss = self.criterion(
                    output.view(-1, self.model_core.trg_vocab_size), flat_targets
                )
                valid_losses.append(loss.item())

                del src_input, trg_input, trg_output, e_mask, d_mask, output

        elapsed_seconds = (datetime.datetime.now() - started).seconds
        hours, remainder = divmod(elapsed_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)

        return np.mean(valid_losses), f"{hours}hrs {minutes}mins {seconds}secs"

    def inference(self, input_sentence, method):
        """Translate one source sentence using greedy or beam-search decoding.

        Args:
            input_sentence: Raw source-language text.
            method: ``"greedy"`` or ``"beam"``.

        Returns:
            The result of the selected search method (``greedy_search`` returns
            the decoded target string).

        Raises:
            ValueError: If ``method`` is neither ``"greedy"`` nor ``"beam"``.
        """
        if method not in ("greedy", "beam"):
            # Previously an unknown method fell through and crashed later with an
            # UnboundLocalError on ``result``.
            raise ValueError(
                f"Unknown decoding method {method!r}; expected 'greedy' or 'beam'."
            )

        print("Inference starts.")
        self.model.eval()

        print("Loading sentencepiece tokenizer...")
        src_sp = spm.SentencePieceProcessor()
        trg_sp = spm.SentencePieceProcessor()
        src_sp.Load(f"{SP_DIR}/{src_model_prefix}.model")
        trg_sp.Load(f"{SP_DIR}/{trg_model_prefix}.model")

        print("Preprocessing input sentence...")
        # Mirror training-time source preprocessing (custom_data.process_src):
        # strip the text and append EOS before padding, so the encoder sees
        # inputs shaped like its training data.
        tokenized = src_sp.EncodeAsIds(input_sentence.strip()) + [eos_id]
        src_data = (
            torch.LongTensor(pad_or_truncate(tokenized)).unsqueeze(0).to(device)
        )  # (1, L)
        e_mask = (src_data != pad_id).unsqueeze(1)  # (1, 1, L), already on device

        start_time = datetime.datetime.now()

        print("Encoding input sentence...")

        with torch.no_grad():
            src_data = self.model_core.src_embedding(src_data)
            src_data = self.model_core.positional_encoder(src_data)
            e_output = self.model_core.encoder(src_data, e_mask)  # (1, L, d_model)

            if method == "greedy":
                print("Greedy decoding selected.")
                result = self.greedy_search(e_output, e_mask, trg_sp)
            else:
                print("Beam search selected.")
                result = self.beam_search(e_output, e_mask, trg_sp)

        elapsed = int((datetime.datetime.now() - start_time).total_seconds())
        minutes, seconds = divmod(elapsed, 60)

        print(f"Input: {input_sentence}")
        print(f"Result: {result}")
        print(
            f"Inference finished! || Total inference time: {minutes}mins {seconds}secs"
        )

        return result

    def greedy_search(self, e_output, e_mask, trg_sp):
        """Greedy autoregressive decoding from encoder outputs.

        Starting from ``[sos]``, the full decoder is run over the padded prefix.
        The arg-max token at the current step is appended, and this repeats
        until ``eos`` is produced or ``seq_len`` positions are filled.

        Args:
            e_output: Encoder output; ``inference`` passes a (1, L, d_model) tensor.
            e_mask: Encoder padding mask; ``inference`` passes (1, 1, L).
            trg_sp: Target SentencePiece processor used to detokenize.

        Returns:
            The decoded target string (without SOS/EOS).
        """
        last_words = torch.LongTensor([pad_id] * seq_len).to(device)  # (L)
        last_words[0] = sos_id
        cur_len = 1

        # The causal (no-peek) mask does not change between steps; build it once.
        nopeak_mask = torch.tril(
            torch.ones([1, seq_len, seq_len], dtype=torch.bool, device=device)
        )  # (1, L, L)

        for i in range(seq_len):
            pad_mask = (last_words.unsqueeze(0) != pad_id).unsqueeze(1)  # (1, 1, L)
            d_mask = pad_mask & nopeak_mask  # (1, L, L)

            trg_embedded = self.model_core.trg_embedding(last_words.unsqueeze(0))
            trg_positional_encoded = self.model_core.positional_encoder(trg_embedded)
            decoder_output = self.model_core.decoder(
                trg_positional_encoded, e_output, e_mask, d_mask
            )  # (1, L, d_model)

            output = self.model_core.softmax(
                self.model_core.output_linear(decoder_output)
            )  # (1, L, trg_vocab_size)

            # Only step i is needed; argmax over that row alone.
            last_word_id = torch.argmax(output[0, i], dim=-1).item()

            if i < seq_len - 1:
                last_words[i + 1] = last_word_id
                cur_len += 1

            if last_word_id == eos_id:
                break

        # Positions 1..cur_len-1 hold every generated token (this also covers a
        # completely filled sequence, where cur_len == seq_len).
        decoded_output = last_words[1:cur_len].tolist()
        # Drop the terminating EOS explicitly rather than relying on the
        # tokenizer to ignore it when decoding.
        if decoded_output and decoded_output[-1] == eos_id:
            decoded_output = decoded_output[:-1]

        return trg_sp.decode_ids(decoded_output)

    def beam_search(self, e_output, e_mask, trg_sp, beam_size=beam_size, alpha=0.7):
        """
        Beam search implemented correctly.
        - e_output: encoder outputs (1, L_enc, d_model)
        - e_mask: encoder mask (1, 1, L_enc)
        - trg_sp: SentencePiece processor for decoding ids->text
        - beam_size: beam width
        - alpha: length normalization hyperparameter (common default ~0.7)
        """
        self.model.eval()

        # Each hypothesis: (tokens_list, cumulative_logprob, is_finished)
        # Initialize with single hypothesis [SOS]
        hypotheses = [([sos_id], 0.0, False)]

        for t in range(seq_len):
            all_candidates = []

            # If all hypotheses are finished, we can stop early
            if all(h[2] for h in hypotheses):
                break

            # Build batch of decoder inputs: for every hypothesis that is not finished,
            # we'll expand with top-k next token candidates. For finished hypos, keep them as-is.
            # We will run decoder on the batch of candidate sequences to get log-probs.
            for h_idx, (tokens, logp, finished) in enumerate(hypotheses):
                if finished:
                    # keep finished hypothesis as a candidate (carry over)
                    all_candidates.append((tokens, logp, True))
                else:
                    # We will expand this hypothesis; but first create its current input (padded)
                    # We'll ask the model for top-k next tokens; to do that efficiently we will
                    # create candidate inputs later.
                    # For now just note we will expand this hypothesis.
                    # We'll create the actual candidate inputs after we determine top-k per hypo.
                    pass

            # To get top-k for each hypothesis we need model output at position t given its tokens.
            # We'll create a batch of current hypotheses (one per non-finished hypo), run decoder,
            # and extract log-probs at time step t, then pick top-k per hypothesis.
            alive_hypos = [(idx, h) for idx, h in enumerate(hypotheses) if not h[2]]
            if len(alive_hypos) == 0:
                break

            # Prepare batch input: for each alive hypo, create padded tensor (seq_len) with its tokens
            batch_inputs = []
            hypo_map = []  # map from batch row -> hypothesis index
            for h_idx, (tokens, logp, finished) in alive_hypos:
                seq = tokens + [pad_id] * (seq_len - len(tokens))
                batch_inputs.append(seq)
                hypo_map.append(h_idx)

            batch_inputs = torch.LongTensor(batch_inputs).to(
                device
            )  # (B_alive, seq_len)

            # Create decoder mask for this batch
            d_mask = (
                (batch_inputs != pad_id).unsqueeze(1).to(device)
            )  # (B_alive, 1, seq_len)
            nopeak = torch.tril(
                torch.ones((1, seq_len, seq_len), dtype=torch.bool, device=device)
            )
            d_mask = d_mask & nopeak  # (B_alive, seq_len, seq_len) broadcasted

            # Run decoder for this batch (one forward)
            with torch.no_grad():
                trg_emb = self.model_core.trg_embedding(
                    batch_inputs
                )  # (B_alive, L, d_model)
                trg_emb = self.model_core.positional_encoder(
                    trg_emb
                )  # (B_alive, L, d_model)
                dec_out = self.model_core.decoder(
                    trg_emb,
                    e_output.repeat(len(batch_inputs), 1, 1)
                    if e_output.size(0) == 1
                    else e_output,
                    e_mask.repeat(len(batch_inputs), 1, 1)
                    if e_mask.size(0) == 1
                    else e_mask,
                    d_mask,
                )  # (B_alive, L, d_model)
                # Get log-probs (use model's output_linear then log_softmax to be safe)
                logits = self.model_core.output_linear(dec_out)  # (B_alive, L, V)
                log_probs = F.log_softmax(logits, dim=-1)  # (B_alive, L, V)

            # For each alive hypothesis, get top-k tokens at time step t
            B_alive = log_probs.size(0)
            V = log_probs.size(-1)
            topk = min(beam_size, V)

            for i in range(B_alive):
                hypo_idx = hypo_map[i]
                tokens, curr_logp, _ = hypotheses[hypo_idx]
                # get log-prob vector at time t
                logp_t = log_probs[i, t]  # (V,)
                top_vals, top_idx = torch.topk(logp_t, k=topk)  # both tensors
                top_vals = top_vals.cpu().tolist()
                top_idx = top_idx.cpu().tolist()
                for k_idx, token_id in enumerate(top_idx):
                    new_tokens = tokens + [token_id]
                    new_logp = curr_logp + top_vals[k_idx]  # cumulative log-prob
                    finished = token_id == eos_id
                    all_candidates.append((new_tokens, new_logp, finished))

            # Also include previous finished hypotheses (they were added earlier)

            # Now select top `beam_size` candidates among all_candidates by cumulative log-prob
            # Note: do NOT normalize length here (we keep cumulative log-prob for beam propagation).
            all_candidates = sorted(all_candidates, key=lambda x: x[1], reverse=True)
            hypotheses = all_candidates[:beam_size]

            # If number of hypotheses < beam_size (possible if many finished), pad by carrying best finished
            # (not strictly necessary)

        # After finishing (either reached seq_len or all finished), choose best hypothesis with length normalization
        # Apply length normalization score = logp / (len(tokens) ** alpha)
        final_scores = []
        for tokens, logp, finished in hypotheses:
            length = len(tokens) - 1  # exclude SOS for length
            if length <= 0:
                length = 1.0
            score = logp / (length**alpha)
            final_scores.append((score, tokens, finished))

        final_scores = sorted(final_scores, key=lambda x: x[0], reverse=True)
        best_tokens = final_scores[0][1]

        # Remove leading SOS and trailing EOS if present
        if best_tokens and best_tokens[0] == sos_id:
            best_tokens = best_tokens[1:]
        if best_tokens and best_tokens[-1] == eos_id:
            best_tokens = best_tokens[:-1]

        return trg_sp.decode_ids(best_tokens)

    def make_mask(self, src_input, trg_input):
        """Build the encoder padding mask and the decoder padding+causal mask.

        Args:
            src_input: LongTensor of source token ids (assumed (B, L_src) —
                TODO confirm). Positions equal to ``pad_id`` are masked out.
            trg_input: LongTensor of target token ids (assumed (B, L_trg)).
                Positions equal to ``pad_id`` are masked out, and each position
                may only attend to itself and earlier positions.

        Returns:
            (e_mask, d_mask): boolean tensors moved to ``device``. ``e_mask`` is
            the source padding mask with a singleton dim inserted at index 1;
            ``d_mask`` is the target padding mask AND-ed with a lower-triangular
            (L_trg, L_trg) mask.
        """
        e_mask = (src_input != pad_id).unsqueeze(1).to(device)  # (B, 1, L_src)
        d_mask = (trg_input != pad_id).unsqueeze(1).to(device)  # (B, 1, L_trg)

        # Size the causal mask from the actual target length instead of the
        # global seq_len; for inputs padded to seq_len this is identical, and it
        # also lets shorter target batches broadcast correctly.
        trg_len = trg_input.size(-1)
        nopeak_mask = torch.tril(
            torch.ones((1, trg_len, trg_len), dtype=torch.bool, device=device)
        )  # (1, L_trg, L_trg) lower-triangular
        d_mask = d_mask & nopeak_mask  # (B, L_trg, L_trg), padding positions False

        return e_mask, d_mask


if __name__ == "__main__":
    # Command-line entry point: train a model, or translate one sentence with a
    # saved checkpoint.
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", required=True, help="train or inference?")
    parser.add_argument("--ckpt_name", required=False, help="best checkpoint file")
    parser.add_argument(
        "--input", type=str, required=False, help="input sentence when inferencing"
    )
    parser.add_argument(
        "--decode", type=str, required=False, default="greedy", help="greedy or beam?"
    )

    args = parser.parse_args()

    if args.mode == "train":
        if args.ckpt_name is not None:
            manager = Manager(is_train=True, ckpt_name=args.ckpt_name)
        else:
            manager = Manager(is_train=True)

        manager.train()
    elif args.mode == "inference":
        # Validate user input with parser.error instead of assert: asserts are
        # removed under `python -O`, and parser.error prints a usage message and
        # exits cleanly instead of dumping a traceback.
        if args.ckpt_name is None:
            parser.error("Please specify the model file name you want to use.")
        if args.input is None:
            parser.error("Please specify the input sentence to translate.")
        if args.decode not in ("greedy", "beam"):
            parser.error(
                "Please specify correct decoding method, either 'greedy' or 'beam'."
            )

        manager = Manager(is_train=False, ckpt_name=args.ckpt_name)
        manager.inference(args.input, args.decode)

    else:
        print("Please specify mode argument either with 'train' or 'inference'.")

Writing main.py


In [9]:
%%writefile sentencepiece_train.py
from constants import DATA_DIR, SP_DIR, SRC_RAW_DATA_NAME, TRG_RAW_DATA_NAME
from constants import src_model_prefix, trg_model_prefix, pad_id, sos_id, eos_id, unk_id, sp_vocab_size, character_coverage, model_type
from constants import SRC_DIR, TRG_DIR, TRAIN_NAME, VALID_NAME
from tqdm import tqdm

import os
import sentencepiece as spm

# Fraction of each raw file's lines (taken from the start, in file order) that
# split_data writes to the training split; the remainder becomes validation.
train_frac = 0.8

def train_sp(is_src=True):
    """Train a SentencePiece model on the raw source or target corpus.

    Uses the raw corpus in DATA_DIR as input and a model prefix under SP_DIR,
    with the special-token ids and hyper-parameters imported from constants.
    SentencePiece then writes ``<prefix>.model`` and ``<prefix>.vocab``.

    Args:
        is_src: True to train the source-side model, False for the target side.
    """
    if is_src:
        this_input_file = f"{DATA_DIR}/{SRC_RAW_DATA_NAME}"
        this_model_prefix = f"{SP_DIR}/{src_model_prefix}"
    else:
        this_input_file = f"{DATA_DIR}/{TRG_RAW_DATA_NAME}"
        this_model_prefix = f"{SP_DIR}/{trg_model_prefix}"

    # One "--flag=value" per option joined by single spaces (the documented
    # string form). The previous backslash-continued literal embedded long runs
    # of indentation whitespace between flags.
    options = [
        f"--input={this_input_file}",
        f"--pad_id={pad_id}",
        f"--bos_id={sos_id}",  # SentencePiece calls the start token "bos"
        f"--eos_id={eos_id}",
        f"--unk_id={unk_id}",
        f"--model_prefix={this_model_prefix}",
        f"--vocab_size={sp_vocab_size}",
        f"--character_coverage={character_coverage}",
        f"--model_type={model_type}",
    ]
    config = " ".join(options)

    print(config)

    # Creates missing parent directories too; no-op if SP_DIR already exists.
    os.makedirs(SP_DIR, exist_ok=True)

    spm.SentencePieceTrainer.Train(config)
    
    
def split_data(raw_data_name, data_dir):
    """Split a raw corpus file into train/validation text files.

    The first ``train_frac`` of the lines (file order, no shuffling) are written
    to ``{DATA_DIR}/{data_dir}/{TRAIN_NAME}`` and the rest to
    ``{DATA_DIR}/{data_dir}/{VALID_NAME}``. Each line is stripped of surrounding
    whitespace and re-terminated with a single newline.

    NOTE(review): source/target pairs stay aligned only if both raw files have
    the same number of lines, since each file is split by its own length.

    Args:
        raw_data_name: file name of the raw corpus inside DATA_DIR.
        data_dir: sub-directory of DATA_DIR that receives the split files.
    """
    # Explicit UTF-8: the corpus may contain non-ASCII text, and the platform
    # default encoding is not guaranteed to be UTF-8.
    with open(f"{DATA_DIR}/{raw_data_name}", encoding='utf-8') as f:
        lines = f.readlines()

    print("Splitting data...")

    split_idx = int(train_frac * len(lines))
    train_lines = lines[:split_idx]
    valid_lines = lines[split_idx:]

    out_dir = f"{DATA_DIR}/{data_dir}"
    os.makedirs(out_dir, exist_ok=True)

    for file_name, subset in ((TRAIN_NAME, train_lines), (VALID_NAME, valid_lines)):
        with open(f"{out_dir}/{file_name}", 'w', encoding='utf-8') as f:
            for line in tqdm(subset):
                f.write(line.strip() + '\n')

    print(f"Train/Validation data saved in {DATA_DIR}/{data_dir}.")


if __name__=='__main__':
    # Train one SentencePiece model per side: source first, then target.
    for is_source_side in (True, False):
        train_sp(is_src=is_source_side)

    # Then split each raw corpus into train/valid files under its own sub-dir.
    for raw_name, sub_dir in ((SRC_RAW_DATA_NAME, SRC_DIR), (TRG_RAW_DATA_NAME, TRG_DIR)):
        split_data(raw_name, sub_dir)

Writing sentencepiece_train.py


In [10]:
%%writefile transformer.py
from torch import nn
from constants import d_model, num_layers
from layers import EncoderLayer, DecoderLayer, PositionalEncoder, LayerNormalization


class Transformer(nn.Module):
    """Encoder-decoder Transformer that returns log-probabilities over target tokens.

    Args:
        src_vocab_size: number of source token ids (rows of the source embedding).
        trg_vocab_size: number of target token ids (rows of the target embedding
            and width of the output projection).
    """

    def __init__(self, src_vocab_size, trg_vocab_size):
        super().__init__()
        self.src_vocab_size = src_vocab_size
        self.trg_vocab_size = trg_vocab_size

        # NOTE(review): creation order below fixes the order of self.parameters()
        # and of default-init RNG draws; keep it stable if checkpoints or
        # optimizer state depend on it.
        self.src_embedding = nn.Embedding(self.src_vocab_size, d_model)
        self.trg_embedding = nn.Embedding(self.trg_vocab_size, d_model)
        self.positional_encoder = PositionalEncoder()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.output_linear = nn.Linear(d_model, self.trg_vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)  # yields log-probabilities despite the name

    def forward(self, src_input, trg_input, e_mask=None, d_mask=None):
        """Encode ``src_input``, decode ``trg_input`` against it, and return log-probs.

        Token ids are embedded, then passed through the shared positional
        encoder. The encoder output is fed to the decoder together with both
        masks. The final projection to ``trg_vocab_size`` goes through LogSoftmax.
        """
        # Tuple expressions evaluate left to right: source before target at each stage.
        src_emb, trg_emb = self.src_embedding(src_input), self.trg_embedding(trg_input)
        src_emb, trg_emb = self.positional_encoder(src_emb), self.positional_encoder(trg_emb)

        memory = self.encoder(src_emb, e_mask)
        decoded = self.decoder(trg_emb, memory, e_mask, d_mask)

        return self.softmax(self.output_linear(decoded))


class Encoder(nn.Module):
    """Stack of ``num_layers`` EncoderLayer modules followed by a final LayerNormalization."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(num_layers)])
        self.layer_norm = LayerNormalization()

    def forward(self, x, e_mask):
        """Apply each encoder layer in order with ``e_mask``, then normalize the result."""
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden, e_mask)
        return self.layer_norm(hidden)


class Decoder(nn.Module):
    """Stack of ``num_layers`` DecoderLayer modules followed by a final LayerNormalization."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(num_layers)])
        self.layer_norm = LayerNormalization()

    def forward(self, x, e_output, e_mask, d_mask):
        """Apply each decoder layer in order, attending to ``e_output``, then normalize.

        Every layer receives the encoder output and both masks unchanged; only
        the running hidden state is threaded from layer to layer.
        """
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden, e_output, e_mask, d_mask)
        return self.layer_norm(hidden)

Writing transformer.py


In [11]:
!python main.py --mode train

Loading vocabs...
The size of src vocab is 16000 and that of trg vocab is 16000.
Loading Transformer model & Adam optimizer...
Detecting 2 GPUs. Using DataParallel!
Starting from epoch 1
Initializing the model...
Loading loss function...
Loading dataloaders...
Getting source/target train.txt...
Tokenizing & Padding src data...
100%|████████████████████████████████| 560000/560000 [00:13<00:00, 41150.81it/s]
The shape of src data: (560000, 128)
Tokenizing & Padding trg data...
100%|████████████████████████████████| 560000/560000 [00:18<00:00, 30012.17it/s]
The shape of input trg data: (560000, 128)
The shape of output trg data: (560000, 128)
Getting source/target valid.txt...
Tokenizing & Padding src data...
100%|████████████████████████████████| 140000/140000 [00:03<00:00, 45802.70it/s]
The shape of src data: (140000, 128)
Tokenizing & Padding trg data...
100%|████████████████████████████████| 140000/140000 [00:04<00:00, 33029.04it/s]
The shape of input trg data: (