## Data

In [11]:
import torch
import torch.nn as nn
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

# Toy training corpus: two short Vietnamese proverbs, one sentence per entry.
corpus = [
    "ăn quả nhớ kẻ trồng cây",
    "có chí thì nên",
]
data_size = len(corpus)  # number of sentences in the corpus

# Maximum vocabulary size, and the fixed number of tokens per sequence.
vocab_size, sequence_length = 13, 6

In [12]:
# Word-level tokenizer: the Whitespace pre-tokenizer splits the text and each
# piece is looked up in the vocabulary. The unknown token is named explicitly
# so it visibly matches the "<unk>" special token registered with the trainer.
tokenizer = Tokenizer(WordLevel(unk_token="<unk>"))
tokenizer.pre_tokenizer = Whitespace()

# Train the tokenizer on the corpus, reserving the three special tokens.
trainer = WordLevelTrainer(vocab_size=vocab_size,
                           special_tokens=["<unk>", "<pad>", "<sos>"])
tokenizer.train_from_iterator(corpus, trainer)

# Pad/truncate every encoding to `sequence_length` ids. The pad id is read
# back from the trained vocabulary instead of being hard-coded, so it cannot
# drift out of sync if the special-token list above is ever reordered.
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("<pad>"),
                         pad_token="<pad>",
                         length=sequence_length)
tokenizer.enable_truncation(max_length=sequence_length)

In [13]:
# Build next-word training pairs: for every position in a sentence, the input
# is "<sos>" followed by the words before that position, and the target is the
# word at that position.
data_x = []
data_y = []
for sentence in corpus:
    words = sentence.split()
    for position, target_word in enumerate(words):
        context = ['<sos>', *words[:position]]
        data_x.append(' '.join(context))
        data_y.append(target_word)

# Show each (input, target) pair, followed by a blank line.
for context_text, target_word in zip(data_x, data_y):
    print(f"{context_text}\n{target_word}\n")

<sos>
ăn

<sos> ăn
quả

<sos> ăn quả
nhớ

<sos> ăn quả nhớ
kẻ

<sos> ăn quả nhớ kẻ
trồng

<sos> ăn quả nhớ kẻ trồng
cây

<sos>
có

<sos> có
chí

<sos> có chí
thì

<sos> có chí thì
nên



In [14]:
# Tokenize and numericalize your samples
def vectorize(x, y, tokenizer, sequence_length):
    """Convert one (context, target) text pair into token ids.

    Args:
        x: Context text; encoded with ``tokenizer.encode``.
        y: Target token; looked up with ``tokenizer.token_to_id``.
        tokenizer: Object providing ``encode`` (whose result exposes ``ids``)
            and ``token_to_id``.
        sequence_length: Not used by this function (the length of the id list
            is whatever ``tokenizer.encode`` produces); kept so existing calls
            keep working.

    Returns:
        A tuple ``(x_ids, y_id)``: the id list for ``x`` and the integer id
        for ``y``.

    Raises:
        ValueError: If ``y`` is not in the vocabulary and the tokenizer has no
            ``"<unk>"`` entry to fall back on.
    """
    x_ids = tokenizer.encode(x).ids
    y_id = tokenizer.token_to_id(y)
    if y_id is None:
        # An out-of-vocabulary target would otherwise put a None into the
        # label list and only fail later, far from the cause. Fall back to the
        # unknown-token id instead.
        y_id = tokenizer.token_to_id("<unk>")
        if y_id is None:
            raise ValueError(f"target token {y!r} is not in the vocabulary")
    print(x_ids, y_id)
    return x_ids, y_id

# Run every (context, target) pair through `vectorize`, then collect the
# results into two integer tensors.
encoded_pairs = [
    vectorize(x, y, tokenizer, sequence_length)
    for x, y in zip(data_x, data_y)
]
data_x_ids = torch.tensor([x_ids for x_ids, _ in encoded_pairs], dtype=torch.long)
data_y_ids = torch.tensor([y_id for _, y_id in encoded_pairs], dtype=torch.long)

[2, 1, 1, 1, 1, 1] 12
[2, 12, 1, 1, 1, 1] 9
[2, 12, 9, 1, 1, 1] 7
[2, 12, 9, 7, 1, 1] 6
[2, 12, 9, 7, 6, 1] 11
[2, 12, 9, 7, 6, 11] 4
[2, 1, 1, 1, 1, 1] 5
[2, 5, 1, 1, 1, 1] 3
[2, 5, 3, 1, 1, 1] 10
[2, 5, 3, 10, 1, 1] 8


## Train with full data

In [15]:
class TG_Model(nn.Module):
    """Next-token classifier: embedding -> RNN -> flatten -> linear.

    The RNN outputs of all time steps are concatenated and fed to a single
    linear layer that produces one logit per vocabulary entry.

    Args:
        vocab_size: Number of embedding rows and of output logits.
        sequence_length: Number of tokens per input sequence; inputs must have
            exactly this length because the linear layer's input size depends
            on it.
        embedding_dim: Size of each token embedding (default 4).
        hidden_dim: Size of the RNN hidden state (default 4).
    """

    def __init__(self, vocab_size, sequence_length, embedding_dim=4, hidden_dim=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.recurrent = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.linear = nn.Linear(sequence_length * hidden_dim, vocab_size)

    def forward(self, x):
        """Map token ids of shape [n, sequence_length] to logits [n, vocab_size]."""
        x = self.embedding(x)        # [n, sequence_length, embedding_dim]
        x, _ = self.recurrent(x)     # [n, sequence_length, hidden_dim]
        # Same result as nn.Flatten(), without building a module on every call.
        x = x.flatten(start_dim=1)   # [n, sequence_length * hidden_dim]
        return self.linear(x)        # [n, vocab_size]

# Seed PyTorch's RNG before the weights are created so that a fresh
# "Restart & Run All" starts from the same initial weights every time.
torch.manual_seed(0)
model = TG_Model(vocab_size, sequence_length)
print(model)

TG_Model(
  (embedding): Embedding(13, 4)
  (recurrent): RNN(4, 4, batch_first=True)
  (linear): Linear(in_features=24, out_features=13, bias=True)
)


In [16]:
# Cross-entropy loss on the model's logits, optimised with Adam.
learning_rate = 0.05
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [17]:
# Full-batch training: every step feeds all samples through the model at once.
num_steps = 50
for step in range(num_steps):
    outputs = model(data_x_ids)            # forward pass -> logits
    loss = criterion(outputs, data_y_ids)

    optimizer.zero_grad()                  # clear gradients left from the previous step
    loss.backward()
    optimizer.step()

In [18]:
# Predicted token id for every training sample: index of the largest logit.
outputs = model(data_x_ids)
predicted_ids = outputs.argmax(dim=-1)
print(predicted_ids)

tensor([12,  9,  7,  6, 11,  4, 12,  3, 10,  8])


In [19]:
# Ground-truth target ids, displayed for comparison with the model's predictions.
data_y_ids

tensor([12,  9,  7,  6, 11,  4,  5,  3, 10,  8])

## Inference

In [42]:
# Prompt to continue; it is padded to `sequence_length` ids by the tokenizer.
promt = '<sos> ăn'
promt_encoding = tokenizer.encode(promt)
promt_ids = promt_encoding.ids
# Number of real (non-padding) tokens in the prompt. Derived from the
# attention mask (1 = token, 0 = padding) instead of being hard-coded, so it
# stays correct when `promt` is changed.
promt_length = sum(promt_encoding.attention_mask)
print(promt_ids)

[2, 12, 1, 1, 1, 1]


In [40]:
# Greedy decoding: repeatedly predict the next token and write it into the
# first free slot after the prompt, until the sequence is full.

# Reset everything after the prompt to padding first. On a first run this is a
# no-op; on a re-run it stops the loop from conditioning on tokens generated
# by the previous run, so the cell gives the same result every time.
pad_id = tokenizer.token_to_id("<pad>")
promt_ids[promt_length:] = [pad_id] * (sequence_length - promt_length)

with torch.no_grad():  # inference only: no autograd graph needed
    for i in range(sequence_length - promt_length):
        promt_tensor = torch.tensor(promt_ids, dtype=torch.long).reshape(1, -1)
        outputs = model(promt_tensor)
        next_id = torch.argmax(outputs, dim=-1)

        promt_ids[promt_length + i] = next_id.item()
        print(promt_ids)

[2, 12, 9, 1, 1, 1]
[2, 12, 9, 7, 1, 1]
[2, 12, 9, 7, 6, 1]
[2, 12, 9, 7, 6, 11]


In [33]:
# Inspect the learned token -> id mapping.
vocab = tokenizer.get_vocab()
print(vocab)

{'nên': 8, 'kẻ': 6, 'quả': 9, 'ăn': 12, 'chí': 3, '<sos>': 2, 'thì': 10, 'trồng': 11, 'cây': 4, 'nhớ': 7, '<unk>': 0, '<pad>': 1, 'có': 5}
