In [1]:
%load_ext autoreload
%autoreload 2
from polysym.torch_operators_2 import Operators
from polysym.model import PolySymModel
from deap import gp, base, creator, tools
from polysym.utils import _RandConst
import sympy as sp

# Algorithm sketch:
# Each node has to choose between an operator/primitive and a variable/terminal.
# Each node also inherits a type constraint from its parent node:
#   0 → free to choose anything, 1 → must produce a scalar, 2 → must produce a vector.

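# One possible rule: never place a reducer operator when the required node rank is 1.
# (NOTE: reducers output scalars — confirm whether "rank 1" here means the scalar
#  constraint code above or a rank-1 array, i.e. a vector.)

# For all other operators, we … TODO: this note is unfinished.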

In [41]:

# ---------- 1. real types ----------
class Scalar:
    """Type tag for scalar-valued nodes in the typed GP primitive set."""


class Vector:
    """Type tag for vector-valued nodes in the typed GP primitive set."""


# short aliases used when declaring typed terminals/primitives
S = Scalar
V = Vector

# ---------- 2. mock arrays ----------
class Arr:
    """Mock 2-D data array: the only thing it provides is a ``shape`` tuple."""

    def __init__(self, a, b):
        # in this notebook only shape[1] (the column count) is read downstream
        dims = (a, b)
        self.shape = dims

# ---------- 3. model ----------
class ModelTest:
    """Stand-alone harness that builds a strongly-typed DEAP GP setup.

    Terminals are SymPy symbols — ``x*`` for scalar features and ``v*`` for
    vector features — plus ephemeral scalar constants. Primitives are every
    operator exposed by ``Operators(select_all=True)``, registered once per
    admissible (input types → output type) signature.
    """

    def __init__(self, objective: int = 1):
        """Build the typed primitive set and the sampling toolbox.

        Parameters
        ----------
        objective : int, default 1
            Selects the type of the tree root (see ``_build_primitives``):
            ``2`` → ``Vector``-valued expressions, any other value →
            ``Scalar``-valued expressions.
        """
        self.objective = objective             # 1 → scalar output, 2 → vector output
        self.X2d = Arr(0, 1)                  # 1 scalar input  (x0)
        self.X3d = Arr(0, 2)                  # 2 vector inputs (v0, v1)
        self.min_constant, self.max_constant = -10, 10   # bounds handed to _RandConst
        self.operators = Operators(select_all=True)

        # sympy symbols: one per input column, scalars first then vectors --
        n2, n3 = self.X2d.shape[1], self.X3d.shape[1]
        sy2 = [sp.symbols(f"x{i}") for i in range(n2)]
        sy3 = [sp.symbols(f"v{j}") for j in range(n3)]
        self.symbols = sy2 + sy3

        # build typed GP tool-box -----------------------------------------
        self.pset = self._build_primitives()
        self.toolbox = self._setup_gp()

    # ---------- 4. primitive set ----------
    def _build_primitives(self):
        """Return a ``gp.PrimitiveSetTyped`` over the ``Scalar``/``Vector`` types.

        The root type follows ``self.objective``. An operator that supports
        several type signatures is added once per signature under the same
        display name.
        """
        ret_type = V if self.objective == 2 else S
        # empty argument-type list so DEAP does not create ARGx terminals;
        # the inputs are added below as named symbol terminals instead
        pset = gp.PrimitiveSetTyped("MAIN", [], ret_type)

        # terminals: variables (the first X2d.shape[1] symbols are scalars) --
        n_scalar = self.X2d.shape[1]
        for idx, sym in enumerate(self.symbols):
            typ = S if idx < n_scalar else V
            pset.addTerminal(sym, typ, name=str(sym))

        # terminals: ephemeral scalar constants ---------------------------
        pset.addEphemeralConstant("randc",
                                  _RandConst(self.min_constant, self.max_constant),
                                  S)
        pset.arguments = []  # defensive: ensure no ARGx names linger

        # primitives: unary ----------------------------------------------
        for name, (fn, _, rank) in self.operators.unary_nonreduce.items():
            if rank == 0:                          # same-in same-out
                pset.addPrimitive(fn, [S], S, name=name)
                pset.addPrimitive(fn, [V], V, name=name)
            elif rank == 1:                        # scalar-only
                pset.addPrimitive(fn, [S], S, name=name)
            elif rank == 2:                        # vector-only
                pset.addPrimitive(fn, [V], V, name=name)
        for name, (fn, _, _) in self.operators.unary_reduce.items():
            pset.addPrimitive(fn, [V], S, name=name)     # reducer: V → S

        # primitives: binary ---------------------------------------------
        for name, (fn, _, rank) in self.operators.binary_nonreduce.items():
            if rank in (0, 3):    # general or (vector,scalar) variants
                pset.addPrimitive(fn, [S, S], S, name=name)
                pset.addPrimitive(fn, [S, V], V, name=name)
                pset.addPrimitive(fn, [V, S], V, name=name)
                pset.addPrimitive(fn, [V, V], V, name=name)
            # NOTE(review): binary operators with any other rank (e.g. 1 or 2)
            # are silently skipped, unlike the unary branch — confirm intended.
        for name, (fn, _, _) in self.operators.binary_reduce.items():
            pset.addPrimitive(fn, [V, V], S, name=name)  # reducer: (V, V) → S

        return pset

    # ---------- 5. toolbox ----------
    def _setup_gp(self):
        """Register the DEAP creator classes (once) and return a sampling toolbox.

        ``creator.create`` overwrites an existing class of the same name and
        emits a RuntimeWarning. Re-instantiating ``ModelTest`` (for example by
        re-running the cell) would therefore swap ``Fitness``/``Individual``
        out from under individuals that were already created. The classes are
        only created when they are missing.

        Returns
        -------
        base.Toolbox
            Has ``expr_init`` (half-and-half trees of height 1..5),
            ``individual`` and ``population`` registered.
        """
        if not hasattr(creator, "Fitness"):
            creator.create("Fitness", base.Fitness, weights=(-1.0,))   # minimise
        if not hasattr(creator, "Individual"):
            creator.create("Individual", gp.PrimitiveTree, fitness=creator.Fitness)
        tb = base.Toolbox()
        tb.register("expr_init", gp.genHalfAndHalf, pset=self.pset, min_=1, max_=5)
        tb.register("individual", tools.initIterate, creator.Individual, tb.expr_init)
        tb.register("population", tools.initRepeat, list, tb.individual)
        return tb

model = ModelTest()  # builds pset + toolbox; objective is 1, so trees are Scalar-valued



In [44]:
# Sample one random individual from the typed GP toolbox and show its prefix form.
ind = model.toolbox.individual()
str(ind)

'sub(sub(spearmanr(div(v1, v0), sub(v0, v1)), exp(cos(9.486))), min(div(exp(v1), std(v1))))'

In [45]:
from sympy import Expr
import sympy as sp
from graphviz import Digraph
from polysym.utils import _round_floats


def draw_deap_tree(ind: gp.PrimitiveTree,
                   filename: str = "expr_tree",
                   fmt: str = "png",
                   round_const: int = 2) -> None:
    """
    Render a DEAP PrimitiveTree *exactly* as stored (no SymPy simplification).

    The prefix-ordered tree is walked with an explicit stack. Each node
    becomes a Graphviz node linked to its parent. Children are emitted
    left-to-right so that non-commutative operators (``sub``, ``div``, ...)
    are drawn with their arguments in the correct order.

    Parameters
    ----------
    ind : gp.PrimitiveTree
        Expression tree to draw.
    filename : str
        Output path stem passed to ``Digraph.render``.
    fmt : str
        Graphviz output format, e.g. ``"png"`` or ``"svg"``.
    round_const : int
        Float terminals are rounded to this many decimals in their labels.
    """
    # ordering="out" makes dot keep children in the order their edges are declared
    dot = Digraph(format=fmt, graph_attr={"ordering": "out"})
    counter = 0

    # stack keeps (node_index, parent_id); start with root at position 0
    stack = [(0, None)]

    while stack:
        idx, parent_id = stack.pop()
        node = ind[idx]
        node_id = str(counter); counter += 1

        # ----- label ----------------------------------------------------
        if node.arity == 0:                       # Terminal
            if isinstance(node.value, float):
                lbl = str(round(node.value, round_const))
            else:
                lbl = str(node.value)
        else:                                     # Primitive
            lbl = node.name

        dot.node(node_id, lbl)
        if parent_id is not None:
            dot.edge(parent_id, node_id)

        # ----- children -------------------------------------------------
        # The first child starts right after the current node. Each following
        # sibling starts where the previous child's subtree ends.
        if node.arity:
            children = []
            child_idx = idx + 1
            for _ in range(node.arity):
                children.append(child_idx)
                child_idx = ind.searchSubtree(child_idx).stop
            # LIFO stack: push right-to-left so the leftmost child pops first
            stack.extend((c, node_id) for c in reversed(children))

    dot.render(filename, cleanup=True)

# Show the textual prefix form, then render the same tree to expr_tree.png.
expr = str(ind)
print(expr)
draw_deap_tree(ind)


sub(sub(spearmanr(div(v1, v0), sub(v0, v1)), exp(cos(9.486))), min(div(exp(v1), std(v1))))


In [11]:
model.pset.terminals[V][1].name  # name of the second Vector-typed terminal

'v1'

In [16]:
f"{ind}"  # prefix-notation string of the current individual

'log10(tan(div(neg(sub(v1, v1)), add(sin(x0), abs(v1)))))'

In [17]:
ind  # raw prefix-ordered list of gp.Primitive / gp.Terminal node objects

[<deap.gp.Primitive at 0x13447e6b0>,
 <deap.gp.Primitive at 0x13447e430>,
 <deap.gp.Primitive at 0x13447ede0>,
 <deap.gp.Primitive at 0x13447e200>,
 <deap.gp.Primitive at 0x13447eb60>,
 <deap.gp.Terminal at 0x1344b4540>,
 <deap.gp.Terminal at 0x1344b4540>,
 <deap.gp.Primitive at 0x13447e980>,
 <deap.gp.Primitive at 0x13447e2f0>,
 <deap.gp.Terminal at 0x1344b4440>,
 <deap.gp.Primitive at 0x13447e2a0>,
 <deap.gp.Terminal at 0x1344b4540>]

In [23]:
str(model.toolbox.individual())  # draw a fresh random individual and stringify it

'div(neg(add(sub(v1, -4.463), log(v0))), min(exp(exp(v0))))'