In [1]:
from copy import deepcopy
from itertools import product, combinations_with_replacement
from operator import itemgetter
from treelib import Tree
from sympy import srepr, sympify

In [39]:
# Depth of the generated expression trees, counting the root as depth 0.
# possible_children only offers leaf tags ('x', 'c') once depth reaches
# max_depth - 1, so every node at depth max_depth is a leaf.
max_depth = 2

In [74]:
# Grammar of the expression trees. Each tag maps to:
#   'children'    - arity of the node (0 = leaf),
#   'op'          - SymPy constructor / srepr snippet used to render it,
#   'commutative' - (binary operators only) whether operand order is irrelevant.
# '-' and '/' reuse Add / Mul because SymPy has no subtract/divide
# constructors; sympify_tree turns their second operand into a negation /
# reciprocal.
all_nodes = {
    # Leaves
    'x': {'children': 0, 'op': "Symbol('x')"},
    'c': {'children': 0, 'op': "Symbol('c')"},  # constant
    # Unary functions
    'exp': {'children': 1, 'op': 'exp'},
    'sin': {'children': 1, 'op': 'sin'},
    'cos': {'children': 1, 'op': 'cos'},
    # Binary operators
    '+': {'children': 2, 'op': 'Add', 'commutative': True},
    '-': {'children': 2, 'op': 'Add', 'commutative': False},
    '*': {'children': 2, 'op': 'Mul', 'commutative': True},
    '/': {'children': 2, 'op': 'Mul', 'commutative': False},
}

In [67]:
def sympify_tree(tree, as_expr=False):
    """Convert an expression tree into a SymPy expression.

    Every node tag must be a key of ``all_nodes``. A node's children, in the
    order they were added to the tree, are its operands. SymPy has no binary
    subtract/divide constructors, so ``a - b`` is emitted as
    ``Add(a, Mul(Integer(-1), b))`` and ``a / b`` as
    ``Mul(a, Pow(b, Integer(-1)))``.

    The source string is built recursively. The negation/reciprocal is
    therefore applied to the whole right operand, whether that operand is a
    leaf or a subtree, and it never leaks onto operands of other operators.

    Args:
        tree: treelib ``Tree`` whose node tags are keys of ``all_nodes``.
        as_expr: If True, return the SymPy expression object; otherwise
            return its ``srepr`` string.

    Returns:
        The SymPy expression, or its ``srepr`` representation.

    Raises:
        ValueError: If the tree is empty, or a node does not have the number
            of children its tag requires.
    """
    if tree.root is None:
        raise ValueError('Cannot sympify an empty tree')

    def build(nid):
        # srepr-style source string for the subtree rooted at ``nid``.
        tag = tree[nid].tag
        spec = all_nodes[tag]
        operands = [build(child.identifier) for child in tree.children(nid)]
        if len(operands) != spec['children']:
            raise ValueError(
                f"Node {tag!r} needs {spec['children']} children, got {len(operands)}"
            )
        if not operands:
            return spec['op']
        if tag == '-':
            left, right = operands
            return f'Add({left}, Mul(Integer(-1), {right}))'
        if tag == '/':
            left, right = operands
            return f'Mul({left}, Pow({right}, Integer(-1)))'
        return f"{spec['op']}({', '.join(operands)})"

    sym_expr = sympify(build(tree.root))
    return sym_expr if as_expr else srepr(sym_expr)

In [75]:
# Terminal tags (no children); every other tag in all_nodes is an operator.
leaf_keys = [name for name, spec in all_nodes.items() if spec['children'] == 0]
print(leaf_keys)

# Candidate child pairs for binary operators. Commutative operators only need
# unordered pairs, because (a, b) and (b, a) describe the same expression.
commutative_pairs = [*combinations_with_replacement(all_nodes, 2)]
all_pairs = [*product(all_nodes, repeat=2)]

# The same candidates restricted to leaves, used on the deepest level.
commutative_leaf_pairs = [*combinations_with_replacement(leaf_keys, 2)]
all_leaf_pairs = [*product(leaf_keys, repeat=2)]

def possible_children(node, depth):
    """Return the candidate children for expanding an operator node.

    Args:
        node: Tag of the node being expanded (a key of ``all_nodes``).
        depth: Depth of ``node`` in the tree (root is 0).

    Returns:
        For unary tags, a list of single child tags. For binary tags, a list
        of child-tag pairs: unordered pairs for commutative operators,
        ordered pairs otherwise. Once ``depth`` reaches ``max_depth - 1``,
        only leaf tags are offered, so the tree ends at ``max_depth``.

    Raises:
        ValueError: If ``node`` is a leaf tag.
    """
    arity = all_nodes[node]['children']
    allow_operators = depth < (max_depth - 1)
    if arity == 0:
        raise ValueError('Leaf nodes can not have any children')
    if arity == 1:
        if allow_operators:
            return list(all_nodes.keys())
        return leaf_keys
    if not all_nodes[node]['commutative']:
        if allow_operators:
            return all_pairs
        return all_leaf_pairs
    if allow_operators:
        return commutative_pairs
    return commutative_leaf_pairs
    
def format_combination(combination):
    """Normalize one choice of children per leaf into a list of tag tuples.

    ``possible_children`` yields bare tags (str) for unary nodes and tag
    tuples for binary nodes. Wrapping each bare tag in a 1-tuple lets the
    caller iterate over every leaf's new children the same way.

    Args:
        combination: Iterable whose items are either a single tag (str) or a
            tuple of tags.

    Returns:
        A list with one tuple of child tags per item of ``combination``.
    """
    # isinstance (rather than type(...) == str) also accepts str subclasses.
    return [(choice,) if isinstance(choice, str) else choice for choice in combination]
    
def generate_trees(tree, depth):
    """Yield every expression tree obtainable by expanding ``tree`` to ``max_depth``.

    Each recursion level gives every still-open leaf (an operator tag with no
    children yet) one of its allowed child sets from ``possible_children``.
    The cartesian product over all open leaves is explored, with each
    combination applied to a deep copy of ``tree``.

    Args:
        tree: treelib ``Tree`` to expand; it is not modified.
        depth: Current depth level being expanded (the root is 0).

    Yields:
        Fully expanded trees. ``tree`` itself is yielded unchanged when
        ``depth`` has reached ``max_depth`` or no leaf can be expanded.
    """
    if depth < max_depth:
        # Terminal tags ('x', 'c') never get children, so they are not open leaves.
        open_leaves = list(tree.filter_nodes(lambda n: n.is_leaf() and n.tag not in leaf_keys))
        if open_leaves:
            # Each open leaf has its own list of alternatives, depending on
            # its arity in all_nodes.
            choices_per_leaf = [possible_children(leaf.tag, depth) for leaf in open_leaves]
            for choice in product(*choices_per_leaf):
                child_sets = format_combination(choice)
                expanded = deepcopy(tree)
                for leaf, new_tags in zip(open_leaves, child_sets):
                    for new_tag in new_tags:
                        expanded.create_node(new_tag, parent=leaf)
                yield from generate_trees(expanded, depth + 1)
            return
    yield tree
                
# Enumerate every tree up to max_depth for each possible root tag and write
# one srepr string per line. Different trees can simplify to the same SymPy
# expression, so the file may contain duplicate lines.
# The `with` block closes the file; no explicit close() is needed.
with open('expressions3.csv', 'w', encoding='utf-8') as out_file:
    for root_tag in all_nodes:
        base_tree = Tree()
        base_tree.create_node(root_tag)
        for tree in generate_trees(base_tree, 0):
            out_file.write(sympify_tree(tree) + '\n')

['x', 'c']


In [73]:
# Remove duplicate expressions from the generated file.
# Lines are written in sorted order because iterating a set of strings is not
# reproducible between interpreter runs (str hashing is randomized).
read_file = 'expressions3.csv'
write_file = 'uniques_depth2_old.csv'

with open(read_file, 'r', encoding='utf-8') as file:
    # Strip the newline so that a final line lacking one still matches its
    # duplicates and is not glued to the next line when written back.
    uniques = {line.rstrip('\n') for line in file}

with open(write_file, 'w', encoding='utf-8') as file:
    for line in sorted(uniques):
        file.write(line + '\n')