# In [None]:
# Operator tables: operator name -> arity (number of operands).
# Each table lists its binary operators first, then its unary ones.

# Operators available for real-valued (float) sequences.
operators_real = {
    **dict.fromkeys(('add', 'sub', 'mul', 'div'), 2),
    **dict.fromkeys(
        ('abs', 'inv', 'sqr', 'sqrt', 'log', 'exp',
         'sin', 'arcsin', 'cos', 'arccos', 'tan', 'arctan'),
        1,
    ),
}

# Operators available for integer sequences ('step' is currently disabled).
operators_int = {
    **dict.fromkeys(('add', 'sub', 'mul', 'idiv', 'mod'), 2),
    **dict.fromkeys(('abs', 'sqr', 'relu', 'sign'), 1),
}

# Binary-only integer subset.
operators_int_gen = dict.fromkeys(('add', 'sub', 'mul', 'idiv'), 2)

# Operators not included in the real/int tables above.
operators_extra = {'pow': 2}

# Leaf names resolved as numpy attributes (np.e, np.pi, np.euler_gamma).
math_constants = ['e', 'pi', 'euler_gamma']

# Union of all tables; later tables win on duplicate keys (arities agree).
all_operators = {**operators_real, **operators_int, **operators_extra}

class Node():
    """A node of an expression tree describing a recurrence relation.

    ``value`` is an operator name (e.g. ``'add'``, ``'sqrt'``) for an internal
    node, or a leaf token: a series reference ``'x_<dim>_<offset>'``, the
    index ``'n'``, ``'rand'``, a name listed in ``math_constants``, or a
    literal string that is ``eval``-ed (e.g. ``'2'``).
    ``params`` must expose ``dimension`` and ``float_sequences`` (read by
    :meth:`val`).
    """

    # Unary operators evaluated by the numpy function of the same name.
    _NUMPY_UNARY_OPS = frozenset(
        {'sqrt', 'log', 'exp', 'sin', 'arcsin', 'cos', 'arccos', 'tan', 'arctan'}
    )

    def __init__(self, value, params, children=None):
        self.value = value
        # A falsy ``children`` (None or empty) gets a fresh list so that
        # instances never share one.
        self.children = children if children else []
        self.params = params

    def push_child(self, child):
        """Append ``child`` as the last operand of this node."""
        self.children.append(child)

    def prefix(self):
        """Return the comma-separated prefix (Polish-notation) serialization."""
        return ','.join([str(self.value)] + [c.prefix() for c in self.children])

    # export to latex qtree format: prefix with \Tree, use package qtree
    def qtree_prefix(self):
        """Return the subtree in LaTeX ``qtree`` bracket syntax."""
        inner = ''.join(c.qtree_prefix() for c in self.children)
        return "[.$" + str(self.value) + "$ " + inner + "]"

    def infix(self):
        """Return a human-readable infix string.

        Leaves print their value. ``sqr`` prints as ``(x)**2`` and other
        unary operators as ``op(x)``. Nodes with two or more children print
        as ``(a op b op ...)``.
        """
        nb_children = len(self.children)
        if nb_children == 0:
            return str(self.value)
        if nb_children == 1:
            inner = self.children[0].infix()
            if str(self.value) == 'sqr':
                return '(' + inner + ')**2'
            return str(self.value) + '(' + inner + ')'
        separator = ' ' + str(self.value) + ' '
        return '(' + separator.join(c.infix() for c in self.children) + ')'

    def __len__(self):
        """Return the number of nodes in this subtree (including this one)."""
        return 1 + sum(len(c) for c in self.children)

    def __str__(self):
        # infix is the default print
        return self.infix()

    def val(self, series, deterministic=False):
        """Evaluate the expression on the sequence generated so far.

        Args:
            series: flat list of already-generated terms. For a
                ``params.dimension``-dimensional sequence the components are
                assumed interleaved, so ``len(series) % dimension`` is the
                component currently being computed.
            deterministic: if True, ``'rand'`` leaves evaluate to 0. The flag
                is propagated to every descendant node.

        Returns:
            The numeric value of the subtree. ``np.nan`` is returned for
            ``div``/``mod``/``idiv``/``inv`` with a zero divisor.

        Raises:
            ValueError: if an internal node's operator is not recognised.
        """
        if not self.children:
            return self._leaf_val(series, deterministic)

        op = self.value
        # Evaluate each operand exactly once. Re-evaluating a subtree that
        # contains 'rand' would draw a new random value every time.
        args = [child.val(series, deterministic) for child in self.children]

        if op in ('add', 'mul'):
            # Fold over all operands so that n-ary nodes agree with infix().
            result = args[0]
            for arg in args[1:]:
                result = result + arg if op == 'add' else result * arg
            return result
        if op == 'sub':
            return args[0] - args[1]
        if op == 'pow':
            return args[0] ** args[1]
        if op == 'max':
            return max(args[0], args[1])
        if op == 'min':
            return min(args[0], args[1])
        if op in ('mod', 'div', 'idiv'):
            lhs, rhs = args[0], args[1]
            if rhs == 0:
                return np.nan
            if op == 'mod':
                return lhs % rhs
            if op == 'div':
                return lhs / rhs
            return lhs // rhs

        x = args[0]
        if op == 'inv':
            # Same convention as 'div': a zero divisor yields NaN.
            return np.nan if x == 0 else 1 / x
        if op == 'sqr':
            return x ** 2
        if op == 'abs':
            return abs(x)
        if op == 'sign':
            # Note: 0 maps to +1.
            return int(x >= 0) * 2 - 1
        if op == 'relu':
            return x if x > 0 else 0
        if op == 'step':
            return 1 if x > 0 else 0
        if op == 'id':
            return x
        if len(args) == 1 and op in self._NUMPY_UNARY_OPS:
            return getattr(np, op)(x)
        raise ValueError(
            "Unknown operator %r with %d operand(s)" % (op, len(args))
        )

    def _leaf_val(self, series, deterministic):
        """Evaluate a leaf node (one without children); see :meth:`val`."""
        token = str(self.value)
        if token.startswith('x_'):
            # 'x_<dim>_<offset>': component <dim> of the term <offset> steps
            # back, in the interleaved flat layout of ``series``.
            _, dim, offset = token.split('_')
            dim, offset = int(dim), int(offset)
            dimension = self.params.dimension
            curr_dim = len(series) % dimension
            return series[-offset * dimension + dim - curr_dim]
        if token == 'n':
            return len(series)
        if token == 'rand':
            if deterministic:
                return 0
            if self.params.float_sequences:
                return np.random.randn()
            return int(np.random.choice([-1, 0, 1]))
        if token in math_constants:
            return getattr(np, token)
        # NOTE(review): eval() executes arbitrary code. It is only acceptable
        # while leaf values are produced internally (e.g. '2'), never from
        # untrusted input.
        return eval(self.value)

    def get_recurrence_degree(self):
        """Return the largest offset among ``x_<dim>_<offset>`` leaves (min 0)."""
        if self.children:
            return max(child.get_recurrence_degree() for child in self.children)
        token = str(self.value)
        if token.startswith('x_'):
            return max(0, int(token.split('_')[2]))
        return 0

    def get_n_ops(self):
        """Count operator nodes (values in ``all_operators``) in this subtree.

        Recursion stops at the first node whose value is not an operator.
        """
        if self.value in all_operators:
            return 1 + sum(child.get_n_ops() for child in self.children)
        return 0

# In [None]:
def generate_tree_poly(self, p_coeff=None, verbose=True):
    """Build the expression tree of the polynomial ``prod_i (1 - n * c_i)``.

    Args:
        self: generator-like object. Only ``self.params`` is read, and it is
            forwarded to every created ``Node``.
        p_coeff: the coefficients ``c_i``. Each one is stored as a string
            leaf (leaf literals are ``eval``-ed by ``Node.val``). Defaults to
            ``['2', '3', '4']``.
        verbose: if True (the original, always-on behavior), print the
            coefficients and the resulting tree.

    Returns:
        The root ``Node``. Factors are combined through a left-nested chain
        of binary ``mul`` nodes, so every factor contributes to the value.
        An empty ``p_coeff`` yields the empty product, the leaf ``'1'``.
    """
    if p_coeff is None:
        p_coeff = ['2', '3', '4']
    if verbose:
        print('coeff=', p_coeff)

    tree = None
    for coeff in p_coeff:
        # One factor (1 - n * coeff) encoded as sub(1, mul(n, coeff)).
        factor = Node('sub', self.params, [
            Node('1', self.params),
            Node('mul', self.params, [
                Node('n', self.params),
                Node(str(coeff), self.params),
            ]),
        ])
        # Keep every 'mul' node binary, so the value never depends on how
        # many operands an evaluator supports for 'mul'.
        tree = factor if tree is None else Node('mul', self.params, [tree, factor])

    if tree is None:
        tree = Node('1', self.params)  # empty product

    if verbose:
        print('gen_tree:', tree)
    return tree

# In [None]:
from types import SimpleNamespace

# Minimal stand-in for the generator object that generate_tree_poly expects.
# It only reads ``self.params``, and Node.val reads ``params.dimension`` and
# ``params.float_sequences``. Use one tree per sequence dimension.
params = SimpleNamespace(dimension=1, float_sequences=False)
trees = [generate_tree_poly(SimpleNamespace(params=params))]

series = []
for _ in range(10):
    # Append each component as soon as it is computed: Node.val derives the
    # component currently being computed from len(series).
    for tree in trees:
        series.append(tree.val(series))

print('remaining points:', series)