In [1]:
### Imports ###
import sympy as sp
import torch

from ginnlp import GINNLP




In [2]:
from contextlib import contextmanager
import sys, os

@contextmanager
def suppress_stdout():
    """Silence everything written to ``sys.stdout`` inside the ``with`` block.

    Output is sent to ``os.devnull``. The stream that was active on entry is
    restored on exit, including when the block raises.

    Yields:
        None
    """
    # The stdlib helper performs the save / swap / restore of sys.stdout that
    # would otherwise have to be written by hand with try/finally.
    from contextlib import redirect_stdout

    with open(os.devnull, "w") as devnull, redirect_stdout(devnull):
        yield

In [7]:
### Data Preparation ###
# Ground-truth targets: rational functions of the single variable `x`, written as
# expression strings so the training cell can lambdify them directly.
functions = [
    '(2*x**2 + 3 * x + 3)/(x + 7)',
    '(x**2 + 2)/(x**2 + 1)',
    '(x**2 + x + 1)',
    '(3*x**3 + 2 * x**2 + 5*x + 2)/(x**2 + 4*x + 8)',
    '(3*x + 1)/(x**2 + 3)',
    '(x**3 + 3*x**2 + 2*x + 1)/(x + 5)',
    '(2*x**4 + x**3 + 7)/(x**2 + 1)',
    '(x**2 + 3*x + 2)/(2*x + 3)',
    '(5*x**2 + 3*x + 4)/(x**3 + 6)',
    '(2*x + 2)/(x + 3)',
    '(3*x**2 + 4*x + 5)/(x**2 + 2*x + 1)',
    '(x**4 + x + 7)/(2*x**2 + 3*x + 9)',
    '(3*x**2 + x + 2)/(x + 1)',
    '(2*x**3 + 3)/(x**2 + 4)',
    '(x**2 + 5)/(3*x + 1)',
    '(x**5 + 2*x**3 + x)/(x**2 + 3*x + 2)',
    '(4*x**2 + 2*x + 3)/(x + 2)',
    '(3*x**3 + 3)/(x**4 + 1)',
    '(2*x + 1)/(x**3 + 2*x**2 + 4*x + 8)',
    '(x**3 + 2*x**2 + 3*x + 1)/(x**2 + 5*x + 2)'
]


# Results accumulator: one recovered equation per entry of `functions`.
recovered_functions = []

# Training grid: N_POINTS evenly spaced samples on [X_MIN, X_MAX].
# Keep X_MIN > 0: every denominator above is strictly positive on [1, 5], and the
# recovered equations contain non-integer powers of the input, which are only
# real-valued for positive x.
X_MIN, X_MAX, N_POINTS = 1, 5, 101
# Shape (N_POINTS, 1): one row per sample, one input feature. unsqueeze(1) builds the
# column directly (contiguous, unlike a transposed row view), and the device is fixed
# at creation instead of being moved afterwards.
x_train = torch.linspace(X_MIN, X_MAX, N_POINTS, device="cpu").unsqueeze(1)

In [8]:
### Training each function ###
# GINN-LP hyperparameters shared by every target, collected in one place for tuning.
GINNLP_PARAMS = dict(num_epochs=500, round_digits=3, start_ln_blocks=1, growth_steps=3,
                     l1_reg=1e-4, l2_reg=1e-4, init_lr=0.01, decay_steps=1000, reg_change=0.5)
# NOTE(review): no random seed is set, so recovered equations may differ between runs.
# Seed whichever backend GINNLP trains with (TODO confirm which) for reproducibility.

# Start from an empty list so re-running this cell replaces the previous results
# instead of appending a second copy of every equation.
recovered_functions = []
for i, function in enumerate(functions):
    # Evaluate the ground-truth expression on the training grid.
    target_function = sp.lambdify('x', function)
    y_train = target_function(x_train)
    # Guard against a target that blows up on the grid (e.g. a vanishing denominator).
    assert torch.isfinite(y_train).all(), "non-finite targets for {}".format(function)
    # A fresh model per target; its training log is silenced to keep the output readable.
    ginnLP = GINNLP(**GINNLP_PARAMS)
    with suppress_stdout():
        ginnLP.fit(x_train, y_train.squeeze())
    print("Functions trained: {}".format(i+1))
    recovered_function = ginnLP.recovered_eq
    recovered_functions.append(recovered_function)

Functions trained: 1
Functions trained: 2
Functions trained: 3
Functions trained: 4
Functions trained: 5
Functions trained: 6
Functions trained: 7
Functions trained: 8
Functions trained: 9
Functions trained: 10
Functions trained: 11
Functions trained: 12
Functions trained: 13
Functions trained: 14
Functions trained: 15
Functions trained: 16
Functions trained: 17
Functions trained: 18
Functions trained: 19
Functions trained: 20


In [9]:
# One recovered equation per target, in the same order as `functions`; a mismatch means
# the training cell was run more than once or out of order against the current targets.
assert len(recovered_functions) == len(functions), "recovered_functions is out of sync with functions; re-run the data-preparation and training cells"
recovered_functions

[0.348/X_0**0.185 + 0.07/X_0**0.018 + 0.581*X_0**1.379,
 0.491/X_0**1.504 + 0.433/X_0**0.076,
 0.36/X_0**0.302 + 0.319*X_0**0.05 + 0.892*X_0**0.292 + 1.458*X_0**1.854,
 -0.135/X_0**0.355 + 0.396/X_0**0.309 - 0.069*X_0**0.555 + 0.715*X_0**1.549,
 -0.179/X_0**0.87 + 0.639*X_0**0.239 + 1.347*X_0**0.261 - 0.778*X_0**0.714,
 -0.3/X_0**0.456 + 0.195*X_0**0.038 + 0.573*X_0**0.053 + 0.702*X_0**2.094,
 2.669/X_0**1.757 + 0.348/X_0**1.264 + 0.315*X_0**1.248 + 1.926*X_0**2.036,
 0.548*X_0**0.962,
 -2.563/X_0**2.222 - 0.742/X_0**2.13 + 3.54/X_0**0.748 + 1.518/X_0**0.747,
 -0.032/X_0**0.111 + 0.669*X_0**0.141 + 0.719*X_0**0.61 - 0.357*X_0**0.766,
 0.829/X_0**0.643 + 1.225/X_0**0.245 - 0.102/X_0**0.21 + 1.043*X_0**0.321,
 0.934/X_0**0.067 - 0.677*X_0**0.763 - 0.282*X_0**1.018 + 0.361*X_0**2.135,
 1.996*X_0**1.153,
 -0.243*X_0**0.442 + 1.001*X_0**1.112 - 0.358*X_0**1.251 + 0.472*X_0**1.606,
 1.001/X_0**0.869 - 0.e-3/X_0**0.042 + 0.334*X_0**0.002 + 0.165*X_0**1.279,
 -0.194/X_0**0.61 + 0.447*X_0**0.01