Будем обучать модель для сортировки чисел. В качестве энкодера и декодера будут LSTM, эмбеддинги у каждого свои.

In [1]:
from numpy.core.fromnumeric import size
from model import *
import numpy as np
import pickle
from tqdm import tqdm

Зададим параметры модели

In [2]:
n_dims_hidden = 50
emb_dims = 6
token_dims = 15
max_len = 20
n_samples = 2000
learning_rate = 5e-3

eps = 1e-7


Пара полезных функций для работы с токенами

In [3]:
def to_one_hot(value):
    if value == -1:
        return np.zeros((token_dims, 1))
    return np.eye(token_dims)[value].reshape((token_dims, 1))

def from_one_hot(value):
    if np.linalg.norm(value) < eps:
        return -1
    return np.argmax(value)

Полные параметры

In [4]:
params = {
        "n_dims_hidden" : n_dims_hidden,
        "enc_emb_dims" : emb_dims,
        "dec_emb_dims" : emb_dims,
        "loss_func" : "softmax_ce",
        "activation_func" : "softmax",
        "token_dims" : token_dims,
        "start_token" : to_one_hot(0),
        "max_len" : max_len
    }

Создадим данные

In [5]:
x = [np.random.randint(low=1, high=token_dims-1, size=np.random.randint(low=5, high=max_len-2)) for _ in range(n_samples)]
x_list = [[0] + val.tolist() + [token_dims - 1] for val in x]
y = [[0] + sorted(val) + [token_dims - 1] + [-1]*(max_len - len(val) - 2) for val in x]

x_oh = [np.array([to_one_hot(val) for val in sample]) for sample in x_list]
y_oh = [np.array([to_one_hot(val) for val in sample]) for sample in y]

In [6]:
model = DecoderWithEncoder(**params)

Можно обучать

In [7]:
epochs = 2000
early_stop = 10
cur_early_stop = 0
min_loss = 1e9
best_model_params = model.copy_params()

for epoch in range(1, epochs + 1):
    if epoch % 30 == 0:
        learning_rate = learning_rate / 5
    total_loss = 0
    for i in tqdm(range(n_samples)):
        model.initialize_gradients()
        y_out = model.forward(x_oh[i])
        model.backprop(list(y_oh[i]))
        model.update_parameters(learning_rate)
        total_loss += cross_entropy_loss(y_out, y_oh[i]).sum()
    
    print(f"\nEpoch {epoch}/{epochs}, loss : {total_loss / n_samples}")
    if total_loss > min_loss:
        cur_early_stop += 1
    else:
        min_loss = total_loss
        cur_early_stop = 0
        best_model_params = model.copy_params()
    if cur_early_stop == early_stop:
        break
    if epoch % 10 == 0:
        with open("model", "wb") as f:
            pickle.dump(best_model_params, f)
model.set_params(best_model_params)

100%|██████████| 2000/2000 [00:22<00:00, 88.24it/s]
  0%|          | 9/2000 [00:00<00:23, 83.04it/s]
Epoch 1/2000, loss : 30.426619211930497
100%|██████████| 2000/2000 [00:22<00:00, 87.51it/s]
  0%|          | 9/2000 [00:00<00:24, 82.50it/s]
Epoch 2/2000, loss : 27.551604479702764
100%|██████████| 2000/2000 [00:22<00:00, 88.47it/s]
  0%|          | 8/2000 [00:00<00:25, 79.21it/s]
Epoch 3/2000, loss : 27.116872448751955
100%|██████████| 2000/2000 [00:22<00:00, 88.07it/s]
  0%|          | 10/2000 [00:00<00:21, 91.97it/s]
Epoch 4/2000, loss : 26.17955686333139
100%|██████████| 2000/2000 [00:22<00:00, 88.79it/s]
  0%|          | 9/2000 [00:00<00:23, 85.04it/s]
Epoch 5/2000, loss : 24.589851731800795
100%|██████████| 2000/2000 [00:22<00:00, 87.44it/s]
  0%|          | 9/2000 [00:00<00:22, 87.86it/s]
Epoch 6/2000, loss : 21.30141860103913
100%|██████████| 2000/2000 [00:22<00:00, 88.71it/s]
  0%|          | 10/2000 [00:00<00:21, 91.10it/s]
Epoch 7/2000, loss : 19.71845299249777
100%|█████████

KeyboardInterrupt: 

Проверим на паре примеров

In [8]:
x_test = [np.random.randint(low=1, high=token_dims, size=np.random.randint(low=5, high=max_len-1)) for _ in range(5)]
print(x_test)
x_test_oh = [np.array([to_one_hot(val) for val in x_i]) for x_i in x_test]
model.enable_caching(False)

y = [model.forward(val) for val in tqdm(x_test_oh)]

y = [[from_one_hot(val) for val in y_i] for y_i in y]
print(y)

100%|██████████| 5/5 [00:00<00:00, 581.25it/s][array([ 2, 10,  6,  1, 10,  4,  1,  3,  6, 11]), array([ 7,  8, 12,  7, 11]), array([10,  4,  8,  3,  9,  8,  5, 14, 12,  3, 10, 13,  8,  9, 14, 10]), array([ 2,  1, 14,  9,  7,  8]), array([ 6,  1,  1,  4, 14,  3,  2,  2,  6, 10, 10,  3,  1,  7, 12])]
[[0, 1, 1, 2, 3, 3, 6, 6, 9, 10, 12, 13, 14, 14, 14, 14, 14, 14, 14, 14], [0, 7, 7, 8, 11, 12, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14], [0, 3, 3, 4, 5, 8, 8, 8, 9, 9, 10, 10, 10, 11, 12, 14, 14, 14, 14, 14], [0, 1, 2, 7, 8, 8, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14], [0, 1, 1, 1, 2, 2, 3, 3, 3, 6, 6, 7, 9, 9, 10, 12, 14, 14, 14, 14]]



В качестве метрики возьмем количество совпадающих n-грам для n от 1 до max_len 

In [9]:
def extract_ngrams_of_size(y, size):
    return set(tuple(y[i:i+size]) for i in range(len(y) - size + 1))

def extract_ngrams(y):
    ngrams = set()
    for size in range(len(y) + 1):
        ngrams = ngrams.union(extract_ngrams_of_size(y, size))

    return ngrams

def compare(y_out, y):
    ngrams_out = extract_ngrams(y_out)
    ngrams = extract_ngrams(y)

    return len(ngrams.intersection(ngrams_out)) / len (ngrams)

Сгенерируем данные для проверки

In [10]:
x_test = []
n_test_samples = 50
while len(x_test) < n_test_samples:
    sample = np.random.randint(low=1, high=token_dims-1, size=np.random.randint(low=5, high=max_len-2))
    if sample.tolist() not in x_list:
        x_test.append(sample)

y_test = [[0] + sorted(val) + [token_dims-1] for val in x_test]

x_test_oh = [np.array([to_one_hot(val) for val in sample]) for sample in x_test]
y_test_oh = [np.array([to_one_hot(val) for val in sample]) for sample in y_test]

Посмотрим метрики

In [16]:
y_out = [model.forward(val) for val in tqdm(x_test_oh)]
y_out = [[from_one_hot(val) for val in y_i] for y_i in y_out]
y_out = [y_i[:y_i.index(token_dims - 1) + 1 if token_dims - 1 in y_i else len(y_i)] for y_i in y_out]

accuracy = 0
print()
for i in tqdm(range(n_test_samples)):
    accuracy += compare(y_out[i], y_test[i])
print()
print(f"{accuracy / n_test_samples}")


100%|██████████| 50/50 [00:00<00:00, 448.40it/s]
100%|██████████| 50/50 [00:00<00:00, 8465.82it/s][0, 1, 2, 4, 4, 4, 8, 9, 9, 14]
[0, 1, 2, 4, 4, 4, 8, 9, 9, 13, 13, 14]


0.36689972361338585

