## Data

In [15]:
import torch
import torch.nn as nn
import torchtext; torchtext.disable_torchtext_deprecation_warning()
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Toy training corpus: two short Vietnamese sentences, one per entry.
corpus = [
    "ăn quả nhớ kẻ trồng cây",
    "có chí thì nên",
]
data_size = len(corpus)

# Seed PyTorch's RNG so the model's random initialisation (and therefore the
# training losses and predictions further down) is reproducible on a fresh kernel.
seed = 42
torch.manual_seed(seed)

# Define the max vocabulary size and sequence length.
# vocab_size: the corpus has 10 distinct words, plus 3 special tokens (<unk>, <pad>, <sos>).
vocab_size = 13
# sequence_length: fixed model input length; the longest training context is
# "<sos>" followed by 5 words, i.e. 6 tokens.
sequence_length = 6

In [16]:
# torchtext's built-in 'basic_english' tokenizer; it is used below to build the vocabulary.
tokenizer = get_tokenizer('basic_english')


def yield_tokens(examples):
    """Lazily yield ``tokenizer(text)`` for every text in ``examples``.

    Feeds ``build_vocab_from_iterator`` one token list per sentence.
    """
    yield from (tokenizer(text) for text in examples)

# Build the vocabulary from the tokenized corpus. The special tokens come first,
# so they receive the lowest ids (<unk>=0, <pad>=1, <sos>=2 in the output below),
# and max_tokens=vocab_size bounds the total number of entries.
vocab = build_vocab_from_iterator(
    yield_tokens(corpus),
    specials=["<unk>", "<pad>", "<sos>"],
    max_tokens=vocab_size,
)
# Unknown tokens fall back to the <unk> id (see torchtext Vocab.set_default_index).
vocab.set_default_index(vocab["<unk>"])
vocab.get_stoi()

{'ăn': 12,
 'nên': 8,
 'nhớ': 7,
 'có': 5,
 'thì': 10,
 'trồng': 11,
 'kẻ': 6,
 '<unk>': 0,
 'cây': 4,
 'quả': 9,
 '<sos>': 2,
 'chí': 3,
 '<pad>': 1}

In [17]:
# Turn each sentence into (context -> next word) training pairs: the context is
# "<sos>" followed by the first t words, and the target is word t.
# Tokenize with the same tokenizer that built the vocabulary, so that context and
# target tokens match vocabulary entries (str.split() would not lower-case or split
# punctuation, and those tokens would silently become <unk>).
data_x = []
data_y = []
for sentence in corpus:
    tokens = tokenizer(sentence)

    for t, target in enumerate(tokens):
        data_x.append(['<sos>'] + tokens[:t])
        data_y.append(target)

# Show every (context, next word) pair, separated by a blank line.
for context, target in zip(data_x, data_y):
    print(context, target, "", sep="\n")

['<sos>']
ăn

['<sos>', 'ăn']
quả

['<sos>', 'ăn', 'quả']
nhớ

['<sos>', 'ăn', 'quả', 'nhớ']
kẻ

['<sos>', 'ăn', 'quả', 'nhớ', 'kẻ']
trồng

['<sos>', 'ăn', 'quả', 'nhớ', 'kẻ', 'trồng']
cây

['<sos>']
có

['<sos>', 'có']
chí

['<sos>', 'có', 'chí']
thì

['<sos>', 'có', 'chí', 'thì']
nên



In [18]:
data_x_ids = []
data_y_ids = []


def vectorize(x, y, vocab, sequence_length, pad_token="<pad>"):
    """Encode one (context, next-word) pair as token ids.

    Args:
        x: list of context tokens, e.g. ['<sos>', 'ăn'].
        y: the target token that follows ``x``.
        vocab: token -> id lookup, indexed as ``vocab[token]``.
        sequence_length: fixed length of the returned id list.
        pad_token: token used to right-pad short contexts (default "<pad>").

    Returns:
        ``(x_ids, y_id)``. ``x_ids`` holds exactly ``sequence_length`` ids: the ids of the
        first ``sequence_length`` context tokens, right-padded with ``vocab[pad_token]``.
        ``y_id`` is ``vocab[y]``.
    """
    # NOTE: over-long contexts keep their *first* tokens. The inference cell truncates
    # the same way, so keep the two in sync if this ever changes.
    x_ids = [vocab[token] for token in x[:sequence_length]]
    # Pad based on the truncated length so the result is always sequence_length long.
    x_ids += [vocab[pad_token]] * (sequence_length - len(x_ids))
    return x_ids, vocab[y]
# Encode every (context, target) pair and collect the ids.
for context, target in zip(data_x, data_y):
    context_ids, target_id = vectorize(context, target, vocab, sequence_length)
    data_x_ids.append(context_ids)
    data_y_ids.append(target_id)

In [19]:
# Show every encoded (context ids, target id) pair, separated by a blank line.
for context_ids, target_id in zip(data_x_ids, data_y_ids):
    print(context_ids, target_id, "", sep="\n")

[2, 1, 1, 1, 1, 1]
12

[2, 12, 1, 1, 1, 1]
9

[2, 12, 9, 1, 1, 1]
7

[2, 12, 9, 7, 1, 1]
6

[2, 12, 9, 7, 6, 1]
11

[2, 12, 9, 7, 6, 11]
4

[2, 1, 1, 1, 1, 1]
5

[2, 5, 1, 1, 1, 1]
3

[2, 5, 3, 1, 1, 1]
10

[2, 5, 3, 10, 1, 1]
8



In [20]:
# as_tensor converts the nested id lists on the first run. On a re-run, data_x_ids is
# already a LongTensor and is returned as-is; torch.tensor(tensor) would warn and copy.
data_x_ids = torch.as_tensor(data_x_ids, dtype=torch.long)
print(data_x_ids.shape)

torch.Size([10, 6])


In [21]:
# as_tensor converts the id list on the first run and returns the existing
# LongTensor unchanged on a re-run (torch.tensor(tensor) would warn and copy).
data_y_ids = torch.as_tensor(data_y_ids, dtype=torch.long)
print(data_y_ids.shape)

torch.Size([10])


## Model

In [22]:
class TG_Model(nn.Module):
    """Next-word model: embedding -> RNN -> one linear layer over the flattened sequence.

    The RNN output at every time step is flattened and fed to a single linear layer
    that produces one logit per vocabulary entry. Inputs must therefore always have
    exactly ``sequence_length`` token ids (this notebook right-pads shorter contexts
    with <pad>).

    Args:
        vocab_size: number of vocabulary entries (embedding rows and output logits).
        sequence_length: fixed number of token ids per input row.
        embedding_dim: embedding size (default 4).
        hidden_size: RNN hidden size (default 4).
    """

    def __init__(self, vocab_size, sequence_length, embedding_dim=4, hidden_size=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.recurrent = nn.RNN(embedding_dim, hidden_size, batch_first=True)
        # The hidden state of every time step is used, hence sequence_length * hidden_size inputs.
        self.linear = nn.Linear(sequence_length * hidden_size, vocab_size)

    def forward(self, x):
        """Map token ids of shape [n, sequence_length] to raw logits of shape [n, vocab_size]."""
        x = self.embedding(x)              # [n, sequence_length, embedding_dim]
        x, _ = self.recurrent(x)           # [n, sequence_length, hidden_size]
        x = torch.flatten(x, start_dim=1)  # [n, sequence_length * hidden_size]
        return self.linear(x)              # [n, vocab_size]

# Instantiate the model for the configured vocabulary / window sizes and list its layers.
model = TG_Model(vocab_size=vocab_size, sequence_length=sequence_length)
print(model)

TG_Model(
  (embedding): Embedding(13, 4)
  (recurrent): RNN(4, 4, batch_first=True)
  (linear): Linear(in_features=24, out_features=13, bias=True)
)


In [23]:
# The model's forward pass ends in a plain Linear layer (raw logits),
# which is the input form CrossEntropyLoss expects.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.05)

## Train

In [24]:
# Full-batch training: each step runs every encoded pair in data_x_ids through
# the model and prints the training loss.
for _epoch in range(35):
    optimizer.zero_grad()
    logits = model(data_x_ids)
    loss = criterion(logits, data_y_ids)
    print(loss.item())
    loss.backward()
    optimizer.step()

2.7055916786193848
2.164893627166748
1.7753992080688477
1.4416204690933228
1.151464819908142
0.9141565561294556
0.7340354323387146
0.6000059247016907
0.5004871487617493
0.4314578175544739
0.38419976830482483
0.3464199900627136
0.31238603591918945
0.28088831901550293
0.2526453137397766
0.22972261905670166
0.21128150820732117
0.19626843929290771
0.1852664202451706
0.17725470662117004
0.17126096785068512
0.1667129099369049
0.16242918372154236
0.15837077796459198
0.15525270998477936
0.15270254015922546
0.15056125819683075
0.14917609095573425
0.14813630282878876
0.14716339111328125
0.1465173065662384
0.1459810435771942
0.14535680413246155
0.1448863446712494
0.1444772481918335


In [25]:
# Predicted next-word id for each training context. Only a forward pass is needed,
# so skip building the autograd graph.
with torch.no_grad():
    outputs = model(data_x_ids)
print(torch.argmax(outputs, dim=-1))

tensor([ 5,  9,  7,  6, 11,  4,  5,  3, 10,  8])


In [26]:
# Ground-truth next-word ids, to compare with the predictions printed above.
# The 1st and 7th contexts are identical ([2, 1, 1, 1, 1, 1]: <sos> + padding) but
# have different targets (12 and 5), so no model can predict both correctly.
data_y_ids

tensor([12,  9,  7,  6, 11,  4,  5,  3, 10,  8])

## Inference

In [27]:
# Encode the prompt the same way as the training contexts: token ids,
# truncated to sequence_length and right-padded with <pad>.
promt = '<sos> ăn'.split()
promt_ids = [vocab[tok] for tok in promt[:sequence_length]]
promt_ids += [vocab["<pad>"]] * (sequence_length - len(promt))

print(promt_ids)

[2, 12, 1, 1, 1, 1]


In [28]:
# Greedy generation: fill each slot after the prompt with the model's most likely next word.
# Rebuild the working sequence from the prompt prefix instead of writing into promt_ids
# in place; otherwise re-running this cell would feed tokens generated by the previous
# run back in as context. The final sequence is still stored in promt_ids afterwards.
pad_id = vocab["<pad>"]
prompt_len = min(len(promt), sequence_length)
generated_ids = promt_ids[:prompt_len] + [pad_id] * (sequence_length - prompt_len)

with torch.no_grad():  # inference only; no autograd graph needed
    for pos in range(prompt_len, sequence_length):
        promt_tensor = torch.tensor(generated_ids, dtype=torch.long).reshape(1, -1)
        outputs = model(promt_tensor)
        next_id = torch.argmax(outputs, dim=-1)

        generated_ids[pos] = next_id.item()
        print(generated_ids)

promt_ids = generated_ids

[2, 12, 9, 1, 1, 1]
[2, 12, 9, 7, 1, 1]
[2, 12, 9, 7, 6, 1]
[2, 12, 9, 7, 6, 11]
