### Requirements

In [1]:
import json
import os

from latex2sympy2 import latex2sympy, latex2latex
from sympy import simplify, srepr, Eq
from sympy.core.basic import Basic
from zss import simple_distance, Node
from difflib import SequenceMatcher
from PrettyPrint import PrettyPrintTree

from main import *

### Constants

In [2]:
# A notebook kernel has no real `__file__`, so data files are resolved against
# the kernel's current working directory (start Jupyter from the project root).
BASE_PATH = os.getcwd()
# File names of the JSON inputs, joined onto BASE_PATH when they are opened.
EXAMPLE_DATA_FILE = "data_example.json"
DATA_FILE = "data.json"

### Read data

In [3]:
# JSON example tree data.
# The encoding is given explicitly so decoding does not depend on the platform default.
with open(os.path.join(BASE_PATH, EXAMPLE_DATA_FILE), 'r', encoding='utf-8') as file:
    example_json_data = json.load(file)
# A missing key silently falls back to an empty dict, so key typos go unnoticed.
# NOTE(review): the first key was spelled "exprl"; "expr1" matches its
# "expr2"/"expr3" siblings — confirm against data_example.json.
tree1 = example_json_data.get("expr1", {})  # Template answer
tree2 = example_json_data.get("expr2", {})  # Right answer
tree3 = example_json_data.get("expr3", {})  # Wrong answer

# Full JSON data
with open(os.path.join(BASE_PATH, DATA_FILE), 'r', encoding='utf-8') as file:
    json_data = json.load(file)

# LaTeX string data
# expr1: d/dx(x^2 + 2x) * ∫x dx = (2x + 2) * x^2/2 = x^3 + x^2
expr1 = r"\frac{d}{dx}(x^2 + 2*x) \times \int x \,dx"
expr2 = r"x^3 + x^2"  # Correct answer: the closed form of expr1
expr3 = r"\frac{(x^3 + x^3)}{\tan(10)}"  # A different expression, used as the mismatching case

### Test of similarity tree analysis

In [None]:
# Convert the two LaTeX strings into trees and print their sequence-based
# similarity as a percentage.
# Note: this rebinds tree1/tree2, replacing the values loaded from the example
# JSON file in the "Read data" cell.
# latex_to_tree / get_tree_sequence_similarity are not defined in this notebook
# (presumably provided by `from main import *`).
tree1 = latex_to_tree(expr1)
tree2 = latex_to_tree(expr2)
expression_tree_similarity = get_tree_sequence_similarity(tree1, tree2)
print(f"Expression tree similarity: {round(expression_tree_similarity*100, 0)}%")

### Test BERT text similarity

In [None]:
from latex2sympy2 import latex2sympy, latex2latex
from sympy import simplify, srepr, Eq
from difflib import SequenceMatcher
from main import *

In [None]:
# Load the pre-trained BERT model and tokenizer.
# BertModel / BertTokenizer are not imported explicitly in this notebook
# (presumably they arrive through `from main import *`).
modelo = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
# Embed the latex2latex form of each expression with BERT, then print how
# similar the two embeddings are as a percentage.
emb1 = get_bert_embeddings(latex2latex(expr1), modelo, tokenizer)
emb2 = get_bert_embeddings(latex2latex(expr2), modelo, tokenizer)
# The label names the metric actually computed here (text embeddings, not the
# expression-tree comparison from the previous section).
print(f"BERT text similarity: {round(get_text_similarity(emb1, emb2)*100,0)}%")

### Tree edit distance (Zhang-Shasha)

In [4]:


# Compare the 'expr_l' trees stored in records 100 and 250 of the full data set.
# The last expression is the cell output (the recorded run shows 17.0).
t1 = json_data[100]['expr_l']
t2 = json_data[250]['expr_l']
compare_database_trees(t1, t2)

17.0

In [5]:
# Reference tree: the 'expr_l' entry of record 1 in the full data set.
t1 = json_data[1]['expr_l']

# Hand-written tree in the same {'val', 'id', 'children'} dict format:
# POW(7.0, POW(7.0, POW(k, -9.5))).
# Judging by the rendered trees in the cells below, it differs from t1 only in
# the order of the two innermost leaves.
t1_copy = {
    'val': 'POW',
    'id': 1,
    'children': [
        {'val': '7.0', 'id': 3},
        {
            'val': 'POW',
            'id': 4,
            'children': [
                {'val': '7.0', 'id': 9},
                {
                    'val': 'POW',
                    'id': 10,
                    'children': [
                        {'val': 'k', 'id': 21},
                        {'val': '-9.5', 'id': 22},
                    ],
                },
            ],
        },
    ],
}

# The last expression is the cell output: the distance between the two trees.
compare_database_trees(t1, t1_copy)

2.0

In [6]:
# Tree printer: reads a node's children from `.children` and its displayed
# text from `.label`.
pt = PrettyPrintTree(lambda x: x.children, lambda x: x.label)

In [7]:
# Render the dataset tree t1 after converting it with parse_database_tree.
pt(parse_database_tree(t1))

     [100m POW [0m
  ┌────┴────┐
[100m 7.0 [0m     [100m POW [0m       
        ┌───┴────┐    
      [100m 7.0 [0m    [100m POW [0m  
              ┌──┴──┐ 
            [100m -9.5 [0m [100m k [0m


In [8]:
# Render the hand-written t1_copy the same way, for visual comparison with t1.
pt(parse_database_tree(t1_copy))

    [100m POW [0m
  ┌───┴────┐
[100m 7.0 [0m    [100m POW [0m        
        ┌──┴───┐      
      [100m 7.0 [0m  [100m POW [0m    
             ┌─┴──┐   
            [100m k [0m [100m -9.5 [0m


In [9]:
# LaTeX string data (repeats the definitions from the "Read data" cell).

# expr1: d/dx(x^2 + 2x) * ∫x dx = (2x + 2) * x^2/2 = x^3 + x^2
expr1 = r"\frac{d}{dx}(x^2 + 2*x) \times \int x \,dx"
expr2 = r"x^3 + x^2" # Correct answer: the closed form of expr1
expr3 = r"\frac{(x^3 + x^3)}{\tan(10)}" # A different expression, used as the mismatching case

In [10]:
# Pass each LaTeX string through simplify_latex_expression and build_tree,
# in order, binding the results to tree1, tree2 and tree3.
tree1, tree2, tree3 = (
    build_tree(simplify_latex_expression(latex))
    for latex in (expr1, expr2, expr3)
)

In [12]:
# Render the expression tree built from expr1.
pt(tree1)

                 [100m <class 'sympy.core.mul.Mul'> [0m
               ┌───────────────┴───────────────┐
[100m <class 'sympy.core.power.Pow'> [0m [100m <class 'sympy.core.add.Add'> [0m
             ┌─┴─┐                           ┌─┴─┐             
            [100m x [0m [100m 2 [0m                         [100m x [0m [100m 1 [0m


In [13]:
# Render the expression tree built from expr2.
pt(tree2)

                 [100m <class 'sympy.core.mul.Mul'> [0m
               ┌───────────────┴───────────────┐
[100m <class 'sympy.core.power.Pow'> [0m [100m <class 'sympy.core.add.Add'> [0m
             ┌─┴─┐                           ┌─┴─┐             
            [100m x [0m [100m 2 [0m                         [100m x [0m [100m 1 [0m


In [14]:
# Render the expression tree built from expr3.
pt(tree3)

                     [100m <class 'sympy.core.mul.Mul'> [0m
              ┌────────────────────┴────────────────────┐
[100m <class 'sympy.core.mul.Mul'> [0m           [100m <class 'sympy.core.power.Pow'> [0m
     ┌────────┴────────┐                              ┌─┴──┐             
    [100m 2 [0m [100m <class 'sympy.core.power.Pow'> [0m            [100m tan [0m [100m -1 [0m           
                     ┌─┴─┐                            |                  
                    [100m x [0m [100m 3 [0m                          [100m 10 [0m


In [15]:
# Pairwise tree edit distances (zss.simple_distance), displayed as a tuple in
# the order (tree1, tree2), (tree1, tree3), (tree2, tree3).
# In the recorded run the first pair scores 0.0, matching the identical
# renderings of tree1 and tree2 above.
simple_distance(tree1, tree2), simple_distance(tree1, tree3), simple_distance(tree2, tree3)

(0.0, 7.0, 7.0)