In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt

from symr.helpers import r2_over_threshold, r2_auc, plot_r2_over_threshold

import pandas as pd
import sympy as sp
import math

In [None]:
# Where the result CSVs live. Point `root_path` (or the sub-folder names) at a
# different location, and change `string_tag`, to load a different set of csvs.
root_path = ".."
_results_subdirs = ("results", "metric_message")
results_path = os.path.join(root_path, *_results_subdirs)
string_tag = ["860200"]

In [None]:

# Single-variable benchmark expressions, spelled exactly as they are matched
# against the " expression" column of the results CSVs (leading space included).
univariate = [
    ' sin(x*exp(x))',
    ' x + log(x**4)',
    ' x*sin(1/x) + 1',
    ' sqrt(x**3)*log(x**2)',
    ' (x**3 + x)/(x*cos(x**2) + 1)',
    ' x/sqrt(x**2 + sin(x))',
    ' cos((x + sin(x))/(x**3 + x*log(x**2)))',
    ' (sqrt(x + 1) + cos(x**2) + 1)*exp(x)/x**2',
    ' x**3 + x**2 + x',
    ' x**4 + x**3 + x**2 + x',
    ' x**5 + x**4 + x**3 + x**2 + x',
    ' x**6 + x**5 + x**3 + x**2 + x',
    ' sin(x**2)*cos(x) - 1',
    ' sin(x) + sin(x**2 + x)',
    ' log(x + 1) + log(x**2 + 1)',
    ' sqrt(x)',
]

# Expression string -> row label used in the LaTeX table. The last four entries
# are two-variable expressions and therefore have no counterpart in `univariate`.
expression_dict = {
    ' sin(x*exp(x))': "A.B.-1",
    ' x + log(x**4)': "A.B.-2",
    ' x*sin(1/x) + 1': "A.B.-3",
    ' sqrt(x**3)*log(x**2)': "A.B.-4",
    ' (x**3 + x)/(x*cos(x**2) + 1)': "A.B.-5",
    ' x/sqrt(x**2 + sin(x))': "A.B.-6",
    ' cos((x + sin(x))/(x**3 + x*log(x**2)))': "A.B.-7",
    # was "A.B-8": the dot was missing, unlike every other A.B. label
    ' (sqrt(x + 1) + cos(x**2) + 1)*exp(x)/x**2': "A.B.-8",
    ' x**3 + x**2 + x': "Nguyen-1",
    ' x**4 + x**3 + x**2 + x': "Nguyen-2",
    ' x**5 + x**4 + x**3 + x**2 + x': "Nguyen-3",
    ' x**6 + x**5 + x**3 + x**2 + x': "Nguyen-4",
    ' sin(x**2)*cos(x) - 1': "Nguyen-5",
    ' sin(x) + sin(x**2 + x)': "Nguyen-6",
    ' log(x + 1) + log(x**2 + 1)': "Nguyen-7",
    ' sqrt(x)': "Nguyen-8",
    ' sin(x) + sin(y**2)': "Nguyen-9",
    ' 2*sin(x)*cos(y)': "Nguyen-10",
    ' x**y': "Nguyen-11",
    ' x**4 - x**3 + y**2/2 - y': "Nguyen-12",
}

In [None]:
# Debug peek at `methods`, which is only assigned by the results-loading cell
# further down. On a fresh kernel the name does not exist yet, so guard the
# lookup instead of raising NameError during Restart & Run All.
methods.shape if "methods" in globals() else None

In [None]:


# Gather in-distribution R^2 statistics for every univariate benchmark
# expression from the result CSVs found in `results_path`.
# `results[expression][method]` ends up holding [mean_r2, std_dev_r2].
list_dir = os.listdir(results_path)

# PySR is not included in this table, because it doesn't use post-inference
# optimization. Filter into a new list: calling `remove()` on the list being
# iterated skips the entry that follows each removed file, so consecutive PySR
# files could slip through.
active_dir = [filename for filename in list_dir if "PySR" not in filename]

results = {key: {} for key in univariate}

for filename in active_dir:

    # only files whose name contains "b1" contribute to this table
    if "b1" not in filename:
        continue

    print(filename)
    filepath = os.path.join(results_path, filename)

    df = pd.read_csv(filepath)

    # NOTE(review): the first unique value of "method" is discarded, exactly as
    # before; presumably it is not a real method entry -- confirm against the
    # CSV layout.
    methods = df["method"].unique()[1:]
    expressions = df[" expression"].unique()

    if methods.shape[0] != 1:
        # This branch used to call the non-existent `pdb.sort_trace()`, which
        # raised AttributeError and left `method` unbound (or stale from the
        # previous file). Fail with an explicit message instead.
        raise ValueError(
            f"{filename}: expected exactly one method after dropping the first "
            f"unique value, found {methods.shape[0]}: {list(methods)}"
        )
    method = methods[0]

    for expression in expressions:
        if expression not in univariate:
            continue

        rows = df.loc[df[" expression"] == expression]

        # A run is kept when its "failed" entry reads False. Comparing the
        # stripped text form accepts both the " False" strings matched
        # previously and a column that pandas parsed as real booleans.
        in_success = (
            rows["failed"].astype(str).str.strip() == "False"
        ).to_numpy(dtype=bool, copy=True)

        # r2: entries that are not numbers (e.g. " None") become NaN here and
        # are rejected by the finiteness check below. A plain float cast would
        # raise on such strings before any filtering could happen.
        in_r2_raw = pd.to_numeric(rows["in_r2"], errors="coerce").to_numpy(dtype=float)
        in_success &= np.isfinite(in_r2_raw)

        in_r2 = in_r2_raw[in_success]

        # Still not sure what the best way is to honestly report r^2
        # while taking into account scores can be unboundedly bad
        # reports in the NSR literature truncate r^2 at a lower bound of 0.0 and/or use median values
        # (TODO: the citation was left empty in the original comment)
        clipped_r2 = np.clip(in_r2, -1, 1.)
        if clipped_r2.size:
            mean_r2 = np.mean(clipped_r2)
            std_dev_r2 = np.std(clipped_r2)
        else:
            # no usable runs: report NaN without triggering numpy's
            # empty-slice RuntimeWarning
            mean_r2 = std_dev_r2 = np.nan

        results[expression][method] = [mean_r2, std_dev_r2]
            


In [None]:

# Render `results` as a LaTeX table: one row per benchmark expression and one
# "$mean \pm std$" cell per method stored for that expression.
# All LaTeX backslashes are written as "\\" so that no invalid escape sequence
# (\h, \p, \c) is left in a non-raw string; the printed text is unchanged.
# NOTE(review): the header names are hard-coded, while the cells follow the
# order in which methods were inserted into `results` -- confirm they line up.
latex_table = """
\\begin{table}[!h]
    \\centering
    \\begin{tabular}{l | c  c  c  c  c  c c}
    Number & Random & Fourier & NSRTS & SymGPT & Symformer & Polynomials \\\\ \\hline
"""

for expression, method_stats in results.items():

    table_row = f"{expression_dict[expression]} & "
    for mean_value, std_value in method_stats.values():
        # mean-only alternative: table_row += f"{mean_value:.2f} & "
        table_row += f"${mean_value:.2f} \\pm {std_value:.2f}$ & "

    # every cell, including the last one, ends in " & ", so each row carries a
    # trailing empty cell before the row terminator
    table_row += "\\\\ \n"
    latex_table += table_row

latex_table += """
    \\end{tabular}
    \\caption{Performance of BFGS on $R^2$ over different methods for benchmark equations \\citep{uy2011}}
    \\label{table:random_accuracy}
\\end{table}
"""

print(latex_table)