In [None]:
import os
import json 
import time

import numpy as np
import torch

import matplotlib.pyplot as plt
my_cmap = plt.get_cmap("viridis")

from sympy import lambdify
import sympy as sp
import sklearn
import sklearn.metrics

from symbolicgpt.models import GPT, GPTConfig, PointNetConfig
from symbolicgpt.utils import processDataFiles, CharDataset,\
        sample_from_model, lossFunc
                         
from scipy.optimize import minimize, least_squares    

from pathlib import Path
from functools import partial

#
import glob



from symr.metrics import compute_complexity, compute_tree_distance, compute_exact_equivalence,\
        compute_r2, compute_r2_truncated, compute_relative_error, compute_isclose_accuracy,\
        compute_r2_over_threshold

In [None]:
# Benchmark run parameters.

# Sampling distribution settings. For a normal distribution the pair is
# (mean, standard deviation); for a uniform distribution it is (min, max).
# NOTE(review): neither name is read by the code visible in this notebook.
distribution_type = "uniform"
distribution_range = [-0.99, 1.0]

number_points = 200   # number of points sampled per equation per trial
number_trials = 100   # seeds will be trial number
logging = False       # when True, results are written to CSV / summary files

In [None]:
# Select the benchmark set. Exactly one branch is active; flip the literal
# conditions to switch between the two sets.
#
# Every sample_meta value is a 7-tuple (a, b, c, d, e, f, g):
#   c samples are drawn in the range a..b,
#   d..e and f..g are the two validation input ranges.
if False:
    complex_1 = "1.0*x0/(1.0*x0*1.0*exp(1.0*x0)*1.0*exp(1.0*sin(1.0*x0+1.0))+1.0*exp(1.0*x0)*1.0*exp(1.0*sin(1.0*x0+1.0)))+1.0"
    complex_2 = "1.0*sqrt(1.0*abs(1.0*sqrt(1.0*abs(1.0*x0+1.0))))*1.0*sin(1.0*x0/(1.0*x0+1.0)+1.0/(1.0*x0+1.0))+1.0"
    complex_3 = "1.0*sqrt(1.0*abs(1.0*x0/(1.0*x0+1.0)+1.0/(1.0*x0+1.0)))+1.0*sqrt(1.0*abs(1.0*exp(1.0*x0)))+1.0"
    complex_4 = "1.0*sqrt(1.0*abs(1.0*x0))+1.0/(x0*1.0*sqrt(1.0*abs(1.0*x0))*1.0*sqrt(1.0*abs(1.0*x0)))+1.0"
    complex_5 = "1.0*x0/(1.0*x0*1.0*sqrt(1.0*abs(1.0*log(1.0*x0)))+1.0*sqrt(1.0*abs(1.0*log(1.0*x0))))+1.0"
    complex_6 = "1.0*x0**2*1.0*sqrt(1.0*abs(1.0*x0+1.0))+1.0*x0+1.0*x0/(1.0*x0)+1.0*log(1.0*x0+1.0)+1.0"
    complex_7 = "1.0*sqrt(1.0*abs(1.0*x0*1.0*sqrt(1.0*abs(1.0*x0))/(1.0*x0+1.0)+1.0/(1.0*x0+1.0)))+1.0"
    complex_8 = "1.0*x0**2+1.0*x0+1.0*exp(1.0*x0)/1.0*sqrt(1.0*abs(1.0*sqrt(1.0*abs(1.0*x0+1.0))))+1.0"

    sample_meta = {
        complex_1: (-.9, 1., number_points, -3, -1.1, 1, 3),
        complex_2: (-.9, 1., number_points, -3, -1.1, 1, 3),
        complex_3: (-0.9, 1., number_points, -3, -1, 1, 3),
        complex_4: (0.1, 2., number_points, 2, 3, 3, 4),
        complex_5: (1.01, 2, number_points, 2, 3, 3, 4),
        complex_6: (-.9, 3., number_points, 3, 5, 5, 7),
        complex_7: (-0.9, 1., number_points, -3, -1.1, 1, 3),
        complex_8: (-0.9, 1., number_points, -3, -1.1, 1, 3),
    }
    set_name = "generated_complex"

elif True:
    complex_1 = "sin(x0 * exp(x0))"
    complex_2 = "x0 + log(x0**4)"
    complex_3 = "1+x0*sin(1/x0)"
    # NOTE(review): sqrt(x0**3) is not real for x0 < 0, yet this equation is
    # sampled on -5..5 below -- confirm that range is intended.
    complex_4 = "sqrt(x0**3) * log(x0**2)"
    complex_5 = "(x0+x0**3) / (1+x0*cos(x0**2))"
    complex_6 = "x0 / (sqrt(x0**2 + sin(x0)))"
    complex_7 = "cos((x0+sin(x0))/ (x0**3+x0*log(x0**2)))"
    # NOTE(review): sqrt(1+x0) is not real on the -7..-5 validation range
    # assigned below -- confirm that range is intended.
    complex_8 = "(exp(x0) * (1+sqrt(1+x0) + cos(x0**2)))/ x0**2"

    # Most equations share the same sampling / validation layout.
    wide_meta = (-5., 5., number_points, -7, -5, 5, 7)

    sample_meta = {
        complex_1: wide_meta,
        complex_2: wide_meta,
        complex_3: wide_meta,
        complex_4: wide_meta,
        complex_5: (-1., .5, number_points, -2, -1., .5, 1.25),
        complex_6: wide_meta,
        complex_7: wide_meta,
        complex_8: (.1, 5., number_points, -7, -5, 5, 7),
    }

    set_name = "AB_complex"


# Ordered list of the selected equations; the position determines the
# "complex-N" numbering used in plots and reports.
benchmark_eqns = [
    complex_1, complex_2, complex_3, complex_4,
    complex_5, complex_6, complex_7, complex_8,
]

In [None]:
# Toggle constant fitting with BFGS; the method tag below ends up in the
# output file names and in the report header.
use_bfgs = True

my_method = "valipour_w_bfgs" if use_bfgs else "valipour_no_bfgs"

In [None]:
# Visualize the benchmark equations on their sampling and validation ranges.

def _plot_benchmark_curves(lower_idx, upper_idx, label_suffix="", **line_kwargs):
    """Plot every benchmark equation onto the current matplotlib figure.

    lower_idx / upper_idx select which entries of the equation's sample_meta
    tuple bound the x axis. The grid step is always 1/10000 of the sampling
    range (entries 0 and 1), whichever bounds are plotted. Extra keyword
    arguments are forwarded to plt.plot.
    """
    for number, fn in enumerate(benchmark_eqns):
        meta = sample_meta[fn]
        step_size = (meta[1] - meta[0]) / 10000
        x = np.arange(meta[lower_idx], meta[upper_idx], step_size)

        my_fn = sp.lambdify("x0", expr=fn)
        plt.plot(x, my_fn(x),
                 color=my_cmap(number/len(benchmark_eqns)),
                 label=f"{set_name}-{1+number}{label_suffix}",
                 **line_kwargs)


# Figure 1: each equation over its sampling range.
plt.figure()
_plot_benchmark_curves(0, 1)
plt.legend()
plt.title(f"{set_name} benchmark equations")

# Figure 2: both validation ranges. The legend is drawn after the first range
# only, so each equation gets a single legend entry.
plt.figure()
_plot_benchmark_curves(3, 4, label_suffix=" val", alpha=0.69)
plt.legend()
_plot_benchmark_curves(5, 6, label_suffix=" val", alpha=0.69)

plt.title(f"{set_name} benchmark validation regions")
plt.show()

In [None]:
# Point-embedding configuration, from symbolicGPT.py in Valipour.

embeddingSize = 512
numPoints = [20,21]
numVars = 1
numYs = 1
method = "EMB_SUM"
variableEmbedding = "NOT_VAR"

# Keyword arguments for the point-embedding config; the number of points is
# the upper entry of numPoints minus one.
pointnet_kwargs = dict(
    embeddingSize=embeddingSize,
    numberofPoints=numPoints[1]-1,
    numberofVars=numVars,
    numberofYs=numYs,
    method=method,
    variableEmbedding=variableEmbedding,
)

# create the model
pconf = PointNetConfig(**pointnet_kwargs)


In [None]:
# Build the character vocabulary from the training files, construct the GPT
# model and load the pretrained weights onto the CPU.

blockSize = 64
maxNumFiles = 100
const_range = [-2.1, 2.1]
decimals = 8
trainRange = [-3.0,3.0]

target = "Skeleton"
addVars = variableEmbedding == 'STR_VAR'
path = os.path.join("./symbolicgpt", "datasets", "exp_test_temp", "Train", "*.json")
my_device = torch.device("cpu")


files = glob.glob(path)[:maxNumFiles]
text = processDataFiles(files)
# extract unique characters from the text before converting the text to a list,
# T is for the test data
chars = sorted(list(set(text))+['_','T','<','>',':'])
text = text.split('\n') # convert the raw text to a set of examples
# NOTE(review): trainText drops a trailing empty example but is not used below;
# CharDataset receives the untrimmed `text` -- confirm this matches upstream.
trainText = text[:-1] if len(text[-1]) == 0 else text
# NOTE(review): hard-coded rather than derived from `chars`; presumably it must
# match the vocabulary size the checkpoint was trained with -- confirm.
vocab_size = 49

train_dataset = CharDataset(text, blockSize, chars, numVars=numVars,
                numYs=numYs, numPoints=numPoints, target=target, addVars=addVars,
                const_range=const_range, xRange=trainRange, decimals=decimals, augment=False)


mconf = GPTConfig(vocab_size, blockSize,
                  n_layer=8, n_head=8, n_embd=embeddingSize,
                  padding_idx=train_dataset.paddingID)

model = GPT(mconf, pconf)

# load the pretrained weights
# NOTE(review): the file name says 30-31 points while numPoints is [20,21] in
# this notebook -- confirm the checkpoint matches the configuration.
model_name = "XYE_1Var_30-31Points_512EmbeddingSize_SymbolicGPT_GPT_PT_EMB_SUM_Skeleton_Padding_NOT_VAR_MINIMIZE.pt"
model_path = os.path.join("symbolicgpt", "Models", model_name)
# map_location remaps tensors saved on another device (e.g. CUDA) to the CPU;
# without it, loading a GPU-saved checkpoint fails on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=my_device))
model = model.eval().to(my_device)

# token index -> character, used to decode sampled token ids
char_dict = dict(enumerate(chars))

In [None]:
# Header row of the per-trial CSV log (comma + space separated).
columns = ", ".join([
    "eqn", "seed", "pred", "target", "correct", "mse", "r2_ood", "r2_id",
    "target_complexity", "complexity", "method", "range_low", "range_high",
    "number_points", "val_low0", "val_high0", "val_lo1", "val_high1",
])

# Unique tag for this run's output files: method name plus a unix timestamp.
eval_tag = f"eval_complex_{my_method}_{int(time.time())}"

if logging:
    with open(f"{eval_tag}.csv", "w") as log_file:
        log_file.write(columns + "\n")

# Fixed arguments passed to sample_from_model on every trial.
variables = torch.tensor([1])
temperature, top_k, top_p = 1.0, 0.0, 0.7
do_sample = False
inputs = torch.tensor([[23]]) # assume 23 is start token '<'
model.to(torch.device("cpu"));

# Per-equation aggregates (one entry appended per benchmark equation).
accuracies = []
all_mses, all_mse_sds = [], []
all_r2_means, all_r2_sds = [], []
all_tree_distances, all_ares = [], []
all_tree_distance_sds, all_are_sds = [], []

# Main benchmark loop.
#
# For every benchmark equation, run `number_trials` seeded trials: sample
# points, let the model predict an expression skeleton, fill in its constants
# (BFGS-fitted, or all ones), and score the result against the target.
# Per-equation means / standard deviations are appended to the accumulator
# lists created before this block.

# Number of (equation, trial) pairs whose evaluation raised an exception.
catastrophic_failure_count = 0

for hh, eqn in enumerate(benchmark_eqns):
    # Per-trial metrics for this equation. Only trials that evaluate without
    # raising contribute.
    # NOTE(review): failed trials are therefore excluded from the accuracy
    # denominator rather than counted as incorrect -- confirm this is intended.
    equivalents = []
    mses = []
    r2s = []
    complexities = []
    ares = []
    tree_distances = []

    # Everything here depends only on the equation, so it is computed once
    # instead of once per trial.
    target_complexity = compute_complexity(eqn)
    tgt_eqn = sp.simplify(eqn)
    tgt_fn = sp.lambdify("x0", expr=eqn)

    (bottom, top, number_samples,
     val_low0, val_high0, val_low1, val_high1) = sample_meta[eqn]

    # In-distribution grid (1000 steps over the sampling range) and the
    # out-of-distribution grid (500 steps over each validation range; np.append
    # without an axis flattens the two columns into one 1-D array).
    id_x = np.arange(bottom, top, (top-bottom)/1000).reshape(-1, 1)
    bigger_x = np.append(
        np.arange(val_low0, val_high0, (val_high0-val_low0)/500).reshape(-1, 1),
        np.arange(val_low1, val_high1, (val_high1-val_low1)/500).reshape(-1, 1))

    for trial in range(number_trials):

        # The trial number doubles as the random seed.
        np.random.seed(trial)
        torch.manual_seed(trial)

        # Uniform samples in [bottom, top), shape (number_samples, 1).
        x = np.random.rand(number_samples, 1) * (top-bottom) + bottom
        y = tgt_fn(x)

        # (N, 1) -> (1, 1, N), then stack x and y along dim 1 for the model.
        x = torch.tensor(x.transpose(1,0)[None,:,:])
        y = torch.tensor(y.transpose(1,0)[None,:,:])
        points = torch.cat([x,y], dim=1).float()
        x_np, y_np = x.numpy(), y.numpy()

        pred_outputs = sample_from_model(model, inputs,
            blockSize, points=points,
            variables=variables, temperature=temperature,
            sample=do_sample, top_k=top_k, top_p=top_p)

        # Decode token ids to text, keep what lies between the first character
        # and the first '>', then apply the character remapping.
        # NOTE(review): the 's'->'x', 'q'->'s' remapping is kept as-is; its
        # rationale is not visible here -- confirm against the training vocab.
        string_output = [char_dict[elem.item()] for elem in pred_outputs[0]]
        pred_skeleton = "".join(string_output).split(">")[0][1:].replace("s","x").replace("q","s").replace("***","**")

        best_eqn = None
        try:
            # One coefficient per 'C' placeholder, initialised to 1.
            c = [1.0] * pred_skeleton.count("C")

            if use_bfgs:
                optimized = minimize(lossFunc, c, args=(pred_skeleton, x_np, y_np), method="BFGS")
                constant_values = iter(optimized.x)
            else:
                constant_values = iter(c)

            # Replace each 'C', in order, with the next constant value.
            pred_expression = "".join(
                f"{next(constant_values)}" if my_char == "C" else my_char
                for my_char in pred_skeleton)

            print(pred_expression)

            best_eqn = sp.simplify(pred_expression)
            print(best_eqn)

            # NOTE(review): "x1" is mapped to "x0" only for the numeric
            # function, not for best_eqn above -- confirm this is intended.
            best_fn = sp.lambdify("x0", expr=pred_expression.replace("x1","x0"))

            is_correct = 1.0 * (sp.simplify(best_eqn - tgt_eqn) == 0)

            # MSE on the sampled points; the smaller of the plain error and
            # the error with the prediction shifted down by 1 is kept.
            # NOTE(review): the reason for the -1.0 variant is not visible
            # here -- confirm.
            y_target = tgt_fn(x_np)
            y_fit = best_fn(x_np)
            my_mse_1 = np.mean((y_target - y_fit)**2)
            my_mse_0 = np.mean((y_target - (y_fit-1.0))**2)
            my_mse = min([my_mse_1, my_mse_0])

            bigger_y_true = tgt_fn(bigger_x)
            bigger_y_pred = best_fn(bigger_x)

            id_y_true = tgt_fn(id_x)
            id_y_pred = best_fn(id_x)

            # NOTE(review): if the prediction is a pure constant, best_fn may
            # return a plain Python float without .mean(), which lands in the
            # failure branch below -- confirm this is acceptable.
            if np.isfinite(bigger_y_pred.mean()):
                my_r2 = sklearn.metrics.r2_score(bigger_y_true, bigger_y_pred)
            else:
                my_r2 = np.nan

            if np.isfinite(id_y_pred.mean()):
                my_r2_id = sklearn.metrics.r2_score(id_y_true, id_y_pred)
            else:
                my_r2_id = np.nan

            my_complexity = compute_complexity(best_eqn)
            my_ted = compute_tree_distance(eqn, best_eqn)
            my_are = compute_relative_error(id_y_true, id_y_pred)

            # Record metrics only once everything above has succeeded.
            complexities.append(my_complexity)
            mses.append(my_mse)
            r2s.append(my_r2_id)
            equivalents.append(is_correct)
            tree_distances.append(my_ted)
            ares.append(my_are)

        except Exception as err:
            # Catch Exception rather than using a bare except, so Ctrl-C still
            # interrupts the run; report the reason instead of hiding it.
            print(f"evaluation failed with predicted expression {pred_skeleton}."
                  f" ({type(err).__name__}: {err})")

            catastrophic_failure_count += 1
            is_correct = 0

            my_mse = np.nan
            my_r2 = np.nan
            my_r2_id = np.nan
            my_complexity = None

        wright = "correct" if is_correct else "incorrect"
        correct = 1 if is_correct else 0

        # Report this trial's own mse (nan on failure) rather than the last
        # entry of `mses`, which belongs to an earlier trial after a failure.
        msg = f"eqn {hh+1}, trial {trial} predicted {wright} equation: \n    predicted: {best_eqn}\n"
        msg += f"    target   : {tgt_eqn}"
        msg += f" with mse {my_mse:.3}\n"
        print(msg)

        if logging:
            # The "seed" column holds the trial number (the seed used above).
            results = f"{eqn}, {trial}, {best_eqn}, {tgt_eqn}, {correct}, {my_mse}, {my_r2}, "\
                    f"{my_r2_id}, {target_complexity}, {my_complexity}, "\
                    f" {my_method}, {bottom}, {top}, {number_samples}, "\
                    f"{val_low0}, {val_high0}, {val_low1}, {val_high1}"

            with open(f"{eval_tag}.csv", "a") as log_file:
                log_file.write(results + "\n")

    # Aggregate this equation's trials; if every trial failed, store infinite
    # placeholders instead of taking the mean of an empty list.
    if len(mses):
        accuracies.append(np.mean(equivalents))
        all_mses.append(np.mean(mses))
        all_mse_sds.append(np.std(mses))
        all_r2_means.append(np.mean(r2s))
        all_r2_sds.append(np.std(r2s))

        all_ares.append(np.mean(ares))
        all_tree_distances.append(np.mean(tree_distances))

        all_are_sds.append(np.std(ares))
        all_tree_distance_sds.append(np.std(tree_distances))
    else:
        accuracies.append(0.0)
        all_mses.append(float("inf"))
        all_mse_sds.append(float("inf"))
        all_r2_means.append(-float("inf"))
        all_r2_sds.append(float("inf"))

        all_ares.append(float("inf"))
        all_tree_distances.append(float("inf"))

        all_are_sds.append(float("inf"))
        all_tree_distance_sds.append(float("inf"))

    msg = f"accuracy for equation {hh+1}: {accuracies[-1]}"\
            f" with mean mse: {all_mses[-1]:.3}, running total failure count: {catastrophic_failure_count}\n"
    print(msg)

# Overall failure rate across all equations and trials.
failure_msg = f"\nTotal failure count: {catastrophic_failure_count}, of {len(benchmark_eqns)*number_trials}"
failure_msg += f" = {catastrophic_failure_count / (len(benchmark_eqns)*number_trials)}"
print(failure_msg)

In [None]:

# Assemble and print the per-equation summary report for this run.
print(failure_msg)

report_parts = [f"{my_method} accuracies\n"]

for ii, eqn in enumerate(benchmark_eqns):
    # mean +/- standard deviation of each metric for equation ii
    report_parts.append(
        f"\n  complex-{ii+1}, exact equ. accuracy: {accuracies[ii]:5f}, "
        f"\ntree edit distance: {all_tree_distances[ii]:.5} +/- {all_tree_distance_sds[ii]:.5} "
        f"\nmse: {all_mses[ii]:.5} +/- {all_mse_sds[ii]:.5} "
        f"\nrelative absolute error: {all_ares[ii]:.5} +/- {all_are_sds[ii]:.5} "
        f"\nr^2 (i.d): {all_r2_means[ii]:.5} +/- {all_r2_sds[ii]:.5}\n\n"
    )
    # followed by the expanded form of the target equation
    report_parts.append(f"  {sp.expand(eqn)} \n")

msg = "".join(report_parts)

if logging:
    summary_file = f"{eval_tag}_summary.txt"
    with open(summary_file, "w") as summary_handle:
        summary_handle.write(msg)

print(msg)

In [None]:
"""
Total failure count: 200, of 800 = 0.25
valipour_w_bfgs accuracies

  complex-1, exact equ. accuracy: 0.000000, 
tree edit distance: 19.38 +/- 2.6068 
mse: 0.46741 +/- 0.3956 
relative absolute error: 4.1322e+09 +/- 5.9063e+09 
r^2 (i.d): -2.708 +/- 4.6607

  sin(x0*exp(x0)) 

  complex-2, exact equ. accuracy: 0.000000, 
tree edit distance: 16.3 +/- 1.7059 
mse: nan +/- nan 
relative absolute error: nan +/- nan 
r^2 (i.d): nan +/- nan

  x0 + log(x0**4) 

  complex-3, exact equ. accuracy: 0.000000, 
tree edit distance: 18.52 +/- 1.3452 
mse: 0.080003 +/- 0.014802 
relative absolute error: 0.13585 +/- 0.004742 
r^2 (i.d): 0.0095249 +/- 0.014791

  x0*sin(1/x0) + 1 

  complex-4, exact equ. accuracy: 0.000000, 
tree edit distance: inf +/- inf 
mse: inf +/- inf 
relative absolute error: inf +/- inf 
r^2 (i.d): -inf +/- inf

  sqrt(x0**3)*log(x0**2) 

  complex-5, exact equ. accuracy: 0.000000, 
tree edit distance: 34.0 +/- 0.0 
mse: 0.014393 +/- 0.0017051 
relative absolute error: 0.79621 +/- 0.092582 
r^2 (i.d): 0.99361 +/- 0.00033622

  x0**3/(x0*cos(x0**2) + 1) + x0/(x0*cos(x0**2) + 1) 

  complex-6, exact equ. accuracy: 0.000000, 
tree edit distance: inf +/- inf 
mse: inf +/- inf 
relative absolute error: inf +/- inf 
r^2 (i.d): -inf +/- inf

  x0/sqrt(x0**2 + sin(x0)) 

  complex-7, exact equ. accuracy: 0.000000, 
tree edit distance: 40.94 +/- 5.6477 
mse: 0.1896 +/- 0.10013 
relative absolute error: 3.5634 +/- 17.851 
r^2 (i.d): -6090.1 +/- 5.1363e+04

  cos(x0/(x0**3 + x0*log(x0**2)) + sin(x0)/(x0**3 + x0*log(x0**2))) 

  complex-8, exact equ. accuracy: 0.000000, 
tree edit distance: 36.7 +/- 3.1417 
mse: nan +/- nan 
relative absolute error: nan +/- nan 
r^2 (i.d): nan +/- nan

  sqrt(x0 + 1)*exp(x0)/x0**2 + exp(x0)*cos(x0**2)/x0**2 + exp(x0)/x0**2 

"""

In [None]:
"""
Total failure count: 200, of 800 = 0.25
valipour_no_bfgs accuracies

  complex-1, exact equ. accuracy: 0.000000, 
tree edit distance: 17.28 +/- 0.69397 
mse: 1.1009 +/- 0.10597 
relative absolute error: 1.3742e+10 +/- 1.759e+09 
r^2 (i.d): -10.76 +/- 0.02368

  sin(x0*exp(x0)) 

  complex-2, exact equ. accuracy: 0.000000, 
tree edit distance: 16.3 +/- 1.7059 
mse: nan +/- nan 
relative absolute error: nan +/- nan 
r^2 (i.d): nan +/- nan

  x0 + log(x0**4) 

  complex-3, exact equ. accuracy: 0.000000, 
tree edit distance: 18.0 +/- 0.0 
mse: 1.4787e+125 +/- 4.0004e+125 
relative absolute error: 4.2579e+60 +/- 0.0 
r^2 (i.d): -5.4681e+125 +/- 2.2546e+110

  x0*sin(1/x0) + 1 

  complex-4, exact equ. accuracy: 0.000000, 
tree edit distance: inf +/- inf 
mse: inf +/- inf 
relative absolute error: inf +/- inf 
r^2 (i.d): -inf +/- inf

  sqrt(x0**3)*log(x0**2) 

  complex-5, exact equ. accuracy: 0.000000, 
tree edit distance: 32.0 +/- 0.0 
mse: 3.2605 +/- 0.41506 
relative absolute error: 11.32 +/- 1.7764e-15 
r^2 (i.d): -1.5806 +/- 2.2204e-16

  x0**3/(x0*cos(x0**2) + 1) + x0/(x0*cos(x0**2) + 1) 

  complex-6, exact equ. accuracy: 0.000000, 
tree edit distance: inf +/- inf 
mse: inf +/- inf 
relative absolute error: inf +/- inf 
r^2 (i.d): -inf +/- inf

  x0/sqrt(x0**2 + sin(x0)) 

  complex-7, exact equ. accuracy: 0.000000, 
tree edit distance: 27.86 +/- 0.51029 
mse: 1.6557e+04 +/- 1.6452e+05 
relative absolute error: 3.0704e+09 +/- 1.1191e+10 
r^2 (i.d): -4.2884e+22 +/- 1.5631e+23

  cos(x0/(x0**3 + x0*log(x0**2)) + sin(x0)/(x0**3 + x0*log(x0**2))) 

  complex-8, exact equ. accuracy: 0.000000, 
tree edit distance: 33.48 +/- 2.9137 
mse: 3.7353e+125 +/- 7.3283e+125 
relative absolute error: 7.8466e+59 +/- 4.4094e+59 
r^2 (i.d): -2.3779e+122 +/- 1.3363e+122

  sqrt(x0 + 1)*exp(x0)/x0**2 + exp(x0)*cos(x0**2)/x0**2 + exp(x0)/x0**2 
"""