In [184]:
from __future__ import annotations
from dataclasses import dataclass, field
import os, sys
from typing import Any, Generator, Iterable, Literal
import hashlib
import numpy as np
import json
from collections import deque

from limes_x.utils import KeyGenerator

class Namespace:
    def __init__(self) -> None:
        self.node_signatures: dict[int, str] = {}
        self._last_k: int = 0
        self._kg = KeyGenerator(True)
        self._KLEN = 4
        self._MAX_K = len(self._kg.vocab)**self._KLEN

    def NewKey(self):
        self._last_k += 1
        assert self._last_k < self._MAX_K
        return self._last_k, self._kg.FromInt(self._last_k, self._KLEN)

class Hashable:
    def __init__(self, ns: Namespace) -> None:
        self.namespace = ns
        self.hash, self.key = ns.NewKey()

    def __hash__(self) -> int:
        return self.hash
    
    def __eq__(self, __value: object) -> bool:
        K = "key"
        return hasattr(__value, K) and self.key == getattr(__value, K)

class Node(Hashable):
    def __init__(
        self,
        ns: Namespace,
        properties: set[str],
        parents: set[Node],
    ) -> None:
        super().__init__(ns)
        self.namespace = ns
        self.properties = properties
        self.parents = parents
        self._sig: str|None = None
        # self._diffs = set()
        # self._sames = set()

    def __str__(self) -> str:
        return f"({self.key}:{'-'.join(self.properties)})"

    def __repr__(self) -> str:
        return f"{self}"
    
    def IsA(self, other: Node) -> bool:
        # if other.key in self._diffs: return False
        # if other.key in self._sames: return True
        if not other.properties.issubset(self.properties):
            # self._diffs.add(other.key)
            return False
        # self._sames.add(other.key)
        # if compare_lineage: return  other.parents.issubset(self.parents)
        return True

    def Signature(self):
        if self._sig is None:
            psig = ",".join(sorted(p.Signature() for p in self.parents))
            sig = ",".join(sorted(self.properties))
            self._sig = f'{sig}:[{psig}]' if len(self.parents)>0 else sig
        return self._sig

class Dependency(Node):
    def __init__(self, namespace: Namespace, properties: set[str], parents: set[Node]) -> None:
        super().__init__(namespace, properties, parents)

    def __str__(self) -> str:
        return f"(D:{'-'.join(self.properties)})"
    
class Endpoint(Node):
    def __init__(self, namespace: Namespace, properties: set[str], parents: dict[Endpoint, Node]=dict()) -> None:
        super().__init__(namespace, properties, set(parents))
        self._parent_map = parents # real, proto

    def Iterparents(self):
        """real, prototype"""
        for e, p in self._parent_map.items():
            yield e, p

class Transform(Hashable):
    def __init__(self, ns: Namespace) -> None:
        super().__init__(ns)
        self.requires: list[Dependency] = list()
        self.produces: list[Dependency] = list()
        self._ns = ns
        self._input_group_map: dict[int, list[Dependency]] = {}
        self._key = ns.NewKey()
        self._seen: set[str] = set()

    def __str__(self) -> str:
        def _props(d: Dependency):
            return "{"+"-".join(d.properties)+"}"
        return f"{','.join(_props(r) for r in self.requires)}->{','.join(_props(p) for p in self.produces)}"

    def __repr__(self): return f"{self}"

    def AddRequirement(self, properties: Iterable[str], parents: set[Dependency]=set()):
        return self._add_dependency(self.requires, properties, parents)

    def AddProduct(self, properties: Iterable[str], parents: set[Dependency]=set()):
        return self._add_dependency(self.produces, properties, parents)

    def _add_dependency(self, destination: list[Dependency], properties: Iterable[str], parents: set[Dependency]=set()):
        _parents: Any = parents
        _dep = Dependency(properties=set(properties), parents=_parents, namespace=self._ns)
        # assert not any(e.IsA(_dep) for e in destination), f"prev. dep ⊆ new dep"
        # assert not any(_dep.IsA(e) for e in destination), f"new dep ⊆ prev. dep "
        # destination.add(_dep)
        destination.append(_dep)
        if destination == self.requires:
            i = len(self.requires)-1
            for p in _parents:
                assert p in self.requires, f"{p} not added as a requirement"
            self._input_group_map[i] = self._input_group_map.get(i, [])+list(_parents)
        return _dep

    def _sig(self, endpoints: Iterable[Endpoint]):
        # return "".join(e.key for e in endpoints)
        return self.key+"-"+ "".join(e.key for e in endpoints)

    # just all possibilities regardless of lineage
    def Possibilities(self, have: set[Endpoint], constraints: dict[Dependency, Endpoint]=dict()) -> Generator[list[Endpoint], Any, None]:
        matches: list[list[Endpoint]] = []
        constraints_used = False
        for req in self.requires:
            if req in constraints:
                must_use = constraints[req]
                _m = [must_use]
            else:
                _m = [m for m in have if m.IsA(req)]
            if len(_m) == 0: return None
            matches.append(_m)
        if len(constraints)>0 and not constraints_used: return None

        indexes = [0]*len(matches)
        indexes[0] = -1
        def _advance():
            i = 0
            while True:
                indexes[i] += 1
                if indexes[i] < len(matches[i]): return True
                indexes[i] = 0
                i += 1
                if i >= len(matches): return False
        while _advance():
            yield [matches[i][j] for i, j in enumerate(indexes)]
    
    # filter possibilities based on correct lineage
    def Valids(self, matches: Iterable[list[Endpoint]]):
        black_list: set[tuple[int, Endpoint]] = set()
        white_list: set[tuple[int, Endpoint]] = set()

        choosen: list[Endpoint] = []
        for config in matches:
            ok = True
            for i, (e, r) in enumerate(zip(config, self.requires)):
                k = (i, e)
                if k in black_list: ok=False; break
                if k in white_list: continue
                
                parents = self._input_group_map.get(i, [])
                if len(parents) == 0: # no lineage req.
                    white_list.add(k)
                    continue
                
                for prototype in parents:
                    # parent must already be in choosen, since it must have been added
                    # as a req. before being used as a parent during setup
                    found = False
                    for p in choosen:
                        if not p.IsA(prototype): continue
                        if p in e.parents: found=True; break
                    if not found: black_list.add(k); ok=False; break
                if not ok: break
            if ok: yield config

    def Apply(self, inputs: Iterable[tuple[Endpoint, Node]]):
        for r, (e, e_proto) in zip(self.requires, inputs):
            assert e.IsA(r), f"{e_proto}, {e}, {r}"

        inputs_dict = dict(inputs)
        parent_dict: dict[Any, Any] = {}
        for e, _ in inputs_dict.items():
            for p, pproto in e.Iterparents():
                if p in parent_dict: continue
                parent_dict[p] = pproto
        for e, eproto in inputs_dict.items():
            parent_dict[e] = eproto
        produced = {
            Endpoint(
                namespace=self._ns,
                properties=out.properties,
                parents=parent_dict
            ):out
        for out in self.produces}
        return Application(self, inputs_dict, produced)

@dataclass
class Application:
    transform: Transform
    used: dict[Endpoint, Node]
    produced: dict[Endpoint, Dependency]

    def __str__(self) -> str:
        return f"{self.transform} || {','.join(str(e) for e in self.used.keys())} >> {','.join(str(e) for e in self.produced)}"

    def __repr__(self) -> str:
        return f"{self}"

@dataclass
class Result:
    steps: int

@dataclass
class TrResult(Result):
    application: Application
    dependency_plan: list[Application]
    
@dataclass
class DepResult(Result):
    plan: list[Application]
    endpoint: Endpoint

def Solve(given: Iterable[Endpoint], target: Transform, transforms: Iterable[Transform]):
    @dataclass
    class State:
        have: dict[Endpoint, Dependency]
        target: Dependency|Transform
        lineage_requirements: dict[Node, Endpoint]
        depth: int

    def _get_producers_of(target: Dependency):
        for tr in transforms:
            for p in tr.produces:
                if p.IsA(target):
                    yield tr
                    break

    # if DEBUG: debug_print = lambda *args: None
    # if DEBUG: debug_print = lambda *args: None
    DEBUG = True
    # DEBUG = False
    log = open("debug_log.txt", "w")
    debug_print = lambda *args: log.write(" ".join(str(a) for a in args)+"\n") if args[0] != "END" else log.close()

    _apply_cache: dict[str, Application] = {}
    def _apply(target: Transform, inputs: Iterable[tuple[Endpoint, Node]]):
        sig  = "".join(e.key+d.key for e, d in inputs)
        if sig in _apply_cache:
            return _apply_cache[sig]
        appl = target.Apply(inputs)
        _apply_cache[sig] = appl
        return appl

    def _satisfies_lineage(tproto: Dependency, candidate: Endpoint):
        for tp_proto in tproto.parents:
            if all(not p.IsA(tp_proto) for p, _ in candidate.Iterparents()):
                return False
        return True

    HORIZON=64
    def _solve_dep(s: State) -> list[DepResult]:
        if s.depth >= HORIZON:
            if DEBUG: debug_print(f" <-  HORIZON", s.depth)
            return []
        target: Any = s.target
        assert isinstance(target, Dependency), f"{s.target}, not dep"
        if DEBUG: debug_print(f" ->", s.target, s.lineage_requirements)
        if DEBUG: debug_print(f"   ", s.have.keys())

        candidates:list[DepResult] = []
        for e, eproto in s.have.items():
            if not e.IsA(target): continue
            acceptable = True
            for rproto, r in s.lineage_requirements.items():
                if e == r: continue
                if eproto.IsA(rproto): # e is protype, but explicitly breaks lineage
                    acceptable=False; break

                for p, pproto in e.Iterparents():
                    if rproto.IsA(pproto):
                        if p != r:
                            acceptable=False; break

            if not acceptable:
                continue
            else:
                if DEBUG: debug_print(f"    ^candidate", e, eproto, e.parents)
                if DEBUG: debug_print(f"    ^reqs.    ", s.lineage_requirements)
                candidates.append(DepResult(0, [], e))
            # elif quality == 2:
            #     if DEBUG: debug_print(f" <-", s.target, e, "DIRECT")
            #     return [DepResult(0, [], e)]

        def _add_result(res: TrResult):
            ep: Endpoint|None = None
            for e in res.application.produced:
                if e.IsA(target):
                    ep = e; break
            assert isinstance(ep, Endpoint)
            if not _satisfies_lineage(target, ep): return
            candidates.append(DepResult(
                res.steps,
                res.dependency_plan+[res.application],
                ep,
            ))

        for tr in _get_producers_of(target):
            results = _solve_tr(State(s.have, tr, s.lineage_requirements, s.depth))
            for res in results:
                _add_result(res)

        if DEBUG: debug_print(f" <-", s.target, f"{len(candidates)} sol.", candidates[0].endpoint if len(candidates)>0 else None)
        return candidates

    _transform_cache: dict[str, list[TrResult]] = {}
    def _solve_tr(s: State) -> list[TrResult]:
        assert isinstance(s.target, Transform), f"{s.target} not tr"
        if DEBUG: debug_print(f">>>{s.depth:02}", s.target, s.lineage_requirements)
        for h in s.have:
            if DEBUG: debug_print(f"      ", h)

        # memoization
        sig = "".join(e.key for e in s.have)
        sig += f":{s.target.key}"
        sig += ":"+"".join(e.key for e in s.lineage_requirements.values())
        if sig in _transform_cache:
            if DEBUG: debug_print(f"<<<{s.depth:02} CACHED: {len(_transform_cache[sig])} solutions")
            return _transform_cache[sig]
        
        plans: list[list[DepResult]] = []
        for i, req in enumerate(s.target.requires):
            req_p = {}
            for proto, e in s.lineage_requirements.items():
                if req.IsA(proto): continue
                req_p[proto] = e

            results = _solve_dep(State(s.have, req, req_p, s.depth+1))
            
            if len(results) == 0:
                if DEBUG: debug_print(f"<<< FAIL", s.target, req)
                return []
            else:
                plans.append(results)

        def _iter_plans():
            indexes = [0]*len(plans)
            indexes[0] = -1
            def _advance():
                i = 0
                while True:
                    indexes[i] += 1
                    if indexes[i] < len(plans[i]): return True
                    indexes[i] = 0
                    i += 1
                    if i >= len(plans): return False
            while _advance():
                yield [plans[i][j] for i, j in enumerate(indexes)]

        target: Transform = s.target
        def _iter_satisfies():
            input_sigs = set()
            if DEBUG: debug_print(f"    #", s.lineage_requirements)
            for i, inputs in enumerate(_iter_plans()):
                sig = "".join(e.endpoint.key for e in inputs)
                if sig in input_sigs: continue
                input_sigs.add(sig)
                if DEBUG: debug_print(f"    .", i+1)

                deps: dict[Node, Endpoint] = {}
                _fail = False
                used: set[Endpoint] = set()
                for res, req in zip(inputs, target.requires):
                    if DEBUG: debug_print(f"          ", deps)
                    if DEBUG: debug_print(f"    ___", req, req.parents)
                    if DEBUG: debug_print(f"        __", res.endpoint, list(res.endpoint.Iterparents()))
                    if res.endpoint in used:
                        if DEBUG: debug_print(f"    ___ FAIL: duplicate input", res.endpoint)
                        _fail=True; break
                    used.add(res.endpoint)

                    if not _satisfies_lineage(req, res.endpoint):
                        if DEBUG: debug_print(f"    ___ FAIL: unsatisfied lineage", req)
                        _fail=True; break

                    for rproto in req.parents:
                        r = deps[rproto]
                        # if all(not p.IsA(rproto) for p, pproto in res.endpoint.Iterparents()):
                        #     if DEBUG: debug_print(f"    ___ FAIL: unsatisfied lineage", rproto)
                        #     _fail=True; break
                        res_parents = list(res.endpoint.Iterparents())
                        res_parents.reverse()
                        for p, pproto in res_parents:
                            if not p.IsA(rproto): continue
                            if p!=r:
                                if DEBUG: debug_print(f"    ___ FAIL: lineage mismatch", p, r)
                                _fail=True; break
                            else:
                                break # in the case of asm -> bin, the closest ancestor takes priority
                        if _fail: break
                    if _fail: break

                    deps[req] = res.endpoint
                if _fail:
                    continue
                if DEBUG: debug_print(f"    ___ KEPT")
                yield inputs

        if DEBUG: debug_print(f"<<<{s.depth:02}", s.target, s.lineage_requirements)
        if DEBUG: debug_print(f"     ", [len(x) for x in plans])
        solutions: list[TrResult] = []
        for inputs in _iter_satisfies():
            my_appl = _apply(s.target, [(res.endpoint, req) for req, res in zip(s.target.requires, inputs)])
            consolidated_plan: list[Application] = []
            produced_sigs: set[str] = {p.Signature() for p in my_appl.produced}
            # if DEBUG: debug_print(f"   __", my_appl)
            for res in inputs:
                for appl in res.plan:
                    if all(p.Signature() in produced_sigs for p in appl.produced): continue
                    consolidated_plan.append(appl)
                    produced_sigs = produced_sigs.union(p.Signature() for p in appl.produced)
            solutions.append(TrResult(
                len(consolidated_plan),
                my_appl,
                consolidated_plan,
            ))
            # if DEBUG: debug_print(f"    *", my_appl)
            # if DEBUG: debug_print(f"     ", [res.endpoint for res in inputs])
            # if DEBUG: debug_print(f"    .", target.requires)
            # for appl in consolidated_plan:
            #     if DEBUG: debug_print(f"    __", appl)
        if DEBUG: debug_print(f"     ", f"{len(solutions)} sol.", solutions[0].application.produced if len(solutions)>0 else None)
        solutions = sorted(solutions, key=lambda s: s.steps)
        _transform_cache[sig] = solutions
        return solutions

    input_tr = Transform(target._ns)
    given_dict = {g:input_tr.AddProduct(g.properties) for g in given}
    res = _solve_tr(State(given_dict, target, {}, 0))
    if DEBUG: debug_print("END")
    return res


def _set(s: str):
    return set(s.split(", "))
 
transforms = []
NS = Namespace()

t = Transform(NS)
t.AddRequirement(_set("dna"))
t.AddProduct(_set("contigs, asm, annable"))
transforms.append(t)

t = Transform(NS)
r = t.AddRequirement(_set("dna"))
t.AddRequirement(_set("contigs, asm"), {r})
t.AddProduct(_set("contigs, bin, annable"))
transforms.append(t)

t = Transform(NS)
t.AddRequirement(_set("db"))
t.AddRequirement(_set("annable"))
t.AddProduct(_set("ann"))
transforms.append(t)

t = Transform(NS)
ann = t.AddRequirement(_set("annable"))
r = t.AddRequirement(_set("db, cog"))
t.AddRequirement(_set("ann"), {r, ann})
r = t.AddRequirement(_set("db, kegg"))
t.AddRequirement(_set("ann"), {r, ann})
t.AddProduct(_set("table"))
transforms.append(t)

t = Transform(NS)
db1 = t.AddRequirement(_set("db, cog"))
db2 = t.AddRequirement(_set("db, kegg"))
asm = t.AddRequirement(_set("contigs, asm"))
bin = t.AddRequirement(_set("contigs, bin"))
# t.AddRequirement(_set("ann"), {asm, db1})
# t.AddRequirement(_set("ann"), {asm, db2})
# t.AddRequirement(_set("ann"), {bin, db1})
# t.AddRequirement(_set("ann"), {bin, db2})
t.AddRequirement(_set("table"), {asm, db1, db2})
t.AddRequirement(_set("table"), {bin, db1, db2})
t.AddProduct(_set("figure"))
transforms.append(t)

t = Transform(NS)
t.AddRequirement(_set("precog"))
t.AddProduct(_set("db, cog"))
transforms.append(t)

# t = Transform(NS)
# t.AddRequirement(_set("prekegg"))
# t.AddProduct(_set("db, kegg"))
# transforms.append(t)

##############
# failing because lineage requirement may be split, thus relieving some inputs of lineage
# but can't proceed if the first input must be relieved by the following inputs
# which can't run becuase they "depend" on the first input
# todo: look ahead (no pathing, so fast) to determine which inputs can be relieved
##############


haves = [Endpoint(NS, _set(r)) for r in [
    # "precog",
    "db, cog",
    "db, kegg",
    "dna",
]]

target = Transform(NS)
# r = target.AddRequirement(_set("bin"))
# db = target.AddRequirement(_set("cog"))
# target.AddRequirement(_set("ann"), {db, r})
# target.AddRequirement(_set("ann"), {db})

# r = target.AddRequirement(_set("bin"))
# target.AddRequirement(_set("table"), {r})

target.AddRequirement(_set("figure"))

solutions = None
def _test():
    global solutions
    solutions = Solve(haves, target, transforms)
    for res in solutions:
        print(res.steps)
        return f"{len(solutions)} solutions", res.dependency_plan+[res.application]
    # if res is not None:
_test()

9


('1 solutions',
 [{dna}->{annable-contigs-asm} || (c000:dna) >> (l000:annable-contigs-asm),
  {dna},{contigs-asm}->{annable-bin-contigs} || (c000:dna),(l000:annable-contigs-asm) >> (m000:annable-bin-contigs),
  {db},{annable}->{ann} || (a000:db-cog),(l000:annable-contigs-asm) >> (n000:ann),
  {db},{annable}->{ann} || (b000:kegg-db),(l000:annable-contigs-asm) >> (o000:ann),
  {annable},{db-cog},{ann},{kegg-db},{ann}->{table} || (l000:annable-contigs-asm),(a000:db-cog),(n000:ann),(b000:kegg-db),(o000:ann) >> (r000:table),
  {db},{annable}->{ann} || (a000:db-cog),(m000:annable-bin-contigs) >> (p000:ann),
  {db},{annable}->{ann} || (b000:kegg-db),(m000:annable-bin-contigs) >> (q000:ann),
  {annable},{db-cog},{ann},{kegg-db},{ann}->{table} || (m000:annable-bin-contigs),(a000:db-cog),(p000:ann),(b000:kegg-db),(q000:ann) >> (s000:table),
  {db-cog},{kegg-db},{contigs-asm},{bin-contigs},{table},{table}->{figure} || (a000:db-cog),(b000:kegg-db),(l000:annable-contigs-asm),(m000:annable-bin-conti

In [185]:
for res in solutions:
    for a in res.dependency_plan:
        print(a)
    print(res.application)
    print()

{dna}->{annable-contigs-asm} || (c000:dna) >> (l000:annable-contigs-asm)
{dna},{contigs-asm}->{annable-bin-contigs} || (c000:dna),(l000:annable-contigs-asm) >> (m000:annable-bin-contigs)
{db},{annable}->{ann} || (a000:db-cog),(l000:annable-contigs-asm) >> (n000:ann)
{db},{annable}->{ann} || (b000:kegg-db),(l000:annable-contigs-asm) >> (o000:ann)
{annable},{db-cog},{ann},{kegg-db},{ann}->{table} || (l000:annable-contigs-asm),(a000:db-cog),(n000:ann),(b000:kegg-db),(o000:ann) >> (r000:table)
{db},{annable}->{ann} || (a000:db-cog),(m000:annable-bin-contigs) >> (p000:ann)
{db},{annable}->{ann} || (b000:kegg-db),(m000:annable-bin-contigs) >> (q000:ann)
{annable},{db-cog},{ann},{kegg-db},{ann}->{table} || (m000:annable-bin-contigs),(a000:db-cog),(p000:ann),(b000:kegg-db),(q000:ann) >> (s000:table)
{db-cog},{kegg-db},{contigs-asm},{bin-contigs},{table},{table}->{figure} || (a000:db-cog),(b000:kegg-db),(l000:annable-contigs-asm),(m000:annable-bin-contigs),(r000:table),(s000:table) >> (t000:fig

In [186]:

NS = Namespace()
transforms = []

t = Transform(NS)
t.AddRequirement(_set("reads"))
t.AddProduct(_set("annable, taxable"))
transforms.append(t)

t = Transform(NS)
t.AddRequirement(_set("annable"))
t.AddProduct(_set("ann"))
transforms.append(t)

# t = Transform(NS)
# t.AddRequirement(_set("ann"))
# t.AddProduct(_set("annable"))
# transforms.append(t)

t = Transform(NS)
t.AddRequirement(_set("taxable"))
t.AddProduct(_set("tax"))
transforms.append(t)

t = Transform(NS)
d_parent = t.AddRequirement(_set("annable, taxable"))
d_ann = t.AddRequirement(_set("ann"), {d_parent})
d_tax = t.AddRequirement(_set("tax"), {d_parent})
t.AddProduct(_set("sum"))
transforms.append(t)

# M, N = 2, 1
# M, N = 2, 2
M, N = 5, 3
# M, N = 6, 6
# M, N = 8, 6
# M, N = 8, 8
# M, N = 50, 2
# M, N = 60, 2
haves = [Endpoint(NS, _set(f"{i+1}, reads")) for i in range(M)]

target = Transform(NS)
# for e in haves[-N:]:
for e in haves[:N]:
    de = target.AddRequirement(e.properties)
    target.AddRequirement(_set("sum"), {de})
    # target.AddRequirement(_set("sum"))

print("Start")
# %prun r = Solve(haves, target, transforms)
%prun solutions = Solve(haves, target, transforms)
# f"input size [{N}], states checked [{r.steps}], {r.message}, {len(target.requires)}"

Start
 

         25653 function calls (23079 primitive calls) in 0.010 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     4658    0.001    0.000    0.003    0.000 562999156.py:256(<genexpr>)
       26    0.001    0.000    0.007    0.000 562999156.py:373(_iter_satisfies)
4066/1570    0.001    0.000    0.004    0.000 {method 'join' of 'str' objects}
     1301    0.001    0.000    0.005    0.000 562999156.py:256(<lambda>)
     1300    0.001    0.000    0.001    0.000 {method 'write' of '_io.TextIOWrapper' objects}
     1495    0.000    0.000    0.001    0.000 562999156.py:55(__repr__)
     1843    0.000    0.000    0.000    0.000 562999156.py:87(Iterparents)
     1302    0.000    0.000    0.001    0.000 562999156.py:52(__str__)
     1122    0.000    0.000    0.000    0.000 562999156.py:79(__str__)
      9/1    0.000    0.000    0.009    0.009 562999156.py:329(_solve_tr)
      336    0.000    0.000    0.001    0.000 562999156.py:267(

In [187]:
print(len(solutions))
for i, res in enumerate(solutions):
    print(res.steps)
    for a in res.dependency_plan:
        print(a)
    print(res.application)
    print()
    if i > 10: break

1
12
{reads}->{annable-taxable} || (J000:1-reads) >> (d000:annable-taxable)
{annable}->{ann} || (d000:annable-taxable) >> (i000:ann)
{taxable}->{tax} || (d000:annable-taxable) >> (n000:tax)
{annable-taxable},{ann},{tax}->{sum} || (d000:annable-taxable),(i000:ann),(n000:tax) >> (s000:sum)
{reads}->{annable-taxable} || (K000:reads-2) >> (e000:annable-taxable)
{annable}->{ann} || (e000:annable-taxable) >> (j000:ann)
{taxable}->{tax} || (e000:annable-taxable) >> (o000:tax)
{annable-taxable},{ann},{tax}->{sum} || (e000:annable-taxable),(j000:ann),(o000:tax) >> (t000:sum)
{reads}->{annable-taxable} || (L000:3-reads) >> (f000:annable-taxable)
{annable}->{ann} || (f000:annable-taxable) >> (k000:ann)
{taxable}->{tax} || (f000:annable-taxable) >> (p000:tax)
{annable-taxable},{ann},{tax}->{sum} || (f000:annable-taxable),(k000:ann),(p000:tax) >> (u000:sum)
{1-reads},{sum},{reads-2},{sum},{3-reads},{sum}-> || (J000:1-reads),(s000:sum),(K000:reads-2),(t000:sum),(L000:3-reads),(u000:sum) >> 



In [188]:


        # todo: deque[Dependency] = deque()
        # loop_marker: Dependency|None = None
        # # check here for lineage constraints
        # while len(todo)>0:
        #     req = todo.popleft()
        #     if req == loop_marker:
        #         if DEBUG: debug_print(f"<<< FAIL", s.target, req)
        #         return

        #     req_p = {}
        #     for proto, e in s.required_parents.items():
        #         if req.IsA(proto): continue
        #         # if already satisfied by other req and lineage not specified for this req: skip
        #         # if e in satisfied_lineages and all(not pproto.IsA(proto) for pproto in req.parents): continue
        #         req_p[proto] = e
        #     if any(p not in deps for p in req.parents):
        #         res = None # requirements of node not satisfied yet
        #     else:
        #         rreq = {p:deps[p] for p in req.parents}
        #         res = _solve_dep(State(_have, req, req_p|rreq, s.steps+1, s.depth+1))
        #         # if res is None:


        #     if res is None:
        #         todo.append(req)
        #         if loop_marker is None: loop_marker = req
        #         continue
        #     loop_marker = None

        #     if res.endpoint in plans: continue # for duplicate reqs...
        #     plans[res.endpoint] = res.plan
        #     deps[req] = res.endpoint
        #     steps += res.steps
        #     for appl in res.plan:
        #         _have |= appl.produced
        #     satisfied_lineages[res.endpoint] = req

In [189]:
    # reqs = deque()
    # for r in target.requires:
    #     reqs.append(r)

    # todo: deque[State] = deque()
    # todo.append(State(set(given), [], target, [], reqs))
    # steps, MAX_S = 0, 5
    # while len(todo)>0:
    #     steps += 1
    #     if steps>MAX_S: 
    #         print("step limit")
    #         return

    #     s = todo.popleft()
        
    #     print(s.target)
    #     # print(s.have)
    #     for x in s.plan:
    #         print(x)
    #     print()

    #     if len(s.requirements) == 0: return s

    #     if isinstance(s.target, Dependency):
    #         for e in s.have:
    #             if not e.IsA(s.target): continue
    #             todo.append(State(
    #                 s.have,
    #                 [],
    #                 s.requirements.popleft(),
    #                 s.all_plans+[s.plan+[e]],
    #                 s.requirements,
    #             ))
            
    #         for tr in _get_producers_of(s.target):
    #             todo.append(State(
    #                 s.have,
    #                 s.plan + [tr],
    #                 tr,
    #                 s.all_plans,
    #                 s.requirements,
    #             ))
    #     else:
    #         for req in target.requires:
    #             if len(req.parents)>0:
    #                 continue # figure out later
    #             todo.append(State(
    #                 s.have,
    #                 s.plan + [req],
    #                 req,
    #                 s.all_plans,
    #                 s.requirements
    #             ))

    # # @dataclass
    # # class State:
    # #     target: Transform
    # #     have: set[Endpoint]
    # #     constraints: dict[Dependency, Endpoint]
    # #     plan: list[Transform]

    # # todo: deque[State] = deque()
    # # todo.append(State(target, set(given), {}, []))
    # # while len(todo)>0:
    # #     s = todo.popleft()
    # #     cons = s.constraints
        
    # #     for tr in transforms:
    # #         fwds = tr.Valids(tr.Possibilities(s.have, cons))

In [190]:
# def _apply_one(have: set[Endpoint], tr: Transform, sources: set[Endpoint]):
#         match = next(tr.NextValid(tr.Possibilities(have, sources)), None)
#         if match is not None:
#             return tr.Apply(match)
    
#     # res = _map()
#     # if not res.success: return res


        # for e in given if len(sources)==0 else sources:
        #     if not e.IsA(target): continue
        #     return MapResult([], e)

        # if "sum" in target.properties and any("2" in s.properties for s in sources):
        # if "sum" in target.properties:
        #     x = 1
        #     print(target, sources)

        # todo: deque[MapState] = deque()
        # todo.append(MapState(given, [], {t for t in transforms}))

        # while len(todo)>0:
        #     s = todo.popleft()
        #     for tr in curr.remaining_transforms:
        #         next_step = _apply_one(curr.have, tr)
        #         if next_step is None: continue
        #             # next_step = _apply_one(curr.have, tr)
        #             # if next_step is None: continue
        #         for e in next_step.produced:
        #             if not e.IsA(target): continue
        #             return MapResult(curr.plan+[next_step], e)

        #         todo.append(MapState(
        #             curr.have | next_step.produced,
        #             curr.plan + [next_step],
        #             curr.remaining_transforms - {tr}
        #         ))

#     def _solve_tr(given: set[Endpoint], target: Transform):
#         have = set(given)
#         plan: list[Application] = []
#         dep2ep: dict[Node, Endpoint] = {} # really Dep -> Ep
#         dep_parent_sets: dict[Dependency, set[Endpoint]] = {}
#         todo: deque[Dependency] = deque()
#         for r in target.requires: todo.append(r)
#         loop_landmark = None
#         while len(todo)>0:
#             curr = todo.popleft()
#             def _skip():
#                 nonlocal loop_landmark
#                 if loop_landmark is None: loop_landmark = curr
#                 todo.append(curr)

#             if loop_landmark is not None and curr == loop_landmark:
#                 return Result([], f"can't make {curr}", info=have)
#             # if any parent not generated, skip for now
#             if any(p not in dep2ep for p in curr.parents): _skip(); continue

#             if curr not in dep_parent_sets:
#                 parents = {dep2ep[p] for p in curr.parents}
#                 dep_parent_sets[curr] = parents
#             # print(f"---",dep_parent_sets)

#             # if "sum" in curr.properties:
#             #     print(curr, dep_parent_sets[curr])
#             sol = _solve_dep(have, curr, dep_parent_sets[curr])
#             if sol is None: _skip(); continue
#             loop_landmark = None

#             # print(curr, dep_parent_sets[curr], loop_landmark)
#             # print(">")
#             # print(have)
#             # print(todo)
#             # print(sol)
#             # for a in sol.plan:
#             #     print(a)
#             # print()

#             dep2ep[curr] = sol.endpoint
#             for a in sol.plan:
#                 have |= a.produced
#             plan += sol.plan

#     return _solve_tr(set(given), target, set())

In [191]:
    # sol = res.solution
    # last_l = 0
    # while last_l != len(sol):
    #     last_l = len(sol)
    #     used = set()
    #     for a in sol:
    #         used |= a.used
    #     sol = [a for a in sol if a.transform==target or any(e in used for e in a.produced)]
    # res.solution = sol    

# @dataclass
    # class State:
    #     have: set[Endpoint]
    #     plan: list[Application]
    #     usage_sigs: set[str]

    # def _local_solve(have: set[Endpoint], target: Dependency):
    #     todo: deque[State] = deque()
    #     todo.append(State(have, [], set()))
    #     MAX_S = 10_000
    #     steps = 0
    #     # _last_depth = 0
    #     while len(todo) > 0:
    #         steps += 1
    #         if steps>MAX_S: return Result([], "step limit", steps, info=todo)
    #         curr = todo.popleft()    

# @dataclass
    # class SubGoal:
    #     target: Dependency

    # have = set(given)
    # dep2endpoint: dict[Dependency, Endpoint] = {}
    # todo: deque[SubGoal] = deque()
    # for d in target.requires: todo.appendleft(SubGoal(d))
    # while len(todo)>0:
    #     subgoal = todo.pop()
    #     sources: list[Endpoint] = []
    #     ok = True
    #     for p in subgoal.target.parents:
    #         if p not in dep2endpoint:
    #             todo.appendleft(subgoal)
    #             ok = False; break
    #         sources.append(dep2endpoint[p])
    #     if not ok: continue

In [192]:
# def Apply(self, have: Iterable[Endpoint], use_signatures: set[str]) -> Iterable[Application]:
#         matches = self.Possibilities(have)
#         if len(matches) == 0: return []

#         # can reduce exponential trial here by enforcning the input groups first
#         def _possible_configs(i: int, choosen: list[Endpoint]) -> list[list[Endpoint]]:
#             if i >= len(self.requires): return [choosen]
#             candidates = matches[i]
#             parents = self._input_group_map.get(i, [])
#             # print(parents, candidates, choosen)
#             if len(parents) > 0:
#                 for prototype in parents:
#                     # parent must be in choosen, since it must have been added
#                     # as a req. before being used as a parent
#                     parent: None|Endpoint = None
#                     for p in choosen:
#                         if p.IsA(prototype): parent = p; break
#                     if parent is None: return []
#                     candidates = [c for c in candidates if parent in c.parents]
#             configs = []
#             for c in candidates:
#                 configs += _possible_configs(i+1, choosen+[c])
#             return configs
#         configs = _possible_configs(0, [])

#         def _same(a: Endpoint, b: Endpoint):
#             return a.properties.issubset(b.properties) and b.properties.issubset(a.properties) \
#                 and a.parents.issubset(b.parents) and b.parents.issubset(a.parents)

#         for input_set in configs:
#             sis = set(input_set)
#             sig = self._sig(input_set)
#             if sig in use_signatures: continue
#             _parents = sis|{p for g in [e.parents for e in input_set] for p in g}
#             produced = {
#                 Endpoint(
#                     namespace=self._ns,
#                     properties=out.properties,
#                     parents=_parents
#                 )
#             for out in self.produces}
#             # if all(_same(e, p) for e in have for p in produced):
#             #     continue
#             #     print(have)
#             #     print(produced)
#             #     print()
#             yield Application(self, sis, produced, sig)

In [193]:
        # if len(target.parents) == 0:
        #     sol = _map_to(have, target)
        #     if sol is None: return Result([], "x")
        #     return Result(sol, success=True)
        # else:
        #     for p in target.parents:
        #         _p: Any = p
        #         res = _map_to(have, target, _p)

        #         print(">",res)
        #         have |= {e for g in [a.produced for a in res.solution] for e in g}
        #         if not res.success: return res

    # have = set(given)
    # for d in target.requires:
    #     _solve(have, d)

In [194]:
    # def Signature(self):
    #     cache = self.namespace.node_signatures
    #     if self.key not in cache:
    #         props = ",".join(sorted(self.properties))
    #         parents = ",".join(sorted([p.Signature() for p in self.parents]))
    #         sig = f"{props}-{parents}"
    #         cache[self.hash] = sig
    #     return cache[self.hash]

In [195]:

    # def _solve():
    #     todo: deque[State] = deque()
    #     todo.append(State(set(given), [], set()))
    #     MAX_S = 100_000
    #     steps = 0
    #     # _last_depth = 0
    #     while len(todo) > 0:
    #         steps += 1
    #         if steps>MAX_S: return Result([], "step limit", todo, steps)
    #         # curr = todo.popleft()
    #         curr = todo.pop()

    #         final_appl = _check_done(curr)
    #         if final_appl is not None: return Result(curr.plan+[final_appl], steps=steps)

    #         # _depth = len(curr.plan)
    #         # if _depth != _last_depth:
    #         #     todo = _deduplicate_states(curr, todo)
    #         #     _last_depth = _depth

    #         next_states = _get_next_states(curr)
    #         for n in next_states:
    #             todo.append(n)

    #     return Result([], "no sol", steps=steps)

In [196]:
    # plans: dict[Endpoint, Path] = {}
    # def _path_to(have: Iterable[Endpoint], target: Dependency) -> Path|None:
    #     if any(e.IsA(target) for e in have): return Path([])
    #     if target in plans: return plans[target]

    #     # DFS back from e
    #     for tr in transforms:
    #         if not any(d.IsA(target) for d in tr.produces): continue 
    #         for req in tr.requires:
    #             path_result = _path_to(have, req)
    #             if path_result is None: continue
    #             path_result.plan.append(tr)
    #             return path_result
    # x = [
    # # @dataclass
    # # class State:
    # #     have: Iterable[Endpoint]
    # #     targets: Iterable[Dependency]
    # #     plan: list[Transform]

    # # todo: deque[State] = deque(maxlen=64)
    # # todo.append(State([], [t for t in target.requires], []))
    # # while len(todo)>0:
    # #     _s = todo.popleft()
    # #     t = next(iter(_s.targets))
    # #     plan = 
    # ]

    # usage_signatures: dict[Transform, set[str]] = {t:set() for t in transforms}
    # def _solve(have: list[Endpoint], target: Transform, sigs: dict) -> list[Application]|None:
    #     possibilities = target.Apply(have, sigs[target])
    #     if len(possibilities)>0: return possibilities[0:1]

    #     for t in target.requires:
    #         path = _path_to(have, t)
    #         if path is None: return None
    #         fist_tr = path.plan[0]
    #         poss = fist_tr.Apply(have, sigs[fist_tr])
    #         if poss
                        
            


    # _solve(list(given), target)

In [197]:
# def Solve(given: Iterable[Endpoint], target: Transform, transforms: Iterable[Transform]):
#     @dataclass
#     class State:
#         have: list[Endpoint]
#         usage_signatures: dict[int, set[str]]
#         plan: list[Application]

#     transforms = list(transforms)
    
#     def _done(state: State):
#         appl = target.Apply(state.have, set())
#         return appl 

#     def _solve() -> Result:
#         MAXS = 10_000
#         todo: deque[State] = deque([State(
#             have = list(given),
#             plan = [],
#             usage_signatures={},
#         )], maxlen=MAXS)
        

#         def _deduplicate_states(current: State):
#             def _get_sig(s: State):
#                 haves_sig = '|'.join([e.Signature() for e in s.have])
#                 return haves_sig
#             seen = {_get_sig(current)}
#             new_todo: deque[State] = deque([], MAXS)
#             for s in todo:
#                 if _get_sig(s) in seen: continue
#                 new_todo.append(s)

#             if len(todo) != len(new_todo):
#                 for s in todo:
#                     print(s)
#                 print("-")
#                 for s in new_todo:
#                     print(s)
#                 print()
#             return new_todo

#         _steps = 0
#         _empty = set()
#         _last_depth = 0
#         while len(todo)>0:
#             _steps += 1
#             if _steps > MAXS: return Result([], f"step limit exceeded", steps=_steps)
#             _s = todo.popleft()

#             _target_applications = target.Apply(_s.have, _empty)
#             if len(_target_applications)>0:
#                 return Result(solution=_s.plan+[_target_applications[0]], steps=_steps)

#             _depth = len(_s.plan)
#             if _depth != _last_depth:
#                 todo = _deduplicate_states(_s)
#                 _last_depth = _depth

#             if _done(_s): return Result(_s.plan, steps=_steps)
#             for tr in transforms:
#                 possibilities = tr.Apply(_s.have, _s.usage_signatures.get(tr.hash, set()))
#                 # for app in possibilities:
#                 #     usage_sigs = _s.usage_signatures.copy()
#                 #     usage_sigs[tr.hash] = usage_sigs.get(tr.hash, set())|{app.signature}
#                 #     todo.append(State(
#                 #         have = _s.have+app.produced,
#                 #         plan = _s.plan+[app],
#                 #         usage_signatures = usage_sigs,
#                 #     ))

#                 if len(possibilities) == 0: continue
#                 usage_sigs = _s.usage_signatures.copy()
#                 new_have = _s.have.copy()
#                 for app in possibilities:
#                     usage_sigs[tr.hash] = usage_sigs.get(tr.hash, set())|{app.signature}
#                     new_have += app.produced
#                 todo.append(State(
#                     have = new_have,
#                     plan = _s.plan+possibilities,
#                     usage_signatures=usage_sigs
#                 ))
#         return Result([], f"ran out of things to try", steps = _steps)
    
#     res = _solve()
#     sol = res.solution
#     last_l = 0
#     while last_l != len(sol):
#         last_l = len(sol)
#         used = set()
#         for a in sol:
#             used |= a.used
#         sol = [a for a in sol if a.transform==target or any(e in used for e in a.produced)]
#     res.solution = sol
#     return res

In [198]:
# from __future__ import annotations
# import os, sys
# import asyncio
# from typing import Iterable, Callable, Any
# from pathlib import Path

# from limes_x.solver import DependencySolver, Plan, Dependency
# from limes_x.persistence import ProjectState, Instance
# from limes_x.compute_module import ComputeModule

# mpath = Path("./test_solver/")
# modules = [
#     ComputeModule(mpath.joinpath(d)) for d in os.listdir(mpath)
# ]
# print(modules)

# given = [
#     ("a", "./test_data/a1"),
#     ("a", "./test_data/a2"),
#     ("b", "./test_data/b1"),
# ]

# prj_path = "./cache/man_test01/"
# state = ProjectState(prj_path, on_exist="overwrite")
# for dtype, val in given:
#     state.RegisterInstance(Instance.Str(dtype, val))
# for m in modules:
#     state.RegisterInstance(Instance.ComputeModule(m))

# deps = []
# for k, inst in state._instances.items():
#     if not inst.IsPyType(ComputeModule): continue
#     deps.append(Dependency(inst.val.requires, inst.val.produces, k))

# solver = DependencySolver(deps)
# # plan = solver.Solve({"a"}, {"reuse", "linear", "branched"})
# plan = solver.Solve({"a"}, {"branched"})
# assert plan != False
# [state.GetInstance(m.ref_key) for m in plan]

In [199]:
# def make_dependency(module: ComputeModule):
#     return Dependency(module.requires, module.produces, module)

# modules = Path("./test_solver/")
# solver = Plan([
#     make_dependency(ComputeModule(p))
# for p in [
#     modules.joinpath(p) for p in os.listdir(modules)
# ]])
# plan = solver.Solve({"a"}, {"reuse", "linear", "branched"})
# plan

In [200]:
# from limes_x.compute_module import ComputeModule

# a = ComputeModule("./test_modules/copy/")
# b = ComputeModule("./test_modules/copy2/")

# a.requires, b.requires

In [201]:
# state = ProjectState("./cache/test_persist")
# ok = Instance("asdf", 1)
# ov = Instance("s", 2)
# state._lineage[ok] = [ov]
# state.Save()

# s2 = ProjectState.Load("./cache/test_persist")
# for k, v in s2._lineage.items():
#     _te = k.type, k.value, ok == k, [(i.type, i.value, i == ov) for i in v]
#     print(_te)

# ok._id