In [1]:
import os
import pandas as pd
import re
from collections import OrderedDict
from IPython.display import display
import random
import networkx as nx
import matplotlib.pyplot as plt
import sys
 

#sys.setrecursionlimit(10**4)
# import display
def load_and_print_csvs_from_folders():
    cwd = os.getcwd()
    folders = [d for d in os.listdir(cwd) if os.path.isdir(os.path.join(cwd, d))]
    data = {}
    for folder in folders:
        folder_path = os.path.join(cwd, folder)
        for file in os.listdir(folder_path):
            if file.endswith('.csv'):
                file_path = os.path.join(folder_path, file)
                df = pd.read_csv(file_path)
                data[file] = df

    return data

def pretty_print_dict(dictionary, indent=0):
    for key, value in dictionary.items():
        if isinstance(value, dict):
            print('  ' * indent + str(key) + ':')
            pretty_print_dict(value, indent+1)
        else:
            print('  ' * indent + str(key) + ': ' + str(value))


#load_and_print_csvs_from_folders()


In [2]:
def read_bnf_file(file_path):
    # read file as string
    with open(file_path, 'r') as f:
        bnf_lines = f.read()
    return bnf_lines

def parse_bnf(grammar_str, depth=1):
    productions = OrderedDict()
    grammar_str = re.sub(r'\s*\|\s*', ' |', grammar_str)
    lines = [line.strip() for line in grammar_str.split('\n') if line.strip()]
    PRODUCTION_SEPARATOR = '::='
    RULE_SEPARATOR = '|'
    RECURSIVE_PRODUCTIONS = ['<expr>']
    for line in lines:
        lhs, rhs = line.split(PRODUCTION_SEPARATOR)
        lhs = lhs.strip()
        rhs_productions = []
        for rule in rhs.split(RULE_SEPARATOR):
            rule_list = list(rule.strip().split())
            rhs_productions.append(rule_list)

        productions[lhs] = rhs_productions

    def list_new_non_recursive_expressions(recursive_rule):
        # create depth expressions
        expressions = [recursive_rule]
        for i in range(1, depth):
            expr_rule = recursive_rule.replace('>', f'{i}>')
            expressions.append(expr_rule)
        return expressions

    def filter_non_recursive_productions(recursive_rule):
        non_recursive_productions = set()
        for exp in productions[recursive_rule]:
            if recursive_rule not in exp:
                non_recursive_productions.add(tuple(exp))

        return non_recursive_productions

    def replace_expression_in_production(old_rule, old, new):
        new_rule = list(old_rule)
        for i, exp in enumerate(new_rule):
            if new_rule[i] == old:
                new_rule[i] = new
        return new_rule

    def fix_recursive_production(recursive_rule):
        non_recursive_productions = filter_non_recursive_productions(recursive_rule)
        if depth == 1:
            productions[recursive_rule] = non_recursive_productions
            return

        new_expressions = list_new_non_recursive_expressions(recursive_rule)

        all_production = set(tuple(p) for p in productions[recursive_rule])
        recursive_productions = all_production - non_recursive_productions

        for i in range(0, depth-1):
            expr_rule = new_expressions[i]
            new_productions = []
            for rule in recursive_productions:
                new_rule = replace_expression_in_production(rule, recursive_rule, new_expressions[i+1])
                new_productions.append(new_rule)
            productions[expr_rule] = [*non_recursive_productions, *new_productions]


        productions[new_expressions[-1]] = list(non_recursive_productions)

    for recursive_rule in RECURSIVE_PRODUCTIONS:
        fix_recursive_production(recursive_rule)


    return productions

demo_bnf = """
<start> ::= <expr> <op> <expr>
<expr> ::= <term> <op> <term> | '(' <term> <op> <term> ')'
<op> ::= '+' | '-' | '/' | '*'
<term> ::= 'x1' | '0.5'
"""

bnf_lines = read_bnf_file('variable_depth.bnf')

def get_terminals(productions):
    terminals = set()
    non_terminals = set(productions.keys())

    for rhs in productions.values():
        for rule in rhs:
            for token in rule:
                if token not in non_terminals:
                    terminals.add(token)

    return terminals

def create_grammar_from_bnf(bnf_file, depth=1):
    bnf_lines = read_bnf_file(bnf_file)
    productions = parse_bnf(bnf_lines, depth)
    non_terminals = set(productions.keys())
    terminals = get_terminals(productions)
    return productions, non_terminals, terminals

productions = parse_bnf(bnf_lines, 5)
terminals = set(productions.keys())
pretty_print_dict(productions)

<start>: [['<expr>']]
<expr>: [('<number>',), ['(', '<uop>', '<expr1>', ')'], ['(', '<op>', '<expr1>', '<expr1>', ')']]
<number>: [['<integer>']]
<integer>: [['(', '<digit>', ')'], ['(', '<non-zero-digit>', '<digit>', ')'], ['(', '<non-zero-digit>', '<digit>', '<digit>', ')']]
<digit>: [['0'], ['<non-zero-digit>']]
<non-zero-digit>: [['1'], ['2'], ['3'], ['4'], ['5'], ['6'], ['7'], ['8'], ['9']]
<op>: [["'+'"], ["'-'"], ["'*'"], ["'/'"]]
<uop>: [["'abs'"]]
<expr1>: [('<number>',), ['(', '<uop>', '<expr2>', ')'], ['(', '<op>', '<expr2>', '<expr2>', ')']]
<expr2>: [('<number>',), ['(', '<uop>', '<expr3>', ')'], ['(', '<op>', '<expr3>', '<expr3>', ')']]
<expr3>: [('<number>',), ['(', '<uop>', '<expr4>', ')'], ['(', '<op>', '<expr4>', '<expr4>', ')']]
<expr4>: [('<number>',)]


In [3]:
terminals = get_terminals(productions)

print("terminals")
print(terminals)
non_terminals = set(productions.keys())
print("non_terminals")
print(non_terminals)


terminals
{'3', "'+'", '4', '9', "'*'", '6', '2', '5', ')', "'-'", '8', '1', '7', '0', "'abs'", "'/'", '('}
non_terminals
{'<expr>', '<uop>', '<expr4>', '<op>', '<integer>', '<non-zero-digit>', '<start>', '<number>', '<digit>', '<expr1>', '<expr2>', '<expr3>'}


In [4]:
def find_recursive_and_non_recursive_terminals(grammar):
    recursive_terminals = set()
    non_recursive_terminals = set()

    def is_recursive(nt, visited):
        if nt in visited:
            return True
        visited.add(nt)
        for rule in grammar[nt]:
            for token in rule:
                if token in non_terminals and is_recursive(token, visited):
                    return True
        visited.remove(nt)
        return False

    for nt in non_terminals:
        if is_recursive(nt, set()):
            recursive_terminals.add(nt)
        else:
            non_recursive_terminals.add(nt)

    non_recursive_terminals |= terminals

    return recursive_terminals, non_recursive_terminals


recursive_terminals, non_recursive_terminals = find_recursive_and_non_recursive_terminals(productions)

display("Recursive terminals:", recursive_terminals)
display("Non-recursive terminals:", non_recursive_terminals)


'Recursive terminals:'

set()

'Non-recursive terminals:'

{"'*'",
 "'+'",
 "'-'",
 "'/'",
 "'abs'",
 '(',
 ')',
 '0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '<digit>',
 '<expr1>',
 '<expr2>',
 '<expr3>',
 '<expr4>',
 '<expr>',
 '<integer>',
 '<non-zero-digit>',
 '<number>',
 '<op>',
 '<start>',
 '<uop>'}

In [5]:
def calculate_non_recursive_productions(symbol, grammar):
    non_recursive_indices = []
    recursive_terminals_in_grammar, _ = find_recursive_and_non_recursive_terminals(grammar)

    if symbol not in grammar:
        return non_recursive_indices

    for i, rule in enumerate(grammar[symbol]):
        if all(token in recursive_terminals_in_grammar for token in rule):
            non_recursive_indices.append(i)

    return non_recursive_indices

def calculate_recursive_productions(non_terminal, grammar, non_recursive_productions):
    all_productions = set(range(len(grammar[non_terminal])))
    non_recursive_set = set(non_recursive_productions[non_terminal])
    return list(all_productions - non_recursive_set)

def get_non_recursive_expansions(grammar):
    non_recursive_expansions_set = OrderedDict()
    for nt in non_terminals:
        non_recursive_expansions_set[nt] = calculate_non_recursive_productions(nt, grammar)
    return non_recursive_expansions_set

def get_recursive_expansions(grammar, non_recursive_expansions_dict):
    recursive_expansions_set = OrderedDict()
    for nt in non_terminals:
        recursive_expansions_set[nt] = calculate_recursive_productions(nt, grammar, non_recursive_expansions_dict)
    return recursive_expansions_set

# create the non-recursive dictionary for each non-terminal
non_recursive_expansions = get_non_recursive_expansions(productions)
recursive_expansions = get_recursive_expansions(productions, non_recursive_expansions)

print("Non recursive expansions per terminal:")
pretty_print_dict(non_recursive_expansions)


print("Recursive expansions per terminal:")
pretty_print_dict(recursive_expansions)


Non recursive expansions per terminal:
<expr>: []
<uop>: []
<expr4>: []
<op>: []
<integer>: []
<non-zero-digit>: []
<start>: []
<number>: []
<digit>: []
<expr1>: []
<expr2>: []
<expr3>: []
Recursive expansions per terminal:
<expr>: [0, 1, 2]
<uop>: [0]
<expr4>: [0]
<op>: [0, 1, 2, 3]
<integer>: [0, 1, 2]
<non-zero-digit>: [0, 1, 2, 3, 4, 5, 6, 7, 8]
<start>: [0]
<number>: [0]
<digit>: [0, 1]
<expr1>: [0, 1, 2]
<expr2>: [0, 1, 2]
<expr3>: [0, 1, 2]


In [6]:
def calculate_non_terminal_references(grammar, non_terminals_set):
    count_references = {nt: {} for nt in non_terminals_set}
    is_referenced_by = {nt: [] for nt in non_terminals_set}

    for nt in non_terminals_set:
        for production in grammar[nt]:
            count = {option: 0 for option in non_terminals_set}
            for option in production:
                if option in non_terminals_set:
                    is_referenced_by[option].append(nt)
                    count[option] += 1
            for key in count:
                count_references[key][nt] = max(count_references[key].get(nt, 0), count[key])

    return count_references, is_referenced_by

# Usage example
count_refs, ref_by = calculate_non_terminal_references(productions, non_terminals)
print("Ref count: ")

pretty_print_dict(count_refs)
print("Ref by: ")
pretty_print_dict(ref_by)

Ref count: 
<expr>:
  <expr>: 0
  <uop>: 0
  <expr4>: 0
  <op>: 0
  <integer>: 0
  <non-zero-digit>: 0
  <start>: 1
  <number>: 0
  <digit>: 0
  <expr1>: 0
  <expr2>: 0
  <expr3>: 0
<uop>:
  <expr>: 1
  <uop>: 0
  <expr4>: 0
  <op>: 0
  <integer>: 0
  <non-zero-digit>: 0
  <start>: 0
  <number>: 0
  <digit>: 0
  <expr1>: 1
  <expr2>: 1
  <expr3>: 1
<expr4>:
  <expr>: 0
  <uop>: 0
  <expr4>: 0
  <op>: 0
  <integer>: 0
  <non-zero-digit>: 0
  <start>: 0
  <number>: 0
  <digit>: 0
  <expr1>: 0
  <expr2>: 0
  <expr3>: 2
<op>:
  <expr>: 1
  <uop>: 0
  <expr4>: 0
  <op>: 0
  <integer>: 0
  <non-zero-digit>: 0
  <start>: 0
  <number>: 0
  <digit>: 0
  <expr1>: 1
  <expr2>: 1
  <expr3>: 1
<integer>:
  <expr>: 0
  <uop>: 0
  <expr4>: 0
  <op>: 0
  <integer>: 0
  <non-zero-digit>: 0
  <start>: 0
  <number>: 1
  <digit>: 0
  <expr1>: 0
  <expr2>: 0
  <expr3>: 0
<non-zero-digit>:
  <expr>: 0
  <uop>: 0
  <expr4>: 0
  <op>: 0
  <integer>: 1
  <non-zero-digit>: 0
  <start>: 0
  <number>: 0
  <digit>

In [7]:
def find_references(nt, start_symbol, is_referenced_by, count_references_by_prod):
    r = get_total_references_of_current_production(count_references_by_prod, nt)
    #print("r: ", r)
    results = []

    if nt == start_symbol:
        return 1

    for ref in is_referenced_by[nt]:
        results.append(find_references(ref, start_symbol, is_referenced_by, count_references_by_prod))

    references = r * max(results)
    return references

def get_total_references_of_current_production(count_references_by_prod, nt):
    return sum(count_references_by_prod[nt].values())


nt = "<digit>"  # Replace this with a non-terminal from your grammar
nt = '<non-zero-digit>'
references = find_references(nt, '<start>', ref_by, count_refs)

display(references)


320

In [15]:
def create_full_tree(grammar, genotype,first_symbol,  ref_by_dict, count_refs_dict, non_terminals_set):
    # for every non-terminal, we create a vector of size equal to the upper bound of productions
    for symbol in non_terminals_set:
        upper_bound = find_references(symbol, first_symbol, ref_by_dict, count_refs_dict)
        possible_productions = [0] * upper_bound
        productions_length = len(grammar[symbol])
        # pick a random production for each possibility
        for i in range(upper_bound):
            possible_productions[i] = random.randint(0, productions_length - 1)

        genotype[symbol] = possible_productions

def create_individual_probabilistic(grammar, max_depth, genotype, symbol, non_terminals_set, depth):
    stack = [(symbol, depth)]

    is_terminal_cache = {s not in non_terminals_set for s in grammar}
    unique_depths = set()
    non_recursive_expansions_dict = get_non_recursive_expansions(grammar)

    while stack:
        symbol, depth = stack.pop()
        if depth not in unique_depths:
          #print(f"Reached new unique depth: {depth}")
          unique_depths.add(depth)
        production_rules = grammar[symbol]
        expansion_index = random.randint(0, len(production_rules) - 1)

        expansion = production_rules[expansion_index]

        # handle case where symbol is recursive, checking the dict
        is_expansion_rec = expansion in recursive_expansions[symbol]
        #if not is_symbol_rec:
          #print(f"Symbol {symbol} is non recursive")
        if is_expansion_rec:
              if depth >= max_depth:
                  non_rec_exps = non_recursive_expansions_dict[symbol]
                  if len(non_rec_exps) == 0:
                      print("Symbol", symbol, "has no non-recursive productions")
                      raise ValueError("No valid productions in this case")
                  expansion_index = random.choice(non_rec_exps)
                  expansion = grammar[symbol][expansion_index]

        if symbol in genotype:
            genotype[symbol].append(expansion_index)
        else:
            genotype[symbol] = [expansion_index]

        expansion_symbols = production_rules[expansion_index]

        for sym in expansion_symbols:
            if not is_terminal_cache.get(sym, True):
                stack.append((sym, depth + 1))
    print("Unique depths: ", unique_depths)
#
# def create_individual_recursive(grammar, max_depth, genotype, symbol, depth):
#     production_rules = grammar[symbol]
#     expansion_index = random.randint(0, len(production_rules) - 1)
#
#     expansion = production_rules[expansion_index]
#     # check if the symbol is a recursive terminal, ie, it can expand to itself
#     if is_recursive(symbol):
#         # check if the expansion is recursive
#         if expansion in recursive_expansions[symbol]:
#             if depth >= max_depth:
#                 # get non recursive productions of the symbol
#                 non_rec_exps = non_recursive_expansions[symbol]
#                 if len(non_rec_exps) == 0:
#                     print("Symbol", symbol, "has no non-recursive productions")
#                     return
#                 expansion_index = random.choice(non_rec_exps)
#                 expansion = grammar[symbol][expansion_index]
#     else:
#       print(f"Symbol {symbol} is non recursive!")
#
#     if symbol in genotype:
#         genotype[symbol].append(expansion_index)
#     else:
#         genotype[symbol] = [expansion_index]
#
#     expansion_symbols = production_rules[expansion_index]
#
#     for sym in expansion_symbols:
#         if not is_terminal(sym):
#             create_individual_probabilistic(grammar, max_depth, genotype, sym, depth + 1)

def create_genotype(grammar_file='variable_depth.bnf', max_depth=6, option='full'):
    new_genotype = {}


    desired_depth_grammar, non_terminals, terminals = create_grammar_from_bnf(grammar_file, max_depth)

    first_symbol = next(iter(desired_depth_grammar.keys()))


    count_refs, ref_by = calculate_non_terminal_references(desired_depth_grammar, non_terminals)

    if option == 'full':
        create_full_tree(desired_depth_grammar, new_genotype, first_symbol, ref_by, count_refs, non_terminals)
    elif option == 'probabilistic':
        create_individual_probabilistic(desired_depth_grammar, max_depth, new_genotype, first_symbol, 0)
    return new_genotype, desired_depth_grammar

create_genotype(max_depth=10)


({'<expr>': [0],
  '<expr7>': [0,
   0,
   0,
   0,
   1,
   0,
   1,
   1,
   1,
   2,
   0,
   0,
   2,
   1,
   2,
   1,
   0,
   0,
   1,
   0,
   0,
   2,
   1,
   2,
   0,
   1,
   1,
   0,
   1,
   2,
   2,
   0,
   1,
   1,
   2,
   0,
   1,
   1,
   0,
   0,
   0,
   2,
   2,
   1,
   0,
   2,
   0,
   2,
   1,
   1,
   1,
   2,
   0,
   2,
   2,
   0,
   0,
   1,
   2,
   2,
   2,
   1,
   1,
   1,
   1,
   0,
   1,
   2,
   1,
   1,
   0,
   2,
   2,
   1,
   0,
   2,
   1,
   2,
   0,
   0,
   1,
   1,
   0,
   2,
   2,
   2,
   2,
   2,
   2,
   1,
   2,
   2,
   1,
   0,
   0,
   0,
   0,
   2,
   2,
   0,
   0,
   0,
   2,
   2,
   0,
   0,
   0,
   0,
   0,
   2,
   1,
   2,
   0,
   1,
   2,
   2,
   1,
   1,
   1,
   2,
   2,
   2,
   0,
   1,
   2,
   1,
   1,
   0],
  '<expr5>': [2,
   2,
   1,
   2,
   0,
   2,
   2,
   2,
   0,
   0,
   2,
   1,
   1,
   1,
   0,
   1,
   0,
   2,
   1,
   1,
   1,
   0,
   0,
   1,
   1,
   1,
   0,
   2,
   2,
   1,
   2,
   1],

In [23]:
class Tree:

    def __init__(self, genome=OrderedDict(), productions=OrderedDict()):
        # get first rule for start symbol from the grammar

        self.productions = productions
        self.non_terminals = set(productions.keys())
        first_rule = next(iter(productions))
        first_production = productions[first_rule][0]
        self.root = Node(first_rule)

        # initialize OrderedDict with each non-terminal as a key and an empty list as the value, without list comprehension
        self.genome = genome

    def __repr__(self):
        return self.root.__repr__()

    def __str__(self):
        return self.root.__str__()

    def _get_next_expansion(self):
        # find the first non-terminal that has not been expanded
        return self.root.find_first_unexpanded_non_terminal()

    def expand_next(self):
        node = self._get_next_expansion()
        if node is None:
            return False

        # get the vector of production indices for the current non-terminal
        production_indices = self.genome[node.label]

        if len(production_indices) == 0:
            raise ValueError(f"Genome for {node.label} is empty")


        # get the next production index
        production_index = production_indices.pop(0)

        # get the production for the desired non-terminal
        new_production = self.productions[node.label][production_index]

        #print("Expanding", node.label, "with", new_production)
        node.apply_rule(self.productions, new_production)
        return True


class Node:
    # create static variable

    def __init__(self, label, first_production=[]):
        self.label = label
        self.children = []
        self.is_terminal = False

        if label in terminals:
            self.is_terminal = True

        if len(first_production) > 0:
            self.apply_rule(first_production)

    def __repr__(self):
        #return f"Node({self.label}, {self.children})"
        # if it has children, call repr on each child
        if len(self.children) > 0:
            merged_string = " ".join([repr(child) for child in self.children])
            return merged_string
        else:
            return self.label
    def __str__(self):
        return self.label

    def __eq__(self, other):
        return self.label == other.label and self.children == other.children

    def __len__(self):
        return len(self.children)

    def __iter__(self):
        return iter(self.children)

    def __setitem__(self, index, value):
        self.children[index] = value

    def __getitem__(self, index):
        return self.children[index]

    def find_first_unexpanded_non_terminal(self):
        if self.is_terminal:
            return None

        if len(self.children) == 0:
            return self

        for child in self.children:
            non_terminal = child.find_first_unexpanded_non_terminal()
            if non_terminal is not None:
                return non_terminal

        return None

    def apply_rule(self, current_tree_productions, production: list):

        if production not in current_tree_productions[self.label]:
            raise ValueError(f"Production {production} not found in grammar for {self.label}")

        # create a new node for each symbol in the production

        self.children = list(map(Node, production))
        self.is_terminal = len(production) == 1 and production[0] in terminals

In [10]:
def plot_tree(tree):
    G = nx.DiGraph()

    def add_edges(node):
        for child in node.children:
            G.add_edge(node.label, child.label)
            add_edges(child)

    add_edges(tree.root)

    pos = nx.drawing.nx_agraph.graphviz_layout(G, prog="dot")
    nx.draw(G, pos, with_labels=True, node_size=2000, node_color="skyblue", font_size=12, font_weight="bold", arrowsize=20)
    plt.show()



In [25]:
# create a node for the start symbol
base_genome = OrderedDict()

"""
base_genome['<start>'] = [0]
base_genome['<expr>'] = [1, 0]
base_genome['<op>'] = [2, 0, 3]
base_genome['<term>'] = [1, 1, 0, 0]
"""
genotype, grammar = create_genotype(max_depth=10)
pretty_print_dict(grammar)
tree = Tree(genotype, grammar)


while tree.expand_next():
    pass


print("Done")
display(tree)

#plot_tree(tree)


<start>: [['<expr>']]
<expr>: [('<number>',), ['(', '<uop>', '<expr1>', ')'], ['(', '<op>', '<expr1>', '<expr1>', ')']]
<number>: [['<integer>']]
<integer>: [['(', '<digit>', ')'], ['(', '<non-zero-digit>', '<digit>', ')'], ['(', '<non-zero-digit>', '<digit>', '<digit>', ')']]
<digit>: [['0'], ['<non-zero-digit>']]
<non-zero-digit>: [['1'], ['2'], ['3'], ['4'], ['5'], ['6'], ['7'], ['8'], ['9']]
<op>: [["'+'"], ["'-'"], ["'*'"], ["'/'"]]
<uop>: [["'abs'"]]
<expr1>: [('<number>',), ['(', '<uop>', '<expr2>', ')'], ['(', '<op>', '<expr2>', '<expr2>', ')']]
<expr2>: [('<number>',), ['(', '<uop>', '<expr3>', ')'], ['(', '<op>', '<expr3>', '<expr3>', ')']]
<expr3>: [('<number>',), ['(', '<uop>', '<expr4>', ')'], ['(', '<op>', '<expr4>', '<expr4>', ')']]
<expr4>: [('<number>',), ['(', '<uop>', '<expr5>', ')'], ['(', '<op>', '<expr5>', '<expr5>', ')']]
<expr5>: [('<number>',), ['(', '<uop>', '<expr6>', ')'], ['(', '<op>', '<expr6>', '<expr6>', ')']]
<expr6>: [('<number>',), ['(', '<uop>', '<ex

( '-' ( 9 0 0 ) ( 'abs' ( 'abs' ( '/' ( '-' ( '*' ( '-' ( '+' ( '-' ( 1 0 ) ( 9 ) ) ( 9 7 0 ) ) ( 4 ) ) ( '/' ( 'abs' ( '*' ( 9 0 ) ( 1 7 ) ) ) ( '*' ( 'abs' ( 9 0 ) ) ( 'abs' ( 7 ) ) ) ) ) ( 'abs' ( 'abs' ( '+' ( 9 ) ( 0 ) ) ) ) ) ( 6 7 ) ) ) ) )