In [None]:
import os
import json 
import time

import numpy as np
import torch

import matplotlib.pyplot as plt
my_cmap = plt.get_cmap("viridis")

from sympy import lambdify
import sympy as sp
import sklearn
import sklearn.metrics

import nesymres
from nesymres.architectures.model import Model
from nesymres.utils import load_metadata_hdf5
from nesymres.dclasses import FitParams, NNEquation, BFGSParams
import omegaconf
  

from pathlib import Path
from functools import partial

import omegaconf
#
import glob

from symr.metrics import compute_complexity, compute_tree_distance, compute_exact_equivalence,\
        compute_r2, compute_r2_truncated, compute_relative_error, compute_isclose_accuracy,\
        compute_r2_over_threshold

In [None]:
# Benchmark run parameters.

distribution_type = "uniform"
# Meaning of the two values depends on the distribution:
#   normal  -> [mean, standard deviation]
#   uniform -> [min, max]
distribution_range = [-0.0, 4.0]

number_points = 200   # samples drawn per equation and trial
number_trials = 100   # each trial index doubles as its random seed
logging = False       # when True, per-trial rows and summaries are written to disk



In [None]:
# Select the benchmark equation set (flip the condition to use "generated_complex").
# Each sample_meta entry is a tuple (a, b, c, d, e, f, g): c training inputs are drawn
# uniformly from [a, b); (d, e) and (f, g) are the two validation input ranges.
if 0:
    # "generated_complex" set
    complex_1 = "1.0*x0/(1.0*x0*1.0*exp(1.0*x0)*1.0*exp(1.0*sin(1.0*x0+1.0))+1.0*exp(1.0*x0)*1.0*exp(1.0*sin(1.0*x0+1.0)))+1.0"
    complex_2 = "1.0*sqrt(1.0*abs(1.0*sqrt(1.0*abs(1.0*x0+1.0))))*1.0*sin(1.0*x0/(1.0*x0+1.0)+1.0/(1.0*x0+1.0))+1.0"
    complex_3 = "1.0*sqrt(1.0*abs(1.0*x0/(1.0*x0+1.0)+1.0/(1.0*x0+1.0)))+1.0*sqrt(1.0*abs(1.0*exp(1.0*x0)))+1.0"
    complex_4 = "1.0*sqrt(1.0*abs(1.0*x0))+1.0/(x0*1.0*sqrt(1.0*abs(1.0*x0))*1.0*sqrt(1.0*abs(1.0*x0)))+1.0"
    complex_5 = "1.0*x0/(1.0*x0*1.0*sqrt(1.0*abs(1.0*log(1.0*x0)))+1.0*sqrt(1.0*abs(1.0*log(1.0*x0))))+1.0"
    complex_6 = "1.0*x0**2*1.0*sqrt(1.0*abs(1.0*x0+1.0))+1.0*x0+1.0*x0/(1.0*x0)+1.0*log(1.0*x0+1.0)+1.0"
    complex_7 = "1.0*sqrt(1.0*abs(1.0*x0*1.0*sqrt(1.0*abs(1.0*x0))/(1.0*x0+1.0)+1.0/(1.0*x0+1.0)))+1.0"
    complex_8 = "1.0*x0**2+1.0*x0+1.0*exp(1.0*x0)/1.0*sqrt(1.0*abs(1.0*sqrt(1.0*abs(1.0*x0+1.0))))+1.0"

    # pair each equation with its sampling metadata, in equation order
    sample_meta = dict(zip(
        [complex_1, complex_2, complex_3, complex_4,
         complex_5, complex_6, complex_7, complex_8],
        [
            (-.9, 1., number_points, -3, -1.1, 1, 3),
            (-.9, 1., number_points, -3, -1.1, 1, 3),
            (-0.9, 1., number_points, -3, -1, 1, 3),
            (0.1, 2., number_points, 2, 3, 3, 4),
            (1.01, 2, number_points, 2, 3, 3, 4),
            (-.9, 3., number_points, 3, 5, 5, 7),
            (-0.9, 1., number_points, -3, -1.1, 1, 3),
            (-0.9, 1., number_points, -3, -1.1, 1, 3),
        ],
    ))
    set_name = "generated_complex"
else:
    # "AB_complex" set
    complex_1 = "sin(x0 * exp(x0))"
    complex_2 = "x0 + log(x0**4)"
    complex_3 = "1+x0*sin(1/x0)"
    complex_4 = "sqrt(x0**3) * log(x0**2)"
    complex_5 = "(x0+x0**3) / (1+x0*cos(x0**2))"
    complex_6 = "x0 / (sqrt(x0**2 + sin(x0)))"
    complex_7 = "cos((x0+sin(x0))/ (x0**3+x0*log(x0**2)))"
    complex_8 = "(exp(x0) * (1+sqrt(1+x0) + cos(x0**2)))/ x0**2"

    # Default: sample on [-5, 5), validate on [-7, -5) and [5, 7).
    # Two equations override it below; overriding keeps the dict order.
    # NOTE(review): some targets are not real-valued on part of their ranges, e.g.
    # sqrt(x0**3) for x0 < 0 (complex_4), sqrt(x0**2 + sin(x0)) for some x0 in (-1, 0)
    # (complex_6), and sqrt(1+x0) for x0 < -1 in complex_8's validation ranges.
    # These evaluate to NaN under the numpy backend. Confirm the ranges are intended.
    sample_meta = {
        eqn: (-5., 5., number_points, -7, -5, 5, 7)
        for eqn in (complex_1, complex_2, complex_3, complex_4,
                    complex_5, complex_6, complex_7, complex_8)
    }
    sample_meta[complex_5] = (-1., .5, number_points, -2, -1., .5, 1.25)
    sample_meta[complex_8] = (.1, 5., number_points, -7, -5, 5, 7)

    set_name = "AB_complex"

benchmark_eqns = [
    complex_1, complex_2, complex_3, complex_4,
    complex_5, complex_6, complex_7, complex_8,
]

In [None]:
# Only the BFGS-refined NeSymReS variant is implemented in this notebook.
use_bfgs = True

if not use_bfgs:
    # An explicit raise (unlike `assert False`) survives `python -O`, so a
    # misconfiguration fails here instead of leaving `my_method` undefined.
    raise NotImplementedError("not implemented")
my_method = "nsrts_w_bfgs"
    

In [None]:
# Visualize each benchmark equation on its sampling range (first figure) and on
# both of its validation ranges (second figure).


def _plot_on_range(ax, fn, lo, hi, step, **plot_kwargs):
    """Plot the lambdified `fn` on np.arange(lo, hi, step) into `ax`."""
    x_grid = np.arange(lo, hi, step)
    ax.plot(x_grid, fn(x_grid), **plot_kwargs)


fig_train, ax_train = plt.subplots()
fig_val, ax_val = plt.subplots()

for number, eqn_str in enumerate(benchmark_eqns):
    # lambdify once per equation and reuse it for all three curves
    eqn_fn = sp.lambdify("x0", expr=eqn_str)
    (bottom, top, n_samples, val_lo0, val_hi0, val_lo1, val_hi1) = sample_meta[eqn_str]

    # every curve uses the resolution of 10,000 steps across the sampling range
    step_size = (top - bottom) / 10000
    curve_color = my_cmap(number / len(benchmark_eqns))
    curve_label = f"{set_name}-{1+number}"

    _plot_on_range(ax_train, eqn_fn, bottom, top, step_size,
                   color=curve_color, label=curve_label)
    # both validation ranges share one legend entry per equation
    _plot_on_range(ax_val, eqn_fn, val_lo0, val_hi0, step_size,
                   color=curve_color, label=f"{curve_label} val", alpha=0.69)
    _plot_on_range(ax_val, eqn_fn, val_lo1, val_hi1, step_size,
                   color=curve_color, alpha=0.69)

ax_train.set(title=f"{set_name} benchmark equations", xlabel="x0", ylabel="f(x0)")
ax_train.legend()
ax_val.set(title=f"{set_name} benchmark validation regions", xlabel="x0", ylabel="f(x0)")
ax_val.legend()
plt.show()

In [None]:
# Beam size used for decoding; increase for greater accuracy at the cost of runtime.
my_beam_width: int = 1

In [None]:
# `nsrts_dir` must point to a directory containing
#   weights/100M.ckpt, jupyter/100M/eq_setting.json and jupyter/100M/config.yaml

eval_tag = f"eval_complex_{my_method}_{int(time.time())}"
# CSV header for the per-trial log (only written when `logging` is enabled)
columns = "eqn, seed, pred, target, correct, mse, r2_ood, r2_id, "
columns += "target_complexity, complexity, method, range_low, range_high, "
columns += "number_points, val_low0, val_high0, val_low1, val_high1"

if logging:
    with open(f"{eval_tag}.csv", "w") as csv_file:
        csv_file.write(columns)
        csv_file.write("\n")

nsrts_dir = "../../nsrts"

json_filepath = os.path.join(nsrts_dir, "jupyter", "100M", "eq_setting.json")
with open(json_filepath, 'r') as json_file:
    eq_setting = json.load(json_file)

cfg_filepath = os.path.join(nsrts_dir, "jupyter", "100M", "config.yaml")
cfg = omegaconf.OmegaConf.load(cfg_filepath)

weights_path = os.path.join(nsrts_dir, "weights", "100M.ckpt")

## Set up BFGS, loading its options from the hydra config yaml
bfgs = BFGSParams(
        activated=cfg.inference.bfgs.activated,
        n_restarts=cfg.inference.bfgs.n_restarts,
        add_coefficients_if_not_existing=cfg.inference.bfgs.add_coefficients_if_not_existing,
        normalization_o=cfg.inference.bfgs.normalization_o,
        idx_remove=cfg.inference.bfgs.idx_remove,
        normalization_type=cfg.inference.bfgs.normalization_type,
        stop_time=cfg.inference.bfgs.stop_time,
    )

# adjust this parameter up for greater accuracy and longer runtime
cfg.inference.beam_size = my_beam_width

params_fit = FitParams(word2id=eq_setting["word2id"],
                       id2word={int(k): v for k, v in eq_setting["id2word"].items()},
                       una_ops=eq_setting["una_ops"],
                       bin_ops=eq_setting["bin_ops"],
                       total_variables=list(eq_setting["total_variables"]),
                       total_coefficients=list(eq_setting["total_coefficients"]),
                       rewrite_functions=list(eq_setting["rewrite_functions"]),
                       bfgs=bfgs,
                       beam_size=cfg.inference.beam_size  # tradeoff between accuracy and fitting time
                       )

## Evaluate the model on every benchmark equation, one trial per seed.


def _as_grid(values, grid):
    """Return `values` as an array shaped like `grid`.

    A lambdified expression that does not depend on x0 returns a scalar. Broadcasting
    it keeps the downstream array operations (masking, R^2, relative error) working.
    """
    values = np.asarray(values)
    if values.shape != grid.shape:
        values = np.broadcast_to(values, grid.shape).copy()
    return values


def _r2_or_nan(y_true, y_pred):
    """Compute R^2 of `y_pred` against `y_true` over the points where the target is finite.

    Returns NaN when fewer than two finite target points remain, or when the
    prediction is non-finite on any of them. sklearn's r2_score rejects NaN/inf inputs.
    """
    y_true = np.ravel(y_true)
    y_pred = np.ravel(y_pred)
    finite_target = np.isfinite(y_true)
    if finite_target.sum() < 2 or not np.isfinite(y_pred[finite_target]).all():
        return np.nan
    return sklearn.metrics.r2_score(y_true[finite_target], y_pred[finite_target])


# The checkpoint is loop-invariant, so load it once and set it into eval mode.
# Each trial below reseeds numpy/torch before sampling and fitting.
# NOTE(review): assumes model.fitfunc does not modify the model between calls — confirm.
model = Model.load_from_checkpoint(weights_path, cfg=cfg.architecture)
model.eval()
if torch.cuda.is_available():
    # prefer the second GPU, but fall back to the first one on single-GPU machines
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    model.to(torch.device(f"cuda:{gpu_index}"))
fitfunc = partial(model.fitfunc, cfg_params=params_fit)

accuracies = []
all_mses = []
all_mse_sds = []

all_r2_means = []
all_r2_sds = []

all_tree_distances = []
all_ares = []

all_tree_distance_sds = []
all_are_sds = []

catastrophic_failure_count = 0
for hh, eqn in enumerate(benchmark_eqns):

    equivalents = []
    mses = []
    r2s = []
    complexities = []
    ares = []
    tree_distances = []

    target_complexity = compute_complexity(eqn)
    my_fn = sp.lambdify("x0", expr=eqn)
    (bottom, top, number_samples, d, e, f, g) = sample_meta[eqn]

    # Evaluation grids:
    #   id_x     - in-distribution, over the sampling range, shape (1000, 1)
    #   bigger_x - out-of-distribution, both validation ranges concatenated (1-D)
    id_x = np.arange(bottom, top, (top-bottom)/1000).reshape(-1, 1)
    bigger_x = np.append(np.arange(d, e, (e-d)/500), np.arange(f, g, (g-f)/500))

    for trial in range(number_trials):

        # the trial index doubles as the seed for sampling and fitting
        np.random.seed(trial)
        torch.manual_seed(trial)

        x = np.random.rand(number_samples, 1) * (top-bottom) + bottom
        y = my_fn(x)

        x = torch.tensor(x)
        y = torch.tensor(y)

        try:
            output = fitfunc(x, y.squeeze())
            raw_pred = output["best_bfgs_preds"][0]

            if "x_2" in raw_pred:
                # Only one input variable (x_1, renamed to x0 below) is given to the model.
                # A prediction that uses x_2 cannot be evaluated on the grids, so count it
                # as incorrect and skip the numeric metrics.
                equivalents.append(False)
                print(f"eqn {hh+1}, trial {trial} predicted {raw_pred}, which uses x_2; counted as incorrect")
                continue

            pred_str = raw_pred.replace("x_1", "x0")
            if (x >= 0).all():
                print("removing abs")
                # remove abs() if the inputs are all non-negative
                best_eqn = sp.simplify(pred_str.replace("Abs", ""))
                tgt_eqn = sp.simplify(eqn.replace("abs", ""))
            else:
                best_eqn = sp.simplify(pred_str)
                tgt_eqn = sp.simplify(eqn)

            is_correct = sp.simplify(best_eqn - tgt_eqn) == 0
            tgt_fn = sp.lambdify("x0", expr=tgt_eqn)
            best_fn = sp.lambdify("x0", expr=best_eqn)

            x_np = x.cpu().numpy()
            my_mse = np.mean((tgt_fn(x_np) - best_fn(x_np))**2)

            bigger_y_true = _as_grid(tgt_fn(bigger_x), bigger_x)
            bigger_y_pred = _as_grid(best_fn(bigger_x), bigger_x)
            id_y_true = _as_grid(tgt_fn(id_x), id_x)
            id_y_pred = _as_grid(best_fn(id_x), id_x)

            my_r2 = _r2_or_nan(bigger_y_true, bigger_y_pred)
            my_r2_id = _r2_or_nan(id_y_true, id_y_pred)

            my_complexity = compute_complexity(best_eqn)
            my_ted = compute_tree_distance(eqn, best_eqn)
            my_are = compute_relative_error(id_y_true, id_y_pred)

            complexities.append(my_complexity)
            mses.append(my_mse)
            r2s.append(my_r2_id)
            equivalents.append(is_correct)

            tree_distances.append(my_ted)
            ares.append(my_are)

            verdict = "correct" if is_correct else "incorrect"
            correct = 1 if is_correct else 0

            msg = f"eqn {hh+1}, trial {trial} predicted {verdict} equation {best_eqn} for target {tgt_eqn}"
            msg += f" with mse {my_mse:.3} and r2 = {my_r2:.4}"
            print(msg)

            if logging:
                # one row per successful trial; the "seed" column is the trial index
                row = [eqn, trial, best_eqn, tgt_eqn, correct, my_mse, my_r2, my_r2_id,
                       target_complexity, my_complexity, my_method, bottom, top,
                       number_samples, d, e, f, g]
                with open(f"{eval_tag}.csv", "a") as csv_file:
                    csv_file.write(", ".join(str(value) for value in row))
                    csv_file.write("\n")
        except Exception as exc:
            # Best effort: one bad trial (e.g. an error inside fitfunc or while parsing/
            # simplifying the prediction) must not abort the whole benchmark. Report the
            # cause, and let KeyboardInterrupt through so the run can still be stopped.
            print(f"catastrophic fail of some kind: {exc!r}")
            catastrophic_failure_count += 1

    msg = f"accuracy for equation {hh+1}: {np.mean(equivalents)}"\
            f" with mean mse: {np.mean(mses):.3}, r2 = {np.mean(r2s)},"\
            f" and avg. complexity {np.mean(complexities)}"

    print(msg)
    if len(tree_distances):
        accuracies.append(np.mean(equivalents))
        all_mses.append(np.mean(mses))
        all_mse_sds.append(np.std(mses))
        all_r2_means.append(np.mean(r2s))
        all_r2_sds.append(np.std(r2s))

        all_ares.append(np.mean(ares))
        all_tree_distances.append(np.mean(tree_distances))

        all_are_sds.append(np.std(ares))
        all_tree_distance_sds.append(np.std(tree_distances))
    else:
        # no evaluable trial for this equation: record worst-case sentinels
        accuracies.append(0.0)
        all_mses.append(float("inf"))
        all_mse_sds.append(float("inf"))
        all_r2_means.append(-float("inf"))
        all_r2_sds.append(float("inf"))

        all_ares.append(float("inf"))
        all_tree_distances.append(float("inf"))

        all_are_sds.append(float("inf"))
        all_tree_distance_sds.append(float("inf"))

failure_msg = f"\nTotal failure count: {catastrophic_failure_count}, of {len(benchmark_eqns)*number_trials}"
failure_msg += f" = {catastrophic_failure_count / (len(benchmark_eqns)*number_trials)}"
print(failure_msg)

In [None]:

# Detailed per-equation summary; also saved to `{eval_tag}_summary.txt` when logging.
print(failure_msg)

summary_parts = [f"{my_method} accuracies\n"]
for idx in range(len(benchmark_eqns)):
    summary_parts.append(
        f"\n  complex-{idx+1}, exact equ. accuracy: {accuracies[idx]:5f}, "
        f"\ntree edit distance: {all_tree_distances[idx]:.5} +/- {all_tree_distance_sds[idx]:.5} "
        f"\nmse: {all_mses[idx]:.5} +/- {all_mse_sds[idx]:.5} "
        f"\nrelative absolute error: {all_ares[idx]:.5} +/- {all_are_sds[idx]:.5} "
        f"\nr^2 (i.d): {all_r2_means[idx]:.5} +/- {all_r2_sds[idx]:.5}\n\n"
    )
    summary_parts.append(f"  {sp.expand(benchmark_eqns[idx])} \n")
msg = "".join(summary_parts)

if logging:
    summary_file = f"{eval_tag}_summary.txt"
    with open(summary_file, "w") as summary_handle:
        summary_handle.write(msg)

print(msg)

In [None]:
"""
Total failure count: 300, of 800 = 0.375
nsrts_w_bfgs accuracies

  complex-1, exact equ. accuracy: 0.990000, 
tree edit distance: 0.03 +/- 0.2985 
mse: 6.989e-15 +/- 6.9539e-14 
relative absolute error: 1.4083e-08 +/- 1.4012e-07 
r^2 (i.d): 1.0 +/- 2.1803e-13

  sin(x0*exp(x0)) 

  complex-2, exact equ. accuracy: 1.000000, 
tree edit distance: 0.0 +/- 0.0 
mse: 0.0 +/- 0.0 
relative absolute error: 0.0 +/- 0.0 
r^2 (i.d): 1.0 +/- 0.0

  x0 + log(x0**4) 

  complex-3, exact equ. accuracy: 0.000000, 
tree edit distance: 8.0 +/- 0.0 
mse: 0.019544 +/- 0.0075421 
relative absolute error: 0.05534 +/- 0.0092023 
r^2 (i.d): 0.65063 +/- 0.034109

  x0*sin(1/x0) + 1 

  complex-4, exact equ. accuracy: 0.000000, 
tree edit distance: inf +/- inf 
mse: inf +/- inf 
relative absolute error: inf +/- inf 
r^2 (i.d): -inf +/- inf

  sqrt(x0**3)*log(x0**2) 

  complex-5, exact equ. accuracy: 0.000000, 
tree edit distance: 20.72 +/- 1.1232 
mse: 0.13449 +/- 0.36205 
relative absolute error: 0.15297 +/- 0.081074 
r^2 (i.d): 0.95056 +/- 0.11777

  x0**3/(x0*cos(x0**2) + 1) + x0/(x0*cos(x0**2) + 1) 

  complex-6, exact equ. accuracy: 0.000000, 
tree edit distance: inf +/- inf 
mse: inf +/- inf 
relative absolute error: inf +/- inf 
r^2 (i.d): -inf +/- inf

  x0/sqrt(x0**2 + sin(x0)) 

  complex-7, exact equ. accuracy: 0.000000, 
tree edit distance: 23.0 +/- 0.0 
mse: 0.15846 +/- 0.041775 
relative absolute error: 0.54942 +/- 0.042619 
r^2 (i.d): 0.32733 +/- 0.044813

  cos(x0/(x0**3 + x0*log(x0**2)) + sin(x0)/(x0**3 + x0*log(x0**2))) 

  complex-8, exact equ. accuracy: 0.000000, 
tree edit distance: inf +/- inf 
mse: inf +/- inf 
relative absolute error: inf +/- inf 
r^2 (i.d): -inf +/- inf

  sqrt(x0 + 1)*exp(x0)/x0**2 + exp(x0)*cos(x0**2)/x0**2 + exp(x0)/x0**2 1
"""

In [None]:

# Compact per-equation summary: accuracy, MSE and in-distribution R^2.
msg = f"{my_method} accuracies\n"

for ii, eqn in enumerate(benchmark_eqns):

    msg += f"\n  Complex-{ii+1},  accuracy: {accuracies[ii]:5f}, "\
            f"mse: {all_mses[ii]:.5} +/- {all_mse_sds[ii]:.5} "\
            f"r^2: {all_r2_means[ii]:.5} +/- {all_r2_sds[ii]:.5}\n"

    msg += f"  {sp.expand(eqn)} \n"

if logging:
    # Use a distinct file name so this short summary does not overwrite the
    # detailed `{eval_tag}_summary.txt` written by the detailed summary cell.
    summary_file = f"{eval_tag}_summary_short.txt"
    with open(summary_file, "w") as summary_handle:
        summary_handle.write(msg)
print(msg)