In [45]:
import os
import pandas as pd
import re
from collections import OrderedDict
from IPython.display import display
import random
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import numba
from numba import jit, prange, njit
import sys
from numba import int64, boolean, types
from numba.experimental import jitclass
from numba.typed import Dict, List

#sys.setrecursionlimit(10**4)
# import display
def load_and_print_csvs_from_folders():
    cwd = os.getcwd()
    folders = [d for d in os.listdir(cwd) if os.path.isdir(os.path.join(cwd, d))]
    data = {}
    for folder in folders:
        folder_path = os.path.join(cwd, folder)
        for file in os.listdir(folder_path):
            if file.endswith('.csv'):
                file_path = os.path.join(folder_path, file)
                df = pd.read_csv(file_path)
                data[file] = df

    return data

def pretty_print_dict(dictionary, indent=0):
    for key, value in dictionary.items():
        if isinstance(value, dict):
            print('  ' * indent + str(key) + ':')
            pretty_print_dict(value, indent+1)
        else:
            print('  ' * indent + str(key) + ': ' + str(value))


#load_and_print_csvs_from_folders()


In [46]:
def read_bnf_file(file_path):
    # read file as string
    with open(file_path, 'r') as f:
        bnf_lines = f.read()
    return bnf_lines

def parse_bnf(grammar_str, depth=1, variables=1):
    #productions = Dict.empty(key_type=types.unicode_type, value_type=types.ListType(types.ListType(types.unicode_type)))
    productions = {}
    grammar_str = re.sub(r'\s*\|\s*', ' |', grammar_str)
    lines = [line.strip() for line in grammar_str.split('\n') if line.strip()]
    PRODUCTION_SEPARATOR = '::='
    RULE_SEPARATOR = '|'
    RECURSIVE_PRODUCTIONS = ['<expr>']
    for line in lines:
        lhs, rhs = line.split(PRODUCTION_SEPARATOR)
        lhs = lhs.strip()
        rhs_productions = []
        for rule in rhs.split(RULE_SEPARATOR):
            rule_list = list(rule.strip().split())
            rhs_productions.append(rule_list)

        productions[lhs] = rhs_productions

    def list_new_non_recursive_expressions(recursive_rule):
        # create depth expressions
        expressions = [recursive_rule]
        for i in range(1, depth):
            expr_rule = recursive_rule.replace('>', f'{i}>')
            expressions.append(expr_rule)
        return expressions

    def filter_non_recursive_productions(recursive_rule):
        non_recursive_productions = set()
        for exp in productions[recursive_rule]:
            if recursive_rule not in exp:
                non_recursive_productions.add(tuple(exp))

        return non_recursive_productions

    def replace_expression_in_production(old_rule, old, new):
        new_rule = list(old_rule)
        for i, exp in enumerate(new_rule):
            if new_rule[i] == old:
                new_rule[i] = new
        return new_rule

    def fix_recursive_production(recursive_rule):
        non_recursive_productions = filter_non_recursive_productions(recursive_rule)
        if depth == 1:
            productions[recursive_rule] = non_recursive_productions
            return

        new_expressions = list_new_non_recursive_expressions(recursive_rule)

        all_production = set(tuple(p) for p in productions[recursive_rule])
        recursive_productions = all_production - non_recursive_productions

        for i in range(0, depth-1):
            expr_rule = new_expressions[i]
            new_productions = []
            for rule in recursive_productions:
                new_rule = replace_expression_in_production(rule, recursive_rule, new_expressions[i+1])
                new_productions.append(new_rule)
            productions[expr_rule] = [*non_recursive_productions, *new_productions]


        productions[new_expressions[-1]] = list(non_recursive_productions)

    for recursive_rule in RECURSIVE_PRODUCTIONS:
        fix_recursive_production(recursive_rule)

    variable_cases = ['x', 'y', 'z', 'w', 'k']
    for var_index in range(variables+1):
        terminal_vars = variable_cases[:var_index]
        productions["<var>"] = terminal_vars

    return productions

demo_bnf = """
<start> ::= <expr> <op> <expr>
<expr> ::= <term> <op> <term> | '(' <term> <op> <term> ')'
<op> ::= '+' | '-' | '/' | '*'
<term> ::= 'x1' | '0.5'
"""

bnf_lines = read_bnf_file('variable_depth.bnf')

def get_terminals(productions):
    terminals = set()
    non_terminals = set(productions.keys())

    for rhs in productions.values():
        for rule in rhs:
            for token in rule:
                if token not in non_terminals:
                    terminals.add(token)

    return terminals

def create_typed_grammar(productions_dict):
    productions = Dict.empty(key_type=types.unicode_type, value_type=types.ListType(types.ListType(types.unicode_type)))

    for nt, rhs_list in productions_dict.items():
        rhs_typed_list = List.empty_list(types.ListType(types.unicode_type))
        for rhs in rhs_list:
            rhs_typed_list.append(List(rhs))
        productions[nt] = rhs_typed_list

    return productions

def create_grammar_from_bnf(bnf_file, depth=1):
    bnf_lines = read_bnf_file(bnf_file)
    productions = parse_bnf(bnf_lines, depth)
    productions = create_typed_grammar(productions)
    non_terminals = set(productions.keys())
    terminals = get_terminals(productions)
    return productions, non_terminals, terminals

productions = parse_bnf(bnf_lines, 5, 3)
terminals = set(productions.keys())
pretty_print_dict(productions)

<start>: [['<expr>']]
<expr>: [('<var>',), ('<number>',), ['(', '<uop>', '<expr1>', ')'], ['(', '<op>', '<expr1>', '<expr1>', ')']]
<number>: [['<integer>']]
<integer>: [['(', '<digit>', ')'], ['(', '<non-zero-digit>', '<digit>', ')'], ['(', '<non-zero-digit>', '<digit>', '<digit>', ')']]
<digit>: [['0'], ['<non-zero-digit>']]
<non-zero-digit>: [['1'], ['2'], ['3'], ['4'], ['5'], ['6'], ['7'], ['8'], ['9']]
<op>: [["'+'"], ["'-'"], ["'*'"], ["'/'"]]
<uop>: [["'abs'"]]
<var>: ['x', 'y', 'z']
<expr1>: [('<var>',), ('<number>',), ['(', '<uop>', '<expr2>', ')'], ['(', '<op>', '<expr2>', '<expr2>', ')']]
<expr2>: [('<var>',), ('<number>',), ['(', '<uop>', '<expr3>', ')'], ['(', '<op>', '<expr3>', '<expr3>', ')']]
<expr3>: [('<var>',), ('<number>',), ['(', '<uop>', '<expr4>', ')'], ['(', '<op>', '<expr4>', '<expr4>', ')']]
<expr4>: [('<var>',), ('<number>',)]


In [47]:
terminals = get_terminals(productions)

print("terminals")
print(terminals)
non_terminals = set(productions.keys())
print("non_terminals")
print(non_terminals)


terminals
{'4', '7', '3', '(', '6', '5', 'y', 'x', "'/'", "'*'", "'abs'", '8', '0', 'z', "'+'", '9', '1', '2', ')', "'-'"}
non_terminals
{'<expr>', '<number>', '<expr1>', '<var>', '<expr3>', '<op>', '<expr4>', '<integer>', '<digit>', '<uop>', '<start>', '<non-zero-digit>', '<expr2>'}


In [48]:
def find_recursive_and_non_recursive_terminals(grammar):
    recursive_terminals = set()
    non_recursive_terminals = set()

    def is_recursive(nt, visited):
        if nt in visited:
            return True
        visited.add(nt)
        for rule in grammar[nt]:
            for token in rule:
                if token in non_terminals and is_recursive(token, visited):
                    return True
        visited.remove(nt)
        return False

    for nt in non_terminals:
        if is_recursive(nt, set()):
            recursive_terminals.add(nt)
        else:
            non_recursive_terminals.add(nt)

    non_recursive_terminals |= terminals

    return recursive_terminals, non_recursive_terminals


recursive_terminals, non_recursive_terminals = find_recursive_and_non_recursive_terminals(productions)

display("Recursive terminals:", recursive_terminals)
display("Non-recursive terminals:", non_recursive_terminals)


'Recursive terminals:'

set()

'Non-recursive terminals:'

{"'*'",
 "'+'",
 "'-'",
 "'/'",
 "'abs'",
 '(',
 ')',
 '0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '<digit>',
 '<expr1>',
 '<expr2>',
 '<expr3>',
 '<expr4>',
 '<expr>',
 '<integer>',
 '<non-zero-digit>',
 '<number>',
 '<op>',
 '<start>',
 '<uop>',
 '<var>',
 'x',
 'y',
 'z'}

In [49]:
def calculate_non_recursive_productions(symbol, grammar):
    non_recursive_indices = []
    recursive_terminals_in_grammar, _ = find_recursive_and_non_recursive_terminals(grammar)

    if symbol not in grammar:
        return non_recursive_indices

    for i, rule in enumerate(grammar[symbol]):
        if all(token in recursive_terminals_in_grammar for token in rule):
            non_recursive_indices.append(i)

    return non_recursive_indices

def calculate_recursive_productions(non_terminal, grammar, non_recursive_productions):
    all_productions = set(range(len(grammar[non_terminal])))
    non_recursive_set = set(non_recursive_productions[non_terminal])
    return list(all_productions - non_recursive_set)

def get_non_recursive_expansions(grammar):
    non_recursive_expansions_set = OrderedDict()
    for nt in non_terminals:
        non_recursive_expansions_set[nt] = calculate_non_recursive_productions(nt, grammar)
    return non_recursive_expansions_set

def get_recursive_expansions(grammar, non_recursive_expansions_dict):
    recursive_expansions_set = OrderedDict()
    for nt in non_terminals:
        recursive_expansions_set[nt] = calculate_recursive_productions(nt, grammar, non_recursive_expansions_dict)
    return recursive_expansions_set

# create the non-recursive dictionary for each non-terminal
non_recursive_expansions = get_non_recursive_expansions(productions)
recursive_expansions = get_recursive_expansions(productions, non_recursive_expansions)

print("Non recursive expansions per terminal:")
pretty_print_dict(non_recursive_expansions)


print("Recursive expansions per terminal:")
pretty_print_dict(recursive_expansions)


Non recursive expansions per terminal:
<expr>: []
<number>: []
<expr1>: []
<var>: []
<expr3>: []
<op>: []
<expr4>: []
<integer>: []
<digit>: []
<uop>: []
<start>: []
<non-zero-digit>: []
<expr2>: []
Recursive expansions per terminal:
<expr>: [0, 1, 2, 3]
<number>: [0]
<expr1>: [0, 1, 2, 3]
<var>: [0, 1, 2]
<expr3>: [0, 1, 2, 3]
<op>: [0, 1, 2, 3]
<expr4>: [0, 1]
<integer>: [0, 1, 2]
<digit>: [0, 1]
<uop>: [0]
<start>: [0]
<non-zero-digit>: [0, 1, 2, 3, 4, 5, 6, 7, 8]
<expr2>: [0, 1, 2, 3]


In [50]:
from numba.typed import Dict, List
from numba.core import types

count_references_type = Dict.empty(key_type=types.unicode_type, value_type=types.DictType(types.unicode_type, types.int64))
is_referenced_by_type = Dict.empty(key_type=types.unicode_type, value_type=types.ListType(types.unicode_type))
def calculate_non_terminal_references(grammar, non_terminals_set):
    #count_references = {nt: {} for nt in non_terminals_set}
    #count_references = Dict.empty(key_type=types.unicode_type, value_type=types.DictType(types.unicode_type, types.int64))
    """
    count_references = dict()
    for nt in non_terminals_set:
        #count_references[nt] = Dict.empty(key_type=types.unicode_type, value_type=types.int64)
        count_references[nt] = dict()
    #is_referenced_by = Dict.empty(key_type=types.unicode_type, value_type=types.unicode_type[:])
    is_referenced_by = dict()
    for nt in non_terminals_set:
        #is_referenced_by[nt] = List.empty_list(types.int64)
        is_referenced_by[nt] = []
    #is_referenced_by = {nt: [] for nt in non_terminals_set}
    """
    count_references = count_references_type.copy()
    is_referenced_by = is_referenced_by_type.copy()

    for nt in non_terminals_set:
        count_references[nt] = Dict.empty(key_type=types.unicode_type, value_type=types.int64)
        is_referenced_by[nt] = List.empty_list(types.unicode_type)

    for nt in non_terminals_set:
        for production in grammar[nt]:
            count = {option: 0 for option in non_terminals_set}
            for option in production:
                if option in non_terminals_set:
                    is_referenced_by[option].append(nt)
                    count[option] += 1
            for key in count:
                count_references[key][nt] = max(count_references[key].get(nt, 0), count[key])

    return count_references, is_referenced_by

# Usage example
count_refs, ref_by = calculate_non_terminal_references(productions, non_terminals)
print("count_refs")

pretty_print_dict(count_refs)
print("Ref by: ")
pretty_print_dict(ref_by)

count_refs
<expr>: {<expr>: 0, <number>: 0, <expr1>: 0, <var>: 0, <expr3>: 0, <op>: 0, <expr4>: 0, <integer>: 0, <digit>: 0, <uop>: 0, <start>: 1, <non-zero-digit>: 0, <expr2>: 0}
<number>: {<expr>: 1, <number>: 0, <expr1>: 1, <var>: 0, <expr3>: 1, <op>: 0, <expr4>: 1, <integer>: 0, <digit>: 0, <uop>: 0, <start>: 0, <non-zero-digit>: 0, <expr2>: 1}
<expr1>: {<expr>: 2, <number>: 0, <expr1>: 0, <var>: 0, <expr3>: 0, <op>: 0, <expr4>: 0, <integer>: 0, <digit>: 0, <uop>: 0, <start>: 0, <non-zero-digit>: 0, <expr2>: 0}
<var>: {<expr>: 1, <number>: 0, <expr1>: 1, <var>: 0, <expr3>: 1, <op>: 0, <expr4>: 1, <integer>: 0, <digit>: 0, <uop>: 0, <start>: 0, <non-zero-digit>: 0, <expr2>: 1}
<expr3>: {<expr>: 0, <number>: 0, <expr1>: 0, <var>: 0, <expr3>: 0, <op>: 0, <expr4>: 0, <integer>: 0, <digit>: 0, <uop>: 0, <start>: 0, <non-zero-digit>: 0, <expr2>: 2}
<op>: {<expr>: 1, <number>: 0, <expr1>: 1, <var>: 0, <expr3>: 1, <op>: 0, <expr4>: 0, <integer>: 0, <digit>: 0, <uop>: 0, <start>: 0, <non-ze

In [51]:
@jit(nopython=True)
def find_references(nt, start_symbol, is_referenced_by, count_references_by_prod):
    r = get_total_references_of_current_production(count_references_by_prod, nt)
    results = []

    if nt == start_symbol:
        return 1
    nt_str = str(nt)
    for ref in is_referenced_by[nt_str]:
        results.append(find_references(ref, start_symbol, is_referenced_by, count_references_by_prod))

    references = r * np.max(np.array(results))
    return references

@jit(nopython=True)
def get_total_references_of_current_production(count_references_by_prod, nt):
    nt_str = str(nt)
    return np.sum(np.array(list(count_references_by_prod[nt_str].values())))


nt = "<digit>"  # Replace this with a non-terminal from your grammar
nt = '<non-zero-digit>'
references = find_references(nt, '<start>', ref_by, count_refs)

display(references)


320

In [52]:
@njit
def seed(a):
    random.seed(a)

@njit
def rand():
    return random.random()

@njit
def random_int(a, b):
    return random.randint(a, b)


production_type = types.ListType(types.int64)
new_genotype_type = Dict.empty(
    key_type=types.unicode_type,
    value_type=production_type
)

@jit(nopython=False)
def create_full_tree(grammar, genotype,first_symbol, ref_by_dict, count_refs_dict, non_terminals_set):
    #print(non_terminals_set)
    #print(numba.typeof(non_terminals_set))
    #print(numba.version_info)
    # for every non-terminal, we create a vector of size equal to the upper bound of productions
    non_terminals_list = list(non_terminals_set)
    for symbol_index in prange(len(non_terminals_list)):
        symbol = str(non_terminals_list[symbol_index])

        upper_bound = find_references(symbol, first_symbol, ref_by_dict, count_refs_dict)

        productions_length = len(grammar[symbol])
        possible_productions = List.empty_list(types.int64)
        for i in range(upper_bound):
            codon = random_int(0, productions_length - 1)
            possible_productions.append(codon)

        genotype[symbol] = possible_productions

def create_individual_probabilistic(grammar, max_depth, genotype, symbol, non_terminals_set, depth):
    stack = [(symbol, depth)]

    is_terminal_cache = {s not in non_terminals_set for s in grammar}
    unique_depths = set()
    non_recursive_expansions_dict = get_non_recursive_expansions(grammar)

    while stack:
        symbol, depth = stack.pop()
        if depth not in unique_depths:
          #print(f"Reached new unique depth: {depth}")
          unique_depths.add(depth)
        production_rules = grammar[symbol]
        expansion_index = random.randint(0, len(production_rules) - 1)

        expansion = production_rules[expansion_index]

        # handle case where symbol is recursive, checking the dict
        is_expansion_rec = expansion in recursive_expansions[symbol]
        #if not is_symbol_rec:
          #print(f"Symbol {symbol} is non recursive")
        if is_expansion_rec:
              if depth >= max_depth:
                  non_rec_exps = non_recursive_expansions_dict[symbol]
                  if len(non_rec_exps) == 0:
                      print("Symbol", symbol, "has no non-recursive productions")
                      raise ValueError("No valid productions in this case")
                  expansion_index = random.choice(non_rec_exps)
                  expansion = grammar[symbol][expansion_index]

        if symbol in genotype:
            genotype[symbol].append(expansion_index)
        else:
            genotype[symbol] = [expansion_index]

        expansion_symbols = production_rules[expansion_index]

        for sym in expansion_symbols:
            if not is_terminal_cache.get(sym, True):
                stack.append((sym, depth + 1))
    print("Unique depths: ", unique_depths)
#
# def create_individual_recursive(grammar, max_depth, genotype, symbol, depth):
#     production_rules = grammar[symbol]
#     expansion_index = random.randint(0, len(production_rules) - 1)
#
#     expansion = production_rules[expansion_index]
#     # check if the symbol is a recursive terminal, ie, it can expand to itself
#     if is_recursive(symbol):
#         # check if the expansion is recursive
#         if expansion in recursive_expansions[symbol]:
#             if depth >= max_depth:
#                 # get non recursive productions of the symbol
#                 non_rec_exps = non_recursive_expansions[symbol]
#                 if len(non_rec_exps) == 0:
#                     print("Symbol", symbol, "has no non-recursive productions")
#                     return
#                 expansion_index = random.choice(non_rec_exps)
#                 expansion = grammar[symbol][expansion_index]
#     else:
#       print(f"Symbol {symbol} is non recursive!")
#
#     if symbol in genotype:
#         genotype[symbol].append(expansion_index)
#     else:
#         genotype[symbol] = [expansion_index]
#
#     expansion_symbols = production_rules[expansion_index]
#
#     for sym in expansion_symbols:
#         if not is_terminal(sym):
#             create_individual_probabilistic(grammar, max_depth, genotype, sym, depth + 1)

def create_genotype(grammar_file='variable_depth.bnf', max_depth=6, option='full'):
    new_genotype = new_genotype_type.copy()

    desired_depth_grammar, non_terminals, terminals = create_grammar_from_bnf(grammar_file, max_depth)
    first_symbol = next(iter(desired_depth_grammar.keys()))

    count_refs, ref_by = calculate_non_terminal_references(desired_depth_grammar, non_terminals)

    if option == 'full':
        create_full_tree(desired_depth_grammar, new_genotype, first_symbol, ref_by, count_refs, non_terminals)
    elif option == 'probabilistic':
        create_individual_probabilistic(desired_depth_grammar, max_depth, new_genotype, first_symbol, 0)
    return new_genotype, desired_depth_grammar


create_genotype(max_depth=10)


  @jit(nopython=False)
Encountered the use of a type that is scheduled for deprecation: type 'reflected set' found for argument 'non_terminals_set' of function 'create_full_tree'.

For more information visit https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-reflection-for-list-and-set-types
[1m
File "C:\Users\diogo\AppData\Local\Temp\ipykernel_11884\4051938699.py", line 21:[0m
[1m@jit(nopython=False)
[1mdef create_full_tree(grammar, genotype,first_symbol, ref_by_dict, count_refs_dict, non_terminals_set):
[0m[1m^[0m[0m
[0m
Compilation is falling back to object mode WITH looplifting enabled because Function create_full_tree failed at nopython mode lowering due to: cannot store {i8*, i32, i8*, i8*, i32}* to i8*: mismatching types[0m
  @jit(nopython=False)
Compilation is falling back to object mode WITHOUT looplifting enabled because Function "create_full_tree" failed type inference due to: [1m[1mCannot determine Numba type of <class 'numba.core.d

(DictType[unicode_type,ListType[int64]]<iv=None>({<expr>: [2, ...], <number>: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [53]:
class Tree:

    def __init__(self, genome, productions):
        # get first rule for start symbol from the grammar

        self.productions = productions
        self.non_terminals = set(productions.keys())
        first_rule = next(iter(productions))
        first_production = productions[first_rule][0]
        self.root = Node(first_rule)

        # initialize OrderedDict with each non-terminal as a key and an empty list as the value, without list comprehension
        self.genome = genome

    def __repr__(self):
        #return f"Node({self.label}, {self.children})"
        # if it has children, call repr on each child
        return self.root.children[0].simple_repr()

    def __str__(self):
        return self.root.__str__()

    def _get_next_expansion(self):
        # find the first non-terminal that has not been expanded
        return self.root.find_first_unexpanded_non_terminal()

    def expand_next(self):
        node = self._get_next_expansion()
        if node is None:
            return False

        # get the vector of production indices for the current non-terminal
        production_indices = self.genome[node.label]

        if len(production_indices) == 0:
            raise ValueError(f"Genome for {node.label} is empty")


        # get the next production index
        production_index = production_indices.pop(0)

        # get the production for the desired non-terminal
        new_production = self.productions[node.label][production_index]

        #print("Expanding", node.label, "with", new_production)
        node.apply_rule(self.productions, new_production)
        return True


class Node:
    # create static variable

    def __init__(self, label, first_production=[]):
        self.label = label
        self.children = []
        self.is_terminal = False

        if len(first_production) > 0:
            self.apply_rule(first_production)
    def simple_repr(self):
        if len(self.children) > 0:
            merged_string = " ".join([child.simple_repr() for child in self.children])
            return merged_string
        return self.label

    def __repr__(self):
        #return f"Node({self.label}, {self.children})"
        # if it has children, call repr on each child

        if len(self.children) > 0:
            merged_string = " ".join([repr(child) for child in self.children])
            base = "Node(", self.label, ")"
            return f"{base} {merged_string}"

        return self.label
    def __str__(self):
        return self.label

    def __eq__(self, other):
        return self.label == other.label and self.children == other.children

    def __len__(self):
        return len(self.children)

    def __iter__(self):
        return iter(self.children)

    def __setitem__(self, index, value):
        self.children[index] = value

    def __getitem__(self, index):
        return self.children[index]
    def clean_parenthesis(self):
        pass

    def find_first_unexpanded_non_terminal(self):
        if self.is_terminal:
            return None

        if len(self.children) == 0:
            return self

        for child in self.children:
            non_terminal = child.find_first_unexpanded_non_terminal()
            if non_terminal is not None:
                return non_terminal

        return None

    def apply_rule(self, current_tree_productions, production: list):

        if production not in current_tree_productions[self.label]:
            raise ValueError(f"Production {production} not found in grammar for {self.label}")

        # create a new node for each symbol in the production

        children_list = list(map(Node, production))
        #print(current_tree_productions)
        for child in children_list:
            if child.label not in current_tree_productions:
                #print("Adding terminal ", child.label)
                child.is_terminal = True
        self.children = children_list

genotype, grammar = create_genotype(max_depth=5)
#pretty_print_dict(grammar)
tree = Tree(genotype, grammar)


while tree.expand_next():
    pass


print("Done")
display(tree)

Done


( 'abs' x )

In [54]:
def plot_tree(tree):
    G = nx.DiGraph()

    def add_edges(node):
        for child in node.children:
            G.add_edge(node.label, child.label)
            add_edges(child)

    add_edges(tree.root)

    pos = nx.drawing.nx_agraph.graphviz_layout(G, prog="dot")
    nx.draw(G, pos, with_labels=True, node_size=2000, node_color="skyblue", font_size=12, font_weight="bold", arrowsize=20)
    plt.show()



In [61]:
def clean_string(target: str):
    # replace quotes with nothing
    return target.replace("'", "").replace('"', '')
def evaluate(node: Node, variables):
    if node.label == "<start>":
        return evaluate(node.children[0], variables)
    elif re.match(r'<expr\d*>', node.label):
        if len(node.children) == 1:  # <number>
            return evaluate(node.children[0], variables)
        else:
            operator_node = node.children[1]
            if len(node.children) == 4:  # ( <uop> <expr> )
                #print("Operator node: ", operator_node)
                uop = evaluate(operator_node, variables)
                uop = clean_string(uop)
                #print("Evaluated operator: [", uop, "]")
                operand_node = node.children[2]
                operand = evaluate(operand_node, variables)
                if operand is None:
                    print("Operand was ", operand_node, "before evaluation, error")
                    #operand = evaluate(operand_node)
                if uop == 'abs':
                    #print("Doing abs of ", operand, " =",abs(operand))
                    return abs(operand)
                else:
                    raise ValueError(f"Unknown unary operator {uop}")
            elif len(node.children) == 5:  # ( <op> <expr> <expr> )
                op = evaluate(operator_node, variables)
                op = clean_string(op)
                op1 = evaluate(node.children[2], variables)
                op2 = evaluate(node.children[3], variables)
                if op == "+":
                    return op1 + op2
                elif op == "-":
                    return op1 - op2
                elif op == "*":
                    return op1 * op2
                elif op == "/":
                    # safe division
                    if op2 == 0:
                        return 0;
                    return op1 / op2

    elif node.label == "<number>":
        return evaluate(node.children[0], variables)

    elif node.label == "<integer>":
        # Remove nodes where label is parenthesis from children
        parenthesis =  ['(', ')']
        digit_nodes = [child for child in node.children if child.label not in parenthesis]

        # Evaluate each digit node and join them together
        merged_digits = [str(evaluate(child, variables)) for child in digit_nodes]
        return int("".join(merged_digits))

    elif node.label == "<var>":
        desired_variable = node.children[0].label
        desired_variable = clean_string(desired_variable)
        if desired_variable not in variables:
            raise ValueError(f"Variable {desired_variable} not found in variables {variables}")
        return variables[desired_variable]
    elif node.label == "<op>":
        return node.children[0].label
    elif node.label == "<uop>":
        return node.children[0].label
    elif node.label in ["<non-zero-digit>", "<digit>"]:
        return evaluate(node.children[0], variables)
    elif node.label in "0123456789":
        return int(node.label)
    elif node.label == "(" or node.label == ")":
        return ""
    else:
        raise ValueError(f"Unexpected node label: {node.label}")

genotype, grammar = create_genotype(max_depth=5)
#pretty_print_dict(grammar)
tree = Tree(genotype, grammar)

while tree.expand_next():
    pass

print("Done")
variables = {'x': 42, 'y': 10}
display(tree)
value = evaluate(tree.root, variables)
print(value)

Done


( 'abs' ( '-' ( 9 0 ) x ) )

48
