# Demo of the Model on Differential Equations

In [1]:
import os

import numpy
import pandas
import sympy
import torch

from src.envs import build_env
from src.envs.sympy_utils import simplify_equa_diff
from src.envs.char_sp import InvalidPrefixExpression
from src.envs.sympy_utils import simplify
from src.model import build_modules
from src.utils import AttrDict
from src.utils import to_cuda

Use the model to find solutions to first- and second-order differential equations.

## First Order Differential Equations (ODE 1)

Procedure:
1. Start from a bivariate function $F(x,c)$, that will be the equation solution, that can be solved in $c$.
2. Solve $F(x,c)$ in $c$.
3. Differentiate in $x$.
4. Simplify the final form.

### Build Environment - Reload Model

Get Trained Model:

In [2]:
model_path = '../models/differential-equations/ode1.pth'
# Fail early and name the missing path instead of raising a bare AssertionError.
assert os.path.isfile(model_path), f"Trained model not found: {model_path}"

Set the Parameters for Environment and for the Model:

 Environment:
- **env_name**: SymPy character environment.
- **int_base**: integer representation base.
- **balanced**: balanced representation (base > 0).
- **positive**: do not sample negative numbers.
- **precision**: float numbers precision.
- **n_variables**: number of variables in expressions (between 1 and 4).
- **n_coefficients**: number of coefficients in expressions (between 0 and 10).
- **leaf_probs**: leaf probabilities of being a variable, a coefficient, an integer, or a constant.
- **max_len**: maximum sequences length.
- **max_int**: maximum integer value.
- **max_ops**: maximum number of operators.
- **max_ops_G**: maximum number of operators for G in IBP.
- **clean_prefix_expr**: clean prefix expressions (f x -> Y, derivative f x x -> Y').
- **rewrite_functions**: rewrite expressions with a given SymPy function.
- **tasks**: tasks to run (prim_fwd, prim_bwd, prim_ibp, ode1, ode2).
- **operators**: considered operators (add, sub, mul, div), followed by (unnormalized) sampling probabilities.


 Model:
- **cpu**: run on CPU.
- **emb_dim**: embedding layer size.
- **n_enc_layers**: number of transformer layers in the encoder.
- **n_dec_layers**: number of transformer layers in the decoder.
- **n_heads**: number of transformer heads.
- **dropout**: dropout.
- **attention_dropout**: dropout in the attention layer.
- **sinusoidal_embeddings**: use sinusoidal embeddings.
- **share_inout_emb**: share input and output embeddings.
- **reload_model**: reload a pretrained model.

In [3]:
# Settings for the environment and the model, wrapped in AttrDict; later cells read them
# as attributes (e.g. params.max_len).
params = AttrDict(dict(
    # Environment parameters
    env_name='char_sp',
    int_base=10,
    balanced=False,
    positive=True,
    precision=10,
    n_variables=1,
    n_coefficients=0,
    leaf_probs='0.75,0,0.25,0',
    max_len=512,
    max_int=5,
    max_ops=15,
    max_ops_G=15,
    clean_prefix_expr=True,
    rewrite_functions='',
    tasks='ode1',
    operators=(
        'add:10,sub:3,mul:10,div:5,sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,ln:4,exp:4,sin:4,cos:4,tan:4,asin:1,'
        'acos:1,atan:1,sinh:1,cosh:1,tanh:1,asinh:1,acosh:1,atanh:1'
    ),

    # Model parameters
    cpu=False,
    emb_dim=1024,
    n_enc_layers=6,
    n_dec_layers=6,
    n_heads=8,
    dropout=0,
    attention_dropout=0,
    sinusoidal_embeddings=False,
    share_inout_emb=True,
    reload_model=model_path,
))

Set the Environment with SymPy:

In [4]:
env = build_env(params)
# Fetch the SymPy objects used by the following cells from the environment's name table.
x, f, a8 = (env.local_dict[name] for name in ('x', 'f', 'a8'))

The primary components of the model are one encoder and one decoder network. The encoder turns each item into a
corresponding hidden vector containing the item and its context. The decoder reverses the process, turning the vector
into an output item, using the previous output as the input context.

Build Model Modules:

In [5]:
modules = build_modules(env, params)
# Keep direct handles on the two networks used below.
encoder, decoder = modules['encoder'], modules['decoder']

### Declare Differential Equation Input and its Solution

Declare a bivariate function $F(x,c=a_8)$, that will be the equation solution, that can be solved in $a_8$:

In [6]:
# Infix form of the solution F(x, a8) used in this demo.
# The commented-out lines are alternative examples; enable exactly one assignment to try another.
# y_infix = 'x*log(a8/x)'
# y_infix = '((4+(2*x))**(-1))*((x+(sin(x)))*((x**4)+(a8*(x**(-1)))))'
y_infix = 'exp(a8+((sqrt(cos(x)))+(acos(2*x))))'
# y_infix = '3+((a8*(x**(-1)))+(sin(tanh(cos(x)))))'

Converts **y_infix** to a type that can be used inside SymPy:

In [7]:
# Parse the infix string, resolving names through the environment's name table (env.local_dict).
y = sympy.sympify(y_infix, locals=env.local_dict)
y

exp(a8 + sqrt(cos(x)) + acos(2*x))

Solve $f(x) = y$ for $a_8$:

In [8]:
# Solve f(x) = y for the constant a8.
# The symbol is looked up in env.local_dict instead of being read from the name `a8`, because
# `a8` is rebound to the solution below; this way the cell gives the same result when re-run.
solve_a8 = sympy.solve(f(x) - y, env.local_dict['a8'], check=False, simplify=False)
assert solve_a8, "sympy.solve returned no solution for a8"
a8 = solve_a8[0]  # from here on `a8` holds the solved expression, not the symbol
a8

log(f(x)*exp(-sqrt(cos(x)))*exp(-acos(2*x)))

Differentiate $a_8$ in $x$:

In [9]:
# Differentiate the solved constant with respect to x and simplify the result (seconds=1).
eq = simplify(a8.diff(x), seconds=1)
eq

sin(x)/(2*sqrt(cos(x))) + Derivative(f(x), x)/f(x) + 2/sqrt(1 - 4*x**2)

Simplify previous differential equation:

In [10]:
# Rewrite the equation with the project helper, passing the first derivative f'(x) as `required`.
# NOTE(review): presumably the helper ensures this derivative is present in the result - confirm
# in src.envs.sympy_utils.
eq = simplify_equa_diff(eq, required=f(x).diff(x))
eq

sqrt(1 - 4*x**2)*f(x)*sin(x) + 2*sqrt(1 - 4*x**2)*sqrt(cos(x))*Derivative(f(x), x) + 4*f(x)*sqrt(cos(x))

### Compute Prefix Representations

In [11]:
# Convert the solution and the equation from SymPy expressions to prefix-notation token lists.
y_prefix = env.sympy_to_prefix(y)
eq_prefix = env.sympy_to_prefix(eq)
print(f"y with Prefix Notation:\n{y_prefix}\n")
print(f"eq with Prefix Notation:\n{eq_prefix}")

y with Prefix Notation:
['exp', 'add', 'a8', 'add', 'sqrt', 'cos', 'x', 'acos', 'mul', 'INT+', '2', 'x']

eq with Prefix Notation:
['add', 'mul', 'INT+', '4', 'mul', 'sqrt', 'cos', 'x', 'f', 'x', 'add', 'mul', 'sqrt', 'add', 'INT+', '1', 'mul', 'INT-', '4', 'pow', 'x', 'INT+', '2', 'mul', 'f', 'x', 'sin', 'x', 'mul', 'INT+', '2', 'mul', 'sqrt', 'add', 'INT+', '1', 'mul', 'INT-', '4', 'pow', 'x', 'INT+', '2', 'mul', 'sqrt', 'cos', 'x', 'derivative', 'f', 'x', 'x']


### Encode Input

Clean prefix expressions before they are converted to PyTorch data.

Examples:
- f x  -> Y
- derivative f x x  -> Y'

In [12]:
# Clean the equation's prefix tokens (e.g. f x -> Y, derivative f x x -> Y').
x1_prefix = env.clean_prefix(eq_prefix)
# The value shown is the cleaned form of `eq`, so label it as such.
print(f"eq Clean Prefix Notation:\n{x1_prefix}")

f Clean Prefix Notation:
['add', 'mul', 'INT+', '4', 'mul', 'sqrt', 'cos', 'x', 'Y', 'add', 'mul', 'sqrt', 'add', 'INT+', '1', 'mul', 'INT-', '4', 'pow', 'x', 'INT+', '2', 'mul', 'Y', 'sin', 'x', 'mul', 'INT+', '2', 'mul', 'sqrt', 'add', 'INT+', '1', 'mul', 'INT-', '4', 'pow', 'x', 'INT+', '2', 'mul', 'sqrt', 'cos', 'x', "Y'"]


Create a PyTorch LongTensor for storing $eq$ as a sequence of indexes based on prefix clean notation "words" (Word to
index dictionary is defined inside the Model environment):

In [13]:
# Map each token to its index, wrap the sequence with the end-of-sequence index on both sides,
# and reshape it into a single column.
x1 = torch.LongTensor(
    [env.eos_index] +
    [env.word2id[w] for w in x1_prefix] +
    [env.eos_index]
).view(-1, 1)
# Show the word -> index dictionary used for the lookup above.
print(env.word2id)

{'<s>': 0, '</s>': 1, '<pad>': 2, '(': 3, ')': 4, '<SPECIAL_5>': 5, '<SPECIAL_6>': 6, '<SPECIAL_7>': 7, '<SPECIAL_8>': 8, '<SPECIAL_9>': 9, 'pi': 10, 'E': 11, 'x': 12, 'y': 13, 'z': 14, 't': 15, 'a0': 16, 'a1': 17, 'a2': 18, 'a3': 19, 'a4': 20, 'a5': 21, 'a6': 22, 'a7': 23, 'a8': 24, 'a9': 25, 'abs': 26, 'acos': 27, 'acosh': 28, 'acot': 29, 'acoth': 30, 'acsc': 31, 'acsch': 32, 'add': 33, 'asec': 34, 'asech': 35, 'asin': 36, 'asinh': 37, 'atan': 38, 'atanh': 39, 'cos': 40, 'cosh': 41, 'cot': 42, 'coth': 43, 'csc': 44, 'csch': 45, 'derivative': 46, 'div': 47, 'exp': 48, 'f': 49, 'g': 50, 'h': 51, 'inv': 52, 'ln': 53, 'mul': 54, 'pow': 55, 'pow2': 56, 'pow3': 57, 'pow4': 58, 'pow5': 59, 'rac': 60, 'sec': 61, 'sech': 62, 'sign': 63, 'sin': 64, 'sinh': 65, 'sqrt': 66, 'sub': 67, 'tan': 68, 'tanh': 69, 'I': 70, 'INT+': 71, 'INT-': 72, 'INT': 73, 'FLOAT': 74, '-': 75, '.': 76, '10^': 77, 'Y': 78, "Y'": 79, "Y''": 80, '0': 81, '1': 82, '2': 83, '3': 84, '4': 85, '5': 86, '6': 87, '7': 88, '8'

Move PyTorch tensors to CUDA (GPU):

In [14]:
# Length of the sequence (first dimension of x1) as a one-element tensor; both tensors then go through to_cuda.
len1 = torch.LongTensor([x1.size(0)])
x1, len1 = to_cuda(x1, len1)

Encode the input sequence with the model's encoder, which produces one hidden vector per input token:

In [15]:
# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    # Run the encoder in 'fwd' mode with causal=False, then swap the first two axes of its output.
    # NOTE(review): presumably this makes the output batch-first for the decoder - confirm in src.model.
    encoded = encoder('fwd', x=x1, lengths=len1, causal=False).transpose(0, 1)

encoded

tensor([[[ 0.0072, -0.0884,  0.0259,  ...,  0.0011, -0.0184, -0.0673],
         [ 0.0148, -0.0460, -0.0399,  ...,  0.0005,  0.0184,  0.0081],
         [-0.0214, -0.0219,  0.0367,  ...,  0.0256,  0.0066,  0.0188],
         ...,
         [ 0.1294, -0.0265,  0.0656,  ...,  0.1684,  0.0164,  0.0663],
         [-0.0210, -0.0229,  0.0597,  ...,  0.0015,  0.0819,  0.0254],
         [ 0.0121, -0.0171,  0.0205,  ..., -0.0246,  0.0144, -0.0151]]],
       device='cuda:0')

### Decode with Beam Search

Instead of keeping only the single most likely output sequence (here, a hypothesis for the solution of the differential
equation), beam search retains several highly probable candidates.

Declare beam size:

In [16]:
# Beam width passed to the decoder; the number of returned hypotheses is later asserted to equal it.
beam_size = 10

Takes the encoder output vector and outputs multiple sequences of "words", that in this case should represent the
solution $y$ for the differential equation $eq$, using the Decoder of the model.

In [17]:
# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    # Beam-search decoding from the encoder output; only the third returned value (the beams) is used.
    _, _, beam = decoder.generate_beam(encoded, len1, beam_size=beam_size, length_penalty=1.0, early_stopping=1,
                                       max_len=params.max_len)
# NOTE(review): presumably one beam per input sequence; a single equation was encoded - confirm.
assert len(beam) == 1
# Each entry of .hyp is unpacked later as a (score, token sequence) pair.
hypotheses = beam[0].hyp
assert len(hypotheses) == beam_size

### View the Results

Input differential equation $eq$:

In [18]:
eq

sqrt(1 - 4*x**2)*f(x)*sin(x) + 2*sqrt(1 - 4*x**2)*sqrt(cos(x))*Derivative(f(x), x) + 4*f(x)*sqrt(cos(x))

Solution $y$ to find:

In [19]:
y

exp(a8 + sqrt(cos(x)) + acos(2*x))

Extract scores and solution hypotheses:

In [20]:
columns = ['Score', 'Solution Hypothesis', 'Valid']
results = []

# Rank hypotheses by score only, best first. Sorting the raw (score, sequence) tuples would fall
# back to comparing the token tensors whenever two scores tie, which raises instead of ordering.
for score, sequence in sorted(hypotheses, key=lambda hyp: hyp[0], reverse=True):
    # Parse decoded hypothesis
    ids = sequence[1:].tolist()  # Decoded token IDs (the first position is skipped)
    hyp_prefix = [env.id2word[word_id] for word_id in ids]  # Convert to prefix notation

    try:
        hyp_infix = env.prefix_to_infix(hyp_prefix)  # Convert to infix notation
        hyp_sympy = env.infix_to_sympy(hyp_infix)  # Convert to SymPy

        # Check if the hypothesis is a valid solution: replace 'f(x)' with 'hyp_sympy' in the
        # equation, evaluate the derivatives, and test whether the result simplifies to 0
        validation = "YES" if simplify(eq.subs(f(x), hyp_sympy).doit(), seconds=1) == 0 else "NO"

        # Transform hypothesis to a LaTeX expression (reuses the expression parsed above)
        hyp_expr = "$" + sympy.latex(hyp_sympy) + "$"

    except InvalidPrefixExpression:
        # The decoded tokens are not a well-formed prefix expression; show them as they are
        validation = "INVALID PREFIX EXPRESSION"
        hyp_expr = hyp_prefix

    # Prepare results
    results.append([score, hyp_expr, validation])

# One 1-based rank label per collected row, so the index always matches the number of rows.
rows = numpy.arange(1, len(results) + 1)

Print results:

In [21]:
# Remove the column-width limit so long cell contents (the LaTeX strings) are not truncated.
pandas.set_option('max_colwidth', None)
# Build the table from the collected rows and centre the cell text through the Styler.
pandas.DataFrame(results, index=rows, columns=columns).style.set_properties(**{'text-align': 'center'})

Unnamed: 0,Score,Solution Hypothesis,Valid
1,-0.040032,$e^{a_{8} + \sqrt{\cos{\left(x \right)}} + \operatorname{acos}{\left(2 x \right)}}$,YES
2,-0.128009,$e^{a_{8} + \sqrt{\cos{\left(x \right)}} - \operatorname{asin}{\left(2 x \right)}}$,YES
3,-0.144557,$a_{8} e^{\sqrt{\cos{\left(x \right)}} + \operatorname{acos}{\left(2 x \right)}}$,YES
4,-0.209824,$a_{8} e^{\sqrt{\cos{\left(x \right)}} - \operatorname{asin}{\left(2 x \right)}}$,YES
5,-0.331873,$a_{8} e^{\sqrt{\cos{\left(x \right)}} + \operatorname{acos}{\left(2 x \right)} + 1}$,YES
6,-0.346632,$a_{8} e^{\sqrt{\cos{\left(x \right)}}} e^{- \operatorname{asin}{\left(2 x \right)}}$,YES
7,-0.348289,$a_{8} e^{\sqrt{\cos{\left(x \right)}} + \operatorname{acos}{\left(2 x \right)} + 2}$,YES
8,-0.349814,$a_{8} e^{\sqrt{\cos{\left(x \right)}} + \operatorname{acos}{\left(2 x \right)} + 5}$,YES
9,-0.349916,$a_{8} e^{\sqrt{\cos{\left(x \right)}}} e^{\operatorname{acos}{\left(2 x \right)}}$,YES
10,-0.350311,$a_{8} e^{\sqrt{\cos{\left(x \right)}} + \operatorname{acos}{\left(2 x \right)} + 3}$,YES


## Second Order Differential Equations (ODE 2)

Procedure:
1. Start from a trivariate function $F(x,c_1,c_2)$, that will be the equation solution, that can be solved in $c_2$.
2. Solve $F(x,c_1,c_2)$ in $c_2$.
3. Differentiate in $x$.
4. Solve in $c_1$.
5. Differentiate in $x$.
6. Simplify the final form.

### Build Environment - Reload Model

Get Trained Model:

In [22]:
model_path = '../models/differential-equations/ode2.pth'
# Fail early and name the missing path instead of raising a bare AssertionError.
assert os.path.isfile(model_path), f"Trained model not found: {model_path}"

params = AttrDict({

    # Environment Parameters
    'env_name': 'char_sp',
    'int_base': 10,
    'balanced': False,
    'positive': True,
    'precision': 10,
    'n_variables': 1,
    'n_coefficients': 0,
    'leaf_probs': '0.75,0,0.25,0',
    'max_len': 512,
    'max_int': 5,
    'max_ops': 15,
    'max_ops_G': 15,
    'clean_prefix_expr': True,
    'rewrite_functions': '',
    'tasks': 'ode2',  # this section loads the second-order model (was left as 'ode1')
    'operators': 'add:10,sub:3,mul:10,div:5,sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,ln:4,exp:4,sin:4,cos:4,tan:4,asin:1,'
                 'acos:1,atan:1,sinh:1,cosh:1,tanh:1,asinh:1,acosh:1,atanh:1',

    # Model Parameters
    'cpu': False,
    'emb_dim': 1024,
    'n_enc_layers': 6,
    'n_dec_layers': 6,
    'n_heads': 8,
    'dropout': 0,
    'attention_dropout': 0,
    'sinusoidal_embeddings': False,
    'share_inout_emb': True,
    'reload_model': model_path,

})

# Rebuild the environment and re-fetch every SymPy object used below from it, so this section
# does not depend on names left over from the first-order section.
env = build_env(params)
x = env.local_dict['x']
f = env.local_dict['f']
a8 = env.local_dict['a8']
a9 = env.local_dict['a9']

# Rebuild the networks; the weights come from the 'reload_model' path set above.
modules = build_modules(env, params)
encoder = modules['encoder']
decoder = modules['decoder']

### Declare Differential Equation Input and its Solution

Declare a trivariate function $F(x,c_1=a_8, c_2=a_9)$, that will be the equation solution, that can be solved in $a_9$:

In [23]:
# Infix form of the solution F(x, a8, a9) used in this demo.
# The commented-out lines are alternative examples; enable exactly one assignment to try another.
# y_infix = 'a8*exp(x)+a9*exp(-x)'
# y_infix = '(2*x)+((x*(a9+(a8*x)))+(exp(3)))'
# y_infix = '(x**9)*((a8+(a9*((cos(cosh(asin(tanh(x)))))**(-1))))**(-1))'
y_infix = 'a9*(5+((a8+(x+(x**3)))*(exp((-1)*((tan(x))**2)))))'

Converts **y_infix** to a type that can be used inside SymPy:

In [24]:
# Parse the infix string, resolving names through the environment's name table (env.local_dict).
y = sympy.sympify(y_infix, locals=env.local_dict)
y

a9*((a8 + x**3 + x)*exp(-tan(x)**2) + 5)

Solve $f(x) = y$ for $a_9$:

In [25]:
# Solve f(x) = y for the constant a9.
# The symbol is looked up in env.local_dict instead of being read from the name `a9`, because
# `a9` is rebound to the solution below; this way the cell gives the same result when re-run.
solve_a9 = sympy.solve(f(x) - y, env.local_dict['a9'], check=False, simplify=False)
assert solve_a9, "sympy.solve returned no solution for a9"
a9 = solve_a9[0]  # from here on `a9` holds the solved expression, not the symbol
a9

-f(x)*exp(tan(x)**2)/(-a8 - x**3 - x - 5*exp(tan(x)**2))

Differentiate $a_9$ in $x$:

In [26]:
# Differentiate the solved constant with respect to x and simplify the result (seconds=1).
eq = simplify(a9.diff(x), seconds=1)
eq

-(2*tan(x)**2 + 2)*f(x)*exp(tan(x)**2)*tan(x)/(-a8 - x**3 - x - 5*exp(tan(x)**2)) - (3*x**2 + 5*(2*tan(x)**2 + 2)*exp(tan(x)**2)*tan(x) + 1)*f(x)*exp(tan(x)**2)/(-a8 - x**3 - x - 5*exp(tan(x)**2))**2 - exp(tan(x)**2)*Derivative(f(x), x)/(-a8 - x**3 - x - 5*exp(tan(x)**2))

Solve $eq = 0$ for $a_8$:

In [27]:
# Solve the first-order equation for the remaining constant a8.
# The symbol is looked up in env.local_dict instead of being read from the name `a8`, because
# `a8` is rebound to the solution below; this way the cell gives the same result when re-run.
solve_a8 = sympy.solve(eq, env.local_dict['a8'], check=False, simplify=False)
assert solve_a8, "sympy.solve returned no solution for a8"
a8 = solve_a8[0]  # from here on `a8` holds the solved expression, not the symbol
a8

(-2*x**3*f(x)*tan(x)**3 - 2*x**3*f(x)*tan(x) - x**3*Derivative(f(x), x) + 3*x**2*f(x) - 2*x*f(x)*tan(x)**3 - 2*x*f(x)*tan(x) - x*Derivative(f(x), x) + f(x) - 5*exp(tan(x)**2)*Derivative(f(x), x))/(2*f(x)*tan(x)**3 + 2*f(x)*tan(x) + Derivative(f(x), x))

Differentiate $a_8$ in $x$:

In [28]:
# Differentiate the solved constant with respect to x and simplify the result (seconds=1).
eq = simplify(a8.diff(x), seconds=1)
eq

(-2*x**3*(tan(x)**2 + 1)*f(x) - 2*x**3*(3*tan(x)**2 + 3)*f(x)*tan(x)**2 - 2*x**3*tan(x)**3*Derivative(f(x), x) - 2*x**3*tan(x)*Derivative(f(x), x) - x**3*Derivative(f(x), (x, 2)) - 6*x**2*f(x)*tan(x)**3 - 6*x**2*f(x)*tan(x) - 2*x*(tan(x)**2 + 1)*f(x) - 2*x*(3*tan(x)**2 + 3)*f(x)*tan(x)**2 + 6*x*f(x) - 2*x*tan(x)**3*Derivative(f(x), x) - 2*x*tan(x)*Derivative(f(x), x) - x*Derivative(f(x), (x, 2)) - 5*(2*tan(x)**2 + 2)*exp(tan(x)**2)*tan(x)*Derivative(f(x), x) - 2*f(x)*tan(x)**3 - 2*f(x)*tan(x) - 5*exp(tan(x)**2)*Derivative(f(x), (x, 2)))/(2*f(x)*tan(x)**3 + 2*f(x)*tan(x) + Derivative(f(x), x)) + (-2*(tan(x)**2 + 1)*f(x) - 2*(3*tan(x)**2 + 3)*f(x)*tan(x)**2 - 2*tan(x)**3*Derivative(f(x), x) - 2*tan(x)*Derivative(f(x), x) - Derivative(f(x), (x, 2)))*(-2*x**3*f(x)*tan(x)**3 - 2*x**3*f(x)*tan(x) - x**3*Derivative(f(x), x) + 3*x**2*f(x) - 2*x*f(x)*tan(x)**3 - 2*x*f(x)*tan(x) - x*Derivative(f(x), x) + f(x) - 5*exp(tan(x)**2)*Derivative(f(x), x))/(2*f(x)*tan(x)**3 + 2*f(x)*tan(x) + Derivative(

Simplify previous differential equation:

In [29]:
# Rewrite the equation with the project helper, passing the second derivative f''(x) as `required`.
# NOTE(review): presumably the helper ensures this derivative is present in the result - confirm
# in src.envs.sympy_utils.
eq = simplify_equa_diff(eq, required=f(x).diff(x, 2))
eq

12*x**2*f(x)*tan(x)**6 + 42*x**2*f(x)*tan(x)**4 + 36*x**2*f(x)*tan(x)**2 + 6*x**2*f(x) + 12*x**2*tan(x)**3*Derivative(f(x), x) + 12*x**2*tan(x)*Derivative(f(x), x) + 3*x**2*Derivative(f(x), (x, 2)) - 12*x*f(x)*tan(x)**3 - 12*x*f(x)*tan(x) - 6*x*Derivative(f(x), x) + 4*f(x)*tan(x)**6 + 14*f(x)*tan(x)**4 + 12*f(x)*tan(x)**2 + 2*f(x) + 20*exp(tan(x)**2)*tan(x)**6*Derivative(f(x), x) + 10*exp(tan(x)**2)*tan(x)**4*Derivative(f(x), x) + 10*exp(tan(x)**2)*tan(x)**3*Derivative(f(x), (x, 2)) - 20*exp(tan(x)**2)*tan(x)**2*Derivative(f(x), x) + 10*exp(tan(x)**2)*tan(x)*Derivative(f(x), (x, 2)) - 10*exp(tan(x)**2)*Derivative(f(x), x) + 4*tan(x)**3*Derivative(f(x), x) + 4*tan(x)*Derivative(f(x), x) + Derivative(f(x), (x, 2))

### Compute Prefix Representations

In [30]:
# Convert the solution and the equation from SymPy expressions to prefix-notation token lists.
y_prefix = env.sympy_to_prefix(y)
eq_prefix = env.sympy_to_prefix(eq)
print(f"y with Prefix Notation:\n{y_prefix}\n")
print(f"eq with Prefix Notation:\n{eq_prefix}")

y with Prefix Notation:
['mul', 'a9', 'add', 'INT+', '5', 'mul', 'add', 'a8', 'add', 'x', 'pow', 'x', 'INT+', '3', 'exp', 'mul', 'INT-', '1', 'pow', 'tan', 'x', 'INT+', '2']

eq with Prefix Notation:
['add', 'mul', 'INT+', '2', 'f', 'x', 'add', 'mul', 'INT-', '1', '0', 'mul', 'derivative', 'f', 'x', 'x', 'exp', 'pow', 'tan', 'x', 'INT+', '2', 'add', 'mul', 'INT-', '6', 'mul', 'x', 'derivative', 'f', 'x', 'x', 'add', 'mul', 'INT+', '3', 'mul', 'pow', 'x', 'INT+', '2', 'derivative', 'derivative', 'f', 'x', 'x', 'x', 'add', 'mul', 'INT+', '4', 'mul', 'pow', 'tan', 'x', 'INT+', '3', 'derivative', 'f', 'x', 'x', 'add', 'mul', 'INT+', '4', 'mul', 'pow', 'tan', 'x', 'INT+', '6', 'f', 'x', 'add', 'mul', 'INT+', '4', 'mul', 'derivative', 'f', 'x', 'x', 'tan', 'x', 'add', 'mul', 'INT+', '6', 'mul', 'pow', 'x', 'INT+', '2', 'f', 'x', 'add', 'mul', 'INT+', '1', '2', 'mul', 'pow', 'tan', 'x', 'INT+', '2', 'f', 'x', 'add', 'mul', 'INT+', '1', '4', 'mul', 'pow', 'tan', 'x', 'INT+', '4', 'f', 'x', 'ad

### Encode Input

In [31]:
# Clean the equation's prefix tokens; in the printed result f x appears as Y and its derivatives as Y' and Y''.
x1_prefix = env.clean_prefix(eq_prefix)
print(f"eq Clean Prefix Notation:\n{x1_prefix}")

eq Clean Prefix Notation:
['add', 'mul', 'INT+', '2', 'Y', 'add', 'mul', 'INT-', '1', '0', 'mul', "Y'", 'exp', 'pow', 'tan', 'x', 'INT+', '2', 'add', 'mul', 'INT-', '6', 'mul', 'x', "Y'", 'add', 'mul', 'INT+', '3', 'mul', 'pow', 'x', 'INT+', '2', "Y''", 'add', 'mul', 'INT+', '4', 'mul', 'pow', 'tan', 'x', 'INT+', '3', "Y'", 'add', 'mul', 'INT+', '4', 'mul', 'pow', 'tan', 'x', 'INT+', '6', 'Y', 'add', 'mul', 'INT+', '4', 'mul', "Y'", 'tan', 'x', 'add', 'mul', 'INT+', '6', 'mul', 'pow', 'x', 'INT+', '2', 'Y', 'add', 'mul', 'INT+', '1', '2', 'mul', 'pow', 'tan', 'x', 'INT+', '2', 'Y', 'add', 'mul', 'INT+', '1', '4', 'mul', 'pow', 'tan', 'x', 'INT+', '4', 'Y', 'add', 'mul', 'INT-', '2', '0', 'mul', 'pow', 'tan', 'x', 'INT+', '2', 'mul', "Y'", 'exp', 'pow', 'tan', 'x', 'INT+', '2', 'add', 'mul', 'INT-', '1', '2', 'mul', 'x', 'mul', 'pow', 'tan', 'x', 'INT+', '3', 'Y', 'add', 'mul', 'INT-', '1', '2', 'mul', 'x', 'mul', 'Y', 'tan', 'x', 'add', 'mul', 'INT+', '1', '0', 'mul', 'pow', 'tan', 'x'

In [32]:
# Token indices wrapped with the end-of-sequence index on both sides, reshaped into a single column.
x1 = torch.LongTensor(
    [env.eos_index, *(env.word2id[token] for token in x1_prefix), env.eos_index]
).view(-1, 1)

In [33]:
# Length of the sequence (first dimension of x1) as a one-element tensor; both tensors then go through to_cuda.
len1 = torch.LongTensor([x1.size(0)])
x1, len1 = to_cuda(x1, len1)

In [34]:
# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    # Run the encoder in 'fwd' mode with causal=False, then swap the first two axes of its output.
    # NOTE(review): presumably this makes the output batch-first for the decoder - confirm in src.model.
    encoded = encoder('fwd', x=x1, lengths=len1, causal=False).transpose(0, 1)

encoded

tensor([[[ 0.0388, -0.0235, -0.0223,  ...,  0.0413,  0.0049, -0.0236],
         [ 0.0129, -0.0387,  0.0985,  ...,  0.0268, -0.0467,  0.0007],
         [ 0.1357,  0.0585,  0.0384,  ...,  0.2114, -0.0529, -0.2556],
         ...,
         [-0.3126,  0.1535,  0.3317,  ..., -0.2210, -0.1265,  0.1046],
         [ 0.1923,  0.2024, -0.0650,  ..., -0.1848, -0.1778, -0.1041],
         [ 0.0417, -0.0210, -0.0255,  ...,  0.0377,  0.0020, -0.0247]]],
       device='cuda:0')

### Decode with Beam Search

In [35]:
# Beam width passed to the decoder; the number of returned hypotheses is asserted to equal it below.
beam_size = 10

# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    # Beam-search decoding from the encoder output; only the third returned value (the beams) is used.
    _, _, beam = decoder.generate_beam(encoded, len1, beam_size=beam_size, length_penalty=1.0, early_stopping=1,
                                       max_len=params.max_len)
# NOTE(review): presumably one beam per input sequence; a single equation was encoded - confirm.
assert len(beam) == 1
# Each entry of .hyp is unpacked later as a (score, token sequence) pair.
hypotheses = beam[0].hyp
assert len(hypotheses) == beam_size

### View the Results

Input differential equation $eq$:

In [36]:
eq

12*x**2*f(x)*tan(x)**6 + 42*x**2*f(x)*tan(x)**4 + 36*x**2*f(x)*tan(x)**2 + 6*x**2*f(x) + 12*x**2*tan(x)**3*Derivative(f(x), x) + 12*x**2*tan(x)*Derivative(f(x), x) + 3*x**2*Derivative(f(x), (x, 2)) - 12*x*f(x)*tan(x)**3 - 12*x*f(x)*tan(x) - 6*x*Derivative(f(x), x) + 4*f(x)*tan(x)**6 + 14*f(x)*tan(x)**4 + 12*f(x)*tan(x)**2 + 2*f(x) + 20*exp(tan(x)**2)*tan(x)**6*Derivative(f(x), x) + 10*exp(tan(x)**2)*tan(x)**4*Derivative(f(x), x) + 10*exp(tan(x)**2)*tan(x)**3*Derivative(f(x), (x, 2)) - 20*exp(tan(x)**2)*tan(x)**2*Derivative(f(x), x) + 10*exp(tan(x)**2)*tan(x)*Derivative(f(x), (x, 2)) - 10*exp(tan(x)**2)*Derivative(f(x), x) + 4*tan(x)**3*Derivative(f(x), x) + 4*tan(x)*Derivative(f(x), x) + Derivative(f(x), (x, 2))

Solution $y$ to find:

In [37]:
y

a9*((a8 + x**3 + x)*exp(-tan(x)**2) + 5)

Extract scores and solution hypotheses:

In [38]:
columns = ['Score', 'Solution Hypothesis', 'Valid']
results = []

# Rank hypotheses by score only, best first. Sorting the raw (score, sequence) tuples would fall
# back to comparing the token tensors whenever two scores tie, which raises instead of ordering.
for score, sequence in sorted(hypotheses, key=lambda hyp: hyp[0], reverse=True):
    # Parse decoded hypothesis
    ids = sequence[1:].tolist()  # Decoded token IDs (the first position is skipped)
    hyp_prefix = [env.id2word[word_id] for word_id in ids]  # Convert to prefix notation

    try:
        hyp_infix = env.prefix_to_infix(hyp_prefix)  # Convert to infix notation
        hyp_sympy = env.infix_to_sympy(hyp_infix)  # Convert to SymPy

        # Check if the hypothesis is a valid solution: replace 'f(x)' with 'hyp_sympy' in the
        # equation, evaluate the derivatives, and test whether the result simplifies to 0
        validation = "YES" if simplify(eq.subs(f(x), hyp_sympy).doit(), seconds=1) == 0 else "NO"

        # Transform hypothesis to a LaTeX expression (reuses the expression parsed above)
        hyp_expr = "$" + sympy.latex(hyp_sympy) + "$"

    except InvalidPrefixExpression:
        # The decoded tokens are not a well-formed prefix expression; show them as they are
        validation = "INVALID PREFIX EXPRESSION"
        hyp_expr = hyp_prefix

    # Prepare results
    results.append([score, hyp_expr, validation])

# One 1-based rank label per collected row, so the index always matches the number of rows.
rows = numpy.arange(1, len(results) + 1)

Print results:

In [39]:
# Remove the column-width limit so long cell contents (the LaTeX strings) are not truncated.
pandas.set_option('max_colwidth', None)
# Build the table from the collected rows and centre the cell text through the Styler.
pandas.DataFrame(results, index=rows, columns=columns).style.set_properties(**{'text-align': 'center'})

Unnamed: 0,Score,Solution Hypothesis,Valid
1,-0.02118,$a_{9} \left(\left(a_{8} + x^{3} + x\right) e^{- \tan^{2}{\left(x \right)}} + 5\right)$,YES
2,-0.051553,$a_{8} \left(\left(a_{9} + x^{3} + x\right) e^{- \tan^{2}{\left(x \right)}} + 5\right)$,YES
3,-0.144064,$a_{9} \left(\left(a_{8} - x^{3} - x\right) e^{- \tan^{2}{\left(x \right)}} - 5\right)$,YES
4,-0.167011,$a_{8} \left(\left(a_{9} - x^{3} - x\right) e^{- \tan^{2}{\left(x \right)}} - 5\right)$,YES
5,-0.173669,$a_{9} \left(\left(a_{8} + x \left(x^{2} + 1\right)\right) e^{- \tan^{2}{\left(x \right)}} + 5\right)$,YES
6,-0.182952,$a_{9} \left(x \left(\frac{a_{8}}{x} + x^{2} + 1\right) e^{- \tan^{2}{\left(x \right)}} + 5\right)$,YES
7,-0.184083,$a_{9} \left(\frac{\left(a_{8} + x^{3} + x\right) e^{- \tan^{2}{\left(x \right)}}}{5} + 1\right)$,YES
8,-0.205459,$a_{9} \left(3 \left(a_{8} + x^{3} + x\right) e^{- \tan^{2}{\left(x \right)}} + 15\right)$,YES
9,-0.206048,$a_{8} \left(\left(a_{9} + x \left(x^{2} + 1\right)\right) e^{- \tan^{2}{\left(x \right)}} + 5\right)$,YES
10,-0.588345,$a_{9} \left(- a_{8} + x^{3} + x + 5\right)$,NO
