<a href="https://www.kaggle.com/code/bcgggg/pytorch-translation-transformer-ddp-en-zh?scriptVersionId=202244328" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
# Install the libraries ddp.py imports plus the two spaCy tokenizer models it loads
# ('zh_core_web_sm' and 'en_core_web_sm').
!pip install torch pandas datasets torchtext spacy
!python -m spacy download zh_core_web_sm
!python -m spacy download en_core_web_sm

import torch
import torchtext

# Show installed versions. NOTE(review): the installs are unpinned — confirm the
# torchtext release is compatible with this torch build.
print(torch.__version__)
print(torchtext.__version__)

In [None]:
# Copy every attached input except this notebook's own output into /kaggle/working/input,
# the directory ddp.py reads (and, if the vocab cache is missing, writes) via WORK_DIRECTORY_INPUT.
!rsync -av --exclude='pytorch-translation-transformer-ddp-en-zh' /kaggle/input/* /kaggle/working/input

In [None]:
# Restore output/ (checkpoints, status/params JSON, history) from a previous version of this
# notebook so ddp.py resumes from it; the commented lines are alternative sources.
!rsync -av /kaggle/input/pytorch-translation-transformer-ddp-en-zh/output/* /kaggle/working/output
# !rsync -av /kaggle/input/chianbcg-translation-en-zh/output/* /kaggle/working/output
# !rsync -av /kaggle/input/bcggggg-translation-en-zh/output/* /kaggle/working/output

In [None]:
import os

# Overrides for ddp.py; the `!torchrun` child processes inherit this kernel's environment.
os.environ.update({
    'BATCH_SIZE': "50",
    'DATA_TRAINING_ITER': "2",
    'WORK_DIRECTORY_INPUT': '/kaggle/working/input',
})
# Optional run-mode switches:
# os.environ['EXPORT_ONNX'] = "True"
# os.environ['INFERENCE_ONLY'] = "True"
# os.environ['TEST_ONLY'] = "True"

In [None]:
%%writefile ddp.py
############################### Utils #######################################
import json
from pathlib import Path
from timeit import default_timer as timer

xprint = print  # keep a handle on the builtin before shadowing it


def print(*args, **kwargs):
    """Builtin ``print`` with a ``[Rank N]`` prefix so output of each DDP process is distinguishable.

    The prefix is glued to the first positional argument; all other arguments and
    keyword arguments (``end``, ``flush``, ...) are forwarded unchanged. A bare
    ``print()`` prints just the prefix instead of raising IndexError.
    """
    if not args:
        xprint(f'[Rank {RANK}]', **kwargs)
        return
    xprint(f'[Rank {RANK}]{args[0]}', *args[1:], **kwargs)
    
def saveText(txt: str, filePath, option=None):
    """Write ``txt`` to ``filePath`` as UTF-8, creating parent directories as needed.

    Args:
        txt: text to write (replaces any existing content).
        filePath: destination path (str or Path).
        option: optional dict; ``{"printContent": True}`` also echoes ``txt`` in the log line.
    """
    option = option or {}
    file = Path(filePath)
    file.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8 so files round-trip with readText/readJSON (which read UTF-8)
    # regardless of the platform's locale encoding.
    file.write_text(txt, encoding="utf-8")
    print(f"Saved text: {filePath}{(', content:' + txt) if option.get('printContent', False) else ''}")

def saveObject(obj: any, filePath):
    """Serialize ``obj`` with ``torch.save`` to ``filePath`` and log where it went."""
    destination = f'{filePath}'
    torch.save(obj, destination)
    print(f'Saved object: {filePath}')
    
def saveJSON(dicts: dict, filePath, option=None):
    """Serialize ``dicts`` to JSON and write it to ``filePath`` via ``saveText``.

    Args:
        dicts: JSON-serializable mapping.
        filePath: destination path; saveText creates missing parent directories.
        option: forwarded to saveText (e.g. ``{"printContent": True}``); defaults to ``{}``.
    """
    saveText(json.dumps(dicts), filePath, option if option is not None else {})
    
def readText(filePath):
    """Return the UTF-8 contents of ``filePath``, or ``''`` when the path does not exist."""
    if not os.path.exists(filePath):
        return ''
    with open(filePath, 'r', encoding="utf-8") as handle:
        return handle.read()
    
def readJSON(filePath):
    """Parse the UTF-8 JSON file at ``filePath``; return ``{}`` when the path does not exist."""
    if not os.path.exists(filePath):
        return {}
    with open(filePath, 'r', encoding="utf-8") as handle:
        return json.load(handle)

def appendText(filePath: str, text: str):
    """Append ``text`` plus a trailing newline to ``filePath``.

    The file and its parent directories are created when missing. The file is opened
    in append mode instead of being re-read and rewritten whole on every call, so the
    cost no longer grows with the file size. It is always written as UTF-8, matching
    ``readText``.
    """
    file = Path(filePath)
    file.parent.mkdir(parents=True, exist_ok=True)
    with file.open('a', encoding="utf-8") as handle:
        handle.write(f'{text}\n')
    print(f"Saved text: {filePath}")

############################### Prepare hyperparameters and model #######################################
import os
import torch
import torch.distributed as dist
import pandas as pd

# Limit the CUDA caching allocator's block split size (fragmentation control).
# NOTE(review): set after `import torch`; this relies on no CUDA memory having been
# allocated yet when the allocator reads it — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"

# Language pair and I/O roots; every setting below can be overridden through env vars.
SRC_LANGUAGE = os.getenv("SRC_LANGUAGE", 'en')
TGT_LANGUAGE = os.getenv("TGT_LANGUAGE", 'zh')
WORK_DIRECTORY_INPUT = os.getenv("WORK_DIRECTORY_INPUT", '/kaggle/input/pytorch-translation-transformer-ddp-en-zh/input')
WORK_DIRECTORY_OUTPUT = os.getenv("WORK_DIRECTORY_OUTPUT", '/kaggle/working/output')

BATCH_SIZE = int(os.getenv("BATCH_SIZE", 128))
NUM_EPOCHS = int(os.getenv("NUM_EPOCHS", 1))
DATA_LEN = int(os.getenv("DATA_LEN", '0'))  # 0 = use every row left after filtering
DATA_LEN_VAL_PERT = float(os.getenv("DATA_LEN_VAL_PERT", 0.001))  # fraction of DATA_LEN for val (and again for test)
DATA_TRAINING_ITER = int(os.getenv("DATA_TRAINING_ITER", '1'))  # number of runs the training rows are split across
# Cached vocabulary; built from train_ds and saved here only when missing. The path is
# fixed, so VOCAB_MIN_FREQ only matters when the cache has to be built.
VOCAB_FILE = f'{WORK_DIRECTORY_INPUT}/vocab-10m/vocab-9999621-20-72696-82723.pt'
VOCAB_MIN_FREQ = int(os.getenv("VOCAB_MIN_FREQ", 40))

# Transformer architecture.
NHEAD = int(os.getenv("NHEAD", 8))
EMB_SIZE = int(os.getenv("EMB_SIZE", 512))
FFN_HID_DIM = int(os.getenv("FFN_HID_DIM", 2048))
NUM_ENCODER_LAYERS = int(os.getenv("NUM_ENCODER_LAYERS", 6))
NUM_DECODER_LAYERS = int(os.getenv("NUM_DECODER_LAYERS", 6))
DROPOUT_RATE = float(os.getenv("DROPOUT_RATE", 0.1))

# Run-mode switches: only the exact string 'True' enables them.
TEST_ONLY = os.getenv("TEST_ONLY", 'False') == 'True'
INFERENCE_ONLY = os.getenv("INFERENCE_ONLY", 'False') == 'True'
EXPORT_ONNX = os.getenv("EXPORT_ONNX", 'False') == 'True'
GRAD_ACCUM_ITER = int(os.getenv("GRAD_ACCUM_ITER", '4'))  # batches accumulated per optimizer step
EARLY_STOPING_PATIENCE = int(os.getenv("EARLY_STOPING_PATIENCE", 6))  # non-improving epochs before stopping
# Process identity set by torchrun; defaults describe a single-process run.
RANK = int(os.getenv("RANK", 0))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0))

GPU_COUNT = torch.cuda.device_count()
DDP_ENABLED = WORLD_SIZE > 1
AMP_ENABLED = GPU_COUNT > 0  # mixed precision (autocast + GradScaler) only with a GPU
# Map the local rank onto the visible GPUs, wrapping around if there are fewer GPUs than ranks.
DEVICE = torch.device(f'cuda:{LOCAL_RANK % GPU_COUNT}' if GPU_COUNT > 0 else 'cpu')

# Data-slice rotation: take the slice after the one recorded in train/params.json
# (slice 0 when that file is absent, e.g. after train/ was promoted to best/), modulo DATA_TRAINING_ITER.
DATA_TRAINING_ITER_INDEX = (readJSON(f'{WORK_DIRECTORY_OUTPUT}/train/params.json').get("DATA_TRAINING_ITER_INDEX", -1) + 1) % DATA_TRAINING_ITER
# True for every slice except the last (and never in inference-only mode). In this mode the
# run trains on a partial slice and skips best-model promotion, early stopping, ONNX export and testing.
IS_DATA_TRAINING_ITER_MODE = (DATA_TRAINING_ITER_INDEX + 1 < DATA_TRAINING_ITER and not INFERENCE_ONLY)

if DATA_TRAINING_ITER > 1:
    NUM_EPOCHS = 1
    print("Fallback NUM_EPOCHES to 1 when DATA_TRAINING_ITER is gt than 1")

print(f'DDP_ENABLED: {DDP_ENABLED}, WORLD_SIZE: {WORLD_SIZE}, RANK: {RANK}, LOCAL_RANK: {LOCAL_RANK}, GPU_COUNT: {GPU_COUNT}, DEVICE: {DEVICE}')

if not INFERENCE_ONLY:
    ############################### Prepare dataset #######################################
    # Corpus 1 ("en-zh-6m"): JSON-lines train + valid files with "english"/"chinese" fields,
    # merged into one frame and renamed to the language codes "en"/"zh".
    df_train_6m_origin = pd.read_json(f'{WORK_DIRECTORY_INPUT}/en-zh-6m/translation2019zh_train.json', lines=True)
    df_valid_6m_origin = pd.read_json(f'{WORK_DIRECTORY_INPUT}/en-zh-6m/translation2019zh_valid.json', lines=True)
    df_origin_6m = pd.concat([df_train_6m_origin, df_valid_6m_origin], ignore_index=True)
    df_origin_6m = df_origin_6m.rename(columns={"english": "en", "chinese": "zh"})
    # Corpus 2 ("en-zh-10m"): two line-aligned plain-text files, one sentence per line.
    # NOTE(review): sep='\r\t' (python engine) presumably never matches, so each whole line
    # lands in the single named column — confirm no line contains "\r\t".
    df_train_10m_origin_en = pd.read_csv(f'{WORK_DIRECTORY_INPUT}/en-zh-10m/train.en', engine='python', sep='\r\t', header=None, keep_default_na=False, names=['en'])
    df_train_10m_origin_zh = pd.read_csv(f'{WORK_DIRECTORY_INPUT}/en-zh-10m/train.zh', engine='python', sep='\r\t', header=None, keep_default_na=False, names=['zh'])
    # Pair the two files row by row (assumes equal line counts).
    df_origin_10m = pd.concat([df_train_10m_origin_en, df_train_10m_origin_zh], axis=1)
    df_origin = pd.concat([df_origin_6m, df_origin_10m], ignore_index=True)
    # Drop pairs whose English side exceeds 300 characters; re-index from 0.
    df = df_origin[df_origin['en'].map(len) <= 300].reset_index(drop=True)

    if RANK == 0:
        print(df)
    
    if DATA_LEN == 0:
        DATA_LEN = df.shape[0]

    # The first DATA_LEN rows are laid out as [train | val | test]; val and test each
    # hold DATA_LEN_VAL_PERT * DATA_LEN rows.
    num_val_sample = int(DATA_LEN_VAL_PERT * DATA_LEN)
    num_train_samples = DATA_LEN - 2 * num_val_sample
    # Size of one training slice when the training rows are spread over DATA_TRAINING_ITER runs.
    num_train_iter_samples = int(num_train_samples / DATA_TRAINING_ITER)

    # Positional slice DATA_TRAINING_ITER_INDEX of the training rows; the last slice
    # (iteration mode off) runs to num_train_samples so rounding leftovers are kept.
    train_ds = df[num_train_iter_samples * DATA_TRAINING_ITER_INDEX: num_train_iter_samples * (DATA_TRAINING_ITER_INDEX + 1) if IS_DATA_TRAINING_ITER_MODE else num_train_samples]
    val_ds = df[num_train_samples: num_train_samples + num_val_sample]
    test_ds = df[num_train_samples + num_val_sample: num_train_samples + 2 * num_val_sample]

    print(f'Train shape: {train_ds.shape}')
    if RANK == 0:
        print(train_ds)
    print(f'  Val shape: {val_ds.shape}')
    print(f' Test shape: {test_ds.shape}')
    # Release the intermediate frames; only the three splits are kept.
    del df_train_6m_origin, df_valid_6m_origin, df_origin_6m, df_train_10m_origin_en, df_train_10m_origin_zh, df_origin_10m, df_origin, df


############################### Prepare vocabulary #######################################
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# spaCy tokenizers per language code (models downloaded in the setup cell).
token_transform = {
 'en': get_tokenizer('spacy', language='en_core_web_sm'),
 'zh': get_tokenizer('spacy', language='zh_core_web_sm')
}

# language code -> torchtext vocabulary (token <-> index), loaded or built below
vocab_transform = {}

# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

# Reuse the cached vocabulary if present, otherwise build it from the training split.
# NOTE(review): train_ds only exists when INFERENCE_ONLY is False, so an inference-only run
# without VOCAB_FILE fails here with NameError. When the cache is missing, every rank builds
# and writes VOCAB_FILE independently (concurrent writes under DDP).
# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which may refuse to
# unpickle vocabulary objects — confirm against the installed torch version.
if os.path.exists(VOCAB_FILE):
    vocab_transform = torch.load(VOCAB_FILE)
    print(f'Loaded voab file: {VOCAB_FILE}')

else :
    for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
        # Tokenize every training sentence of this language. Tokens seen fewer than
        # VOCAB_MIN_FREQ times are left out (and later map to <unk>); special_first=True
        # puts the four special symbols at indices 0-3.
        vocab_transform[ln] = build_vocab_from_iterator(train_ds[ln].transform(lambda x: token_transform[ln](x)),
                              min_freq=VOCAB_MIN_FREQ,
                              specials=special_symbols,
                              special_first=True)
        # Set ``UNK_IDX`` as the default index. This index is returned when the token is not found.
        # If not set, it throws ``RuntimeError`` when the queried token is not found in the Vocabulary.
        vocab_transform[ln].set_default_index(UNK_IDX)
    # NOTE(review): the parent directory of VOCAB_FILE is not created here — confirm it exists.
    torch.save(vocab_transform, VOCAB_FILE)
    print(f'Saved voab file: {VOCAB_FILE}')

# if RANK == 0:
#     for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
#         print(f'Vocab {ln}: {dict(enumerate(vocab_transform[ln].get_itos()))}')

############################### Prepare model #######################################

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to token embeddings, then apply dropout.

    The table is precomputed for ``maxlen`` positions and stored as the buffer
    ``pos_embedding`` of shape ``(maxlen, 1, emb_size)``. The singleton middle axis
    broadcasts over the batch of sequence-first input ``(seq_len, batch, emb_size)``.
    """

    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super().__init__()
        # Inverse frequencies 10000^(-2i / emb_size), one per even channel index 2i.
        inv_freq = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        positions = torch.arange(0, maxlen).reshape(maxlen, 1)
        angles = positions * inv_freq
        table = torch.zeros((maxlen, emb_size))
        table[:, 0::2] = torch.sin(angles)  # even channels
        table[:, 1::2] = torch.cos(angles)  # odd channels

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', table.unsqueeze(-2))

    def forward(self, token_embedding: Tensor):
        seq_len = token_embedding.size(0)
        return self.dropout(token_embedding + self.pos_embedding[:seq_len, :])

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    """Look token ids up in a learned ``(vocab_size, emb_size)`` table and scale by sqrt(emb_size)."""

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        # nn.Embedding needs integer indices, so cast whatever dtype arrives to int64.
        ids = tokens.long()
        scale = math.sqrt(self.emb_size)
        return self.embedding(ids) * scale

# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder translation model built around ``torch.nn.Transformer``.

    Inputs are sequence-first token-id tensors ``(seq_len, batch)``; the wrapped
    ``Transformer`` keeps its default ``batch_first=False``. ``forward`` returns
    unnormalised logits ``(tgt_len, batch, tgt_vocab_size)``. ``encode``/``decode``
    expose the two halves for step-by-step greedy decoding.
    """

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int,
                 dropout: float):
        super().__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                tgt: Tensor):
        """Return next-token logits for the teacher-forced ``tgt`` prefix given ``src``."""
        # Build the masks on the inputs' own device rather than the script-global DEVICE,
        # so the module works wherever its inputs live.
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt, src.device)

        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        # Positional order: src_mask, tgt_mask, memory_mask (None), src_key_padding_mask,
        # tgt_key_padding_mask, memory_key_padding_mask (source padding again).
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, src_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        """Run the encoder only (no padding mask) and return its memory."""
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """Run the decoder over ``tgt`` against encoder ``memory``; returns hidden states, not logits."""
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)
    
############################### Prepare masking #######################################
def generate_square_subsequent_mask(sz):
    """Return an ``(sz, sz)`` float32 causal attention mask.

    Entries on and below the diagonal are ``0.0`` (attention allowed). Entries above
    it are ``-inf``, so position ``i`` cannot attend to any later position.
    """
    return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)


def create_mask(src, tgt, device):
    """Build the attention and padding masks for sequence-first ``src``/``tgt`` id tensors.

    Returns:
        src_mask: ``(src_len, src_len)`` all-False bool mask (no restriction in the encoder).
        tgt_mask: ``(tgt_len, tgt_len)`` float causal mask.
        src_padding_mask / tgt_padding_mask: ``(batch, seq_len)`` bool, True at PAD_IDX positions.
    """
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len).to(device)
    # Allocate directly on the requested device; previously this went through the global
    # DEVICE and was then moved, ignoring the `device` argument.
    src_mask = torch.zeros((src_seq_len, src_seq_len), dtype=torch.bool, device=device)

    # Inputs are (seq_len, batch); key-padding masks are expected as (batch, seq_len).
    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask


############################### Prepare training dataset #######################################
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from typing import List

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    """Compose ``transforms`` left to right into one callable.

    ``sequential_transforms(f, g, h)(x)`` returns ``h(g(f(x)))``; with no transforms the
    input is returned unchanged.
    """
    def composed(value):
        result = value
        for step in transforms:
            result = step(result)
        return result
    return composed

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    """Wrap ``token_ids`` in BOS/EOS and return them as a 1-D int64 tensor.

    ``dtype=torch.long`` is explicit so an empty ``token_ids`` (e.g. a blank sentence)
    still yields an integer tensor ``[BOS_IDX, EOS_IDX]``. Without it, ``torch.tensor([])``
    would be float32 and the result would no longer be an integer id tensor.
    """
    return torch.tensor([BOS_IDX, *token_ids, EOS_IDX], dtype=torch.long)

# ``src`` and ``tgt`` language text transforms to convert raw strings into tensors indices
# Per language: raw sentence -> spaCy tokens -> vocabulary ids -> 1-D tensor [BOS, ids..., EOS].
text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln], #Tokenization
                                               vocab_transform[ln], #Numericalization
                                               tensor_transform) # Add BOS/EOS and create tensor


# function to collate data samples into batch tensors
def collate_fn(batch):
    """Turn a list of ``(source_text, target_text)`` pairs into two padded id tensors.

    Each text has a trailing newline stripped and is run through ``text_transform``
    (tokenize, numericalize, add BOS/EOS). The sequences are then padded with PAD_IDX
    into sequence-first ``(max_len, batch)`` tensors.
    """
    to_src_ids = text_transform[SRC_LANGUAGE]
    to_tgt_ids = text_transform[TGT_LANGUAGE]
    src_seqs = []
    tgt_seqs = []
    for pair in batch:
        src_text, tgt_text = pair
        src_seqs.append(to_src_ids(src_text.rstrip("\n")))
        tgt_seqs.append(to_tgt_ids(tgt_text.rstrip("\n")))
    return (pad_sequence(src_seqs, padding_value=PAD_IDX),
            pad_sequence(tgt_seqs, padding_value=PAD_IDX))

class PandasDataset(Dataset):
    """Map-style dataset over a DataFrame that yields ``(source_text, target_text)`` pairs.

    Rows are addressed by position (0 .. len-1), not by index label. Slices of a larger
    frame, whose index does not start at 0, therefore work unchanged.

    Args:
        dataframe: frame with one sentence pair per row.
        src_column: source-text column; defaults to ``SRC_LANGUAGE``.
        tgt_column: target-text column; defaults to ``TGT_LANGUAGE``.
    """

    def __init__(self, dataframe, src_column=None, tgt_column=None):
        self.dataframe = dataframe
        self.src_column = SRC_LANGUAGE if src_column is None else src_column
        self.tgt_column = TGT_LANGUAGE if tgt_column is None else tgt_column

    def __getitem__(self, index):
        # Positional scalar access. The previous version built two one-row DataFrame
        # slices per item, which is far slower inside a DataLoader loop.
        feature = self.dataframe[self.src_column].iat[index]
        label = self.dataframe[self.tgt_column].iat[index]
        return feature, label

    def __len__(self):
        return len(self.dataframe)


############################### Prepare hyperparameters and model #######################################
import torch.distributed as dist

# Vocabulary sizes fix the embedding-table and output-layer widths.
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, 
                                 NUM_DECODER_LAYERS, 
                                 EMB_SIZE, 
                                 NHEAD, 
                                 SRC_VOCAB_SIZE, 
                                 TGT_VOCAB_SIZE, 
                                 FFN_HID_DIM, 
                                 DROPOUT_RATE)

# Xavier-uniform init for every parameter with more than one dimension (weight matrices,
# embedding tables); 1-D parameters such as biases keep their default init.
# Checkpoint weights loaded below overwrite this when present.
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)
# Target positions equal to PAD_IDX do not contribute to the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
# Loss scaler for mixed precision; disabled (pass-through) when AMP_ENABLED is False.
# NOTE(review): newer PyTorch deprecates torch.cuda.amp.GradScaler in favour of torch.amp.GradScaler('cuda').
scaler = torch.cuda.amp.GradScaler(enabled=AMP_ENABLED)


############################### Loading weights #######################################
from datetime import datetime
import time
import os.path
import pytz
import json

def getNowInLocal():
    """Return the current Asia/Shanghai wall-clock time as a string (used in log lines)."""
    shanghai = pytz.timezone('Asia/Shanghai')
    return str(datetime.now(shanghai))

# Best validation loss so far; the training loop promotes train/ to best/ only when it beats this.
val_loss_prev = float('inf')
# Consecutive non-improving epochs, compared against EARLY_STOPING_PATIENCE.
val_loss_increased_count = 0

# Resume weights: slice 0 restarts from best/, later slices continue from the previous slice's
# train/ checkpoint. State-dict keys carry no "module." prefix: the loop saves
# transformer.module's state under DDP, and the model is not wrapped in DDP yet at this point.
weightFoler = 'best' if DATA_TRAINING_ITER_INDEX == 0 else 'train'
weightsFilePath =  f'{WORK_DIRECTORY_OUTPUT}/{weightFoler}/model_weights.pth' 
if os.path.exists(weightsFilePath):
    # NOTE(review): torch.load unpickles the file — only load checkpoints produced by this script.
    checkpoint = torch.load(weightsFilePath, map_location=torch.device(DEVICE))
    transformer.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scaler.load_state_dict(checkpoint['scaler'])
    # Only used for the log line; the comparison baseline is restored from best/ below.
    val_loss = readJSON(f'{WORK_DIRECTORY_OUTPUT}/{weightFoler}/status.json').get('val_loss_prev', float('inf'))
    print(f'Loaded weights file, type: {weightFoler}, val_loss: {val_loss}, file: {weightsFilePath}.')
    del weightsFilePath, checkpoint

# The improvement baseline always comes from best/, even when the weights came from train/.
statusFilePath =  f'{WORK_DIRECTORY_OUTPUT}/best/status.json'
if os.path.exists(statusFilePath):
    status = readJSON(statusFilePath)
    val_loss_prev = status.get('val_loss_prev', float('inf'))
    print(f'Loaded status file from {statusFilePath}, val_loss_prev: {val_loss_prev}.')
    del statusFilePath, status


############################### ENABLE DDP #######################################
if DDP_ENABLED:
    if DEVICE.type != 'cpu':
        # Bind this process to its GPU before creating the NCCL group, as the PyTorch docs
        # recommend, so collectives (including the later barrier()) run on this rank's device.
        torch.cuda.set_device(DEVICE)
    dist.init_process_group(backend=("nccl" if GPU_COUNT > 0 else 'gloo'))
    # Use the device the model actually lives on (LOCAL_RANK % GPU_COUNT), not the raw
    # LOCAL_RANK, which would name a non-existent GPU when ranks outnumber GPUs.
    transformer = torch.nn.parallel.DistributedDataParallel(
        transformer.to(DEVICE),
        device_ids=([DEVICE.index] if DEVICE.type != 'cpu' else None),
        output_device=(DEVICE.index if DEVICE.type != 'cpu' else None))

    if not INFERENCE_ONLY:
        # Each rank sees a disjoint, equally sized share of train/val rows.
        train_ds_sampler = torch.utils.data.distributed.DistributedSampler(PandasDataset(train_ds))
        val_ds_sampler = torch.utils.data.distributed.DistributedSampler(PandasDataset(val_ds))

############################### Prepare traning and validation epoch #######################################
# The epoch helpers and the training loop below only exist/run when the script is
# neither inference-only nor test-only.
if not INFERENCE_ONLY and not TEST_ONLY:

    from torch.utils.data import DataLoader
    import math

    def train_epoch(model, optimizer, epoch, scaler):
        """Train ``model`` for one epoch and return the mean per-batch training loss.

        Gradients are accumulated over ``GRAD_ACCUM_ITER`` consecutive batches (each batch's
        loss is divided by ``GRAD_ACCUM_ITER``). The optimizer steps once per window, and the
        last, possibly shorter, window also steps. Under DDP, the gradient all-reduce is
        deferred with ``no_sync()`` on every batch except the one closing a window, so ranks
        synchronise once per optimizer step instead of once per batch. Rank 0 prints a
        progress line at the end of each window.
        """
        from contextlib import nullcontext

        if not DEVICE.type == 'cpu':
            torch.cuda.reset_peak_memory_stats(DEVICE)
        model.train()
        losses = 0
        losses_ga = 0
        batch = 0
        train_dataloader = DataLoader(PandasDataset(train_ds), batch_size=BATCH_SIZE, collate_fn=collate_fn, sampler=train_ds_sampler if DDP_ENABLED else None)
        if DDP_ENABLED:
            train_dataloader.sampler.set_epoch(epoch)
        num_batches = len(train_dataloader)

        for src, tgt in train_dataloader:
            batch += 1
            # True when this batch closes an accumulation window (optimizer step + DDP sync).
            window_done = (batch % GRAD_ACCUM_ITER == 0) or (batch == num_batches)
            start_time = timer()
            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)

            # Teacher forcing: the decoder sees tokens [0, n-1) and predicts tokens [1, n).
            tgt_input = tgt[:-1, :]
            tgt_out = tgt[1:, :]

            # DDP needs both forward and backward inside no_sync() to skip the all-reduce.
            sync_ctx = model.no_sync() if DDP_ENABLED and not window_done else nullcontext()
            with sync_ctx:
                with torch.autocast(device_type=DEVICE.type, enabled=AMP_ENABLED):
                    logits = model(src, tgt_input)

                loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
                loss /= GRAD_ACCUM_ITER
                losses_ga += loss.item()

                scaler.scale(loss).backward()

            if window_done:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                # losses_ga is sum(loss_i / GRAD_ACCUM_ITER); undo the scaling so that
                # `losses` is the plain sum of per-batch losses.
                losses += losses_ga * GRAD_ACCUM_ITER

            end_time = timer()

            if RANK == 0 and window_done:
                print(f'[Tra]Epoch: {epoch}, Batch: {batch}/{num_batches}, Data: {((batch - 1) * BATCH_SIZE + src.shape[1])}/{math.floor(train_ds.shape[0] / WORLD_SIZE)}, Train loss: {losses_ga:.3f}, ETA: {((num_batches - batch) * (end_time - start_time)):.3f}s, Memory allocated: {torch.cuda.memory_allocated(DEVICE) // (1024 ** 3)} GB/{torch.cuda.max_memory_allocated(DEVICE) // (1024 ** 3)} GB, Datetime:{getNowInLocal()}', end="\r" if batch != num_batches else '\n', flush=False if batch == 1 else True)

            if window_done:
                losses_ga = 0

            del src, tgt, start_time, tgt_input, logits, tgt_out, loss, end_time

        return losses / num_batches

    def evaluate(model, epoch):
        """Return the mean per-batch validation loss for ``epoch``.

        Forward passes run under ``torch.no_grad()``: validation needs no autograd graph,
        and building one only wastes memory. Under DDP, the per-rank means are averaged with
        ``all_reduce``, so every rank returns the same value. All ranks then make the same
        checkpoint-promotion and early-stopping decision in the training loop; a rank that
        stopped alone would deadlock the others.
        """
        model.eval()
        losses = 0
        losses_ga = 0
        batch = 0
        val_dataloader = DataLoader(PandasDataset(val_ds), batch_size=BATCH_SIZE, collate_fn=collate_fn, sampler=val_ds_sampler if DDP_ENABLED else None)
        if DDP_ENABLED:
            val_dataloader.sampler.set_epoch(epoch)
        num_batches = len(val_dataloader)

        with torch.no_grad():
            for src, tgt in val_dataloader:
                if not DEVICE.type == 'cpu':
                    torch.cuda.reset_peak_memory_stats(DEVICE)
                batch += 1
                start_time = timer()
                src = src.to(DEVICE)
                tgt = tgt.to(DEVICE)

                tgt_input = tgt[:-1, :]

                with torch.autocast(device_type=DEVICE.type, enabled=AMP_ENABLED):
                    logits = model(src, tgt_input)

                tgt_out = tgt[1:, :]
                loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
                # Same 1/GRAD_ACCUM_ITER scaling as training so the per-window figure
                # printed below is comparable with the [Tra] lines.
                loss /= GRAD_ACCUM_ITER
                losses_ga += loss.item()
                end_time = timer()

                window_done = (batch % GRAD_ACCUM_ITER == 0) or (batch == num_batches)
                if window_done:
                    losses += losses_ga * GRAD_ACCUM_ITER

                if RANK == 0 and window_done:
                    print(f'[Val]Epoch: {epoch}, Batch: {batch}/{num_batches}, Data: {((batch - 1) * BATCH_SIZE + src.shape[1])}/{math.floor(val_ds.shape[0] / WORLD_SIZE)}, Val loss: {losses_ga:.3f}, ETA: {((num_batches - batch) * (end_time - start_time)):.3f}s, Memory allocated: {torch.cuda.memory_allocated(DEVICE) // (1024 ** 3)} GB/{torch.cuda.max_memory_allocated(DEVICE) // (1024 ** 3)} GB, Datetime:{getNowInLocal()}', end="\r" if batch != num_batches else '\n', flush=False if batch == 1 else True)

                if window_done:
                    losses_ga = 0

                del src, tgt, start_time, tgt_input, logits, tgt_out, loss, end_time

        mean_loss = losses / num_batches
        if DDP_ENABLED:
            # DistributedSampler gives every rank the same number of batches, so the mean of
            # the per-rank means is the global mean.
            loss_tensor = torch.tensor([mean_loss], dtype=torch.float64, device=DEVICE)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
            mean_loss = loss_tensor.item() / WORLD_SIZE
        return mean_loss


    ############################### Training model #######################################
    from timeit import default_timer as timer
    import os.path
    import shutil

    # Set once rank 0 has written train/history.txt during this run.
    train_history_written = False

    for epoch in range(1, NUM_EPOCHS+1):
        start_time = timer()
        train_loss = train_epoch(transformer, optimizer, epoch, scaler)
        end_time = timer()
        start_time_val = timer()
        val_loss = evaluate(transformer, epoch)
        end_time_val = timer()

        if RANK == 0:
            log = f"[Sum]Epoch: {epoch}, Data: {train_ds.shape[0] / WORLD_SIZE}, DATA_TRAINING_ITER: {DATA_TRAINING_ITER_INDEX + 1}/{DATA_TRAINING_ITER}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, Epoch time = {(end_time - start_time):.3f}s, Val Epoch time = {(end_time_val - start_time_val):.3f}s, Memory allocated: {torch.cuda.memory_allocated(DEVICE) // (1024 ** 3)} GB/{torch.cuda.max_memory_allocated(DEVICE) // (1024 ** 3)} GB, Datetime:{getNowInLocal()}"
            print(log)
            pathOfTrainlogFile = f"{WORK_DIRECTORY_OUTPUT}/train/history.txt"
            pathOfBestlogFile = f"{WORK_DIRECTORY_OUTPUT}/best/history.txt"
            if DATA_TRAINING_ITER_INDEX == 0 and (not train_history_written or not os.path.exists(pathOfTrainlogFile)):
                # This run resumed from best/, so train/'s history starts as best/'s history.
                # Overwrite, not append: a stale train/history.txt from an earlier run must not
                # keep best/'s history duplicated inside it. Later epochs append, except after
                # train/ was promoted to best/ (file gone), when the history is re-seeded.
                saveText(f'{readText(pathOfBestlogFile)}{log}\n', pathOfTrainlogFile)
            else:
                appendText(pathOfTrainlogFile, log)
            train_history_written = True

            weights = {
                "model": transformer.module.state_dict() if DDP_ENABLED else transformer.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scaler": scaler.state_dict()
            }
            saveObject(weights, f'{WORK_DIRECTORY_OUTPUT}/train/model_weights.pth')

            status = {
                "val_loss_prev": val_loss,
                "datatime": time.ctime()
            }
            saveJSON(status, f'{WORK_DIRECTORY_OUTPUT}/train/status.json', { "printContent": True })

            params = {
                "GRAD_ACCUM_ITER":  GRAD_ACCUM_ITER,
                "NUM_EPOCHS": NUM_EPOCHS,
                "DATA_LEN": DATA_LEN,
                "VOCAB_FILE": VOCAB_FILE,
                "VOCAB_MIN_FREQ": VOCAB_MIN_FREQ,
                "EMB_SIZE": EMB_SIZE,
                "NHEAD": NHEAD,
                "FFN_HID_DIM": FFN_HID_DIM,
                "BATCH_SIZE": BATCH_SIZE,
                "NUM_ENCODER_LAYERS": NUM_ENCODER_LAYERS,
                "NUM_DECODER_LAYERS":NUM_DECODER_LAYERS,
                "EARLY_STOPING_PATIENCE": EARLY_STOPING_PATIENCE,
                "WORLD_SIZE": WORLD_SIZE,
                "GPU_COUNT": GPU_COUNT,
                "DDP_ENABLED": DDP_ENABLED,
                "DATA_TRAINING_ITER": DATA_TRAINING_ITER,
                "DATA_TRAINING_ITER_INDEX": DATA_TRAINING_ITER_INDEX
            }
            saveJSON(params, f'{WORK_DIRECTORY_OUTPUT}/train/params.json')

            # Promote the just-saved checkpoint to best/ when validation improved
            # (never for intermediate data slices).
            if not IS_DATA_TRAINING_ITER_MODE and val_loss < val_loss_prev:
                shutil.rmtree(f'{WORK_DIRECTORY_OUTPUT}/best', ignore_errors=True)
                os.rename(f'{WORK_DIRECTORY_OUTPUT}/train', f'{WORK_DIRECTORY_OUTPUT}/best')

        # Early-stopping bookkeeping runs on every rank; val_loss is identical across ranks.
        if not IS_DATA_TRAINING_ITER_MODE:
            if val_loss < val_loss_prev:
                val_loss_increased_count = 0
                val_loss_prev = val_loss

            else:
                val_loss_increased_count += 1

            if val_loss_increased_count == EARLY_STOPING_PATIENCE:
                print(f'Stop training as val_loss is increased {EARLY_STOPING_PATIENCE} times.')
                break
            
if EXPORT_ONNX and not IS_DATA_TRAINING_ITER_MODE and RANK == 0:
    # Export from rank 0 only, instead of every rank writing the same file concurrently.
    # The sample inputs just drive tracing, but they must be on the model's device.
    src = text_transform[SRC_LANGUAGE]('Hello world!').view(-1, 1).to(DEVICE)
    tgt = text_transform[TGT_LANGUAGE]('<bos>').view(-1, 1).to(DEVICE)
    torch.onnx.export(transformer.module if DDP_ENABLED else transformer,
                      (src, tgt),
                      f"{WORK_DIRECTORY_OUTPUT}/pytorch_translation_{SRC_LANGUAGE}_{TGT_LANGUAGE}.onnx",
                      verbose=False,
                      opset_version=19,
                      input_names=['src', 'tgt'],
                      output_names=['output'],
                      # Axis 0 (sequence length) and axis 1 (batch) are dynamic everywhere.
                      dynamic_axes={
                        "src": [0, 1],
                        "tgt": [0, 1],
                        "output": [0, 1]
                      })

############################### Prediction #######################################
# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Greedily decode one target sequence.

    Args:
        model: unwrapped Seq2SeqTransformer (needs ``encode``, ``decode`` and ``generator``).
        src: source token ids; presumably ``(src_len, 1)`` as built by ``translate``.
        src_mask: encoder self-attention mask.
        max_len: maximum length of the returned sequence, including ``start_symbol``.
        start_symbol: first target token (BOS).

    Returns:
        ``(out_len, 1)`` int64 tensor on DEVICE starting with ``start_symbol``. It ends with
        EOS_IDX if EOS was produced within ``max_len`` tokens.
    """
    # Pure inference: no autograd graph is needed for any of the decoding steps.
    with torch.no_grad():
        src = src.to(DEVICE)
        src_mask = src_mask.to(DEVICE)

        # Encode once; the memory is loop-invariant, so it is moved to DEVICE only once.
        memory = model.encode(src, src_mask).to(DEVICE)
        ys = torch.tensor([[start_symbol]]).to(DEVICE)
        for _ in range(max_len - 1):
            tgt_mask = (generate_square_subsequent_mask(ys.size(0)).type(torch.bool)).to(DEVICE)
            out = model.decode(ys, memory, tgt_mask)
            out = out.transpose(0, 1)
            # Logits for the last generated position only.
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()
            ys = torch.cat([ys, torch.tensor([[next_word]]).to(DEVICE)], dim=0)
            if next_word == EOS_IDX:
                break
    return ys
 

# actual function to translate input sentence into target language
def translate(src_sentence: str):
    """Translate one SRC_LANGUAGE sentence into TGT_LANGUAGE with greedy decoding.

    The output has at most ``len(source ids) + 5`` tokens, including the leading BOS.
    Target tokens are joined without separators for Chinese and with single spaces
    otherwise; the ``<bos>``/``<eos>`` markers are removed from the returned string.
    """
    transformer.eval()
    src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)  # (src_len, batch=1)
    num_tokens = src.shape[0]
    # All-False mask: every source position may attend to every other.
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    model = transformer.module if DDP_ENABLED else transformer
    tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    words = vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))
    separator = "" if TGT_LANGUAGE == 'zh' else " "
    return separator.join(words).replace("<bos>", "").replace("<eos>", "")

# Test pass (skipped for intermediate data slices): each rank translates a contiguous
# 1/WORLD_SIZE share of test_ds and prints source, reference and model output.
if not INFERENCE_ONLY and not IS_DATA_TRAINING_ITER_MODE:
    test_ds_start = RANK * test_ds.shape[0] // WORLD_SIZE
    test_ds_end = (RANK + 1) * test_ds.shape[0] // WORLD_SIZE
    print(f'Test size: {test_ds_start} - {test_ds_end}')

    # Wait for every rank to get here, presumably so rank 0 has finished its checkpoint
    # writes before decoding starts.
    if DDP_ENABLED:
        torch.distributed.barrier()

    for index, row in test_ds[test_ds_start: test_ds_end].iterrows():
        print(f"""{row[SRC_LANGUAGE]}
        {row[TGT_LANGUAGE]}
        {translate(row[SRC_LANGUAGE])}""")

# Remove the rank-prefixing override so `print` resolves to the builtin again.
# NOTE(review): the DDP process group is never destroyed (dist.destroy_process_group()).
del print

In [None]:
# Run ddp.py as 2 worker processes on this machine; torchrun supplies each worker's
# RANK, LOCAL_RANK and WORLD_SIZE environment variables, which ddp.py reads.
!torchrun --standalone --nnodes=1 --nproc-per-node=2 ddp.py

In [None]:
# The fallbacks must match ddp.py's defaults (SRC_LANGUAGE='en', TGT_LANGUAGE='zh'); the old
# "en" fallback for the target language pointed at a file that is never exported.
onnxFileName = f'/kaggle/working/output/pytorch_translation_{os.getenv("SRC_LANGUAGE", "en")}_{os.getenv("TGT_LANGUAGE", "zh")}.onnx'
if os.path.exists(onnxFileName):
    # !pip install onnx onnxruntime  # uncomment if these are not preinstalled
    import onnx
    import onnxruntime as ort
    import numpy as np

    # Load the ONNX model and check that it is well formed.
    model = onnx.load(onnxFileName)
    onnx.checker.check_model(model)

    ort_session = ort.InferenceSession(onnxFileName)

    # Smoke test with hand-written token ids, sequence-first (seq_len, batch=1). Explicit
    # int64 numpy arrays match the int64 id tensors the model was exported with.
    src = np.array([[2], [12987], [143], [1898], [3]], dtype=np.int64)
    tgt = np.array([[2], [0], [0], [0], [0]], dtype=np.int64)
    outputs = ort_session.run(None, {'src': src, 'tgt': tgt})
    print(outputs[0])