In [None]:
import json, os, time
import torch, numpy, random
from tqdm import tqdm

from datasets import load_metric
! pip install sacrebleu

from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences

from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

In [None]:
# Seed every RNG the notebook touches so data sampling and weight init are reproducible.
SEED = 3
random.seed(SEED)
numpy.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
try:
    import wandb
except:
    ! pip install wandb
import wandb
wandb.login(
    key = "280aa3837eb27ece3c32ed8e27e3e233d0afdc9c"
)
wandb.init("RNN Encoder - Decoder Model For Project Deep Learning")

# 0. Config

In [None]:
config = dict(
    # --- Dataset ---
    train_data_file="/kaggle/input/phomt-dl-2023-1/PhoMT_json/tokenization/train/train.json",
    dev_data_file="/kaggle/input/phomt-dl-2023-1/PhoMT_json/tokenization/dev/dev.json",
    test_data_file="/kaggle/input/phomt-dl-2023-1/PhoMT_json/tokenization/test/test.json",
    # Number of training pairs to randomly sample for a smaller run (e.g. 100000); 0 = use the full training set.
    small_train_data=30000,

    # --- Training ---
    batch_size=16,
    epoch=10,

    # --- Model ---
    # Path of a saved model.pth to resume training or run inference from (None = no checkpoint).
    initial_model=None,
    embedding_dim=256,
    hidden_size=1024,

    # --- Optimizer / StepLR learning-rate scheduler ---
    step_size=2,  # decay the learning rate every `step_size` scheduler steps (one per epoch)
    gamma=0.2,    # multiplicative decay factor

    # --- Output ---
    model_save_path="/kaggle/working/save_model.pth",
)

# Các hàm phụ

In [None]:
def load_model(encoder: nn.Module = None, decoder: nn.Module = None, path: str = None):
    """Restore encoder/decoder weights from a checkpoint dict with 'encoder' and 'decoder' keys.

    This is the format of `best_statedict`, which the SAVE cell writes with torch.save.

    Args:
        encoder: module that receives the 'encoder' weights. Defaults to the notebook-level
            `encoder` (the model cell must have been run).
        decoder: module that receives the 'decoder' weights. Defaults to the notebook-level
            `decoder`.
        path: checkpoint file. Defaults to config["initial_model"].

    Returns:
        The (encoder, decoder) pair the weights were loaded into.

    Raises:
        ValueError: if neither `path` nor config["initial_model"] names a checkpoint.
    """
    if path is None:
        path = config["initial_model"]
    if path is None:
        raise ValueError("No checkpoint to load: pass `path` or set config['initial_model'].")
    # The parameters shadow the notebook globals, so look the defaults up explicitly.
    if encoder is None:
        encoder = globals()["encoder"]
    if decoder is None:
        decoder = globals()["decoder"]
    # Map to CPU so a GPU-saved checkpoint also opens on a CPU-only machine; load_state_dict
    # then copies the values into each module's existing parameters, wherever they live.
    state = torch.load(path, map_location="cpu")
    encoder.load_state_dict(state['encoder'])
    decoder.load_state_dict(state['decoder'])
    return encoder, decoder

# **1. Phần dataset**

In [None]:
def preprocess(json_file, train = True, sample_size = None):
    """Read a PhoMT-style JSON file and return parallel English / Vietnamese sentence lists.

    The file is expected to hold {"data": [{"translation": {"en": ..., "vi": ...}}, ...]}.
    Every sentence is stripped and lower-cased.

    Args:
        json_file: path of the JSON file (read as UTF-8).
        train: when True and `sample_size` is not given, randomly keep
            config["small_train_data"] pairs (0 or None keeps the whole file).
        sample_size: explicit number of pairs to sample; overrides `train`/config.
            0 or None (with train=False) keeps every pair.

    Returns:
        (en_sentences, vi_sentences): two lists of equal length, aligned by index.
    """
    if sample_size is None:
        sample_size = config["small_train_data"] if train else 0

    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)["data"]

    if sample_size and data:
        # Sample WITHOUT replacement: drawing with replacement silently duplicated pairs
        # in the "small" training set. Asking for more pairs than exist keeps all of them.
        k = min(int(sample_size), len(data))
        picked = numpy.random.choice(len(data), size=k, replace=False)
        data = [data[i] for i in picked]

    en_sentences = [sample["translation"]["en"].strip().lower() for sample in data]
    vi_sentences = [sample["translation"]["vi"].strip().lower() for sample in data]
    return en_sentences, vi_sentences

In [None]:
class Lang:
    """Vocabulary plus id-encoded, padded sentences for one language.

    Every sentence is wrapped in <START> ... <END> and right-padded with <PAD> to the length
    of the longest sentence in `sentence_list`.

    Args:
        sentence_list: whitespace-tokenised sentences.
        train: if True, build a new vocabulary from `sentence_list` (plus an <UNK> entry);
            if False, reuse the vocabulary given by `word2id` / `id2word`.
        word2id, id2word: an existing vocabulary (needed when train=False).

    Attributes set: sentences, max_len, word2id, id2word, vocab_size, wordvec.
    """

    # Id target for words that are not in the vocabulary.
    UNK = '<UNK>'

    def __init__(self, sentence_list, train=True, word2id=None, id2word=None):
        self.word2id = word2id
        self.id2word = id2word
        self.train = train
        self.preprocess(sentence_list)
        self.get_vocab()
        self.get_word_vectors()

    def preprocess(self, sentence_list):
        """Add <START>/<END> to every sentence, then right-pad all of them with <PAD> to max_len."""
        wrapped = ['<START> ' + sen + ' <END>' for sen in sentence_list]
        self.max_len = max((len(sen.split()) for sen in wrapped), default=0)
        self.sentences = [sen + ' <PAD>' * (self.max_len - len(sen.split())) for sen in wrapped]

    def get_vocab(self):
        """Build word2id / id2word (train mode only) and set vocab_size.

        Ids follow first-occurrence order. <UNK> is appended last, so the ids of real words are
        unaffected, and unseen words in dev/test sentences get one fixed id.
        """
        if self.train:
            self.word2id = {}
            self.id2word = []
            for sen in self.sentences:
                for word in sen.split():
                    if word not in self.word2id:
                        self.word2id[word] = len(self.id2word)
                        self.id2word.append(word)
            if self.UNK not in self.word2id:
                self.word2id[self.UNK] = len(self.id2word)
                self.id2word.append(self.UNK)
        self.vocab_size = len(self.id2word)

    def get_word_vectors(self):
        """Encode every padded sentence as a row of ids; store the rows as a numpy array in wordvec.

        Words missing from the vocabulary map to the <UNK> id (rather than a random id,
        which made the encoded inputs non-deterministic noise).

        Raises:
            KeyError: if an out-of-vocabulary word is met and the vocabulary has no <UNK> entry.
        """
        unk_id = self.word2id.get(self.UNK)
        self.wordvec = []
        for sen in self.sentences:
            ids = []
            for word in sen.split():
                if word in self.word2id:
                    ids.append(self.word2id[word])
                elif unk_id is not None:
                    ids.append(unk_id)
                else:
                    raise KeyError(f"{word!r} is not in the vocabulary and the vocabulary has no {self.UNK} entry")
            self.wordvec.append(ids)
        self.wordvec = numpy.array(self.wordvec)

In [None]:
class MTDataset(Dataset):
    """Map-style dataset of (source ids, target ids) pairs.

    In this notebook the source rows are English id vectors and the target rows are Vietnamese
    ones. Items are returned as float tensors (torch.Tensor); the training loop casts them to long.
    """

    def __init__(self,
                 input_matrix,  # id vectors of the source sentences
                 target_matrix  # id vectors of the target sentences
                ):
        # Pair row i of the source matrix with row i of the target matrix.
        self.data = [(input_matrix[row], target_matrix[row]) for row in range(len(input_matrix))]

    def __getitem__(self, idx):
        src_ids, tgt_ids = self.data[idx]
        return torch.Tensor(src_ids), torch.Tensor(tgt_ids)

    def __len__(self):
        return len(self.data)

# **2. Model**

In [None]:
class Encoder(nn.Module):
    """Bidirectional GRU encoder over embedded source tokens."""

    def __init__(self,
                 en_vocab_size: int, # size of the input (English) vocabulary
                 embedding_dim: int, # dimension of the token embeddings
                 hidden_size: int,   # size of the GRU hidden state h, per direction
                ):
        super().__init__()
        self.hidden_size = hidden_size

        # Token ids -> dense vectors.
        self.embedding_layer = nn.Embedding(num_embeddings=en_vocab_size, embedding_dim=embedding_dim)
        # Reads the sequence in both directions; per-step outputs concatenate both directions.
        self.gru = nn.GRU(input_size=embedding_dim, hidden_size=hidden_size,
                          batch_first=True, bidirectional=True)

    def forward(self, x):
        """Encode a batch of id sequences.

        Args:
            x: token ids, (batch, seq_len) since the GRU is batch_first.

        Returns:
            output: (batch, seq_len, 2 * hidden_size) per-step states, forward half first.
            init_hidden: (1, batch, hidden_size) backward-direction state at position 0, i.e.
                after the backward pass has read the whole sentence; used to start the decoder.
        """
        states, _ = self.gru(self.embedding_layer(x))
        backward_first = states[:, 0, self.hidden_size:]
        return states, backward_first.unsqueeze(0)

In [None]:
class Decoder(nn.Module):
    """GRU decoder with additive (Bahdanau-style) attention over the encoder outputs.

    Each call to `forward` decodes a single time step.
    """

    def __init__(self, hidden_size, vocab_size, embedding_dim):
        super().__init__()

        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        # Alignment model: score(s, h_j) = Va(tanh(Wa(s) + Ua(h_j))).
        self.Wa = nn.Linear(hidden_size, hidden_size)      # projects the decoder state s
        self.Ua = nn.Linear(hidden_size * 2, hidden_size)  # projects each bidirectional encoder output h_j
        self.Va = nn.Linear(hidden_size, 1)                # one score per source position
        self.softmax = nn.Softmax(dim=1)                   # normalises over source positions

        # Recurrent part: input is [embedding of previous token ; context vector].
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(embedding_dim + hidden_size * 2, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, dec_input, hidden, enc_out):
        """Run one decoding step.

        Args:
            dec_input: (batch,) ids of the previous target token.
            hidden: (1, batch, hidden_size) previous decoder state.
            enc_out: (batch, src_len, 2 * hidden_size) encoder outputs.

        Returns:
            logits: (batch, 1, vocab_size) scores for the next token.
            hidden: (1, batch, hidden_size) updated decoder state.
        """
        src_len = enc_out.shape[1]
        # One copy of the decoder state per source position: (batch, src_len, hidden_size).
        state_per_pos = hidden.permute(1, 0, 2).repeat(1, src_len, 1)
        scores = self.Va(torch.tanh(self.Wa(state_per_pos) + self.Ua(enc_out)))  # (batch, src_len, 1)
        weights = self.softmax(scores)
        # Attention-weighted sum of encoder outputs: (batch, 1, 2 * hidden_size).
        context = (weights * enc_out).sum(dim=1, keepdim=True)
        prev_emb = self.embedding(dec_input.unsqueeze(1))  # (batch, 1, embedding_dim)
        rnn_out, hidden = self.gru(torch.cat((prev_emb, context), dim=-1), hidden.contiguous())
        return self.out(rnn_out), hidden

In [None]:
def get_encoder(en_vocab_size: int):
    """Build an Encoder sized from `config` and move it to `device`."""
    model = Encoder(
        en_vocab_size=en_vocab_size,
        embedding_dim=config["embedding_dim"],
        hidden_size=config["hidden_size"],
    )
    return model.to(device)

def get_decoder(vi_vocab_size: int):
    """Build a Decoder sized from `config` and move it to `device`."""
    model = Decoder(
        hidden_size=config["hidden_size"],
        vocab_size=vi_vocab_size,
        embedding_dim=config["embedding_dim"],
    )
    return model.to(device)

# **3. Loss function**

In [None]:
def getLossfn(**loss_kwargs):
    """Return the token-level classification loss used for training.

    Args:
        **loss_kwargs: forwarded to nn.CrossEntropyLoss (e.g. ignore_index, reduction,
            label_smoothing). With no arguments this is the plain mean-reduced
            CrossEntropyLoss, exactly as before.

    Note:
        If you pass ignore_index=<PAD id>, also pass reduction="sum" (or skip all-<PAD>
        steps). The training loop computes one loss per time step, and a mean-reduced loss
        over a step whose targets are all ignored is NaN.
    """
    return nn.CrossEntropyLoss(**loss_kwargs)

# **4. Optimizer (and Learning rate Scheduler)**

In [None]:
def getOptimizer(encoder: nn.Module, decoder: nn.Module):
    """Adam with default hyper-parameters over all encoder parameters, then all decoder parameters."""
    all_params = [*encoder.parameters(), *decoder.parameters()]
    return torch.optim.Adam(params=all_params)

def getLrScheduler(optimizer: torch.optim.Adam):
    """StepLR: multiply the learning rate by config["gamma"] every config["step_size"] scheduler steps."""
    scheduler_cls = torch.optim.lr_scheduler.StepLR
    return scheduler_cls(optimizer, step_size=config["step_size"], gamma=config["gamma"])

# **5. Hàm train và eval**

In [None]:
### Translate a single sentence
def translate(sentence, en_word2id, vi_word2id, vi_id2word, encoder, decoder, vi_max_len):
    """Greedily translate one English sentence into Vietnamese.

    The sentence is normalised like the training data (strip + lower-case), then wrapped in
    <START>/<END> and encoded. It is decoded one arg-max token at a time until <END> is produced
    or `vi_max_len - 1` tokens have been generated.

    The modules are switched to eval mode while decoding and restored to their previous mode
    afterwards, so calling this between training epochs leaves no side effect.

    Returns:
        The translation as a space-joined string, without <START>/<END>.

    Raises:
        KeyError: if a word of `sentence` is not in `en_word2id` (eval_bleu counts this as OOV).
    """
    tokens = ('<START> ' + sentence.strip().lower() + ' <END>').split()
    src_ids = [en_word2id[tok] for tok in tokens]  # KeyError on out-of-vocabulary words
    # Build inputs on the same device as the model weights.
    model_device = next(encoder.parameters()).device
    src = torch.tensor(src_ids, dtype=torch.long, device=model_device).unsqueeze(0)

    enc_was_training, dec_was_training = encoder.training, decoder.training
    encoder.eval()
    decoder.eval()
    try:
        with torch.no_grad():
            enc_out, dec_hidden = encoder(src)
            dec_input = torch.tensor([vi_word2id['<START>']], dtype=torch.long, device=model_device)
            output_list = []
            for _ in range(1, vi_max_len):
                out, dec_hidden = decoder(dec_input, dec_hidden, enc_out)
                dec_input = out.argmax(dim=-1).squeeze(1)  # (1,) id of the most likely next token
                next_word = vi_id2word[dec_input.item()]
                if next_word == '<END>':
                    break
                output_list.append(next_word)
        return ' '.join(output_list)
    finally:
        encoder.train(enc_was_training)
        decoder.train(dec_was_training)

In [None]:
### Train function
def _snapshot_state(encoder, decoder):
    """Return detached copies of both modules' weights as {'encoder': ..., 'decoder': ...}.

    state_dict() tensors share storage with the live parameters; without the clone a saved
    "best" state would keep changing as training continues.
    """
    return {
        'encoder': {name: t.detach().clone() for name, t in encoder.state_dict().items()},
        'decoder': {name: t.detach().clone() for name, t in decoder.state_dict().items()},
    }


def _bleu_value(bleu):
    """Return the numeric BLEU from eval_bleu's result (a sacreBLEU result dict or a plain number)."""
    return bleu["score"] if isinstance(bleu, dict) else bleu


def train(train_loader,
          encoder, decoder, loss_fn, optimizer, lr_scheduler,
          id2word
         ):
    """Train encoder/decoder with teacher forcing and keep the weights with the best dev BLEU.

    Each epoch runs one pass over `train_loader`, with the loss summed over target time steps.
    It then calls lr_scheduler.step() once and computes BLEU on the dev set via eval_bleu.

    Relies on notebook globals: config["epoch"], device, eval_bleu, en_sentences_dev,
    vi_sentences_dev, en_train, vi_train, and an initialised wandb run.

    Args:
        id2word: unused; kept for call compatibility.

    Returns:
        (best_statedict, best_bleu): detached copies of the best encoder/decoder weights and
        their dev BLEU score.
    """
    best_bleu = 0
    best_statedict = _snapshot_state(encoder, decoder)

    for epoch in range(config["epoch"]):
        print(f"Start epoch {epoch}")
        t1 = time.time()
        ### Part 1. Train
        # eval_bleu/translate switch the modules to eval mode, so re-enable training each epoch.
        encoder.train()
        decoder.train()
        train_loss = 0.0
        for x, y in train_loader:
            x = x.type(torch.long).to(device)
            y = y.type(torch.long).to(device)

            enc_out, dec_hidden = encoder(x)
            dec_input = y[:, 0]
            loss = 0
            optimizer.zero_grad()
            # Teacher forcing: feed the gold token y[:, t-1] and predict y[:, t].
            for t in range(1, y.size(1)):
                out, dec_hidden = decoder(dec_input, dec_hidden, enc_out)
                dec_input = y[:, t]
                loss += loss_fn(out.squeeze(1), y[:, t])

            loss.backward()
            optimizer.step()
            # Accumulate a float: summing tensors would keep every batch's autograd history alive.
            train_loss += loss.item()

        train_loss /= max(len(train_loader), 1)

        lr_scheduler.step()

        ### Part 2. Eval
        t2 = time.time()
        bleu, num_oov = eval_bleu(en_sentences_dev, vi_sentences_dev, en_train.word2id, vi_train.word2id, vi_train.id2word, encoder, decoder, vi_train.max_len )
        bleu = _bleu_value(bleu)  # compare numbers, not the sacreBLEU result dict
        t3 = time.time()

        if bleu > best_bleu:
            best_statedict = _snapshot_state(encoder, decoder)
            best_bleu = bleu

        ### Print:
        print(f"Train loss: {train_loss}")
        print(f"Train time: {t2-t1}")
        print(f"Bleu score on dev set: {bleu} and num_oov = {num_oov}")
        print(f"Eval Bleu time: {t3 - t2}")
        print(f"End epoch {epoch}\n********************************************************")

        ### wandb
        wandb.log({
            "Train loss": train_loss,
            "Bleu score on dev set": bleu
        })

    return best_statedict, best_bleu

In [None]:
### Eval: compute the BLEU score
def eval_bleu(en_sentences_test, vi_sentences_test,
              en_word2id, vi_word2id, vi_id2word,
              encoder, decoder, vi_max_len
             ):
    """Translate every English sentence and score the results against the references with sacreBLEU.

    A sentence containing a word outside `en_word2id` cannot be encoded (translate raises
    KeyError). It is skipped and counted as OOV instead of being scored.

    Returns:
        (bleu_score, number_oov): bleu_score is the sacrebleu metric's compute() result (a dict
        whose "score" entry is the corpus BLEU), or {"score": 0.0} when no sentence could be
        translated. number_oov is the number of skipped sentences.
    """
    # NOTE(review): datasets.load_metric is deprecated (removed in recent `datasets` releases);
    # evaluate.load("sacrebleu") is the replacement if the environment is upgraded.
    bleu_metric = load_metric("sacrebleu")
    number_oov = 0

    ### Translate the evaluation set
    translated_sentences = []
    reference_sentences = []
    for en_sen, vi_ref in zip(en_sentences_test, vi_sentences_test):
        try:
            translate_sen = translate(en_sen, en_word2id, vi_word2id, vi_id2word, encoder, decoder, vi_max_len)
        except KeyError:  # out-of-vocabulary word in the source sentence
            number_oov += 1
            continue
        translated_sentences.append(translate_sen)
        reference_sentences.append(vi_ref)

    if not translated_sentences:
        return {"score": 0.0}, number_oov

    ### Compute BLEU (sacreBLEU takes a list of references per prediction)
    bleu_metric.add_batch(predictions=translated_sentences,
                          references=[[ref] for ref in reference_sentences])
    bleu_score = bleu_metric.compute()
    return bleu_score, number_oov

# **Run**

In [None]:
### 1. Build the sentence lists
# Only the training split may be sub-sampled; dev and test are always read in full.
en_sentences_train, vi_sentences_train = preprocess(config["train_data_file"], train=True)
en_sentences_dev, vi_sentences_dev = preprocess(config["dev_data_file"], train=False)
en_sentences_test, vi_sentences_test = preprocess(config["test_data_file"], train=False)

In [None]:
# Vocabularies are built from the training split only; dev/test reuse them.
en_train = Lang(en_sentences_train)
vi_train = Lang(vi_sentences_train)
en_dev = Lang(en_sentences_dev, train=False, word2id=en_train.word2id, id2word=en_train.id2word)
vi_dev = Lang(vi_sentences_dev, train=False, word2id=vi_train.word2id, id2word=vi_train.id2word)
en_test = Lang(en_sentences_test, train=False, word2id=en_train.word2id, id2word=en_train.id2word)
vi_test = Lang(vi_sentences_test, train=False, word2id=vi_train.word2id, id2word=vi_train.id2word)

In [None]:
# Training pairs: English id rows as inputs, Vietnamese id rows as targets, reshuffled every epoch.
train_dataset = MTDataset(input_matrix=en_train.wordvec, target_matrix=vi_train.wordvec)
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=config["batch_size"],
                              shuffle=True)

In [None]:
### 2. Build the model
# Embedding/output layers are sized by the training vocabularies.
encoder = get_encoder(en_vocab_size=en_train.vocab_size)
decoder = get_decoder(vi_vocab_size=vi_train.vocab_size)

In [None]:
### 3. Create the loss function
loss_fn = getLossfn()

In [None]:
### 4. Create the optimizer and the learning-rate scheduler
optimizer = getOptimizer(encoder=encoder, decoder=decoder)
lr_scheduler = getLrScheduler(optimizer=optimizer)

## **Train & eval on dev set**

In [None]:
### Train (returns the weights that scored best on the dev set)
best_statedict, best_bleu = train(
    train_dataloader,
    encoder, decoder, loss_fn, optimizer, lr_scheduler,
    id2word=vi_train.id2word,
)
print("Complete train!")
print(f"The best bleu score on dev set: {best_bleu}")
print()

## **Eval on test set**

In [None]:
# Evaluate the checkpoint that scored best on the dev set, not whatever the last epoch left.
encoder.load_state_dict(best_statedict['encoder'])
decoder.load_state_dict(best_statedict['decoder'])

t1 = time.time()
bleu, num_oov = eval_bleu(en_sentences_test, vi_sentences_test, en_train.word2id, vi_train.word2id, vi_train.id2word, encoder, decoder, vi_train.max_len )
t2 = time.time()
# eval_bleu returns sacreBLEU's result dict (or a plain number); report the corpus score.
test_bleu = bleu["score"] if isinstance(bleu, dict) else bleu
print(f"Time for eval bleu score on test set: {t2 - t1}")
print(f"BLEU and OOV on test set: {test_bleu, num_oov}")
print()

# **SAVE**

In [None]:
# Persist the best-on-dev weights as {'encoder': state_dict, 'decoder': state_dict}.
save_path = config["model_save_path"]
torch.save(best_statedict, save_path)
print("Saved model!")