## Data

In [29]:
import torch
import torch.nn as nn
import torchtext; torchtext.disable_torchtext_deprecation_warning()
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Define tokenizer function
tokenizer = get_tokenizer('basic_english')

# Create a function to yield list of tokens
def yield_tokens(examples):
    for text in examples:
        yield tokenizer(text)

# Tokenize and numericalize your samples
def vectorize_context(text, vocab, sequence_length):
    tokens = tokenizer(text)
    tokens = [vocab[token] for token in tokens][:sequence_length]
    token_ids = tokens + [vocab["<pad>"]] * (sequence_length - len(tokens))
    
    return token_ids

def vectorize_generation(x, y, vocab, sequence_length):    
    x_ids = [vocab[token] for token in x][:sequence_length]
    x_ids = x_ids + [vocab["<pad>"]] * (sequence_length - len(x))

    y_ids = [vocab[token] for token in y][:sequence_length]
    y_ids = y_ids + [vocab["<pad>"]] * (sequence_length - len(y))
    
    return x_ids, y_ids

#### corpus_generation

In [30]:
corpus_generation = [
    "ăn quả nhớ kẻ trồng cây",
    "làm giàu không khó"    
]
data_size_generation = len(corpus_generation)

# max vocabulary size and sequence length
vocab_size_generation = 14
sequence_length_generation = 7

In [31]:
# Create vocabulary
vocab_generation = build_vocab_from_iterator(yield_tokens(corpus_generation),
                                     max_tokens=vocab_size_generation,
                                     specials=["<unk>", "<pad>", "<sos>", "<eos>"])
vocab_generation.set_default_index(vocab_generation["<unk>"])
vocab_generation.get_stoi()

{'ăn': 13,
 'quả': 11,
 'nhớ': 10,
 'làm': 9,
 'trồng': 12,
 '<eos>': 3,
 'kẻ': 8,
 'không': 7,
 'khó': 6,
 '<unk>': 0,
 'cây': 4,
 'giàu': 5,
 '<sos>': 2,
 '<pad>': 1}

In [32]:
data_x = []
data_y = []
for vector in corpus_generation:
    vector = ['<sos>'] + vector.split() + ['<eos>']
    data_x.append(vector[:-1])
    data_y.append(vector[1:])  

print(data_x)
print(data_y)

[['<sos>', 'ăn', 'quả', 'nhớ', 'kẻ', 'trồng', 'cây'], ['<sos>', 'làm', 'giàu', 'không', 'khó']]
[['ăn', 'quả', 'nhớ', 'kẻ', 'trồng', 'cây', '<eos>'], ['làm', 'giàu', 'không', 'khó', '<eos>']]


In [33]:
data_x_ids = []
data_y_ids = []
for x, y in zip(data_x, data_y):
    x_ids, y_ids = vectorize_generation(x, y, 
                                        vocab_generation, 
                                        sequence_length_generation)
    data_x_ids.append(x_ids)
    data_y_ids.append(y_ids)

print(data_x_ids)
print(data_y_ids)

[[2, 13, 11, 10, 8, 12, 4], [2, 9, 5, 7, 6, 1, 1]]
[[13, 11, 10, 8, 12, 4, 3], [9, 5, 7, 6, 3, 1, 1]]


In [34]:
data_x_ids = torch.tensor(data_x_ids, dtype=torch.long)
print(data_x_ids.shape)

data_y_ids = torch.tensor(data_y_ids, dtype=torch.long)
print(data_y_ids.shape)

torch.Size([2, 7])
torch.Size([2, 7])


#### topics

In [35]:
topics = [
    'khuyên răn',
    'kinh doanh'
]
vocab_size_context = 6
sequence_length_context = 2



In [36]:
topics = [
    'khuyên răn',
    'kinh doanh'
]
vocab_size_context = 6
sequence_length_context = 2
vocab_context = build_vocab_from_iterator(yield_tokens(topics),
                                          max_tokens=vocab_size_context,
                                          specials=["<unk>", "<pad>"])

topics_ids = []
for x in topics:
    x_ids = vectorize_context(x, vocab_context, 
                              sequence_length_context)
    topics_ids.append(x_ids)

In [37]:
# Create vocabulary
vocab_context = build_vocab_from_iterator(yield_tokens(topics),
                                          max_tokens=vocab_size_context,
                                          specials=["<unk>", "<pad>"])
vocab_context.get_stoi()

{'răn': 5, '<unk>': 0, '<pad>': 1, 'khuyên': 3, 'kinh': 4, 'doanh': 2}

In [38]:
topics_ids = []
for vector in topics:
    vector = vector.split()
    topics_ids.append(vector)

topics_ids2 = []
for x in topics:
    x_ids = vectorize_context(x, vocab_context, 
                              sequence_length_context)
    topics_ids2.append(x_ids)

# print
print(topics_ids)
print(topics_ids2)

[['khuyên', 'răn'], ['kinh', 'doanh']]
[[3, 5], [4, 2]]


In [39]:
topics_tensor = torch.tensor(topics_ids2, dtype=torch.long)
print(topics_tensor)

tensor([[3, 5],
        [4, 2]])


## Train with full data

In [40]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, model_dim, nhead):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.transformer_encoder = nn.TransformerEncoderLayer(d_model=model_dim, 
                                                              nhead=nhead, 
                                                              dim_feedforward=6,
                                                              batch_first=True)

    def forward(self, src):
        # src = [batch_size, seq_length]
        
        embedded = self.embedding(src)                  
        # [batch_size, seq_length, embedding_dim]
        
        context = self.transformer_encoder(embedded)        
        return context

In [41]:
embedding_dim, model_dim, nhead = 6, 6, 2
encoder = Encoder(vocab_size_context, embedding_dim, model_dim, nhead)

context_sample = encoder(topics_tensor)
print(context_sample.shape)

torch.Size([2, 2, 6])


In [42]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, model_dim, nhead):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.transformer_decoder = nn.TransformerDecoderLayer(d_model=model_dim, 
                                                        nhead=nhead, 
                                                        dim_feedforward=6,
                                                        batch_first=True)
        self.fc_out = nn.Linear(model_dim, vocab_size)

    def forward(self, input, context):
        # input = [batch_size, seq_length_vn]   
        # context = [batch_size, seq_length_en, model_dim]   
                
        embedded = self.embedding(input)
        # embedded = [batch_size, seq_length_vn, embedding_dim]
        
        output = self.transformer_decoder(embedded, context)
        # output = [batch_size, seq_length_vn, model_dim]
        
        prediction = self.fc_out(output)
        # prediction = [batch_size, vocab_size_vn]
        
        return prediction.permute(0, 2, 1)

In [43]:
decoder = Decoder(vocab_size_generation, embedding_dim, model_dim, nhead)

prediction = decoder(data_x_ids, context_sample)
print(prediction.shape)

torch.Size([2, 14, 7])


In [44]:
class Seq2Seq_Model(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder  

    def forward(self, sequence_encoder, sequence_decoder):        
        context = self.encoder(sequence_encoder)
        outputs = self.decoder(sequence_decoder, context)
            
        return outputs

In [45]:
model = Seq2Seq_Model(encoder, decoder)
outputs = model(topics_tensor, data_x_ids)
print(outputs.shape)

torch.Size([2, 14, 7])


In [46]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

In [47]:
# train
for _ in range(40):
    optimizer.zero_grad()
    outputs = model(topics_tensor, data_x_ids)
    loss = criterion(outputs, data_y_ids)
    print(loss.item())
    loss.backward()
    optimizer.step()

2.8102221488952637
2.2792866230010986
2.0750935077667236
1.7887763977050781
1.6383552551269531
1.4232064485549927
1.2275269031524658
1.0885570049285889
0.88576740026474
0.7645182013511658
0.6225671768188477
0.5774521231651306
0.5062259435653687
0.405166357755661
0.2784051299095154
0.26622429490089417
0.1934419870376587
0.1445288211107254
0.12828584015369415
0.10890782624483109
0.06873469799757004
0.06670587509870529
0.05160193890333176
0.05473150685429573
0.027748581022024155
0.021766694262623787
0.041832808405160904
0.025965657085180283
0.014230777509510517
0.014112441800534725
0.01168507058173418
0.007866929285228252
0.0061887046322226524
0.011358500458300114
0.005096558015793562
0.04116428270936012
0.0047452799044549465
0.006873291451483965
0.0061792186461389065
0.0055232904851436615


In [48]:
outputs = model(topics_tensor, data_x_ids)
#print(outputs)
print(torch.argmax(outputs, axis=1))

tensor([[13, 11, 10,  8, 12,  4,  3],
        [ 9,  5,  7,  6,  3,  1,  1]])


In [49]:
data_y_ids

tensor([[13, 11, 10,  8, 12,  4,  3],
        [ 9,  5,  7,  6,  3,  1,  1]])

In [50]:
# check
topic1 = topics_tensor[0:1,:]
print(topic1)

data_x_id1 = data_x_ids[0:1,:]
print(data_x_id1)

data_y_id1 = data_y_ids[0:1,:]
print(data_y_id1)

tensor([[3, 5]])
tensor([[ 2, 13, 11, 10,  8, 12,  4]])
tensor([[13, 11, 10,  8, 12,  4,  3]])


In [51]:
outputs = model(topic1, data_x_id1)
print(outputs.shape)
print(torch.argmax(outputs, axis=1))

torch.Size([1, 14, 7])
tensor([[13, 11, 10,  8, 12,  4,  3]])


In [52]:
# check
topic1 = topics_tensor[1:2,:]
print(topic1)

data_x_id1 = data_x_ids[1:2,:]
print(data_x_id1)

data_y_id1 = data_y_ids[1:2,:]
print(data_y_id1)

tensor([[4, 2]])
tensor([[2, 9, 5, 7, 6, 1, 1]])
tensor([[9, 5, 7, 6, 3, 1, 1]])


In [53]:
outputs = model(topic1, data_x_id1)
print(outputs.shape)
print(torch.argmax(outputs, axis=1))

torch.Size([1, 14, 7])
tensor([[9, 5, 7, 6, 3, 1, 1]])


## Inference

In [58]:
promt = '<sos>'
promt = promt.split()
promt_ids = [vocab_generation[token] for token in promt][:sequence_length_generation]
promt_ids = promt_ids + [vocab_generation["<pad>"]] * (sequence_length_generation - len(promt))

print(promt_ids)

[2, 1, 1, 1, 1, 1, 1]


In [59]:
topic = 'kinh doanh'   # 'kinh doanh' ; 'khuyên răn'
topic = topic.split()
topic_ids = [vocab_context[token] for token in topic]
topic_tensor = torch.tensor(topic_ids, dtype=torch.long).reshape(1, -1)
print(topic_tensor)

tensor([[4, 2]])


In [None]:
for i in range(sequence_length_generation - len(promt)):
    promt_tensor = torch.tensor(promt_ids, 
                                dtype=torch.long).reshape(1, -1)
    outputs = model(topic_tensor, promt_tensor)
    outputs = torch.argmax(outputs, axis=1)   
    next_id = outputs[0][len(promt)+i-1]

    promt_ids[len(promt)+i] = next_id.item()
print(promt_ids)

[2, 9, 5, 7, 6, 3, 7]


In [57]:
{'ăn': 13,
 'quả': 11,
 'nhớ': 10,
 'làm': 9,
 'trồng': 12,
 '<eos>': 3,
 'kẻ': 8,
 'không': 7,
 'khó': 6,
 '<unk>': 0,
 'cây': 4,
 'giàu': 5,
 '<sos>': 2,
 '<pad>': 1}

{'ăn': 13,
 'quả': 11,
 'nhớ': 10,
 'làm': 9,
 'trồng': 12,
 '<eos>': 3,
 'kẻ': 8,
 'không': 7,
 'khó': 6,
 '<unk>': 0,
 'cây': 4,
 'giàu': 5,
 '<sos>': 2,
 '<pad>': 1}