In [None]:
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Iterable
import sys
import time
from ortools.sat.python import cp_model

In [4]:
class Node:
    """A dependency descriptor: a set of property tags plus a set of parent nodes.

    ``IsA`` is a structural test: a node "is a" ``other`` when it carries every
    property tag that ``other`` carries.
    """

    def __init__(
        self,
        properties: set[str],
        parents: set[Node] | None = None,
    ) -> None:
        self.properties = properties
        # Fresh set per instance. The former ``set()`` default was evaluated once
        # and shared (mutably) by every Node constructed without parents.
        self.parents = set() if parents is None else parents

    def __str__(self) -> str:
        # Sorted so the rendering (it is interpolated into CP-SAT variable names)
        # is stable across runs instead of following hash-dependent set order.
        return f"<N:{','.join(sorted(self.properties))}>"

    def __repr__(self) -> str:
        return f"{self}"

    def IsA(self, other: Node) -> bool:
        """Return True if every property of ``other`` is also a property of self."""
        return other.properties.issubset(self.properties)

class Transform:
    """A rule with an ordered list of required nodes and one of produced nodes.

    A requirement may name earlier requirements of the same transform as its
    ``parents``; ``_input_group_map`` records, for each requirement's index in
    ``requires``, the list of those parent nodes.
    """

    def __init__(self) -> None:
        self.requires: list[Node] = list()
        self.produces: list[Node] = list()
        # requirement index in ``requires`` -> parent requirement nodes
        self._input_group_map: dict[int, list[Node]] = {}

    def __str__(self) -> str:
        def _props(d: Node):
            # Sorted so the rendering (used in CP-SAT variable names) is stable
            # across runs instead of following hash-dependent set order.
            return "{" + ",".join(sorted(d.properties)) + "}"
        return f"<{','.join(_props(r) for r in self.requires)}->{','.join(_props(p) for p in self.produces)}>"

    def __repr__(self):
        return f"{self}"

    def AddRequirement(
        self, properties: Iterable[str], parents: set[Node] | None = None
    ) -> Node:
        """Append and return a required node built from ``properties``.

        Every node in ``parents`` must already be a requirement of this
        transform (checked with ``assert``).
        """
        return self._add_dependency(self.requires, properties, parents)

    def AddProduct(
        self, properties: Iterable[str], parents: set[Node] | None = None
    ) -> Node:
        """Append and return a produced node built from ``properties``."""
        return self._add_dependency(self.produces, properties, parents)

    def _add_dependency(
        self,
        destination: list[Node],
        properties: Iterable[str],
        parents: set[Node] | None = None,
    ) -> Node:
        """Create a Node and append it to ``destination`` (requires or produces).

        ``parents=None`` means "no parents" and gets a fresh set. The former
        ``set()`` default was evaluated once and shared by every Node created
        without explicit parents.
        """
        if parents is None:
            parents = set()
        # Identity, not ``==``: we need to know *which* list we were handed.
        is_requirement = destination is self.requires
        if is_requirement:
            # Validate before mutating so a failed check leaves state untouched.
            for p in parents:
                assert p in self.requires, f"{p} not added as a requirement"
        dep = Node(properties=set(properties), parents=parents)
        destination.append(dep)
        if is_requirement:
            i = len(self.requires) - 1
            self._input_group_map[i] = self._input_group_map.get(i, []) + list(parents)
        return dep
    
def _s(s: str) -> set[str]:
    """Parse a comma-separated property list into a set of property names.

    Whitespace around each name is stripped and empty entries are dropped, so
    ``"a, b"``, ``"a,b"`` and ``"a , b,"`` all give ``{"a", "b"}`` and ``""``
    gives an empty set (the old exact ``", "`` split produced ``{"a,b"}`` and
    ``{""}`` for those inputs).
    """
    return {part.strip() for part in s.split(",") if part.strip()}

# Transform catalogue. The single rule so far: a node tagged "annable" is
# required, and a node tagged "ann" is produced.
t = Transform()
t.AddRequirement(_s("annable"))
t.AddProduct(_s("ann"))
transforms: list[Transform] = [t]

# Nodes available up front (`given`) and goal nodes (`targets`, still empty).
# Neither list is referenced by the model-building cells that follow yet.
given = [Node(properties=_s("annable, taxable, ge1"))]
targets = []

model = cp_model.CpModel()

# Flatten every transform's requirement and product nodes into one indexable
# list; a node's position here is its index j / k in the tables below.
nodes = [n for tr in transforms for n in tr.requires + tr.produces]

# Boolean decision variables, created per time step i in [0, HORIZON):
#   inventory[i, j]    -- "have node j" at step i
#   links[i, j, k]     -- one per ordered node pair (j, k), j == k included;
#                         no constraints use these yet
#   applications[i, j] -- "apply transform j" at step i
# The creation order (per node: inventory then its links; then applications)
# is kept as-is so variable indices in the model are unchanged.
inventory = {}
applications = {}
links = {}
HORIZON = 16
for i in range(HORIZON):
    _state = f"s{i:02}"
    for j, n in enumerate(nodes):
        inventory[i, j] = model.NewBoolVar(f"{_state}: have {n}")
        for k, m in enumerate(nodes):
            links[i, j, k] = model.NewBoolVar(f"{_state}: {n}={m}")
    for j, tr in enumerate(transforms):
        applications[i, j] = model.NewBoolVar(f"{_state}: apply {tr}")
    

# Solver configuration: enumerate every solution (SolPrinter stops the search
# once `solution_limit` solutions have been seen) with linearization level 0.
solver = cp_model.CpSolver()
solver.parameters.enumerate_all_solutions = True
solver.parameters.linearization_level = 0
solution_limit = 5

class SolPrinter(cp_model.CpSolverSolutionCallback):
    """CP-SAT solution callback that prints variable values and caps the count.

    Args:
        variables: variables whose values are printed, as a list, for each
            solution. ``None`` (the default) falls back to the module-level
            global ``test``, looked up at callback time as before.
        limit: call ``StopSearch()`` once this many solutions have been seen.
            ``None`` falls back to the module-level global ``solution_limit``.
    """

    def __init__(
        self,
        variables: Iterable | None = None,
        limit: int | None = None,
    ):
        super().__init__()
        # Snapshot explicit variables so a one-shot iterator works every time.
        self._variables = None if variables is None else list(variables)
        self._limit = limit
        self._solutions = 0

    def on_solution_callback(self):
        """Print the current solution; stop the search once the limit is hit."""
        self._solutions += 1
        # NOTE(review): no visible cell defines the global ``test``; prefer
        # passing ``variables=`` explicitly over relying on this fallback.
        variables = test if self._variables is None else self._variables
        print([self.Value(x) for x in variables])
        limit = solution_limit if self._limit is None else self._limit
        if self._solutions >= limit:
            self.StopSearch()

# Run the search with the printing callback. `Solve` returns the CpSolverStatus
# code, which the notebook displays as this cell's output value.
solver.Solve(model, solution_callback=SolPrinter())

[1, 0]
[0, 1]


4