# Interactive Demo of the Model on Function Integration

In [1]:
import os
import sympy as sp
import torch

from src.utils import AttrDict
from src.envs import build_env
from src.model import build_modules
from src.utils import to_cuda
from src.envs.sympy_utils import simplify

## Build Environment - Reload Model

### Get Trained Model:

In [2]:
# Location of the pretrained checkpoint; a relative path is resolved against
# the current working directory.
model_path = "../models/fwd_bwd_ibp_model.pth"
# Fail early, and say which absolute path was checked, so a wrong working
# directory or a missing download is obvious from the error alone.
assert os.path.isfile(model_path), (
    f"Pretrained model not found at '{os.path.abspath(model_path)}'"
)

### Set the Parameters for Environment and for the Model:

 Environment:
- **env_name**: SymPy character environment.
- **int_base**: integer representation base.
- **balanced**: balanced representation (base > 0).
- **positive**: do not sample negative numbers.
- **precision**: float numbers precision.
- **n_variables**: number of variables in expressions (between 1 and 4).
- **n_coefficients**: number of coefficients in expressions (between 0 and 10).
- **leaf_probs**: leaf probabilities of being a variable, a coefficient, an integer, or a constant.
- **max_len**: maximum sequence length.
- **max_int**: maximum integer value.
- **max_ops**: maximum number of operators.
- **max_ops_G**: maximum number of operators for G in IBP.
- **clean_prefix_expr**: clean prefix expressions (f x -> Y, derivative f x x -> Y').
- **rewrite_functions**: rewrite expressions with SymPy.
- **tasks**: tasks (prim_fwd, prim_bwd, prim_ibp, ode1, ode2).
- **operators**: operators (add, sub, mul, div), followed by weight.


 Model:
- **cpu**: run on CPU.
- **emb_dim**: embedding layer size.
- **n_enc_layers**: number of transformer layers in the encoder.
- **n_dec_layers**: number of transformer layers in the decoder.
- **n_heads**: number of transformer heads.
- **dropout**: dropout.
- **attention_dropout**: dropout in the attention layer.
- **sinusoidal_embeddings**: use sinusoidal embeddings.
- **share_inout_emb**: share input and output embeddings.
- **reload_model**: reload a pretrained model.

In [3]:
# Single configuration object handed to build_env and build_modules below.
# Key meanings are listed in the markdown cell above.
params = AttrDict(dict(
    # ---- environment ----
    env_name='char_sp',
    int_base=10,
    balanced=False,
    positive=True,
    precision=10,
    n_variables=1,
    n_coefficients=0,
    leaf_probs='0.75,0,0.25,0',
    max_len=512,
    max_int=5,
    max_ops=15,
    max_ops_G=15,
    clean_prefix_expr=True,
    rewrite_functions='',
    tasks='prim_fwd',
    # "name:weight" pairs, comma separated (adjacent literals are concatenated).
    operators=(
        'add:10,sub:3,mul:10,div:5,'
        'sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,'
        'ln:4,exp:4,sin:4,cos:4,tan:4,'
        'asin:1,acos:1,atan:1,'
        'sinh:1,cosh:1,tanh:1,'
        'asinh:1,acosh:1,atanh:1'
    ),
    # ---- model ----
    cpu=False,
    emb_dim=1024,
    n_enc_layers=6,
    n_dec_layers=6,
    n_heads=8,
    dropout=0,
    attention_dropout=0,
    sinusoidal_embeddings=False,
    share_inout_emb=True,
    reload_model=model_path,
))

### Set the Environment with SymPy

In [4]:
# Instantiate the environment from the parameter set defined above.
env = build_env(params)
# Fetch the object registered under 'x' in the environment's local dictionary;
# it is used below as the differentiation variable.
x = env.local_dict['x']

### Build Model Modules

In [5]:
# Build the model components, then pull out the two used in this notebook.
modules = build_modules(env, params)
encoder, decoder = modules['encoder'], modules['decoder']

## Start from a function F, compute its derivative f = F', and try to recover F from f

In [6]:
# Example functions F (in infix notation) that the model should recover.
# Add your own entries here, then change the index below to pick one.
F_EXAMPLES = (
    'x * tan(exp(x)/x)',
    'x * cos(x**2) * tan(x)',
    'cos(x**2 * exp(x * cos(x)))',
    'ln(cos(x + exp(x)) * sin(x**2 + 2) * exp(x) / x)',
)

# Explicit selection replaces the previous chain of overwritten assignments,
# where only the last one had any effect.
F_infix = F_EXAMPLES[-1]

In [7]:
# F is the integral the model will try to predict. The infix string is parsed
# into a SymPy expression, with names looked up in the environment's local dictionary.
F = sp.sympify(F_infix, locals=env.local_dict)
F

log(exp(x)*sin(x**2 + 2)*cos(x + exp(x))/x)

In [8]:
# f = F' is the derivative of F with respect to x; it is the model's input.
f = sp.diff(F, x)
f

x*(2*exp(x)*cos(x + exp(x))*cos(x**2 + 2) - (exp(x) + 1)*exp(x)*sin(x + exp(x))*sin(x**2 + 2)/x + exp(x)*sin(x**2 + 2)*cos(x + exp(x))/x - exp(x)*sin(x**2 + 2)*cos(x + exp(x))/x**2)*exp(-x)/(sin(x**2 + 2)*cos(x + exp(x)))

### Compute prefix representations

In [9]:
# Convert both expressions to prefix token lists with the environment's encoder.
F_prefix, f_prefix = (env.sympy_to_prefix(expr) for expr in (F, f))
print(f"F prefix: {F_prefix}\nf prefix: {f_prefix}")

F prefix: ['ln', 'mul', 'pow', 'x', 'INT-', '1', 'mul', 'cos', 'add', 'x', 'exp', 'x', 'mul', 'exp', 'x', 'sin', 'add', 'INT+', '2', 'pow', 'x', 'INT+', '2']
f prefix: ['mul', 'x', 'mul', 'pow', 'cos', 'add', 'x', 'exp', 'x', 'INT-', '1', 'mul', 'pow', 'sin', 'add', 'INT+', '2', 'pow', 'x', 'INT+', '2', 'INT-', '1', 'mul', 'add', 'mul', 'INT+', '2', 'mul', 'cos', 'add', 'INT+', '2', 'pow', 'x', 'INT+', '2', 'mul', 'cos', 'add', 'x', 'exp', 'x', 'exp', 'x', 'add', 'mul', 'pow', 'x', 'INT-', '1', 'mul', 'cos', 'add', 'x', 'exp', 'x', 'mul', 'exp', 'x', 'sin', 'add', 'INT+', '2', 'pow', 'x', 'INT+', '2', 'add', 'mul', 'INT-', '1', 'mul', 'pow', 'x', 'INT-', '2', 'mul', 'cos', 'add', 'x', 'exp', 'x', 'mul', 'exp', 'x', 'sin', 'add', 'INT+', '2', 'pow', 'x', 'INT+', '2', 'mul', 'INT-', '1', 'mul', 'pow', 'x', 'INT-', '1', 'mul', 'add', 'INT+', '1', 'exp', 'x', 'mul', 'exp', 'x', 'mul', 'sin', 'add', 'INT+', '2', 'pow', 'x', 'INT+', '2', 'sin', 'add', 'x', 'exp', 'x', 'exp', 'mul', 'INT-', '

### Encode input

In [10]:
# Prepend the task header tokens to f and let the environment clean the prefix.
x1_prefix = env.clean_prefix(['sub', 'derivative', 'f', 'x', 'x'] + f_prefix)

# Map tokens to vocabulary ids and wrap the sequence with the EOS index on both ends.
token_ids = [env.eos_index, *(env.word2id[w] for w in x1_prefix), env.eos_index]

# Column tensor of shape (sequence_length, 1), plus a one-element length tensor.
x1 = torch.LongTensor(token_ids).view(-1, 1)
len1 = torch.LongTensor([len(x1)])
x1, len1 = to_cuda(x1, len1)

# Run the encoder without tracking gradients; swap the first two dimensions of its output.
with torch.no_grad():
    encoded = encoder('fwd', x=x1, lengths=len1, causal=False)
    encoded = encoded.transpose(0, 1)

### Decode with beam search

In [11]:
# Number of hypotheses kept by the beam search.
beam_size = 10

# Decode without tracking gradients; only the third returned value is used.
with torch.no_grad():
    _, _, beam = decoder.generate_beam(
        encoded,
        len1,
        beam_size=beam_size,
        length_penalty=1.0,
        early_stopping=1,
        max_len=200,
    )

# A single input was encoded, so exactly one beam is expected.
assert len(beam) == 1
hypotheses = beam[0].hyp
# The beam must hold one hypothesis per beam slot.
assert len(hypotheses) == beam_size

### Print results

In [12]:
print(f"Input function f: {f}")
print(f"Reference function F: {F}")
print("")

# Walk through the hypotheses from the highest score to the lowest.
for score, sent in sorted(hypotheses, key=lambda h: h[0], reverse=True):

    # decoded token IDs; the first position is skipped
    # (presumably a start-of-sequence marker -- confirm against generate_beam)
    ids = sent[1:].tolist()
    tok = [env.id2word[wid] for wid in ids]  # convert to prefix tokens

    # Step 1: parse. Only failures here mean the prefix expression is invalid.
    try:
        hyp = env.prefix_to_infix(tok)       # convert to infix
        hyp = env.infix_to_sympy(hyp)        # convert to SymPy
    except Exception:
        # Exception (not a bare except) so KeyboardInterrupt still stops the cell.
        res = "INVALID PREFIX EXPRESSION"
        hyp = tok
    else:
        # Step 2: verify. Check whether differentiating the hypothesis gives back f.
        # Note that SymPy sometimes fails to show that hyp' - f == 0, in which
        # case the result is reported as "NO" although it may be correct.
        try:
            res = "OK" if simplify(hyp.diff(x) - f, seconds=1) == 0 else "NO"
        except Exception:
            # The expression parsed fine but the check itself raised; keep the
            # parsed hypothesis and do not mislabel it as an invalid prefix.
            res = "CHECK FAILED"

    # print result
    print("%.5f  %s  %s" % (score, res, hyp))

Input function f: x*(2*exp(x)*cos(x + exp(x))*cos(x**2 + 2) - (exp(x) + 1)*exp(x)*sin(x + exp(x))*sin(x**2 + 2)/x + exp(x)*sin(x**2 + 2)*cos(x + exp(x))/x - exp(x)*sin(x**2 + 2)*cos(x + exp(x))/x**2)*exp(-x)/(sin(x**2 + 2)*cos(x + exp(x)))
Reference function F: log(exp(x)*sin(x**2 + 2)*cos(x + exp(x))/x)

-0.00001  INVALID PREFIX EXPRESSION  ['ln', 'mul', 'pow', 'x', 'INT-', '1', 'mul', 'cos', 'add', 'x', 'exp', 'x', 'mul', 'exp', 'x', 'sin', 'add', 'INT+', '2', 'pow', 'x', 'INT+', '2']
-0.33615  INVALID PREFIX EXPRESSION  ['ln', 'mul', 'INT-', '1', 'mul', 'pow', 'x', 'INT-', '1', 'mul', 'cos', 'add', 'x', 'exp', 'x', 'mul', 'exp', 'x', 'sin', 'add', 'INT+', '2', 'pow', 'x', 'INT+', '2']
-0.33943  INVALID PREFIX EXPRESSION  ['ln', 'mul', 'pow', 'x', 'INT-', '1', 'mul', 'cos', 'add', 'x', 'exp', 'x', 'mul', 'sin', 'add', 'INT+', '2', 'pow', 'x', 'INT+', '2', 'sin', 'add', 'INT+', '2', 'pow', 'x', 'INT+', '2']
-0.40456  INVALID PREFIX EXPRESSION  ['ln', 'mul', 'pow', 'x', 'INT-', '1', 'm