<a href="https://colab.research.google.com/github/nhanphanvan/Transformer/blob/main/En_Vi_Machine_Translation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Neural Machine Translation with Transformer and BERT

### 1. Install Transformer package and extract data


In [1]:
# Clone the project repository; later cells import the `Transformer` package
# and read the data files under ./Transformer/data from this checkout.
!git clone https://github.com/nhanphanvan/Transformer.git

Cloning into 'Transformer'...
remote: Enumerating objects: 234, done.[K
remote: Counting objects: 100% (234/234), done.[K
remote: Compressing objects: 100% (153/153), done.[K
remote: Total 234 (delta 124), reused 166 (delta 64), pack-reused 0[K
Receiving objects: 100% (234/234), 51.32 MiB | 9.87 MiB/s, done.
Resolving deltas: 100% (124/124), done.
Checking out files: 100% (72/72), done.


In [2]:
import os, sys, tarfile

def extract(tar_url, extract_path='.'):
    """Extract a tar archive, recursing into nested archives it contains.

    Args:
        tar_url: Path of the archive to open.
        extract_path: Directory the archive members are extracted into.

    Any regular-file member whose name contains ".tgz" or ".tar" is itself
    extracted into the directory where it was just written.

    NOTE(review): members are extracted without any path sanitisation, so
    only use this on archives you trust.
    """
    print(tar_url)
    # The context manager closes the archive even if extraction fails.
    with tarfile.open(tar_url, 'r') as tar:
        for item in tar:
            tar.extract(item, extract_path)
            if item.isfile() and ('.tgz' in item.name or '.tar' in item.name):
                # The nested archive lives under extract_path, not under the
                # current working directory.
                nested_path = os.path.join(extract_path, item.name)
                extract(nested_path, os.path.dirname(nested_path) or '.')

# try:
#     extract(dev_path)
#     extract(test_path)
#     extract(train_path)
#     print('Done.')
# except:
#     print('Error')

In [3]:
# Compressed archives shipped with the repository.
_ZIP_DIR = './Transformer/data/en-vi-translation/zip-file/'
dev_path = _ZIP_DIR + 'dev-2012-en-vi.tgz'
test_path = _ZIP_DIR + 'test-2013-en-vi.tgz'
train_path = _ZIP_DIR + 'train-en-vi.tgz'

# Parallel English (.en) / Vietnamese (.vi) text files used for training,
# validation and testing.
_SPLIT_DIR = './Transformer/data/en-vi-translation/split-data/'
train_src_path = _SPLIT_DIR + 'short_train.en'
train_tgt_path = _SPLIT_DIR + 'short_train.vi'
dev_src_path = _SPLIT_DIR + 'short_dev.en'
dev_tgt_path = _SPLIT_DIR + 'short_dev.vi'
test_src_path = _SPLIT_DIR + 'short_test.en'
test_tgt_path = _SPLIT_DIR + 'short_test.vi'

# Medical-domain parallel text files.
_MEDICAL_DIR = './Transformer/data/en-vi-translation/medical-data/'
medical_train_src_path = _MEDICAL_DIR + 'long_medical_set.en'
medical_train_tgt_path = _MEDICAL_DIR + 'long_medical_set.vi'
medical_test_src_path = _MEDICAL_DIR + 'medical_test_set.en'
medical_test_tgt_path = _MEDICAL_DIR + 'medical_test_set.vi'

# Folder where the model and optimizer checkpoints are saved during and
# after training; the results log sits in the same folder.
FOLDER_PATH = './'
RESULT_PATH = FOLDER_PATH + "result.txt"

### 2. Build Custom Dataset

In [4]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import Tensor
from typing import Optional

In [5]:
def segment_sentence(sentence):
  """Run `sentence` through the global `rdrsegmenter` and return one string.

  Each element returned by `rdrsegmenter.tokenize` is joined with single
  spaces, the joined pieces are joined with single spaces again, and the
  result is stripped of surrounding whitespace.
  """
  pieces = rdrsegmenter.tokenize(sentence)
  flattened = " ".join(" ".join(tokens) for tokens in pieces)
  return flattened.strip()

In [6]:
class CustomDataset(Dataset):
  """Parallel corpus read from two UTF-8 text files, one sentence per line.

  Line i of `src_path` is paired with line i of `tgt_path`; indexing returns
  the `(source, target)` tuple of strings.

  Args:
    src_path: Path of the source-language file.
    tgt_path: Path of the target-language file.
    segment: `None` leaves both sides untouched; a truthy value runs
      `segment_sentence` on every source line; a falsy non-None value runs
      it on every target line.

  If the two files have a different number of lines, the extra lines of the
  longer file are dropped by `zip`.
  """
  def __init__(self, src_path, tgt_path, segment=None):
    src = self._read_lines(src_path)
    tgt = self._read_lines(tgt_path)
    if segment is not None:
      if segment:
        src = [segment_sentence(sentence) for sentence in src]
      else:
        tgt = [segment_sentence(sentence) for sentence in tgt]
    self.samples = list(zip(src, tgt))

  @staticmethod
  def _read_lines(path):
    """Return the lines of `path` without their trailing newline."""
    # Iterating the file splits on newlines only. str.splitlines() would
    # also split on characters such as U+2028 or form feed, which can occur
    # inside a sentence and would shift every following sentence pair.
    with open(path, 'r', encoding='utf-8') as file:
      return [line.rstrip('\n') for line in file]

  def __len__(self):
    return len(self.samples)

  def __getitem__(self, index):
    return self.samples[index]

### 3. Install vncorenlp and transformers 

In [7]:
# Packages used for tokenization and word segmentation in this notebook
# (transformers, vncorenlp), plus fairseq and fastBPE.
# NOTE(review): no versions are pinned, so a fresh run may install different
# releases than the ones this notebook was developed with.
!pip -q install transformers
!pip -q install vncorenlp
!pip -q install fairseq
!pip -q install fastBPE

# Web-serving packages.
!pip -q install fastapi
!pip -q install uvicorn
!pip -q install pyngrok

[K     |████████████████████████████████| 3.8 MB 13.0 MB/s 
[K     |████████████████████████████████| 895 kB 43.3 MB/s 
[K     |████████████████████████████████| 67 kB 4.8 MB/s 
[K     |████████████████████████████████| 596 kB 42.3 MB/s 
[K     |████████████████████████████████| 6.5 MB 39.0 MB/s 
[K     |████████████████████████████████| 2.6 MB 12.1 MB/s 
[?25h  Building wheel for vncorenlp (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 1.7 MB 13.2 MB/s 
[K     |████████████████████████████████| 90 kB 9.8 MB/s 
[K     |████████████████████████████████| 145 kB 49.9 MB/s 
[K     |████████████████████████████████| 74 kB 3.2 MB/s 
[K     |████████████████████████████████| 112 kB 56.1 MB/s 
[?25h  Building wheel for antlr4-python3-runtime (setup.py) ... [?25l[?25hdone
  Building wheel for fastBPE (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 54 kB 247 kB/s 
[K     |████████████████████████████████| 58 kB 5.5 MB/s 
[K  

In [8]:
!gdown https://drive.google.com/a/gm.uit.edu.vn/uc?id=1pXJZ9eHp6DWkQ5MhCzmWYsKyLQEDiodz&export=download
!tar xzf /content/vn_sbert_deploy.tar.gz

!mkdir -p vncorenlp/models/wordsegmenter
!wget -q --show-progress https://raw.githubusercontent.com/vncorenlp/VnCoreNLP/master/VnCoreNLP-1.1.1.jar
!wget -q --show-progress https://raw.githubusercontent.com/vncorenlp/VnCoreNLP/master/models/wordsegmenter/vi-vocab
!wget -q --show-progress https://raw.githubusercontent.com/vncorenlp/VnCoreNLP/master/models/wordsegmenter/wordsegmenter.rdr
!mv VnCoreNLP-1.1.1.jar vncorenlp/ 
!mv vi-vocab vncorenlp/models/wordsegmenter/
!mv wordsegmenter.rdr vncorenlp/models/wordsegmenter/

Access denied with the following error:

 	Cannot retrieve the public link of the file. You may need to change
	the permission to 'Anyone with the link', or have had many accesses. 

You may still be able to access the file from the browser:

	 https://drive.google.com/a/gm.uit.edu.vn/uc?id=1pXJZ9eHp6DWkQ5MhCzmWYsKyLQEDiodz 

tar (child): /content/vn_sbert_deploy.tar.gz: Cannot open: No such file or directory
tar (child): Error is not recoverable: exiting now
tar: Child returned status 2
tar: Error is not recoverable: exiting now


In [9]:
# Use the first CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Vocabulary sizes (presumably those of the source/target tokenizers loaded
# later in the notebook — TODO confirm).
SRC_VOCAB_SIZE = 28996
TGT_VOCAB_SIZE = 64001

# Model architecture.
HIDDEN_SIZE = 768
NUM_ENCODER_LAYERS = 12
NUM_DECODER_LAYERS = 12
NUM_ATTENTION_HEADS = 12
FEEDFORWARD_SIZE = 3072
DROPOUT = 0.1
ACTIVATION = 'gelu'
LAYER_NORM_EPS = 1e-12
NORM_FIRST = True
MAX_SEQUENCE_LENGTH = 1024

# Special-token ids of the source vocabulary.
SRC_UNK_ID = 100
SRC_PADDING_ID = 0
SRC_BOS_ID = 101
SRC_EOS_ID = 102

# Special-token ids of the target vocabulary.
TGT_UNK_ID = 3
TGT_PADDING_ID = 1
TGT_BOS_ID = 0
TGT_EOS_ID = 2

# Training batch size and embedding/output options passed to the model config.
BATCH_SIZE = 10
BERT_EMBEDDING = True
OUTPUT_HIDDEN_STATES = True
APPLY_LAYER_NORM = True

In [10]:
from vncorenlp import VnCoreNLP
from transformers import AutoModel, AutoTokenizer, AutoConfig, BertModel, RobertaModel

# Word segmenter backed by the VnCoreNLP jar placed under ./vncorenlp;
# only the "wseg" annotator is enabled and the JVM heap is capped at 500 MB.
rdrsegmenter = VnCoreNLP("./vncorenlp/VnCoreNLP-1.1.1.jar", annotators="wseg", max_heap_size='-Xmx500m') 

# Pretrained checkpoints whose configs and tokenizers are used for the
# source and target sides respectively.
src_model_id = 'bert-base-cased'
tgt_model_id = 'vinai/phobert-base'

# Source side: config and tokenizer only; loading the model weights is
# commented out here.
src_config = AutoConfig.from_pretrained(src_model_id)
# src_bert = AutoModel.from_pretrained(src_model_id, config=src_config)
src_tokenizer = AutoTokenizer.from_pretrained(src_model_id)
# Override the tokenizer's length limit with the notebook-wide maximum.
src_tokenizer.model_max_length = MAX_SEQUENCE_LENGTH

# Target side: same setup as the source side.
tgt_config = AutoConfig.from_pretrained(tgt_model_id)
# tgt_bert = AutoModel.from_pretrained(tgt_model_id, config=tgt_config)
tgt_tokenizer = AutoTokenizer.from_pretrained(tgt_model_id)
tgt_tokenizer.model_max_length = MAX_SEQUENCE_LENGTH

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/208k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/426k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/557 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/874k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.08M [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


### 4. Init Custom Transformer

In [11]:
from Transformer.modules.config import TransformerConfig
from Transformer.modules.transformer import Transformer
from Transformer.modules.embedding import PositionalEmbedding, TransformerEmbedding
from Transformer.modules.seq2seq_transformer import Seq2SeqTransformer

In [12]:
# Keyword arguments for TransformerConfig, taken from the notebook constants.
kwargs = dict(
    src_vocab_size=SRC_VOCAB_SIZE,
    tgt_vocab_size=TGT_VOCAB_SIZE,
    hidden_size=HIDDEN_SIZE,
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    num_attention_heads=NUM_ATTENTION_HEADS,
    feedforward_size=FEEDFORWARD_SIZE,
    dropout=DROPOUT,
    activation=ACTIVATION,
    layer_norm_eps=LAYER_NORM_EPS,
    src_padding_id=SRC_PADDING_ID,
    tgt_padding_id=TGT_PADDING_ID,
    norm_first=NORM_FIRST,
    max_sequence_length=MAX_SEQUENCE_LENGTH,
    bert_embedding=BERT_EMBEDDING,
    output_hidden_states=OUTPUT_HIDDEN_STATES,
    apply_layer_norm=APPLY_LAYER_NORM,
    device=DEVICE,
    dtype=torch.float32,
)

config = TransformerConfig(**kwargs)

In [13]:
# Note: This part is incredibly important.
# Without this learning-rate schedule, training the model is very unstable.
class NoamOpt:
    """Optimizer wrapper that sets the learning rate before every step.

    The rate at step `s` is
    `factor * model_size**-0.5 * min(s**-0.5, s * warmup**-1.5)`,
    i.e. it grows linearly for `warmup` steps and then decays with the
    inverse square root of the step number.

    Args:
        model_size: Value used for the `model_size**-0.5` scaling term.
        factor: Constant multiplier applied to the rate.
        warmup: Number of warm-up steps.
        optimizer: Wrapped optimizer; must expose `param_groups`, `step()`
            and `zero_grad()`.
    """
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def state_dict(self):
        """Returns the state of the warmup scheduler as a :class:`dict`.
        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the warmup scheduler's state.
        Arguments:
            state_dict (dict): warmup scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def zero_grad(self, *args, **kwargs):
        """Forward to the wrapped optimizer's `zero_grad`."""
        return self.optimizer.zero_grad(*args, **kwargs)

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step = None):
        """Return the learning rate for `step` (default: the current step).

        Steps <= 0 yield 0.0: the warm-up term `step * warmup**-1.5` is zero
        there, and evaluating `step ** -0.5` would raise ZeroDivisionError.
        """
        if step is None:
            step = self._step
        if step <= 0:
            return 0.0
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup**(-1.5)))
        
def get_std_opt(model, d_model):
    """Build the notebook's default optimizer for `model`.

    Wraps Adam (lr=0, betas=(0.9, 0.98), eps=1e-9) in a NoamOpt with
    factor 0.25 and 8000 warm-up steps, scaled by `d_model`.
    """
    adam = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
    return NoamOpt(d_model, 0.25, 8000, adam)

In [14]:
# Build the sequence-to-sequence model from the config created above.
transformer = Seq2SeqTransformer(config=config)

# Xavier-uniform initialisation for every parameter with more than one
# dimension; one-dimensional parameters keep their default initialisation.
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

# Cross-entropy with label smoothing; positions holding the target padding
# id do not contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=TGT_PADDING_ID, label_smoothing=0.1)

# optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
# optimizer = torch.optim.SGD(transformer.parameters(), lr=0.0001, momentum=0.9, nesterov=True)
# Adam wrapped in the NoamOpt warm-up schedule.
optimizer = get_std_opt(transformer, HIDDEN_SIZE)

### 5. Utility Functions


In [15]:
def generate_square_subsequent_mask(sz):
    """Return a (sz, sz) boolean mask on DEVICE for causal attention.

    Entry [i, j] is True exactly when j > i, i.e. for positions that lie
    after position i.
    """
    upper_with_diagonal = torch.triu(torch.ones((sz, sz), device=DEVICE))
    strictly_lower = upper_with_diagonal == 0
    return strictly_lower.transpose(0, 1)


def create_mask(src, tgt):
    """Build the attention and padding masks for a batch.

    `src` and `tgt` are indexed as (batch, sequence) id tensors. Returns, in
    order: an all-False (S, S) source mask, the causal (T, T) target mask,
    and boolean tensors marking where `src` / `tgt` hold their padding id.
    """
    src_len, tgt_len = src.shape[1], tgt.shape[1]

    causal_tgt_mask = generate_square_subsequent_mask(tgt_len)
    open_src_mask = torch.zeros((src_len, src_len), device=DEVICE).type(torch.bool)

    return (
        open_src_mask,
        causal_tgt_mask,
        src == SRC_PADDING_ID,
        tgt == TGT_PADDING_ID,
    )

In [16]:
# function to collate data samples into batch tensors
def collate_fn(batch):
    """Turn a list of (source, target) sentence pairs into padded id tensors.

    Each side is tokenized with its own tokenizer and padded to the longest
    sentence of the batch. Sentences longer than the tokenizer's
    `model_max_length` are truncated so the ids never exceed the maximum
    sequence length the tokenizers were configured with.

    Returns:
        (src_ids, tgt_ids): two tensors of token ids, one row per sentence.
    """
    src_batch = [src_sample.rstrip("\n") for src_sample, _ in batch]
    tgt_batch = [tgt_sample.rstrip("\n") for _, tgt_sample in batch]

    src_encodings = src_tokenizer.batch_encode_plus(src_batch, padding=True, truncation=True)
    src_ids = torch.tensor(src_encodings.get('input_ids'))

    tgt_encodings = tgt_tokenizer.batch_encode_plus(tgt_batch, padding=True, truncation=True)
    tgt_ids = torch.tensor(tgt_encodings.get('input_ids'))

    return src_ids, tgt_ids

In [17]:
def create_token_type_ids(src, tgt):
  """Return all-zero int64 tensors on DEVICE shaped like `src` and `tgt`."""
  def zeros_for(ids):
    return torch.zeros(ids.shape, dtype=torch.int64, device=DEVICE)

  return zeros_for(src), zeros_for(tgt)

In [18]:
def read_best_result(path):
  """Return the last line of the results file at `path` as a list of floats.

  Every line is parsed as whitespace-separated floats; only the final
  line's values are returned.
  """
  parsed = []
  with open(path, 'r') as file:
    for line in file.read().splitlines():
      parsed.append([float(value) for value in line.split()])

  return parsed[-1]

def write_best_result(path, train_loss, valid_loss, is_best=False):
  """Record a (train_loss, valid_loss) pair in the results file at `path`.

  The file holds one "train valid" pair per line and its last line is the
  best result so far. The new pair is inserted just before that last line;
  when `is_best` is true the last line is additionally replaced by the new
  pair. A missing file is treated as empty, so the first call creates it.
  """
  try:
    with open(path, 'r') as file:
      lines = file.read().splitlines()
  except FileNotFoundError:
    # First run: nothing recorded yet.
    lines = []

  results = [list(map(float, line.split())) for line in lines]
  entry = [train_loss, valid_loss]
  # insert(-1, ...) places the entry before the final (best) line; on an
  # empty list it simply appends.
  results.insert(-1, entry)
  if is_best:
    results.pop(-1)
    results.append(entry)

  with open(path, 'w') as file:
    for result in results:
      file.write(' '.join(map(str, result)) + '\n')

In [19]:
def convert_ids_to_string(tokenizer, ids):
  """
    Convert a list of ids (not a tensor) to a plain string.

    The tokens returned by `tokenizer.convert_ids_to_tokens` are joined with
    spaces; the "@@ " sub-word joins, "<unk> " tokens and the "<s>" / "</s>"
    markers are then removed, in that order, and the result is stripped.
  """
  text = " ".join(tokenizer.convert_ids_to_tokens(ids))
  for marker in ("@@ ", "<unk> ", "<s>", "</s>"):
    text = text.replace(marker, "")
  return text.strip()

In [20]:
def segment_file(src_path, tgt_path):
  """Word-segment every line of `src_path` and write the result to `tgt_path`.

  Both files are handled as UTF-8, matching how CustomDataset reads the
  corpus. The source is read completely before the target is opened for
  writing, so an error during segmentation leaves no truncated output and
  passing the same path twice does not erase the input before it is read.
  Output lines are joined with "\n" (no trailing newline).
  """
  with open(src_path, 'r', encoding='utf-8') as src_file:
    # Split on newlines only; str.splitlines() would also break sentences
    # at characters such as U+2028.
    src = [line.rstrip('\n') for line in src_file]

  segmented = [segment_sentence(sentence) for sentence in tqdm(src)]

  with open(tgt_path, 'w', encoding='utf-8') as tgt_file:
    tgt_file.write("\n".join(segmented))

In [21]:
def extract_features(src_sentence, tgt_sentence):
    """Tokenize one sentence pair and return `(src_ids, tgt_ids)` tensors.

    Each sentence is encoded as a batch of one with its own tokenizer, so
    each returned tensor has a single row of token ids.
    """
    def encode(tokenizer, sentence):
        encodings = tokenizer.batch_encode_plus([sentence])
        return torch.tensor(encodings.get('input_ids'))

    return encode(src_tokenizer, src_sentence), encode(tgt_tokenizer, tgt_sentence)

In [22]:
def forward_transformer(model, src_sentence, tgt_sentence, return_tgt=False):
  """Run one teacher-forced forward pass for a single sentence pair.

  The model is put in eval mode and called without gradient tracking. The
  decoder input is the target ids without their last position. Returns the
  model's output object, or `(output, tgt_ids)` when `return_tgt` is true.
  """
  model.eval()
  with torch.no_grad():
    src, tgt = (ids.to(DEVICE) for ids in extract_features(src_sentence, tgt_sentence))
    decoder_input = tgt[:, :-1]
    src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, decoder_input)
    logits = model(
        src,
        decoder_input,
        src_mask=src_mask,
        tgt_mask=tgt_mask,
        src_key_padding_mask=src_padding_mask,
        tgt_key_padding_mask=tgt_padding_mask,
        memory_key_padding_mask=src_padding_mask,
    )
  return (logits, tgt) if return_tgt else logits

### 6. Training and Validation Functions

In [23]:
from tqdm import tqdm

def train_epoch(model, optimizer):
    """Train `model` for one pass over the training set; return the mean batch loss.

    `optimizer` is the NoamOpt-style wrapper: gradients are cleared on its
    inner optimizer and `optimizer.step()` applies the scheduled update.
    Model and optimizer state are saved under FOLDER_PATH every 1000 batches
    and once more at the end of the epoch.
    """
    def save_checkpoint():
        torch.save(model.state_dict(), FOLDER_PATH + "middle.pt")
        torch.save(optimizer.state_dict(), FOLDER_PATH + "optimizer.pt")

    model.train()
    dataset = CustomDataset(train_src_path, train_tgt_path)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=2, shuffle=True)

    total_loss = 0
    for batch_number, (src, tgt) in enumerate(tqdm(loader, desc='Training'), start=1):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        # Teacher forcing: feed all but the last token, predict all but the first.
        decoder_input = tgt[:, :-1]
        expected = tgt[:, 1:]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, decoder_input)
        outputs = model(
            src,
            decoder_input,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_padding_mask,
        )
        logits = outputs.output

        optimizer.optimizer.zero_grad()
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), expected.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        if batch_number % 1000 == 0:
            save_checkpoint()

    save_checkpoint()

    return total_loss / len(loader)


def evaluate(model):
    """Return the mean batch loss of `model` on the validation set.

    The model is put in eval mode and run without gradient tracking, so no
    autograd graph is built for the validation batches. The loader is not
    shuffled, which keeps the batch composition — and therefore the reported
    mean of per-batch losses — the same from one call to the next.
    """
    model.eval()
    losses = 0

    val_iter = CustomDataset(dev_src_path, dev_tgt_path)
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=2, shuffle=False)

    with torch.no_grad():
        for src, tgt in tqdm(val_dataloader, desc='Testing '):
            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)
            # Teacher forcing: feed all but the last token, score all but the first.
            tgt_input = tgt[:, :-1]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
            logits = model(src,
                           tgt_input,
                           src_mask=src_mask,
                           tgt_mask=tgt_mask,
                           src_key_padding_mask=src_padding_mask,
                           tgt_key_padding_mask=tgt_padding_mask,
                           memory_key_padding_mask=src_padding_mask)
            logits = logits.output
            tgt_out = tgt[:, 1:]
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            losses += loss.item()

    return losses / len(val_dataloader)

### 7. Initialize the Transformer from BERT

In [None]:
def freeze_layer(layer):
  """Disable gradient tracking for every parameter of `layer`."""
  for parameter in layer.parameters():
    parameter.requires_grad = False

In [None]:
def copy_encoder_layer_parameters(encoder_layer, bert_layer, freeze=False):
  """Copy one BERT layer's linear weights into one encoder layer.

  The query/key/value projections, the attention output projection and the
  two feed-forward layers receive the corresponding BERT `weight` tensors.
  Only `weight` is copied: biases, dropout and LayerNorm are left untouched.
  With `freeze=True` the six target sub-layers are frozen afterwards.
  """
  attention = encoder_layer.self_attention
  bert_attention = bert_layer.attention
  # (target sub-layer, BERT sub-layer) pairs, in copy order.
  pairs = [
      (attention.query_projection, bert_attention.self.query),
      (attention.key_projection, bert_attention.self.key),
      (attention.value_projection, bert_attention.self.value),
      (attention.weight_matrix, bert_attention.output.dense),
      (encoder_layer.linear1, bert_layer.intermediate.dense),
      (encoder_layer.linear2, bert_layer.output.dense),
  ]
  for target, source in pairs:
    target.weight.data.copy_(source.weight.data)

  if freeze:
    for target, _ in pairs:
      freeze_layer(target)

In [None]:
def copy_decoder_layer_parameters(decoder_layer, bert_layer, freeze=False):
  """Copy one BERT layer's linear weights into one decoder layer.

  The self-attention query/key/value projections, the attention output
  projection and the two feed-forward layers receive the corresponding BERT
  `weight` tensors. Only `weight` is copied: biases, dropout and LayerNorm
  are left untouched. With `freeze=True` the six target sub-layers are
  frozen afterwards.
  """
  attention = decoder_layer.self_attention
  bert_attention = bert_layer.attention
  # (target sub-layer, BERT sub-layer) pairs, in copy order.
  pairs = [
      (attention.query_projection, bert_attention.self.query),
      (attention.key_projection, bert_attention.self.key),
      (attention.value_projection, bert_attention.self.value),
      (attention.weight_matrix, bert_attention.output.dense),
      (decoder_layer.linear1, bert_layer.intermediate.dense),
      (decoder_layer.linear2, bert_layer.output.dense),
  ]
  for target, source in pairs:
    target.weight.data.copy_(source.weight.data)

  if freeze:
    for target, _ in pairs:
      freeze_layer(target)

In [None]:
def copy_src_embeddings_parameters_from_bert(transformer, src_bert, freeze=False):
  """Copy the source BERT embedding tables into `transformer.src_embedding`.

  * word embeddings: copied as a whole;
  * position embeddings: as many leading rows as both tables have (instead
    of a hard-coded 512), any further rows of the transformer's table keep
    their current values;
  * token-type embeddings: the first BERT row is copied.

  With `freeze=True` the three embedding layers are frozen afterwards.
  """
  target = transformer.src_embedding
  source = src_bert.embeddings

  target.word_embedding.weight.data.copy_(source.word_embeddings.weight.data)

  bert_positions = source.position_embeddings.weight.data
  own_positions = target.position_embedding.weight.data
  # Derive the row count from the tables so that checkpoints with a
  # different maximum position count work too.
  rows = min(own_positions.shape[0], bert_positions.shape[0])
  own_positions[:rows, :].copy_(bert_positions[:rows, :])

  target.token_type_embedding.weight.data.copy_(source.token_type_embeddings.weight.data[:1, :])

  if freeze:
    freeze_layer(target.word_embedding)
    freeze_layer(target.position_embedding)
    freeze_layer(target.token_type_embedding)

In [None]:
def copy_tgt_embeddings_parameters_from_bert(transformer, tgt_bert, freeze=False):
  """Copy the target BERT embedding tables into `transformer.tgt_embedding`.

  * word embeddings: copied as a whole;
  * position embeddings: as many leading rows as both tables have (instead
    of a hard-coded 258), any further rows of the transformer's table keep
    their current values;
  * token-type embeddings: copied as a whole.

  With `freeze=True` the three embedding layers are frozen afterwards.

  NOTE(review): RoBERTa-style checkpoints reserve their first position rows
  for a padding offset; confirm that the transformer indexes positions the
  same way before relying on the copied rows.
  """
  target = transformer.tgt_embedding
  source = tgt_bert.embeddings

  target.word_embedding.weight.data.copy_(source.word_embeddings.weight.data)

  bert_positions = source.position_embeddings.weight.data
  own_positions = target.position_embedding.weight.data
  # Derive the row count from the tables so that checkpoints with a
  # different maximum position count work too.
  rows = min(own_positions.shape[0], bert_positions.shape[0])
  own_positions[:rows, :].copy_(bert_positions[:rows, :])

  target.token_type_embedding.weight.data.copy_(source.token_type_embeddings.weight.data)

  if freeze:
    freeze_layer(target.word_embedding)
    freeze_layer(target.position_embedding)
    freeze_layer(target.token_type_embedding)

In [None]:
def copy_encoder_parameters_from_bert(encoder, bert, freeze=False):
  """Initialise encoder layer i from BERT layer i, for every encoder layer.

  `encoder` and `bert` are indexable layer sequences. The layer counts are
  not required to match; `bert` only needs at least as many layers as
  `encoder`.
  """
  for position in range(len(encoder)):
    copy_encoder_layer_parameters(encoder[position], bert[position], freeze)

In [None]:
def copy_decoder_parameters_from_bert(decoder, bert, freeze=False):
  """Initialise decoder layer i from BERT layer i, for every decoder layer.

  `decoder` and `bert` are indexable layer sequences. The layer counts are
  not required to match; `bert` only needs at least as many layers as
  `decoder`.
  """
  for position in range(len(decoder)):
    copy_decoder_layer_parameters(decoder[position], bert[position], freeze)

In [None]:
def copy_transformer_parameters_from_berts(transformer, src_bert, tgt_bert, freeze=False):
  """Initialise the whole model from two BERT-style models.

  The encoder layers and source embeddings are filled from `src_bert`, the
  decoder layers and target embeddings from `tgt_bert`. `freeze` is passed
  through to every copy helper.
  """
  stacks = transformer.transformer
  copy_encoder_parameters_from_bert(stacks.encoder.layers, src_bert.encoder.layer, freeze)
  copy_decoder_parameters_from_bert(stacks.decoder.layers, tgt_bert.encoder.layer, freeze)

  embedding_jobs = (
      (copy_src_embeddings_parameters_from_bert, src_bert),
      (copy_tgt_embeddings_parameters_from_bert, tgt_bert),
  )
  for copy_embeddings, bert in embedding_jobs:
    copy_embeddings(transformer, bert, freeze)

In [None]:
# The pretrained models are loaded here because nothing earlier in the
# notebook defines src_bert / tgt_bert (the loading lines in the tokenizer
# cell are commented out), so this cell would otherwise raise NameError.
src_bert = AutoModel.from_pretrained(src_model_id, config=src_config)
tgt_bert = AutoModel.from_pretrained(tgt_model_id, config=tgt_config)

# # using for BERT Encoder
# copy_encoder_parameters_from_bert(transformer.transformer.encoder.layers, src_bert.encoder.layer, freeze)
# copy_src_embeddings_parameters_from_bert(transformer, src_bert, freeze)

# using for BERT Encoder-Decoder
copy_transformer_parameters_from_berts(transformer, src_bert, tgt_bert, freeze=True)

### 8. Search Algorithms

In [24]:
# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Decode one source sequence by always taking the highest-scoring token.

    Starting from `start_symbol`, at most `max_len - 1` tokens are appended;
    decoding stops early once TGT_EOS_ID has been appended. Returns the
    generated ids as a (1, length) tensor that includes the start symbol.
    """
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)
    memory = model.encode(src, src_mask)

    generated = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for _ in range(max_len - 1):
        causal_mask = generate_square_subsequent_mask(generated.size(1)).type(torch.bool).to(DEVICE)
        decoded = model.decode(generated, memory, causal_mask)
        scores = model.generator(decoded[:, -1])
        next_word = torch.max(scores, dim=1)[1].item()
        appended = torch.ones(1, 1).type_as(src.data).fill_(next_word)
        generated = torch.cat([generated, appended], dim=1)
        if next_word == TGT_EOS_ID:
            break
    return generated


# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, src_sentence: str):
    """Translate `src_sentence` with greedy decoding and return the text.

    The model is put in eval mode and decoding runs without gradient
    tracking, so no autograd graph is kept for the generated tokens.
    The output length is limited to twice the source token count plus 5.
    """
    model.eval()
    src_encodings = src_tokenizer.batch_encode_plus([src_sentence], padding=True)
    src_ids = torch.tensor(src_encodings.get('input_ids'))
    num_tokens = src_ids.shape[1]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    with torch.no_grad():
        tgt_tokens = greedy_decode(
            model,  src_ids, src_mask, max_len=num_tokens*2+5, start_symbol=TGT_BOS_ID).flatten()
    return convert_ids_to_string(tgt_tokenizer, tgt_tokens.tolist())

In [25]:
def beam_search(model, src, src_mask, max_len, start_symbol, num_beams, k, is_train=False):
  """Beam search over single-sequence hypotheses.

  Args:
    model: Object providing `encode`, `decode` and `generator`.
    src: Source ids for one sentence.
    src_mask: Source attention mask.
    max_len: Maximum output length, capped at 256.
    start_symbol: Id the hypotheses start with.
    num_beams: Number of hypotheses kept after every step.
    k: Number of top-scoring continuations tried for each live hypothesis
      (clamped to the vocabulary size).
    is_train: Unused; kept for interface compatibility.

  Returns:
    The (1, length) id tensor of the hypothesis with the lowest accumulated
    negative log-probability.
  """
  max_len = min(max_len, 256)
  src = src.to(DEVICE)
  src_mask = src_mask.to(DEVICE)

  memory = model.encode(src, src_mask)
  ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
  # Each beam is [ids of shape (1, length), accumulated negative log-prob].
  beam_results = [[ys, 0.0]]
  for _ in range(max_len - 1):
    beam_candidates = []
    is_change = False
    for ys, cost in beam_results:
      # ys is (1, length), so the most recent token is ys[0, -1];
      # ys[-1][0] would read the first token, which is always the start
      # symbol, and finished beams would never be recognised.
      if ys[0, -1].item() == TGT_EOS_ID:
        beam_candidates.append([ys, cost])
        continue
      is_change = True
      tgt_mask = generate_square_subsequent_mask(ys.size(1)).type(torch.bool).to(DEVICE)
      out = model.decode(ys, memory, tgt_mask)
      log_probs = F.log_softmax(model.generator(out[:, -1]), dim=-1)
      topk = torch.topk(log_probs, min(k, log_probs.size(-1)), dim=1)
      # Expand with every one of the top-k tokens; iterating over the
      # returned candidates avoids an IndexError when k < num_beams.
      for token, log_prob in zip(topk.indices[0].tolist(), topk.values[0].tolist()):
        ids = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(token)], dim=1)
        beam_candidates.append([ids, cost - log_prob])

    beam_candidates = sorted(beam_candidates, key=lambda candidate: candidate[1])
    beam_results = beam_candidates[:num_beams]
    if not is_change:
      # Every kept hypothesis already ends with EOS.
      break

  return beam_results[0][0]

def beam_translate(model: torch.nn.Module, src_sentence: str, num_beams: int = 4, k: int = 4):
    """Translate `src_sentence` with `beam_search` and return the text.

    The model is put in eval mode and the search runs without gradient
    tracking. The output length is limited to twice the source token count
    plus 5.
    """
    model.eval()
    src_encodings = src_tokenizer.batch_encode_plus([src_sentence], padding=True)
    src_ids = torch.tensor(src_encodings.get('input_ids'))
    num_tokens = src_ids.shape[1]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    with torch.no_grad():
        tgt_tokens = beam_search(
            model,  src_ids, src_mask, max_len=num_tokens*2+5, start_symbol=TGT_BOS_ID, num_beams=num_beams, k=k).flatten()
    return convert_ids_to_string(tgt_tokenizer, tgt_tokens.tolist())

In [26]:
def init_beam_search(model, src, src_mask, max_len, start_symbol, num_beams):
  """
  Start a beam search: decode from the BOS symbol and keep the top candidates.

  :param model: object providing encode, decode and generator
  :param src: (N, S), batch size N = 1
  :param src_mask: (S, S)
  :param max_len: number of columns allocated for the output ids
  :param start_symbol: BOS
  :param num_beams: number of beams
  :return: outputs (num_beams, max_len) with column 0 = BOS and column 1 =
           the top first tokens, memory expanded to (num_beams, S, E),
           log_scores (N, num_beams)
  """
  memory = model.encode(src.to(DEVICE), src_mask.to(DEVICE))
  _, src_length, hidden_size = memory.shape

  bos = torch.LongTensor([[start_symbol]]).to(DEVICE)
  bos_mask = generate_square_subsequent_mask(bos.size(1)).type(torch.bool)
  first_step = model.generator(model.decode(bos, memory, bos_mask)[:, -1])
  log_scores, first_tokens = F.log_softmax(first_step, dim=-1).topk(num_beams)

  outputs = torch.zeros((num_beams, max_len), dtype=torch.int32, device=DEVICE)
  outputs[:, 0] = start_symbol
  outputs[:, 1] = first_tokens[0]

  # Share the single encoded source across all beams.
  shared_memory = memory.expand(num_beams, src_length, hidden_size)
  return outputs, shared_memory, log_scores

In [27]:
def choose_topk(outputs, prob, log_scores, i, num_beams):
  """
  Keep the best `num_beams` of the num_beams x num_beams candidate extensions.

  `prob` holds one row of scores per beam; `log_scores` is (1, num_beams).
  `outputs` is updated in place: columns [:i] are replaced by the prefixes
  of the chosen parent beams and column i by the chosen tokens. Returns
  `(outputs, new_log_scores)` with new_log_scores shaped (1, num_beams).
  """
  step_log_probs, step_tokens = F.log_softmax(prob, dim=-1).topk(num_beams)
  # Transposing makes each beam's running score broadcast along its own row.
  joint_scores = step_log_probs + log_scores.transpose(0, 1)
  best_scores, flat_index = joint_scores.view(-1).topk(num_beams)

  # Recover (parent beam, candidate rank) from the flattened index.
  parent_beam = torch.div(flat_index, num_beams, rounding_mode='floor')
  candidate_rank = flat_index % num_beams
  outputs[:, :i] = outputs[parent_beam, :i]
  outputs[:, i] = step_tokens[parent_beam, candidate_rank]

  return outputs, best_scores.unsqueeze(0)

In [28]:
def model_generate(model,
                   tgt: Tensor,
                   memory: Tensor,
                   tgt_mask: Optional[Tensor] = None):
  """Decode `tgt` against `memory` and return the generator output for the last position."""
  decoded = model.decode(tgt, memory, tgt_mask)
  last_position = decoded[:, -1]
  return model.generator(last_position)

In [29]:
def beam_generate(model, outputs, memory, log_scores, i, num_beams):
  """Score the next token for the first `i` columns of every beam and keep the best beams."""
  prefix = outputs[:, :i]
  causal_mask = generate_square_subsequent_mask(prefix.size(1)).type(torch.bool)
  step_scores = model_generate(model, prefix, memory, causal_mask)
  return choose_topk(outputs, step_scores, log_scores, i, num_beams)


In [30]:
def beam_search_1(model,src, src_mask, max_len, start_symbol, end_symbol, num_beams):
  """
  Beam-search decode a single source sentence and return the chosen hypothesis.

  Decoding stops early once every beam contains ``end_symbol``; the winner is then the
  beam maximising ``log_score / eos_position ** 0.7`` (length normalisation). If the
  length budget runs out first, beam 0 is returned.

  :param model: model exposing ``decode`` and ``generator`` (also passed to init_beam_search)
  :param src: source token ids, forwarded to init_beam_search
  :param src_mask: source mask, forwarded to init_beam_search
  :param max_len: maximum number of target positions, start symbol included; capped at 256
  :param start_symbol: id forwarded to init_beam_search as the first token of every beam
  :param end_symbol: id that marks the end of a hypothesis
  :param num_beams: number of hypotheses kept at every step
  :return: 1-D tensor with the chosen beam's ids up to and including its first
           ``end_symbol``; the whole beam when it never produced ``end_symbol``
  """
  max_len = min(max_len, 256)
  chosen_sentence_index = 0
  outputs, memory, log_scores = init_beam_search(model, src, src_mask, max_len, start_symbol, num_beams)
  for i in range(2, max_len):
    # Columns 0 and 1 come from init_beam_search; each iteration fills column i.
    prefix = outputs[:, :i]
    tgt_mask = generate_square_subsequent_mask(prefix.size(1)).type(torch.bool)
    prob = model.decode(prefix, memory, tgt_mask[:i, :i])
    prob = model.generator(prob[:, -1])
    outputs, log_scores = choose_topk(outputs, prob, log_scores, i, num_beams)

    # Record, per beam, the position of its first end_symbol (0 means "not finished").
    # Only the columns filled so far are scanned, so the zero-initialised tail of
    # `outputs` can never be mistaken for an end_symbol. nonzero() lists matches row by
    # row in increasing column order, so the first hit seen for a beam is its earliest.
    finished_sentences = (outputs[:, :i + 1] == end_symbol).nonzero()
    mark_eos = torch.zeros(num_beams, dtype=torch.int64, device=DEVICE)
    num_finished_sentences = 0
    for sentence_ind, eos_location in finished_sentences:
      if mark_eos[sentence_ind] == 0:
        mark_eos[sentence_ind] = eos_location
        num_finished_sentences += 1

    if num_finished_sentences == num_beams:
      # Length-normalised score: divide by (hypothesis length ** alpha).
      alpha = 0.7
      division = mark_eos.type_as(log_scores)**alpha
      _, chosen_sentence_index = torch.max(log_scores / division, 1)
      chosen_sentence_index = chosen_sentence_index[0]
      break

  best_beam = outputs[chosen_sentence_index]
  eos_positions = (best_beam == end_symbol).nonzero()
  if len(eos_positions) == 0:
    # The length budget ran out before this beam emitted end_symbol. Return the whole
    # hypothesis: the previous `[:-1 + 1]` slice was `[:0]` and yielded an empty tensor.
    return best_beam
  return best_beam[:int(eos_positions[0]) + 1]

def beam_translate_1(model: torch.nn.Module, src_sentence: str, num_beams: int = 3):
    """
    Translate one source sentence with beam search and return the decoded string.

    The sentence is encoded with the global ``src_tokenizer``, decoded with
    ``beam_search_1`` using the global ``TGT_BOS_ID`` / ``TGT_EOS_ID``, and converted
    back to text with ``convert_ids_to_string`` and the global ``tgt_tokenizer``.

    :param model: translation model; switched to eval mode by this call
    :param src_sentence: raw source sentence
    :param num_beams: beam width forwarded to ``beam_search_1``
    :return: result of ``convert_ids_to_string`` for the selected hypothesis
    """
    model.eval()
    with torch.no_grad():
        encoded = src_tokenizer.batch_encode_plus([src_sentence], padding=True)
        src_ids = torch.tensor(encoded.get('input_ids'))
        src_len = src_ids.shape[1]
        # All-False square boolean mask over the source positions
        # (presumably "nothing is masked" -- TODO confirm against the model's convention).
        src_mask = torch.zeros(src_len, src_len).type(torch.bool)
        # Target length budget grows linearly with the source length.
        best_ids = beam_search_1(
            model,
            src_ids,
            src_mask,
            max_len=int(src_len*1.5+5),
            start_symbol=TGT_BOS_ID,
            end_symbol=TGT_EOS_ID,
            num_beams=num_beams,
        )
        return convert_ids_to_string(tgt_tokenizer, best_ids.flatten().tolist())

In [31]:
from torchtext.data.metrics import bleu_score

def bleu(src_path, tgt_path, model, beam_search=True, num_beams=3, k=3, return_sentence=False):
    """
    Translate every pair of a parallel corpus and compute the corpus BLEU score.

    :param src_path: path of the source-language file given to ``CustomDataset``
    :param tgt_path: path of the target-language file given to ``CustomDataset``
    :param model: translation model; switched to eval mode by this call
    :param beam_search: use ``beam_translate_1`` when true, ``translate`` otherwise
    :param num_beams: beam width, only used when ``beam_search`` is true
    :param k: accepted for compatibility; not used by this function
    :param return_sentence: also return the raw predictions and targets
    :return: the BLEU score, or ``(score, pred_sents, tgt_sents)`` when
             ``return_sentence`` is true
    """
    model.eval()

    def _tokenize(sentence):
        # Underscores are turned into spaces before whitespace tokenisation.
        return sentence.strip().replace('_', ' ').split()

    pred_sents, tgt_sents = [], []
    for src, tgt in tqdm(CustomDataset(src_path, tgt_path), desc='Blue score'):
        pred_tgt = (beam_translate_1(model, src, num_beams=num_beams)
                    if beam_search
                    else translate(model, src))
        pred_sents.append(pred_tgt)
        tgt_sents.append(tgt)

    candidates = [_tokenize(sent) for sent in pred_sents]
    # bleu_score expects a list of references per candidate; there is exactly one here.
    references = [[_tokenize(sent)] for sent in tgt_sents]

    score = bleu_score(candidates, references)
    return (score, pred_sents, tgt_sents) if return_sentence else score

### 9. Training and Evaluating

In [None]:
from timeit import default_timer as timer
# Training configuration for this run.
# NOTE(review): BATCH_SIZE is not referenced in this cell; presumably train_epoch /
# evaluate read it as a global -- TODO confirm.
NUM_EPOCHS = 2
BATCH_SIZE = 10

# One pass per epoch: train, validate, log, and keep the best checkpoint.
for epoch in range(1, NUM_EPOCHS+1):
    # Only the training pass is timed; validation runs after end_time is taken.
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer)
    end_time = timer()
    val_loss = evaluate(transformer)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))
    # Compare against the best result stored at RESULT_PATH; best_train_lost is
    # unpacked but not used afterwards in this cell.
    best_train_lost, best_val_loss = read_best_result(RESULT_PATH)
    is_best = True if val_loss < best_val_loss else False
    # Called every epoch; is_best tells the writer whether this epoch is a new best.
    write_best_result(RESULT_PATH, train_loss, val_loss, is_best)
    if is_best:
      # Overwrite the checkpoint only when the validation loss improved.
      torch.save(transformer.state_dict(), FOLDER_PATH + "best.pt")

In [None]:
# Corpus BLEU on the test split, decoding with beam search (beam width 10).
bleu(test_src_path, test_tgt_path, transformer, num_beams=10)

In [48]:
# trained model path
# Google Drive share links for the trained model and the pre-built kNN-MT artifacts.
# The cells in this section read local files under PATH rather than these URLs, so the
# files have to be downloaded manually first.
# NOTE(review): `vals` is rebound to a NumPy array by the datastore cells further down,
# so this link is no longer reachable under that name once one of them has run.
folder = "https://drive.google.com/drive/folders/1HkRLj9iTdUi1pPUk_hU0fXH2BERAsCXf?usp=sharing"
best_bert = 'https://drive.google.com/file/d/1a5-iSc08WdpZmIWmQezBTKSpWI3RoU17/view?usp=sharing'
long_dataset_70000_index = 'https://drive.google.com/file/d/1H0WgrRJxmYuZcw3qoYEd_tGv22lUkvWx/view?usp=sharing'
medical_dataset_70000_index = 'https://drive.google.com/file/d/1FlKCWtemEUfWDEggMD5_2guxtEXEOBVh/view?usp=sharing'
medical_vals = 'https://drive.google.com/file/d/1cciP8LLqUlYddYuGPbxZdOD-VGmsTQdn/view?usp=sharing'
vals = 'https://drive.google.com/file/d/1fBBtd7eYbk8VGk-cy5pMH32oXrqPJQE1/view?usp=sharing'

# please download and move to a folder, enter folder path here
# PATH is joined to file names by plain string concatenation, so it must end with a
# path separator.
PATH = './'

In [None]:
# Load the pre-trained weights from the checkpoint expected under PATH.
# map_location=DEVICE remaps the stored tensors onto the device this notebook runs on,
# so a checkpoint that was saved from a GPU session also loads on a CPU-only runtime.
transformer.load_state_dict(torch.load(PATH + 'best-NMT.pt', map_location=DEVICE))

In [41]:
# Read the test-set source sentences into a list, one entry per line, without the
# line terminators. The encoding is stated explicitly so the result does not depend on
# the platform's locale default.
with open(test_src_path, 'r', encoding='utf-8') as file:
  contents = file.read().splitlines()

In [42]:
# Translate the first ten test sentences with beam search (default beam width of 3).
translated_sents = [beam_translate_1(transformer, sentence) for sentence in contents[:10]] 

In [None]:
# Display the beam-search translations of the first ten test sentences.
translated_sents

In [None]:
import torch
# Select the compute device: the GPU when PyTorch can see one, the CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)
print()

# Extra diagnostics that are only available for CUDA devices.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    bytes_per_gb = 1024**3
    memory_report = (
        ('Allocated:', torch.cuda.memory_allocated(0)),
        ('Cached:   ', torch.cuda.memory_reserved(0)),
    )
    for label, num_bytes in memory_report:
        print(label, round(num_bytes/bytes_per_gb, 1), 'GB')

### Test GitHub package: kNN-MT with Faiss

In [44]:
# ### for cpu
# !apt install libomp-dev
# !pip install faiss
# ### for gpu
!pip install faiss-gpu

Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[K     |████████████████████████████████| 85.5 MB 117 kB/s 
[?25hInstalling collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2


In [45]:
import faiss
import numpy as np
from torchtext.data.metrics import bleu_score
from tqdm import tqdm

In [46]:
from Transformer.application.NMT import Datastore, DatastoreBuilder, NMTModel, TranslateMachine, CustomDataset, calculate_bleu_score 

In [47]:
# Bundle the trained transformer with its tokenizers, special-token ids and config into
# the NMTModel object that DatastoreBuilder and TranslateMachine take below.
nmt_model = NMTModel(SRC_BOS_ID, SRC_EOS_ID, TGT_BOS_ID, TGT_EOS_ID, src_tokenizer, tgt_tokenizer, config, transformer)

In [None]:
# build datastore and kNN-MT
# Builds the medical datastore from scratch and saves it, so later sessions can use the
# "load built datastore" cells instead of repeating this step.

# Create the datastore features from the medical training pairs, limited by
# end_index=70000, and save the returned values array (np.save appends ".npy").
datastore_builder = DatastoreBuilder(nmt_model, DEVICE)
embeddings_results, vals = datastore_builder.batch_create_features_file(medical_train_src_path, medical_train_tgt_path, batch_size=20, end_index=70000)
np.save(FOLDER_PATH + 'medical_vals', vals)

# NOTE(review): get_data_store_length is not among the names imported from
# Transformer.application.NMT above; it must come from an earlier cell -- confirm.
# Its result is only displayed at the end of this cell.
data_store_length = get_data_store_length(medical_train_src_path, medical_train_tgt_path, end_index=70000)
# 768 is presumably the dimension of the stored feature vectors -- TODO confirm.
data_store = Datastore(768, size_value_array=TGT_VOCAB_SIZE, num_centroid=128, nprobe=32)
data_store.build_datastore(embeddings_results)
data_store.save_index(FOLDER_PATH + 'medical_dataset_70000_index')

translate_machine = TranslateMachine(nmt_model, data_store, vals, device=DEVICE)
data_store_length

In [None]:
# load built datastore and kNN-MT for normal dataset
# Uses the pre-built index and values array expected under PATH instead of rebuilding
# them; rebinds the globals data_store, vals and translate_machine.

import numpy as np
load_path = PATH + 'long_dataset_70000_index'
val_path = PATH + 'vals.npy'

# Same Datastore settings as the build cell, plus load_file pointing at the saved index.
data_store = Datastore(768, size_value_array=TGT_VOCAB_SIZE, num_centroid=128, nprobe=32, load_file=load_path)
# data_store.build_datastore(embeddings_results)
vals = np.load(val_path)
# NOTE(review): use_layernorm=False is passed here but not in the medical cell --
# confirm the difference is intended.
translate_machine = TranslateMachine(nmt_model, data_store, vals, use_layernorm=False, device=DEVICE)

In [49]:
# load built datastore and kNN-MT for medical dataset
# Uses the pre-built medical index and values array expected under PATH instead of
# rebuilding them; rebinds the globals data_store, vals and translate_machine.

import numpy as np
load_path = PATH + 'medical_dataset_70000_index'
val_path = PATH + 'medical_vals.npy'

# Same Datastore settings as the build cell, plus load_file pointing at the saved index.
data_store = Datastore(768, size_value_array=TGT_VOCAB_SIZE, num_centroid=128, nprobe=32, load_file=load_path)
# data_store.build_datastore(embeddings_results)
vals = np.load(val_path)
translate_machine = TranslateMachine(nmt_model, data_store, vals, device=DEVICE)

In [50]:
# Sample English medical sentence used to try out the kNN-MT translator.
txt = "The distribution of the hemorrhage suggests a possible aneurysm of the left middle cerebral artery as a source."

In [None]:
# Translate the sample sentence with the kNN-MT machine. num_knns=64 is presumably the
# number of neighbours retrieved from the datastore per step -- TODO confirm.
translate_machine.beam_translate(txt, num_knns=64)

In [None]:
# BLEU of the kNN-MT machine on the medical test set. gamma=0.4 is presumably the
# weight given to the kNN distribution when mixing with the model's -- TODO confirm.
calculate_bleu_score(translate_machine, medical_test_src_path, medical_test_tgt_path, num_knns=64, gamma=0.4)

Blue score: 100%|██████████| 574/574 [06:46<00:00,  1.41it/s]


0.26484711562534274