In [None]:
import json

import numpy as np
import torch
from torch.utils.data import DataLoader

from models import SymbolicDiffusion, PointNetConfig
from utils import load_dataset, lossFunc, relativeErr

In [None]:
# --- Model / diffusion hyper-parameters -------------------------------------
n_embd = 32        # passed as both the point-embedding size and the model width
blockSize = 32     # passed to the model as its maximum sequence length
timesteps = 1000   # passed to the model as its number of timesteps

# --- Data description --------------------------------------------------------
numVars = 1
numYs = 1
numPoints = 250
batch_size = 1024

# --- Settings not referenced by the visible evaluation cells ----------------
# NOTE(review): presumably kept in sync with the training notebook -- confirm.
learning_rate = 0.0001
num_epochs = 10
const_range = [-2.1, 2.1]
trainRange = [-3.0, 3.0]
decimals = 8
addVars = False

# --- Hardware ----------------------------------------------------------------
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the held-out split.  ``load_dataset`` returns the dataset object plus a
# second value that is kept as ``testText``.
# NOTE(review): the path is a hard-coded Colab/Drive location; adjust it when
# running anywhere else.
test_dataset, testText = load_dataset("/content/drive/MyDrive/Colab Notebooks/STAT946_proj/data/1_var_test.json")
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    # Pinned host memory only speeds up host-to-GPU copies; requesting it on a
    # CPU-only runtime just produces a warning, so follow CUDA availability.
    pin_memory=torch.cuda.is_available(),
    # Keep dataset order so per-sample output is reproducible between runs.
    shuffle=False,
    num_workers=2,
)

In [None]:
import torch
from scipy.optimize import minimize

# Rebuild the network and restore the trained weights.
# ``device`` and all sizes come from the configuration cell, so the data
# description (numPoints / numVars / numYs) is declared in exactly one place
# instead of being repeated here as literals.
pconfig = PointNetConfig(
    embeddingSize=n_embd,
    numberofPoints=numPoints,
    numberofVars=numVars,
    numberofYs=numYs,
)

# NOTE(review): vocab_size, max_num_vars, n_layer and n_head are literals;
# presumably they must match the values used when the checkpoint was trained,
# otherwise ``load_state_dict`` below will reject the weights -- confirm.
model = SymbolicDiffusion(
    pconfig=pconfig,
    vocab_size=50,
    max_seq_len=blockSize,
    padding_idx=test_dataset.paddingID,
    max_num_vars=9,
    n_layer=4,
    n_head=4,
    n_embd=n_embd,
    timesteps=timesteps,
    beta_start=0.0001,
    beta_end=0.02,
).to(device)

# NOTE(review): ``torch.load`` unpickles the file, so only load checkpoints
# from a trusted source.
model_path = "/content/drive/MyDrive/Colab Notebooks/STAT946_proj/models/symbolic_diffusion_model.pth"
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()


@torch.no_grad()
def test_model(model, test_loader, test_dataset, device):
    results = {'target': [], 'predicted': [], 'error': []}
    
    for batch_idx, (_, tokens, points, variables) in enumerate(test_loader):
        points = points.to(device)    # [B, 2, 250]
        tokens = tokens.to(device)    # [B, L]
        variables = variables.to(device)  # [B]

        generated_tokens = model.sample(points, variables, device)  

        for i in range(batch_size):  
            # Ground truth
            gt_tokens = tokens[i].cpu().numpy()
            gt_expr = ''.join(test_dataset.itos[int(idx)] for idx in gt_tokens)
            gt_expr = gt_expr.strip(test_dataset.paddingToken).strip('<').strip('>').split(':')[-1]

            # Predicted
            pred_tokens = generated_tokens[i].cpu().numpy()
            predicted = ''.join(test_dataset.itos[int(idx)] for idx in pred_tokens if int(idx) < len(test_dataset.itos))
            predicted = predicted.strip(test_dataset.paddingToken).strip('<').strip('>').split(':')[-1]
            predicted = predicted.replace('Ce', 'C*e')

            # train a regressor to find the constants (too slow)
            c = [1.0 for i,x in enumerate(predicted) if x=='C'] # initialize coefficients as 1
            # c[-1] = 0 # initialize the constant as zero
            b = [(-2,2) for i,x in enumerate(predicted) if x=='C']  # bounds on variables
            try:
                if len(c) != 0:
                    # This is the bottleneck in our algorithm
                    # for easier comparison, we are using minimize package  
                    cHat = minimize(lossFunc, c, #bounds=b,
                                args=(predicted, t['X'], t['Y'])) 

                    predicted = predicted.replace('C','{}').format(*cHat.x)
            except ValueError:
                raise 'Err: Wrong Equation {}'.format(predicted)
            except Exception as e:
                raise 'Err: Wrong Equation {}, Err: {}'.format(predicted, e)

            Ys = [] #t['YT']
            Yhats = []
            for xs in t['XT']:
                try:
                    eqTmp = gt_expr + '' # copy eq
                    eqTmp = eqTmp.replace(' ','')
                    eqTmp = eqTmp.replace('\n','')
                    for i,x in enumerate(xs):
                        # replace xi with the value in the eq
                        eqTmp = eqTmp.replace('x{}'.format(i+1), str(x))
                        if ',' in eqTmp:
                            assert 'There is a , in the equation!'
                    YEval = eval(eqTmp)
                    # YEval = 0 if np.isnan(YEval) else YEval
                    # YEval = 100 if np.isinf(YEval) else YEval
                except:
                    print('TA: For some reason, we used the default value. Eq:{}'.format(eqTmp))
                    print(i)
                    raise
                    continue # if there is any point in the target equation that has any problem, ignore it
                    YEval = 100 #TODO: Maybe I have to punish the model for each wrong template not for each point
                Ys.append(YEval)
                try:
                    eqTmp = predicted + '' # copy eq
                    eqTmp = eqTmp.replace(' ','')
                    eqTmp = eqTmp.replace('\n','')
                    for i,x in enumerate(xs):
                        # replace xi with the value in the eq
                        eqTmp = eqTmp.replace('x{}'.format(i+1), str(x))
                        if ',' in eqTmp:
                            assert 'There is a , in the equation!'
                    Yhat = eval(eqTmp)
                    # Yhat = 0 if np.isnan(Yhat) else Yhat
                    # Yhat = 100 if np.isinf(Yhat) else Yhat
                except:
                    print('PR: For some reason, we used the default value. Eq:{}'.format(eqTmp))
                    Yhat = 100
                Yhats.append(Yhat)
            err = relativeErr(Ys,Yhats, info=True)


            results['target'].append(gt_expr)
            results['predicted'].append(predicted)
            results['error'].append(err)

            print(f"\nSample {batch_idx * batch_size + i + 1}:")
            print(f"Target: {gt_expr}")
            print(f"Predicted: {predicted}")
            print(f"Relative Error: {err:.6f}")
            print("-" * 50)

    return results

In [None]:
# Run the evaluation, then print a compact recap of every scored sample.
print("Testing SymbolicDiffusion model...")
test_results = test_model(model, test_loader, test_dataset, device, num_samples=5)

print("\nSummary:")
summary_rows = zip(
    test_results['target'],
    test_results['predicted'],
    test_results['error'],
)
for sample_no, (target_eq, predicted_eq, rel_err) in enumerate(summary_rows, start=1):
    print(f"Sample {sample_no}:")
    print(f"  Target: {target_eq}")
    print(f"  Predicted: {predicted_eq}")
    print(f"  Error: {rel_err:.6f}")