### Requirements

In [3]:
!pip install latex2sympy2

Defaulting to user installation because normal site-packages is not writeable
Collecting latex2sympy2
  Downloading latex2sympy2-1.9.1-py3-none-any.whl (89 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.8/89.8 KB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting antlr4-python3-runtime==4.7.2
  Downloading antlr4-python3-runtime-4.7.2.tar.gz (112 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m112.3/112.3 KB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: antlr4-python3-runtime
  Building wheel for antlr4-python3-runtime (setup.py) ... [?25ldone
[?25h  Created wheel for antlr4-python3-runtime: filename=antlr4_python3_runtime-4.7.2-py3-none-any.whl size=140949 sha256=1a12c1fb9668ea8ce8e6da0054b52450a2523c17eb6cfba440a52187ad7cacf1
  Stored in directory: /home/lucas/.cache/pip/wheels/79/20/ec/30bf7dabc29319ccc0d0c96f910a64

In [1]:
import json
import os

from latex2sympy2 import latex2sympy, latex2latex
from sympy import simplify, srepr, Eq
from difflib import SequenceMatcher
from main import *

  from .autonotebook import tqdm as notebook_tqdm


### Constants

In [2]:
BASE_PATH = os.path.dirname(os.path.abspath("__file__"))
DATA_FILE = "data_example.json"

### Read data

In [3]:
# Latex string data

expr1 = r"\frac{d}{dx}(x^2 + 2*x) \times \int x \,dx"
expr2 = r"x^3 + x^2" # Correct would be x^3 + x^2

In [4]:
# JSON tree data

with open(os.path.join(BASE_PATH, DATA_FILE), 'r') as file:
    json_data = json.load(file)

tree1 = json_data.get("exprl", {}) # Template answer
tree2 = json_data.get("expr2", {}) # Right answer
tree3 = json_data.get("expr3", {}) # Wrong answer

### Test of similarity tree analysis

In [5]:
tree1 = latex_to_tree(expr1)
tree2 = latex_to_tree(expr2)
expression_tree_similarity = get_tree_sequence_similarity(tree1, tree2)
print(f"Expression tree similarity: {round(expression_tree_similarity*100, 0)}%")

Expression tree similarity: 100.0%


### Test Bert text similarity

In [6]:
from latex2sympy2 import latex2sympy, latex2latex
from sympy import simplify, srepr, Eq
from difflib import SequenceMatcher
from main import *

In [10]:
# Carregar o modelo BERT pré-treinado e o tokenizador
modelo = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

config.json: 100%|██████████| 570/570 [00:00<00:00, 1.65MB/s]
model.safetensors:   0%|          | 0.00/440M [00:03<?, ?B/s]


KeyboardInterrupt: 

In [None]:
emb1 = get_bert_embeddings(latex2latex(expr1), modelo, tokenizer)
emb2 = get_bert_embeddings(latex2latex(expr2), modelo, tokenizer)
print(f"Expression tree similarity: {round(get_text_similarity(emb1, emb2)*100,0)}%")

### Tree edit distance (Zhang-Shasha)

In [1]:
!pip install zss

Defaulting to user installation because normal site-packages is not writeable


In [15]:
#Load example dataset
import json

json_data = open('test_datagen.json')
data = json.load(json_data)

In [26]:
data[2]['expr_l'] == data[3]['expr_l']

False

In [16]:
from zss import simple_distance, Node

def parse_tree(expression):
    children = []

    if 'children' not in expression:
        return Node(expression['val'])

    for i in range(len(expression['children'])):
        children.append(parse_tree(expression['children'][i]))
    
    root = Node(expression['val'], children=children)

    return root

t1 = parse_tree(data['expr_l'])

TypeError: list indices must be integers or slices, not str

In [9]:

print_list = [0 for _ in range(7)]

def print_tree(root, pos=0):
    print("--------------------------------")
    print(root.label)
    print_list.insert(pos, root.label)
    print("Children:")
    if len(root.children) > 0:

        for i, child in enumerate(root.children):
            print_tree(child, pos=pos+i+1)
                                                                 

print_tree(t1)


--------------------------------
POW
Children:
--------------------------------
7.0
Children:
--------------------------------
POW
Children:
--------------------------------
7.0
Children:
--------------------------------
POW
Children:
--------------------------------
-9.5
Children:
--------------------------------
k
Children:


In [27]:
def compare_trees(tree1, tree2):
    t1 = parse_tree(tree1)
    t2 = parse_tree(tree2)
    
    return simple_distance(t1, t2)

In [30]:
compare_trees(data[0]['expr_l'], data[0]['expr_l'])

0.0

In [10]:
t2 = parse_tree(data['expr3'])

print_tree(t2)

--------------------------------
POW
Children:
--------------------------------
7.0
Children:
--------------------------------
MUL
Children:
--------------------------------
7.0
Children:
--------------------------------
POW
Children:
--------------------------------
-9.5
Children:
--------------------------------
k
Children:


In [11]:
t1 = parse_tree(data['exprl'])
t2 = parse_tree(data['expr3'])

simple_distance(t1, t2)

1.0

In [5]:
def tree_edit_distance(tree1, tree2):
    # Helper function to calculate the cost between two nodes
    def node_cost(node1, node2):
        return 0 if node1["val"] == node2["val"] else 1

    # Helper function to compute the minimum cost between subtrees
    def edit_distance(node1, node2, memo):
        if "children" not in node1 and "children" not in node2:
            cost = node_cost(node1, node2)
        else:
            cost = float('inf')
            for child1 in node1.get("children", []):
                for child2 in node2.get("children", []):
                    cost = min(cost, edit_distance(child1, child2, memo))

            cost += node_cost(node1, node2)

        memo[(node1["id"], node2["id"])] = cost
        return cost

    # Initialize memoization dictionary
    memo = {}

    # Calculate the tree edit distance
    distance = edit_distance(tree1, tree2, memo)
    return distance

In [6]:
distance = tree_edit_distance(tree1, tree2)
print("Tree Edit Distance:", distance)

Tree Edit Distance: 0


In [7]:
distance = tree_edit_distance(tree1, tree3)
print("Tree Edit Distance:", distance)

Tree Edit Distance: 0


In [13]:
tree1

"Mul(Pow(Symbol('x'), Integer(2)), Add(Symbol('x'), Integer(1)))"