In [263]:
!pip install pyvi torchsummary



In [264]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import spacy
import torchtext
import tqdm
import random
from spacy.lang.vi import Vietnamese
from spacy.lang.en import English
from torch.utils.data import Dataset, random_split
from torchtext.vocab import build_vocab_from_iterator
from torchsummary import summary

In [265]:
seed = 1234

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

In [266]:
def load_data(path):
    data = []
    with open(path,'r') as file:
        for line in file.readlines():
            splitted_line = line.split('\t')
            eng = splitted_line[0]
            vi = splitted_line[1]
            data.append({'vi':vi, 
                         'en':eng})
    return data

In [267]:
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, index):
        return self.data[index]

In [268]:
dataset = CustomDataset(load_data('/kaggle/input/languagedata/data/vie.txt'))

In [269]:
#7:2:1
total_samples = len(dataset)
train_size = int(0.8 * total_samples)
val_size = int(0.1 * total_samples)
test_size = total_samples - train_size - val_size

In [270]:
train_data, valid_data, test_data = random_split(dataset, [train_size, val_size, test_size])
print("Số lượng mẫu trong tập train:", len(train_data))
print("Số lượng mẫu trong tập validation:", len(valid_data))
print("Số lượng mẫu trong tập test:", len(test_data))

Số lượng mẫu trong tập train: 7542
Số lượng mẫu trong tập validation: 942
Số lượng mẫu trong tập test: 944


In [271]:
train_data[0]

{'vi': 'Bạn thực sự muốn mặc cái đó sao?',
 'en': 'Do you really want to wear that?'}

In [272]:
en_nlp = English()
vi_nlp = Vietnamese()

In [273]:
string = "What a lovely day it is today!"
[token.text for token in en_nlp.tokenizer(string)]

['What', 'a', 'lovely', 'day', 'it', 'is', 'today', '!']

In [274]:
def tokenize_example(example, en_nlp, vi_nlp, max_length, lower, sos_token, eos_token):
    en_tokens = [token.text for token in en_nlp.tokenizer(example["en"])][:max_length]
    vi_tokens = [token.text for token in vi_nlp.tokenizer(example["vi"])][:max_length]
    if lower:
        en_tokens = [token.lower() for token in en_tokens]
        vi_tokens = [token.lower() for token in vi_tokens]
    en_tokens = [sos_token] + en_tokens + [eos_token]
    vi_tokens = [sos_token] + vi_tokens + [eos_token]
    example["en_tokens"] = en_tokens
    example["vi_tokens"] = vi_tokens
    return example

In [275]:
max_length = 50
lower = True
sos_token = "<sos>"
eos_token = "<eos>"

fn_kwargs = {
    "en_nlp": en_nlp,
    "vi_nlp": vi_nlp,
    "max_length": max_length,
    "lower": lower,
    "sos_token": sos_token,
    "eos_token": eos_token,
}
train_data = [tokenize_example(example, **fn_kwargs) for example in train_data]
valid_data = [tokenize_example(example, **fn_kwargs) for example in valid_data]
test_data = [tokenize_example(example, **fn_kwargs) for example in test_data]

In [276]:
train_data[0]

{'vi': 'Bạn thực sự muốn mặc cái đó sao?',
 'en': 'Do you really want to wear that?',
 'en_tokens': ['<sos>',
  'do',
  'you',
  'really',
  'want',
  'to',
  'wear',
  'that',
  '?',
  '<eos>'],
 'vi_tokens': ['<sos>',
  'bạn',
  'thực sự',
  'muốn',
  'mặc',
  'cái',
  'đó',
  'sao',
  '?',
  '<eos>']}

In [277]:
def yield_tokens(data,s):
    for dct in data:
        yield dct[s]

In [278]:
min_freq = 2
unk_token = "<unk>"
pad_token = "<pad>"

special_tokens = [
    unk_token,
    pad_token,
    sos_token,
    eos_token,
]

en_vocab = torchtext.vocab.build_vocab_from_iterator(
    yield_tokens(train_data,'en_tokens'),
    min_freq=min_freq,
    specials=special_tokens,
)

vi_vocab = torchtext.vocab.build_vocab_from_iterator(
    yield_tokens(train_data,'vi_tokens'),
    min_freq=min_freq,
    specials=special_tokens,
)

In [279]:
assert en_vocab[unk_token] == vi_vocab[unk_token]
assert en_vocab[pad_token] == vi_vocab[pad_token]

unk_index = en_vocab[unk_token]
pad_index = en_vocab[pad_token]

In [280]:
en_vocab.set_default_index(unk_index)
vi_vocab.set_default_index(unk_index)

In [281]:
tokens = ["i", "love", "watching", "crime", "shows"]
en_vocab.lookup_indices(tokens)

[5, 173, 509, 0, 0]

In [282]:
en_vocab.lookup_tokens(en_vocab.lookup_indices(tokens))

['i', 'love', 'watching', '<unk>', '<unk>']

In [283]:
def numericalize_example(example, en_vocab, vi_vocab):
    en_ids = en_vocab.lookup_indices(example["en_tokens"])
    vi_ids = vi_vocab.lookup_indices(example["vi_tokens"])
    example["en_ids"] = en_ids
    example["vi_ids"] = vi_ids
    return example

In [284]:
fn_kwargs = {"en_vocab": en_vocab, "vi_vocab": vi_vocab}
train_data = [numericalize_example(example, **fn_kwargs) for example in train_data]
valid_data = [numericalize_example(example, **fn_kwargs) for example in valid_data]
test_data = [numericalize_example(example, **fn_kwargs) for example in test_data]

In [285]:
train_data[0]

{'vi': 'Bạn thực sự muốn mặc cái đó sao?',
 'en': 'Do you really want to wear that?',
 'en_tokens': ['<sos>',
  'do',
  'you',
  'really',
  'want',
  'to',
  'wear',
  'that',
  '?',
  '<eos>'],
 'vi_tokens': ['<sos>',
  'bạn',
  'thực sự',
  'muốn',
  'mặc',
  'cái',
  'đó',
  'sao',
  '?',
  '<eos>'],
 'en_ids': [2, 14, 8, 88, 37, 6, 431, 15, 10, 3],
 'vi_ids': [2, 8, 184, 30, 281, 34, 15, 97, 11, 3]}

In [286]:
en_vocab.lookup_tokens(train_data[0]["en_ids"])

['<sos>', 'do', 'you', 'really', 'want', 'to', 'wear', 'that', '?', '<eos>']

In [287]:
def to_tensor(example):
    example['en_ids'] = torch.tensor(np.array(example['en_ids']), dtype=torch.int64)
    example['vi_ids'] = torch.tensor(np.array(example['vi_ids']), dtype=torch.int64)
    return example

In [288]:
train_data = [to_tensor(example) for example in train_data]
valid_data = [to_tensor(example) for example in valid_data]
test_data = [to_tensor(example) for example in test_data]

In [289]:
def get_collate_fn(pad_index, max_length):
    def collate_fn(batch):
        batch_en_ids = []
        batch_vi_ids = []
#         batch_en_ids = nn.utils.rnn.pad_sequence(batch_en_ids, padding_value=pad_index)
#         batch_vi_ids = nn.utils.rnn.pad_sequence(batch_vi_ids, padding_value=pad_index)
        for example in batch:
            en_ids = example["en_ids"]
            vi_ids = example["vi_ids"]
            if len(en_ids) > max_length:
                en_ids = en_ids[:max_length]
            else:
                en_ids = torch.cat((en_ids, torch.tensor([pad_index] * (max_length - len(en_ids)))))
            if len(vi_ids) > max_length:
                vi_ids = vi_ids[:max_length]
            else:
                vi_ids = torch.cat((vi_ids, torch.tensor([pad_index] * (max_length - len(vi_ids)))))
            assert len(en_ids) == max_length
            assert len(vi_ids) == max_length
            batch_en_ids.append(en_ids)
            batch_vi_ids.append(vi_ids)
        batch = {
            "en_ids": torch.stack(batch_en_ids),
            "vi_ids": torch.stack(batch_vi_ids),
        }
        return batch

    return collate_fn

In [290]:
def get_data_loader(dataset, batch_size, pad_index, max_length, shuffle=False):
    collate_fn = get_collate_fn(pad_index, max_length)
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=shuffle,
    )
    return data_loader

In [291]:
batch_size = 128
max_length = 50
train_data_loader = get_data_loader(train_data, batch_size, pad_index,max_length, shuffle=True)
valid_data_loader = get_data_loader(valid_data, batch_size, pad_index, max_length)
test_data_loader = get_data_loader(test_data, batch_size, pad_index, max_length)

In [292]:
class TokenAndPositionEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_length, device='cpu'):
        super().__init__()
        self.device = device
        self.word_emb = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim)
        self.pos_emb = nn.Embedding(
            num_embeddings=max_length,
            embedding_dim=embed_dim
        )

    def forward(self, x):
        N, seq_len = x.size()
        positions = torch.arange(0, seq_len).expand(N, seq_len).to(self.device)
        output1 = self.word_emb(x)
        output2 = self.pos_emb(positions)
        output =  output1 + output2
        return output

In [293]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim = embed_dim,
            num_heads = num_heads,
            batch_first = True
        )
        self.ffn = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=ff_dim, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=ff_dim, out_features=embed_dim, bias=True)
        )
        self.layernorm_1 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.layernorm_2 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.dropout_1 = nn.Dropout(p=dropout)
        self.dropout_2 = nn.Dropout(p=dropout)

    def forward(self, query, key, value):
        attn_output, _ = self.attn(query, key, value)
        attn_output = self.dropout_1(attn_output)
        out_1 = self.layernorm_1(query + attn_output)
        ffn_output = self.ffn(out_1)
        ffn_output = self.dropout_2(ffn_output)
        out_2 = self.layernorm_2(out_1 + ffn_output)
        return out_2

In [294]:
class TransformerEncoder(nn.Module):
    def __init__(self, src_vocab_size, embed_dim, max_length, num_layers, num_heads, ff_dim, dropout=0.1, device='cpu'):
        super().__init__()
        self.embedding = TokenAndPositionEmbedding(src_vocab_size, embed_dim, max_length, device)
        self.layers = nn.ModuleList(
            [
                TransformerEncoderBlock(embed_dim, num_heads, ff_dim, dropout) for i in range(num_layers)
            ]
        )

    def forward(self, x):
        output = self.embedding(x)
        for layer in self.layers:
            output = layer(output, output, output)
        return output
     

In [295]:
class TransformerDecoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim = embed_dim,
            num_heads = num_heads,
            batch_first = True
        )
        self.cross_attn = nn.MultiheadAttention(
            embed_dim = embed_dim,
            num_heads = num_heads,
            batch_first = True
        )
        self.ffn = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=ff_dim, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=ff_dim, out_features=embed_dim, bias=True)
        )
        self.layernorm_1  = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.layernorm_2  = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.layernorm_3  = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.dropout_1 = nn.Dropout(p=dropout)
        self.dropout_2 = nn.Dropout(p=dropout)
        self.dropout_3 = nn.Dropout(p=dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output, _ = self.attn(x, x, x, attn_mask=tgt_mask)
        attn_output = self.dropout_1(attn_output)
        out_1 = self.layernorm_1(x + attn_output)
        attn_output, _ = self.cross_attn(out_1, enc_output, enc_output)
        attn_output = self.dropout_2(attn_output)
        out_2 = self.layernorm_2(out_1 + attn_output)
        ffn_output = self.ffn(out_2)
        ffn_output = self.dropout_3(ffn_output)
        out_3 = self.layernorm_3(out_2 + ffn_output)
        return out_3

In [296]:
class TransformerDecoder(nn.Module):
    def __init__(self, tgt_vocab_size, embed_dim, max_length, num_layers, num_aheads, ff_dim, dropout=0.1, device='cpu'):
        super().__init__()
        self.embedding = TokenAndPositionEmbedding(tgt_vocab_size, embed_dim, max_length, device)
        self.layers = nn.ModuleList(
            [
                TransformerDecoderBlock(embed_dim, num_heads, ff_dim, dropout) for i in range(num_layers)
            ]
        )

    def forward(self, x, enc_output, src_mask, tgt_mask):
        output = self.embedding(x)
        for layer in self.layers:
            output = layer(output, enc_output, src_mask, tgt_mask)
        return output

In [297]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, embed_dim, max_length, num_layers, num_heads, ff_dim, dropout=0.1, device='cpu'):
        super().__init__()
        self.device = device
        self.encoder = TransformerEncoder(src_vocab_size, embed_dim, max_length, num_layers, num_heads, ff_dim, dropout, device)
        self.decoder = TransformerDecoder(tgt_vocab_size, embed_dim, max_length, num_layers, num_heads, ff_dim, dropout, device)
        self.fc = nn.Linear(embed_dim, tgt_vocab_size)

    def generate_mask(self, src, tgt):
        src_seq_len = src.shape[1]
        tgt_seq_len = tgt.shape[1]

        src_mask = torch.zeros((src_seq_len, src_seq_len), device=self.device).type(torch.bool)
        tgt_mask = (torch.triu(torch.ones((tgt_seq_len, tgt_seq_len), device=self.device).type(torch.bool)) == 1).transpose(0,1)
        tgt_mask = tgt_mask.float().masked_fill(tgt_mask == 0, float('-inf')).masked_fill(tgt_mask == 1, float(0.0))
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        enc_output = self.encoder(src)
        dec_output = self.decoder(tgt, enc_output, src_mask, tgt_mask)
        output = self.fc(dec_output)
        return output

In [376]:
src_vocab_size = len(en_vocab)
tgt_vocab_size = len(vi_vocab)
embed_dim = 128
max_length = 50
num_layers = 5
num_heads = 2
ff_dim = 256
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dropout = 0.3
model = Transformer(src_vocab_size, tgt_vocab_size, embed_dim, max_length, num_layers, num_heads, ff_dim, dropout, device).to(device)

In [377]:
print(model)

Transformer(
  (encoder): TransformerEncoder(
    (embedding): TokenAndPositionEmbedding(
      (word_emb): Embedding(2187, 128)
      (pos_emb): Embedding(50, 128)
    )
    (layers): ModuleList(
      (0-4): 5 x TransformerEncoderBlock(
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (ffn): Sequential(
          (0): Linear(in_features=128, out_features=256, bias=True)
          (1): ReLU()
          (2): Linear(in_features=256, out_features=128, bias=True)
        )
        (layernorm_1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (layernorm_2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (dropout_1): Dropout(p=0.3, inplace=False)
        (dropout_2): Dropout(p=0.3, inplace=False)
      )
    )
  )
  (decoder): TransformerDecoder(
    (embedding): TokenAndPositionEmbedding(
      (word_emb): Embedding(2065, 128)
      (pos_emb): Embedding(50,

In [378]:
def train_fn(model, data_loader, optimizer, criterion, device):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(data_loader):
        src = batch['en_ids'].to(device)
        trg = batch['vi_ids'].to(device)
        #src: n x src_seq_length
        #trg: n x trg_seq_length
        optimizer.zero_grad()
        output = model(src, trg)
        #output: n x trg_seq_length x trg_vocab_size
        output_dim = output.shape[-1]
        output = output[:,1:,].reshape(-1,output_dim)
        #output: (n * trg_seq_length - 1) x trg_vocab_size
        trg = trg[:,1:].reshape(-1)
        #trg: n x trg_seq_length-1
        loss = criterion(output, trg)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(data_loader)

In [379]:
def evaluate_fn(model, data_loader, criterion, device):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            src = batch['en_ids'].to(device)
            trg = batch['vi_ids'].to(device)
            #src: n x src_seq_length
            #trg: n x trg_seq_length
            output = model(src, trg)
            output_dim = output.shape[-1]
            output = output[:,1:,].reshape(-1,output_dim)
            #output: n x trg_seq_legth - 1 x trg_vocab_size
            trg = trg[:,1:].reshape(-1)
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    return epoch_loss / len(data_loader)

In [380]:
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=pad_index)

In [381]:
n_epochs = 20
best_valid_loss = float("inf")
for epoch in tqdm.tqdm(range(n_epochs)):
    train_loss = train_fn(
        model,
        train_data_loader,
        optimizer,
        criterion,
        device,
    )
    valid_loss = evaluate_fn(
        model,
        valid_data_loader,
        criterion,
        device,
    )
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), "model.pt")
    print(f"\tTrain Loss: {train_loss:7.3f} | Train PPL: {np.exp(train_loss):7.3f}")
    print(f"\tValid Loss: {valid_loss:7.3f} | Valid PPL: {np.exp(valid_loss):7.3f}")

  5%|▌         | 1/20 [00:03<01:03,  3.34s/it]

	Train Loss:   3.738 | Train PPL:  41.994
	Valid Loss:   1.516 | Valid PPL:   4.556


 10%|█         | 2/20 [00:06<00:59,  3.32s/it]

	Train Loss:   1.173 | Train PPL:   3.232
	Valid Loss:   0.611 | Valid PPL:   1.843


 15%|█▌        | 3/20 [00:09<00:56,  3.31s/it]

	Train Loss:   0.570 | Train PPL:   1.769
	Valid Loss:   0.320 | Valid PPL:   1.377


 20%|██        | 4/20 [00:13<00:53,  3.31s/it]

	Train Loss:   0.326 | Train PPL:   1.385
	Valid Loss:   0.182 | Valid PPL:   1.200


 25%|██▌       | 5/20 [00:16<00:49,  3.31s/it]

	Train Loss:   0.195 | Train PPL:   1.215
	Valid Loss:   0.106 | Valid PPL:   1.112


 30%|███       | 6/20 [00:19<00:46,  3.31s/it]

	Train Loss:   0.118 | Train PPL:   1.125
	Valid Loss:   0.059 | Valid PPL:   1.061


 35%|███▌      | 7/20 [00:23<00:43,  3.31s/it]

	Train Loss:   0.071 | Train PPL:   1.073
	Valid Loss:   0.034 | Valid PPL:   1.035


 40%|████      | 8/20 [00:26<00:39,  3.32s/it]

	Train Loss:   0.042 | Train PPL:   1.043
	Valid Loss:   0.019 | Valid PPL:   1.020


 45%|████▌     | 9/20 [00:29<00:36,  3.31s/it]

	Train Loss:   0.026 | Train PPL:   1.026
	Valid Loss:   0.012 | Valid PPL:   1.012


 50%|█████     | 10/20 [00:33<00:33,  3.31s/it]

	Train Loss:   0.017 | Train PPL:   1.017
	Valid Loss:   0.008 | Valid PPL:   1.008


 55%|█████▌    | 11/20 [00:36<00:29,  3.31s/it]

	Train Loss:   0.012 | Train PPL:   1.012
	Valid Loss:   0.006 | Valid PPL:   1.006


 60%|██████    | 12/20 [00:39<00:26,  3.30s/it]

	Train Loss:   0.010 | Train PPL:   1.010
	Valid Loss:   0.005 | Valid PPL:   1.005


 65%|██████▌   | 13/20 [00:43<00:23,  3.31s/it]

	Train Loss:   0.008 | Train PPL:   1.008
	Valid Loss:   0.004 | Valid PPL:   1.004


 70%|███████   | 14/20 [00:46<00:19,  3.31s/it]

	Train Loss:   0.006 | Train PPL:   1.006
	Valid Loss:   0.003 | Valid PPL:   1.003


 75%|███████▌  | 15/20 [00:49<00:16,  3.31s/it]

	Train Loss:   0.005 | Train PPL:   1.005
	Valid Loss:   0.003 | Valid PPL:   1.003


 80%|████████  | 16/20 [00:52<00:13,  3.31s/it]

	Train Loss:   0.005 | Train PPL:   1.005
	Valid Loss:   0.002 | Valid PPL:   1.002


 85%|████████▌ | 17/20 [00:56<00:09,  3.31s/it]

	Train Loss:   0.004 | Train PPL:   1.004
	Valid Loss:   0.002 | Valid PPL:   1.002


 90%|█████████ | 18/20 [00:59<00:06,  3.32s/it]

	Train Loss:   0.004 | Train PPL:   1.004
	Valid Loss:   0.002 | Valid PPL:   1.002


 95%|█████████▌| 19/20 [01:02<00:03,  3.33s/it]

	Train Loss:   0.003 | Train PPL:   1.003
	Valid Loss:   0.002 | Valid PPL:   1.002


100%|██████████| 20/20 [01:06<00:00,  3.31s/it]

	Train Loss:   0.003 | Train PPL:   1.003
	Valid Loss:   0.001 | Valid PPL:   1.001





In [382]:
model.load_state_dict(torch.load("model.pt"))
test_loss = evaluate_fn(model, test_data_loader, criterion, device)
print(f"| Test Loss: {test_loss:.3f} | Test PPL: {np.exp(test_loss):7.3f} |")

| Test Loss: 0.001 | Test PPL:   1.001 |


In [383]:
def translate_sentence(
    sentence,
    model,
    en_nlp,
    de_nlp,
    en_vocab,
    de_vocab,
    lower,
    sos_token,
    eos_token,
    device,
    max_output_length=20,
):
    model.eval()
    with torch.no_grad():
        if isinstance(sentence, str):
            tokens = [token.text for token in en_nlp.tokenizer(sentence)]
        else:
            tokens = [token for token in sentence]
        if lower:
            tokens = [token.lower() for token in tokens]
        res = []
        tokens = [sos_token] + tokens + [eos_token]
        ids = en_vocab.lookup_indices(tokens)
        print(ids)
        src = torch.LongTensor(ids).unsqueeze(0).to(device)
        encoder_output = model.encoder(src)
        tgt = torch.LongTensor(de_vocab.lookup_indices([sos_token])).unsqueeze(0).to(device)
        for i in range(max_output_length):
            src_mask,tgt_mask = model.generate_mask(src, tgt)
            decoder_output = model.decoder(tgt,encoder_output,src_mask,tgt_mask)
            predicted_token = decoder_output[0,i].argmax(-1).item()
            res.append(predicted_token)
            tgt = torch.cat((tgt, torch.tensor([[predicted_token]]).to(device)), dim=1)
            if predicted_token == en_vocab[eos_token]:
                break
        output = tgt[0].cpu().numpy()
#         output = model(src,torch.tensor(res).unsqueeze(0).to(device))
#         output = output[0].argmax(-1).squeeze().detach().cpu().numpy()
        tokens = de_vocab.lookup_tokens(output)
    return " ".join(tokens)

In [384]:
sentence = train_data[35]['en']
print(sentence)
translate_sentence(
    sentence,
    model,
    en_nlp,
    vi_nlp,
    en_vocab,
    vi_vocab,
    lower,
    sos_token,
    eos_token,
    device,
)

Tom is almost never wrong.
[2, 7, 13, 413, 87, 268, 4, 3]


'<sos> không thể không thể không thể không thể không thể không thể không thể không thể không thể không thể không thể không thể không thể không thể không thể không thể không thể không thể đã để'