In [None]:
import os
import json 
import time

import numpy as np
import torch

import matplotlib.pyplot as plt
my_cmap = plt.get_cmap("viridis")

from sympy import lambdify
import sympy as sp
import sklearn
import sklearn.metrics

from symbolicgpt.models import GPT, GPTConfig, PointNetConfig
from symbolicgpt.utils import processDataFiles, CharDataset,\
        sample_from_model, lossFunc
                         
from scipy.optimize import minimize, least_squares    

from pathlib import Path
from functools import partial

#
import glob

from symr.metrics import compute_complexity, compute_shannon_diversity, compute_r2
from symr.fake_sr import loss_function, PolySR, FourierSR, RandomSR


In [None]:
# Benchmark run configuration.

# Input sampling distribution. For "uniform" the range holds the min and max
# values; for a normal distribution it holds the mean and standard deviation.
distribution_type = "uniform"
distribution_range = [-0.99, 1.0]

# number_trials: each trial index is also used as that trial's random seed.
number_points, number_trials = 200, 2

logging = False

In [None]:
# "A.B." benchmark expressions; every expression is a function of x0 only.
eqn_1 = "sin(x0*exp(x0))"
eqn_2 = "x0 + log(x0**4)"
eqn_3 = "1 + x0 * sin(1/x0)"
eqn_4 = "sqrt(x0**3) * log(x0**2)"
eqn_5 = "(x0+x0**3) / (1+x0*cos(x0))"
eqn_6 = "x0 / (sqrt(x0**2 + sin(x0)))"
eqn_7 = "cos( (x0 + sin(x0)) / (x0**3 + x0*log(x0**2)) )"
eqn_8 = "(exp(x0) * (1 + sqrt(1+x0) + cos(x0**2))) / (x0**2)"

benchmark_eqns = [
    eqn_1, eqn_2, eqn_3, eqn_4,
    eqn_5, eqn_6, eqn_7, eqn_8,
]

# Sampling metadata keyed by expression string. Each value is a 7-tuple that
# the later cells unpack as
# (range low, range high, number of samples, d, e, f, g).
# The first six expressions share one setting; the last two get their own.
sample_meta = dict.fromkeys(benchmark_eqns[:6], (-1, 1., 200, -2, -1, 1, 2))
sample_meta[eqn_7] = (0, 2., 200, 2, 3, 3, 4)
sample_meta[eqn_8] = (0.2, 2., 200, 2, 3, 3, 4)

set_name = "ab"

In [None]:
# Nguyen benchmark expressions; every expression is a function of x0 only.
nguyen_1 = "x0**3 + x0**2 + x0"
nguyen_2 = "x0**4 + x0**3 + x0**2 + x0"
nguyen_3 = "x0**5 + x0**4 + x0**3 + x0**2 + x0"
nguyen_4 = "x0**6 + x0**5 + x0**4 + x0**3 + x0**2 + x0"
nguyen_5 = "sin(x0**2) * cos(x0) - 1"
nguyen_6 = "sin(x0) + sin(x0+x0**2)"
nguyen_7 = "log(x0+1) + log(x0**2 + 1) "
nguyen_8 = "x0**(1/2)"

benchmark_eqns = [
    nguyen_1, nguyen_2, nguyen_3, nguyen_4,
    nguyen_5, nguyen_6, nguyen_7, nguyen_8,
]

# Sampling metadata keyed by expression string, as (a, b, c, d, e, f, g):
#   c samples are drawn in the range from a to b;
#   d..e and f..g are the min/max of the two validation input ranges.
# The first six expressions share one setting; the last two get their own.
sample_meta = dict.fromkeys(benchmark_eqns[:6], (-1, 1., 200, -2, -1, 1, 2))
sample_meta[nguyen_7] = (0, 2., 200, 2, 3, 3, 4)
sample_meta[nguyen_8] = (0, 4., 200, 4, 6, 6, 8)

set_name = "nguyen"

In [None]:
"""
Nguyen
equation 0 complexity: 8 diversity: 1.93
equation 1 complexity: 11 diversity: 2.17
equation 2 complexity: 14 diversity: 2.34
equation 3 complexity: 17 diversity: 2.48
equation 4 complexity: 9 diversity: 2.15
equation 5 complexity: 9 diversity: 2.05
equation 6 complexity: 11 diversity: 2.34
equation 7 complexity: 3 diversity: 1.1

A.B. expressions
equation 0 complexity: 5 diversity: 1.55
equation 1 complexity: 6 diversity: 1.73
equation 2 complexity: 8 diversity: 2.03
equation 3 complexity: 10 diversity: 2.25
equation 4 complexity: 14 diversity: 2.4
equation 5 complexity: 10 diversity: 2.15
equation 6 complexity: 19 diversity: 2.76
equation 7 complexity: 17 diversity: 2.63
"""

# Report the complexity and Shannon diversity of each active benchmark
# expression, one line per expression (diversity to 3 significant digits).
for index, expression in enumerate(benchmark_eqns):
    print(
        f"equation {index} complexity: {compute_complexity(expression)} "
        f"diversity: {compute_shannon_diversity(expression):.3}"
    )

In [None]:
# visualize equations
#
# Figure 1 plots every expression in `benchmark_eqns` over its sampling range
# (first two fields of its `sample_meta` tuple). Figure 2 plots the same
# expressions over their two validation ranges (d..e and f..g).
# Labels and titles are derived from `set_name`, so the figures stay correct
# whichever benchmark-definition cell was run last.

plt.figure()
for number, fn in enumerate(benchmark_eqns):

    my_fn = sp.lambdify("x0", expr=fn)

    (bottom, top, c, d, e, f, g) = sample_meta[fn]

    # 10000 evenly spaced points across the sampling range
    step_size = (top - bottom) / 10000
    x = np.arange(bottom, top, step_size)

    my_color = my_cmap(number / len(benchmark_eqns))
    plt.plot(x, my_fn(x), color=my_color, label=f"{set_name}-{1+number}")

plt.legend()
plt.title(f"{set_name} benchmark equations")

plt.figure()
for number, fn in enumerate(benchmark_eqns):

    my_fn = sp.lambdify("x0", expr=fn)

    (bottom, top, c, d, e, f, g) = sample_meta[fn]

    # same step as the sampling-range plot, so point density matches
    step_size = (top - bottom) / 10000
    my_color = my_cmap(number / len(benchmark_eqns))

    for region, (low, high) in enumerate([(d, e), (f, g)]):
        x = np.arange(low, high, step_size)
        # Label only the first region: both regions of an equation share a
        # colour, and a leading underscore keeps a line out of the legend.
        label = f"{set_name}-{1+number} val" if region == 0 else "_nolegend_"
        plt.plot(x, my_fn(x), color=my_color, label=label, alpha=0.69)

plt.legend()
plt.title(f"{set_name} benchmark validation regions")
plt.show()

In [None]:
# Passed as the `degree` keyword to each baseline model constructor
# (RandomSR, PolySR, FourierSR) in the benchmark loop.
my_degree = 20

In [None]:
# Benchmark the baseline "fake" symbolic-regression models.
#
# For every (method, equation, trial) combination this cell:
#   1. samples `number_samples` inputs uniformly in [bottom, top) with the
#      trial index as random seed,
#   2. asks the model for an expression skeleton in which the character 'C'
#      marks a constant placeholder,
#   3. fits the constants with BFGS on `loss_function`,
#   4. scores the fitted expression by MSE on the sampled points, by r^2 on a
#      dense in-distribution grid, and by r^2 on the two validation ranges
#      (out of distribution).
# A trial counts as "correct" when its in-distribution r^2 exceeds 0.99.
# When `logging` is true, per-trial rows go to "<eval_tag>.csv" and the
# per-method summary goes to "<eval_tag>_summary.txt".


def _fill_constants(skeleton, constants):
    """Return `skeleton` with each 'C' replaced, in order, by the next entry of `constants`."""
    values = iter(constants)
    return "".join(f"{next(values)}" if token == "C" else token for token in skeleton)


def _finite_r2(y_true, y_pred):
    """Return compute_r2(y_true, y_pred), or NaN when the mean prediction is not finite."""
    if np.isfinite(y_pred.mean()):
        return compute_r2(y_true, y_pred)
    return np.nan


columns = "eqn, seed, pred, target, correct, mse, r2_ood, r2_id, "
columns += "target_complexity, complexity, method, range_low, range_high, "
columns += "number_points, val_low0, val_high0, val_lo1, val_high1"

for my_method, model_class in zip(["random", "poly", "Fourier"], [RandomSR, PolySR, FourierSR]):

    model = model_class(degree=my_degree)

    # Tag output files with the active benchmark set rather than a fixed name.
    eval_tag = f"eval_{set_name}_{my_method}_{int(time.time())}"

    if logging:
        with open(f"{eval_tag}.csv", "w") as csv_file:
            # Terminate the header so the first result row starts on its own line.
            csv_file.write(columns + "\n")

    accuracies = []
    all_mses = []
    all_mse_sds = []

    all_r2_means = []
    all_r2_sds = []

    all_id_r2_means = []
    all_id_r2_sds = []

    # NOTE(review): nothing in this cell increments this counter, so the
    # reported failure count is always 0 — confirm what should count as a failure.
    catastrophic_failure_count = 0

    for hh, eqn in enumerate(benchmark_eqns):
        equivalents = []
        mses = []

        r2s = []
        r2_ids = []
        complexities = []

        # Everything below depends only on the equation, so compute it once
        # per equation instead of once per trial.
        target_complexity = compute_complexity(eqn)
        tgt_eqn = sp.simplify(eqn)
        tgt_fn = sp.lambdify("x0", expr=eqn)

        (bottom, top, number_samples, d, e, f, g) = sample_meta[eqn]

        # dense in-distribution grid (1000 points over the sampling range)
        id_x = np.arange(bottom, top, (top-bottom)/1000).reshape(-1, 1)
        # out-of-distribution grid: 500 points in each validation range,
        # flattened to 1-D by np.append
        bigger_x = np.append(np.arange(d, e, (e-d)/500).reshape(-1, 1),
                             np.arange(f, g, (g-f)/500).reshape(-1, 1))

        id_y_true = tgt_fn(id_x)
        bigger_y_true = tgt_fn(bigger_x)

        for trial in range(number_trials):

            # the trial index is the seed
            np.random.seed(trial)
            torch.manual_seed(trial)

            x = np.random.rand(number_samples, 1) \
                    * (top-bottom) \
                    + bottom

            y = tgt_fn(x)

            pred_skeleton = model(x, y)

            # one initial value of 1.0 per 'C' placeholder in the skeleton
            initial_constants = [1.0 for token in pred_skeleton if token == "C"]

            optimized = minimize(loss_function, initial_constants,
                                 args=(pred_skeleton, x, y), method="BFGS")

            pred_expression = _fill_constants(pred_skeleton, optimized.x)

            best_eqn = sp.simplify(pred_expression)
            best_fn = sp.lambdify("x0", expr=pred_expression)

            my_mse = np.mean((y - best_fn(x))**2)

            # out-of-distribution r^2 on the validation ranges
            my_r2 = _finite_r2(bigger_y_true, best_fn(bigger_x))
            # in-distribution r^2 on the sampling range (must use the id grid)
            my_r2_id = _finite_r2(id_y_true, best_fn(id_x))

            my_complexity = compute_complexity(best_eqn)
            complexities.append(my_complexity)
            mses.append(my_mse)
            r2s.append(my_r2)
            r2_ids.append(my_r2_id)

            # fake is correct, based on r2 in distribution
            # (a NaN r^2 compares False and therefore counts as incorrect)
            is_correct = my_r2_id > 0.99
            equivalents.append(is_correct)

            verdict = "correct" if is_correct else "incorrect"

            correct = 1 if is_correct else 0

            try:
                msg = f"eqn {hh+1}, trial {trial} predicted {verdict} equation: \n   "
                msg += f" with mse {my_mse:.3} and r^2 ood = {my_r2:.4}, r^2 id {my_r2_id:.4f}"
                msg += f"\n"
            except (TypeError, ValueError):
                # best effort: a metric that cannot be formatted must not abort the run
                msg = ""
            print(msg)

            if logging:
                # second field is the seed, i.e. the trial index
                results = f"{eqn}, {trial}, {best_eqn}, {tgt_eqn}, {correct}, {my_mse}, {my_r2}, "\
                        f"{my_r2_id}, {target_complexity}, {my_complexity}, "\
                        f" {my_method}, {bottom}, {top}, {number_samples}, {d}, {e}, {f}, {g}"

                with open(f"{eval_tag}.csv", "a") as csv_file:
                    csv_file.write(results)
                    csv_file.write("\n")

        msg = f"accuracy for equation {hh+1}: {np.mean(equivalents)}"\
                f" with mean mse: {np.mean(mses):.3}, running total failure count: {catastrophic_failure_count}\n"
        print(msg)
        accuracies.append(np.mean(equivalents))
        all_mses.append(np.mean(mses))
        all_mse_sds.append(np.std(mses))
        all_r2_means.append(np.mean(r2s))
        all_r2_sds.append(np.std(r2s))
        all_id_r2_means.append(np.mean(r2_ids))
        all_id_r2_sds.append(np.std(r2_ids))

    failure_msg = f"\nTotal failure count: {catastrophic_failure_count}, of {len(benchmark_eqns)*number_trials}"
    failure_msg += f" = {catastrophic_failure_count / (len(benchmark_eqns)*number_trials)}"
    print(failure_msg)

    # per-method summary, one entry per benchmark equation
    msg = f"{my_method} accuracies\n"

    for ii, eqn in enumerate(benchmark_eqns):

        msg += f"\n  equation-{ii+1},  accuracy: {accuracies[ii]:5f}, "\
                f"mse: {all_mses[ii]:.5} +/- {all_mse_sds[ii]:.5} \n"\
                f"r^2 i.d.: {all_id_r2_means[ii]:.5} +/- {all_id_r2_sds[ii]:.5} and "\
                f"r^2 o.o.d.: {all_r2_means[ii]:.5} +/- {all_r2_sds[ii]:.5}\n"

        msg += f"  {sp.expand(eqn)} \n"

    if logging:
        summary_file = f"{eval_tag}_summary.txt"
        with open(summary_file, "w") as summary_handle:
            summary_handle.write(msg)

    print(msg)

In [None]:
"""
Total failure count: 0, of 80 = 0.0
random accuracies

  equation-1,  accuracy: 0.300000, mse: 1610.3 +/- 2917.0 
r^2 i.d.: -2861.1 +/- 4417.5 and r^2 o.o.d.: -10.184 +/- 11.945
  x0**3 + x0**2 + x0 

  equation-2,  accuracy: 0.300000, mse: 1609.9 +/- 2917.7 
r^2 i.d.: -2327.4 +/- 3594.9 and r^2 o.o.d.: -4.2523 +/- 4.5978
  x0**4 + x0**3 + x0**2 + x0 

  equation-3,  accuracy: 0.300000, mse: 1610.0 +/- 2918.2 
r^2 i.d.: -1547.3 +/- 2390.5 and r^2 o.o.d.: 0.53232 +/- 0.59292
  x0**5 + x0**4 + x0**3 + x0**2 + x0 

  equation-4,  accuracy: 0.300000, mse: 1609.8 +/- 2918.6 
r^2 i.d.: -1321.7 +/- 2042.6 and r^2 o.o.d.: 0.22509 +/- 0.56231
  x0**6 + x0**5 + x0**4 + x0**3 + x0**2 + x0 

  equation-5,  accuracy: 0.300000, mse: 1614.9 +/- 2911.1 
r^2 i.d.: -1.0304e+05 +/- 1.5864e+05 and r^2 o.o.d.: -2.1256e+04 +/- 1.9781e+04
  sin(x0**2)*cos(x0) - 1 

  equation-6,  accuracy: 0.300000, mse: 1610.9 +/- 2917.6 
r^2 i.d.: -2980.7 +/- 4601.3 and r^2 o.o.d.: -1111.4 +/- 1004.8
  sin(x0) + sin(x0**2 + x0) 

  equation-7,  accuracy: 1.000000, mse: 8.55e-07 +/- 1.4807e-06 
r^2 i.d.: 1.0 +/- 5.9056e-06 and r^2 o.o.d.: -1.245e+04 +/- 3.5789e+04
  log(x0 + 1) + log(x0**2 + 1) 

  equation-8,  accuracy: 0.700000, mse: 3.099e+05 +/- 9.297e+05 
r^2 i.d.: nan +/- nan and r^2 o.o.d.: -3.6281e+10 +/- 1.0868e+11
  sqrt(x0) 


poly accuracies (10th degree)

  equation-1,  accuracy: 1.000000, mse: 0.0022448 +/- 0.0010967 
r^2 i.d.: 0.99933 +/- 0.00071256 and r^2 o.o.d.: -2.1803e+09 +/- 1.3989e+09
  1.0*x0/(1.0*x0*exp(1.0*x0)*exp(1.0*sin(1.0*x0 + 1.0)) + 1.0*exp(1.0*x0)*exp(1.0*sin(1.0*x0 + 1.0))) + 1.0 

  equation-2,  accuracy: 1.000000, mse: 6.7079e-08 +/- 3.2145e-08 
r^2 i.d.: 0.99999 +/- 1.6076e-06 and r^2 o.o.d.: -5.9602e+06 +/- 1.7796e+06
  1.0*sin(1.0*x0/(1.0*x0 + 1.0) + 1.0/(1.0*x0 + 1.0))*Abs(1.0*x0 + 1.0)**(1/4) + 1.0 

  equation-3,  accuracy: 1.000000, mse: 6.2544e-08 +/- 1.7855e-08 
r^2 i.d.: 1.0 +/- 6.9966e-07 and r^2 o.o.d.: -1.3639e+04 +/- 1.2049e+04
  1.0*exp(0.5*re(x0)) + 1.0*sqrt(Abs(1.0*x0/(1.0*x0 + 1.0) + 1.0/(1.0*x0 + 1.0))) + 1.0 

  equation-4,  accuracy: 0.000000, mse: 3.9605 +/- 3.7622 
r^2 i.d.: 0.96345 +/- 0.017341 and r^2 o.o.d.: -3.9659e+14 +/- 2.215e+14
  1.0*sqrt(Abs(x0)) + 1.0 + 1.0/(x0*Abs(x0)) 

  equation-5,  accuracy: 0.000000, mse: 0.017301 +/- 0.009087 
r^2 i.d.: 0.90617 +/- 0.031298 and r^2 o.o.d.: -3.3336e+13 +/- 3.9097e+13
  1.0*x0/(1.0*x0*sqrt(Abs(log(x0))) + 1.0*sqrt(Abs(log(x0)))) + 1.0 

  equation-6,  accuracy: 1.000000, mse: 0.0024263 +/- 0.0011845 
r^2 i.d.: 0.99993 +/- 3.7152e-05 and r^2 o.o.d.: -5.1229e+07 +/- 7.2238e+07
  1.0*x0**2*sqrt(Abs(1.0*x0 + 1.0)) + 1.0*x0 + 1.0*log(1.0*x0 + 1.0) + 2.0 

  equation-7,  accuracy: 1.000000, mse: 1.0064e-05 +/- 7.3207e-07 
r^2 i.d.: 0.99852 +/- 0.00018173 and r^2 o.o.d.: -1.3808e+08 +/- 2.104e+07
  1.0*sqrt(Abs(1.0*x0*sqrt(Abs(x0))/(1.0*x0 + 1.0) + 1.0/(1.0*x0 + 1.0))) + 1.0 

  equation-8,  accuracy: 1.000000, mse: 8.6062e-07 +/- 1.9929e-07 
r^2 i.d.: 1.0 +/- 7.3036e-08 and r^2 o.o.d.: -8333.7 +/- 592.41
  1.0*x0**2 + 1.0*x0 + 1.0*exp(1.0*x0)*Abs(1.0*x0 + 1.0)**(1/4) + 1.0 
  

Fourier accuracies (10 components)

  equation-1,  accuracy: 1.000000, mse: 0.0021909 +/- 0.0011683 
r^2 i.d.: 0.99923 +/- 0.0010648 and r^2 o.o.d.: -111.66 +/- 47.461
  1.0*x0/(1.0*x0*exp(1.0*x0)*exp(1.0*sin(1.0*x0 + 1.0)) + 1.0*exp(1.0*x0)*exp(1.0*sin(1.0*x0 + 1.0))) + 1.0 

  equation-2,  accuracy: 1.000000, mse: 1.1282e-06 +/- 4.3993e-07 
r^2 i.d.: 0.99988 +/- 9.7977e-05 and r^2 o.o.d.: -48.767 +/- 1.7621
  1.0*sin(1.0*x0/(1.0*x0 + 1.0) + 1.0/(1.0*x0 + 1.0))*Abs(1.0*x0 + 1.0)**(1/4) + 1.0 

  equation-3,  accuracy: 1.000000, mse: 9.5529e-07 +/- 1.1701e-06 
r^2 i.d.: 0.99998 +/- 1.862e-05 and r^2 o.o.d.: -2.7922 +/- 0.18415
  1.0*exp(0.5*re(x0)) + 1.0*sqrt(Abs(1.0*x0/(1.0*x0 + 1.0) + 1.0/(1.0*x0 + 1.0))) + 1.0 

  equation-4,  accuracy: 0.900000, mse: 0.23874 +/- 0.16687 
r^2 i.d.: 0.99564 +/- 0.0046563 and r^2 o.o.d.: -2.8586e+08 +/- 2.3535e+08
  1.0*sqrt(Abs(x0)) + 1.0 + 1.0/(x0*Abs(x0)) 

  equation-5,  accuracy: 0.000000, mse: 0.0025657 +/- 0.0018011 
r^2 i.d.: 0.9804 +/- 0.010192 and r^2 o.o.d.: -2.2758e+07 +/- 2.8889e+07
  1.0*x0/(1.0*x0*sqrt(Abs(log(x0))) + 1.0*sqrt(Abs(log(x0)))) + 1.0 

  equation-6,  accuracy: 1.000000, mse: 1.2452e-05 +/- 1.3044e-05 
r^2 i.d.: 1.0 +/- 1.1244e-05 and r^2 o.o.d.: -3.9517 +/- 0.077104
  1.0*x0**2*sqrt(Abs(1.0*x0 + 1.0)) + 1.0*x0 + 1.0*log(1.0*x0 + 1.0) + 2.0 

  equation-7,  accuracy: 1.000000, mse: 2.3033e-06 +/- 1.7938e-06 
r^2 i.d.: 0.99966 +/- 0.0002593 and r^2 o.o.d.: -97.253 +/- 10.432
  1.0*sqrt(Abs(1.0*x0*sqrt(Abs(x0))/(1.0*x0 + 1.0) + 1.0/(1.0*x0 + 1.0))) + 1.0 

  equation-8,  accuracy: 1.000000, mse: 1.5729e-06 +/- 1.3872e-06 
r^2 i.d.: 1.0 +/- 1.5567e-06 and r^2 o.o.d.: -0.69905 +/- 0.043185
  1.0*x0**2 + 1.0*x0 + 1.0*exp(1.0*x0)*Abs(1.0*x0 + 1.0)**(1/4) + 1.0 
"""