Будем обучать модель для сортировки чисел. В качестве энкодера и декодера будут LSTM, эмбеддинги у каждого свои.

In [1]:
from numpy.core.fromnumeric import size
from model import *
import numpy as np
import pickle
from tqdm import tqdm

Зададим параметры модели

In [2]:
n_dims_hidden = 50
emb_dims = 6
token_dims = 15
max_len = 20
n_samples = 2000
learning_rate = 5e-3

eps = 1e-7


Пара полезных функций для работы с токенами

In [3]:
def to_one_hot(value):
    if value == -1:
        return np.zeros((token_dims, 1))
    return np.eye(token_dims)[value].reshape((token_dims, 1))

def from_one_hot(value):
    if np.linalg.norm(value) < eps:
        return -1
    return np.argmax(value)

Полные параметры

In [4]:
params = {
        "n_dims_hidden" : n_dims_hidden,
        "enc_emb_dims" : emb_dims,
        "dec_emb_dims" : emb_dims,
        "loss_func" : "softmax_ce",
        "activation_func" : "softmax",
        "token_dims" : token_dims,
        "start_token" : to_one_hot(0),
        "max_len" : max_len
    }

Создадим данные

In [5]:
x = [np.random.randint(low=1, high=token_dims-1, size=np.random.randint(low=5, high=max_len-2)) for _ in range(n_samples)]
x_list = [[0] + val.tolist() + [token_dims - 1] for val in x]
y = [[0] + sorted(val) + [token_dims - 1] + [-1]*(max_len - len(val) - 2) for val in x]

x_oh = [np.array([to_one_hot(val) for val in sample]) for sample in x_list]
y_oh = [np.array([to_one_hot(val) for val in sample]) for sample in y]

In [6]:
model = DecoderWithEncoder(**params)

Можно обучать

In [7]:
epochs = 100
early_stop = 10
cur_early_stop = 0
min_loss = 1e9
best_model_params = model.copy_params()

for epoch in range(1, epochs + 1):
    if epoch % 50 == 0:
        learning_rate = learning_rate / 3
    total_loss = 0
    for i in tqdm(range(n_samples)):
        model.initialize_gradients()
        y_out = model.forward(x_oh[i])
        model.backprop(list(y_oh[i]))
        model.update_parameters(learning_rate)
        total_loss += cross_entropy_loss(y_out, y_oh[i]).sum()
    
    print(f"\nEpoch {epoch}/{epochs}, loss : {total_loss / n_samples}")
    if total_loss > min_loss:
        cur_early_stop += 1
    else:
        min_loss = total_loss
        cur_early_stop = 0
        best_model_params = model.copy_params()
    if cur_early_stop == early_stop:
        break
    if epoch % 10 == 0:
        with open("model", "wb") as f:
            pickle.dump(best_model_params, f)
model.set_params(best_model_params)

100%|██████████| 2000/2000 [00:22<00:00, 90.03it/s]
  0%|          | 10/2000 [00:00<00:22, 90.34it/s]
Epoch 1/100, loss : 30.526526483218685
100%|██████████| 2000/2000 [00:22<00:00, 90.25it/s]
  0%|          | 10/2000 [00:00<00:21, 91.92it/s]
Epoch 2/100, loss : 27.63238113916928
100%|██████████| 2000/2000 [00:21<00:00, 92.19it/s]
  0%|          | 10/2000 [00:00<00:21, 92.66it/s]
Epoch 3/100, loss : 26.381996320172405
100%|██████████| 2000/2000 [00:21<00:00, 91.96it/s]
  0%|          | 10/2000 [00:00<00:21, 91.90it/s]
Epoch 4/100, loss : 23.47813457564262
100%|██████████| 2000/2000 [00:21<00:00, 92.05it/s]
  0%|          | 10/2000 [00:00<00:20, 94.90it/s]
Epoch 5/100, loss : 20.79769373500732
100%|██████████| 2000/2000 [00:21<00:00, 91.00it/s]
  0%|          | 10/2000 [00:00<00:21, 92.96it/s]
Epoch 6/100, loss : 19.440059638721497
100%|██████████| 2000/2000 [00:21<00:00, 92.04it/s]
  0%|          | 10/2000 [00:00<00:21, 93.32it/s]
Epoch 7/100, loss : 18.346149801902513
100%|██████████|

Проверим на паре примеров

In [8]:
x_test = [np.random.randint(low=1, high=token_dims, size=np.random.randint(low=5, high=max_len-1)) for _ in range(5)]
print([val.tolist() for val in x_test])
x_test_oh = [np.array([to_one_hot(val) for val in x_i]) for x_i in x_test]
model.enable_caching(False)

y = [model.forward(val) for val in tqdm(x_test_oh)]

y = [[from_one_hot(val) for val in y_i] for y_i in y]
print(y)

100%|██████████| 5/5 [00:00<00:00, 592.55it/s][[8, 7, 3, 7, 3, 3, 7, 5], [5, 9, 2, 4, 6, 6, 14, 5, 7, 5, 13, 8], [1, 3, 14, 14, 6, 6, 11, 2, 6, 12], [5, 12, 2, 11, 9, 6, 11, 8, 12, 10, 11, 4, 6], [14, 5, 13, 4, 4, 8, 4, 6, 9, 1, 14, 3, 7, 10, 10]]
[[0, 3, 3, 3, 7, 7, 7, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14], [0, 2, 4, 5, 5, 5, 6, 6, 9, 9, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14], [0, 1, 2, 3, 6, 6, 6, 9, 9, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14], [0, 2, 4, 5, 6, 6, 8, 9, 11, 12, 12, 13, 14, 14, 14, 14, 14, 14, 14, 14], [0, 1, 3, 4, 4, 4, 5, 6, 7, 8, 9, 9, 9, 14, 14, 14, 14, 14, 14, 14]]



В качестве метрики возьмем количество совпадающих n-грам для n от 1 до max_len 

In [9]:
def extract_ngrams_of_size(y, size):
    return set(tuple(y[i:i+size]) for i in range(len(y) - size + 1))

def extract_ngrams(y):
    ngrams = set()
    for size in range(len(y) + 1):
        ngrams = ngrams.union(extract_ngrams_of_size(y, size))

    return ngrams

def compare(y_out, y):
    ngrams_out = extract_ngrams(y_out)
    ngrams = extract_ngrams(y)

    return len(ngrams.intersection(ngrams_out)) / len (ngrams)

Сгенерируем данные для проверки

In [10]:
x_test = []
n_test_samples = 50
while len(x_test) < n_test_samples:
    sample = np.random.randint(low=1, high=token_dims-1, size=np.random.randint(low=5, high=max_len-2))
    if sample.tolist() not in x_list:
        x_test.append(sample)

y_test = [[0] + sorted(val) + [token_dims-1] for val in x_test]

x_test_oh = [np.array([to_one_hot(val) for val in sample]) for sample in x_test]
y_test_oh = [np.array([to_one_hot(val) for val in sample]) for sample in y_test]

Посмотрим метрики

In [11]:
y_out = [model.forward(val) for val in tqdm(x_test_oh)]
y_out = [[from_one_hot(val) for val in y_i] for y_i in y_out]
y_out = [y_i[:y_i.index(token_dims - 1) + 1 if token_dims - 1 in y_i else len(y_i)] for y_i in y_out]

accuracy = 0
print()
for i in tqdm(range(n_test_samples)):
    accuracy += compare(y_out[i], y_test[i])
print()
print(f"{accuracy / n_test_samples}")


100%|██████████| 50/50 [00:00<00:00, 670.66it/s]
100%|██████████| 50/50 [00:00<00:00, 9974.56it/s]

0.44363725320767666

