In [1]:
import pickle
from src.expressions_generator import ExpressionGenerator
from src.evaluate import make_model, Evaluator
import torch
from torch.utils.data import DataLoader
from src.dataset import DiffeqDataset
import sympy as sp
import numpy as np
from tqdm import tqdm

In [9]:
def infix_eq_to_prefix_id(infix_expr):
    sympy_expr = gen.infix_to_sympy(infix_expr)
    prefix_expr = gen.sympy_to_prefix(sympy_expr)
    prefix_expr = gen.replace_f_by_y(prefix_expr)
    prefix_id = [token2id['<s>']] + [token2id[tok] for tok in prefix_expr] + [token2id['</s>']]
    return sympy_expr, prefix_id

In [10]:
def solution_to_sympy(prefix_sol):
    prefix_sol = [id2token[i] for i in prefix_sol]
    x_prefix = []
    y_prefix = []
    try:
        sep_id = prefix_sol.index('end_x_start_y')
    except:
        print('Solution is incorrect')
    x_prefix = prefix_sol[:sep_id]
    y_prefix = prefix_sol[sep_id+1:]
    x_t = gen.prefix_to_simpy(x_prefix)
    y_t = gen.prefix_to_simpy(y_prefix)
    return x_t, y_t

In [3]:
def create_data_iterator(train, batch_size, path):
       
        dataset = DiffeqDataset(train, batch_size, path)
    
        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=False,
                          collate_fn=dataset.collate_fn
        )

## Опредление качества обученной модели

In [4]:
DATA_PATH = '/home/ilia/diff_eq_project/unresolved_der_dataset/'
MODEL_PATH = '/home/ilia/diff_eq_project/unresolved_der_models/'

with open(DATA_PATH + '/vocabulary', "rb") as f:
    vocab = pickle.load(f)
    token2id = vocab['token2id']
    id2token = vocab['id2token']
    
VOCAB_SIZE = len(token2id)
N_LAYERS = 6
D_MODEL = 256
D_FF = 1024
N_HEADS = 8
DROPOUT = 0.1
PAD_IDX = token2id['<pad>']
BATCH_SIZE = 64
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cpu')

Загружаем лучшую обученную модель:

In [5]:
model = make_model(DEVICE, VOCAB_SIZE, PAD_IDX, N_LAYERS, 
               D_MODEL, D_FF, N_HEADS, DROPOUT).to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH + 'transformer_8.pt', map_location=torch.device('cpu')))
model.eval();

In [6]:
gen = ExpressionGenerator()
evaluator = Evaluator(model, DEVICE, vocab)

In [7]:
test_dataset = DiffeqDataset(train=False, batch_size=64, path=DATA_PATH)

In [13]:
x = gen.vars['x']
y = gen.vars['y']
t = gen.vars['t']
f = gen.functions['f']
c = gen.coef['C1']

Выберем некоторое количество уравнений из тестового датасета и посчитаем долю правильно решенных моделью уравнений:

In [21]:
check_sample = np.random.choice(len(test_dataset), size=500, replace=False)

In [22]:
count_right_solutions = 0
for i in tqdm(check_sample):
    prefix_eq_int = [int(id_token) for id_token in test_dataset[i][0]]
    prefix_sol_int = [int(id_token) for id_token in test_dataset[i][1]]
    try:
        beam_res, _, _ = evaluator.beam_decode(prefix_eq_int, beam_size=5, alpha=1)
        x_t, y_t = solution_to_sympy(prefix_sol_int[1:-1])
        x_t_model, y_t_model = solution_to_sympy(beam_res)
    except:
        continue
    
    x_t_p = x_t.subs([(t, 1.33), (c, 1.33)])
    y_t_p = y_t.subs([(t, 1.33), (c, 1.33)])
    x_t_model_p = x_t_model.subs([(t, 1.33), (c, 1.33)])
    y_t_model_p = y_t_model.subs([(t, 1.33), (c, 1.33)])
    
    if x_t_p.equals(x_t_model_p) and y_t_p.equals(y_t_model_p):
        count_right_solutions += 1

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [20:30<00:00,  2.46s/it]


In [23]:
print("Accuracy: ",  count_right_solutions/500)

Accuracy:  0.52


## Пример решения дифференциального уравнения

Попробуем решить ДУ Лагранжа (здесь t обозначает y'):

In [93]:
infix_eq = 't + ln(t)+y'

In [94]:
sympy_eq, id_eq = infix_eq_to_prefix_id(infix_eq)

"Решаем" ДУ с помощью модели:

In [95]:
beam_best_res, hyps, _= evaluator.beam_decode(id_eq, beam_size=10, alpha=0.5)

Исходное уравнение:

In [96]:
sympy_eq

t + y + log(t)

Полученное решение:

In [97]:
x_t, y_t = solution_to_sympy(hyps[0])

In [98]:
x_t

C1 - log(t) + 1/t

In [99]:
y_t

-t - log(t)

Проверим правильностьполученного решения подстановкой его в исходное уравнение:

In [100]:
sp.simplify(sympy_eq.subs(x, x_t).subs(y, y_t))

0