In [91]:
from __future__ import annotations
from dataclasses import dataclass, field
import os, sys
from typing import Any, Iterable, Literal
import hashlib
import numpy as np
import json
from collections import deque

from limes_x.utils import KeyGenerator

class Namespace:
    """Source of unique keys and home of the node-signature cache.

    Every object created through ``NewKey`` receives a fresh integer together
    with a string rendering of it produced by the ``KeyGenerator``.
    """

    def __init__(self) -> None:
        generator = KeyGenerator(True)
        key_length = 4
        # signature cache used by Node.Signature
        self.node_signatures: dict[int, str] = {}
        self._last_k: int = 0
        self._kg = generator
        self._KLEN = key_length
        # presumably the number of distinct keys of length _KLEN over the
        # generator's vocabulary -- confirm against KeyGenerator
        self._MAX_K = len(generator.vocab) ** key_length

    def NewKey(self):
        """Issue the next key as an ``(int, str)`` pair.

        The counter is advanced before the bound is checked, so the first key
        issued is 1. Fails an assertion once the counter reaches ``_MAX_K``.
        """
        issued = self._last_k + 1
        self._last_k = issued
        assert issued < self._MAX_K
        return issued, self._kg.FromInt(issued, self._KLEN)

class Hashable:
    """Base class for objects identified by a key pair drawn from a Namespace.

    The integer half of the pair serves as the hash; the string half (``key``)
    decides equality.
    """

    def __init__(self, ns: Namespace) -> None:
        self.namespace = ns
        issued = ns.NewKey()
        self.hash, self.key = issued

    def __hash__(self) -> int:
        return self.hash

    def __eq__(self, __value: object) -> bool:
        # Anything exposing an equal ``key`` attribute compares equal;
        # objects without one never do.
        try:
            other_key = getattr(__value, "key")
        except AttributeError:
            return False
        return self.key == other_key

class Node(Hashable):
    """A set of string properties together with the nodes it descends from."""

    def __init__(
        self,
        ns: Namespace,
        properties: set[str],
        parents: set[Node],
    ) -> None:
        super().__init__(ns)
        self.namespace = ns
        self.properties = properties
        self.parents = parents

    def __str__(self) -> str:
        # property order follows set iteration order, so it is not stable
        return f"<{self.key}:{','.join(self.properties)}>"

    def __repr__(self) -> str:
        return f"{self}"

    def IsA(self, other: Node, compare_lineage=False) -> bool:
        """Return True when this node has every property of ``other``.

        With ``compare_lineage`` the result additionally depends on whether
        ``other.parents`` is a subset of ``self.parents``.
        """
        if not other.properties.issubset(self.properties):
            return False
        # NOTE(review): this returns True when other's parents are NOT all
        # among ours, which looks inverted for an "is a" test. No visible
        # caller passes compare_lineage=True, so behaviour is left as-is.
        if compare_lineage: return not other.parents.issubset(self.parents)
        return True

    def Signature(self):
        """Return ``"<sorted props>-<sorted parent signatures>"``, cached per node.

        The cache lives on the namespace and is keyed by ``self.hash``.
        """
        cache = self.namespace.node_signatures
        # Look up with the same key that is used for storing (the int hash);
        # checking the string key never hit and recomputed the whole ancestry
        # on every call.
        if self.hash not in cache:
            props = ",".join(sorted(self.properties))
            parents = ",".join(sorted(p.Signature() for p in self.parents))
            cache[self.hash] = f"{props}-{parents}"
        return cache[self.hash]

    def MatchesMemberOf(self, collection: Iterable[Node]):
        """Return True if any member of ``collection`` compares equal to this node."""
        return any(self == m for m in collection)

class Dependency(Node):
    """Node used by Transform to describe what it requires or produces."""

    def __init__(self, namespace: Namespace, properties: set[str], parents: set[Node]) -> None:
        super().__init__(ns=namespace, properties=properties, parents=parents)

class Endpoint(Node):
    """Concrete node: either given up front or produced by applying a Transform.

    ``parents`` defaults to a fresh empty set for each endpoint. A literal
    ``set()`` default would be one object shared by every endpoint created
    without parents.
    """

    def __init__(self, namespace: Namespace, properties: set[str], parents: set[Node] | None = None) -> None:
        super().__init__(namespace, properties, set() if parents is None else parents)

class Transform(Hashable):
    """A rule that consumes endpoints matching ``requires`` and emits new
    endpoints carrying the properties listed in ``produces``.

    A requirement may name earlier requirements as parents. When applying the
    transform, the endpoint chosen for such a requirement must have the
    endpoint matched to each named parent among its own parents.
    """

    def __init__(self, ns: Namespace) -> None:
        super().__init__(ns)
        self.requires: list[Dependency] = []
        self.produces: list[Dependency] = []
        self._ns = ns
        # requirement index -> the requirements that slot named as parents
        self._input_group_map: dict[int, list[Dependency]] = {}
        # NOTE(review): draws a second key from the namespace; nothing in this
        # class reads _key (signatures use self.key). Kept so that the keys
        # issued afterwards do not shift.
        self._key = ns.NewKey()
        self._seen: set[str] = set()

    def __str__(self) -> str:
        def _props(d: Dependency):
            return "{"+"-".join(d.properties)+"}"
        needs = ",".join(_props(r) for r in self.requires)
        gives = ",".join(_props(p) for p in self.produces)
        return f"<{needs}->{gives}>"

    def __repr__(self): return f"{self}"

    def AddRequirement(self, properties: Iterable[str], parents: set[Dependency] | None = None):
        """Append a requirement and return the Dependency created for it.

        ``parents`` must only contain requirements that were added earlier.
        """
        return self._add_dependency(self.requires, properties, parents)

    def AddProduct(self, properties: Iterable[str], parents: set[Dependency] | None = None):
        """Append a product and return the Dependency created for it."""
        return self._add_dependency(self.produces, properties, parents)

    def _add_dependency(self, destination: list[Dependency], properties: Iterable[str], parents: set[Dependency] | None = None):
        """Create a Dependency, append it to ``destination`` and return it.

        When ``destination`` is the requirements list, the parents are also
        recorded against the new requirement's index.
        """
        # A fresh set per call: a literal ``set()`` default would be one
        # object shared by every dependency created without parents.
        _parents: Any = set() if parents is None else parents
        _dep = Dependency(properties=set(properties), parents=_parents, namespace=self._ns)
        destination.append(_dep)
        # Identity, not equality: ``==`` compares the two lists element by
        # element and says nothing about which list was passed in.
        if destination is self.requires:
            i = len(self.requires)-1
            for p in _parents:
                assert p in self.requires, f"{p} not added as a requirement"
            self._input_group_map[i] = self._input_group_map.get(i, [])+list(_parents)
        return _dep

    def _sig(self, endpoints: Iterable[Endpoint]):
        """Signature of one application: this transform's key, a dash, then the
        keys of ``endpoints`` concatenated in the given order."""
        return self.key+"-"+ "".join(e.key for e in endpoints)

    def Possibilities(self, have: Iterable[Endpoint]):
        """For each requirement, list the endpoints in ``have`` that satisfy it.

        Returns one candidate list per requirement, in requirement order, or
        ``[]`` as soon as any requirement has no candidate. A transform with
        no requirements also yields ``[]``.
        """
        # Materialize once so a one-shot iterator is not exhausted by the
        # first requirement.
        pool = list(have)
        matches: list[list[Endpoint]] = []
        for req in self.requires:
            _m = [m for m in pool if m.IsA(req)]
            if not _m: return []
            matches.append(_m)
        return matches

    def Apply(self, have: Iterable[Endpoint], use_signatures: set[str]) -> Iterable[Application]:
        """Lazily yield one Application per admissible choice of inputs.

        have: endpoints currently available.
        use_signatures: signatures of applications to skip.

        Each yielded Application carries newly created Endpoints, one per
        entry in ``produces``, whose parents are the chosen inputs plus the
        parents of those inputs.
        """
        matches = self.Possibilities(have)
        if not matches: return

        # can reduce exponential trial here by enforcing the input groups first
        def _possible_configs(i: int, chosen: list[Endpoint]) -> list[list[Endpoint]]:
            """Extend ``chosen`` with every valid pick for requirements i, i+1, ..."""
            if i >= len(self.requires): return [chosen]
            candidates = matches[i]
            for prototype in self._input_group_map.get(i, []):
                # The parent must already be in `chosen`: it had to be added
                # as a requirement before it could be named as a parent.
                # NOTE(review): this takes the first chosen endpoint that IsA
                # the prototype, which is not necessarily the one picked for
                # that requirement's own slot when several requirements accept
                # the same endpoints.
                parent = next((p for p in chosen if p.IsA(prototype)), None)
                if parent is None: return []
                candidates = [c for c in candidates if parent in c.parents]
            configs = []
            for c in candidates:
                configs += _possible_configs(i+1, chosen+[c])
            return configs

        for input_set in _possible_configs(0, []):
            sig = self._sig(input_set)
            if sig in use_signatures: continue
            sis = set(input_set)
            # lineage of the outputs: the inputs themselves plus their parents
            _parents = sis|{p for e in input_set for p in e.parents}
            produced = {
                Endpoint(
                    namespace=self._ns,
                    properties=out.properties,
                    parents=_parents
                )
            for out in self.produces}
            yield Application(self, sis, produced, sig)

@dataclass
class Application:
    """One concrete use of a Transform on a specific set of input endpoints."""
    # the transform that was applied
    transform: Transform
    # the endpoints chosen as its inputs
    used: set[Endpoint]
    # the endpoints newly created by this application
    produced: set[Endpoint]
    # transform key and input keys, as built by Transform._sig
    signature: str

@dataclass
class Result:
    """Outcome of a Solve call."""
    # the applications making up the plan; empty when no plan was found
    solution: list[Application]
    # short note explaining why the search stopped without a plan
    message: str = ""
    # optional extra data; Solve stores the search state it stopped at
    info: Any = None
    # number of states visited when the result was created
    steps: int = 0
    # True only when a plan satisfying the target was found
    success: bool = False
    
def Solve(given: Iterable[Endpoint], target: Transform, transforms: Iterable[Transform]):
    """Depth-first search for a plan of applications that makes ``target`` applicable.

    given: endpoints available at the start.
    target: a Transform whose requirements describe the goal; the search
        succeeds as soon as it can be applied to the endpoints at hand.
    transforms: the transforms the plan may use.

    Returns a Result. On success ``solution`` holds the plan ending with the
    application of ``target``, with applications whose products nothing else
    consumes removed. On failure ``solution`` is empty and ``message`` says
    whether the search space was exhausted or the step limit was reached.
    """
    # Materialize once: the transforms are re-iterated for every state, which
    # would silently exhaust a one-shot iterator after the first state.
    transforms = list(transforms)

    @dataclass
    class State:
        have: set[Endpoint]
        plan: list[Application]
        usage_sigs: set[str]

    def _get_next_states(curr: State):
        # one successor per application that has not been made on this path
        for tr in transforms:
            for appl in tr.Apply(curr.have, curr.usage_sigs):
                yield State(
                    have = curr.have|appl.produced,
                    plan = curr.plan + [appl],
                    usage_sigs = curr.usage_sigs|{appl.signature},
                )

    def _check_done(curr: State):
        # first way of applying the target, or None when there is none
        return next(iter(target.Apply(curr.have, set())), None)

    MAX_S = 100_000
    MAX_D = 32
    steps = 0
    def _solve(curr: State, depth: int, depth_lim: int) -> Result:
        nonlocal steps
        # honour the limit that was passed in rather than the constant
        if depth>=depth_lim: return Result([], f"depth limit: {depth}")
        steps += 1
        if steps>MAX_S: return Result([], f"step limit: {steps}", curr, steps)

        final_appl = _check_done(curr)
        if final_appl is not None: return Result(curr.plan+[final_appl], steps=steps, success=True)

        for n in _get_next_states(curr):
            res = _solve(n, depth+1, depth_lim)
            # Once the step budget is spent, hand the step-limit result
            # straight back up. Carrying on would expand every remaining
            # sibling at every level only to fail immediately, and the caller
            # would be told "no sol" instead of the real reason.
            if res.success or steps>MAX_S: return res
        return Result([], "no sol", curr, steps)

    start = State(set(given), [], set())
    res = _solve(start, 0, MAX_D)

    # Prune to a fixed point: drop applications none of whose products is
    # consumed by another application in the plan (the target is always kept).
    sol = res.solution
    last_l = 0
    while last_l != len(sol):
        last_l = len(sol)
        used = set()
        for a in sol:
            used |= a.used
        sol = [a for a in sol if a.transform==target or not used.isdisjoint(a.produced)]
    res.solution = sol
    return res

# Namespace shared by the example transforms and endpoints created in this cell.
NS = Namespace()
def _set(s: str):
    return set(s.split(", "))

# Example 1: annotate and tax each sample, then summarise the two per sample.

# annable -> ann
anner = Transform(NS)
anner.AddRequirement(_set("annable"))
anner.AddProduct(_set("ann"))

# taxable -> tax
taxer = Transform(NS)
taxer.AddRequirement(_set("taxable"))
taxer.AddProduct(_set("tax"))

# (sample, ann, tax) -> sum; the ann and tax requirements name the sample
# requirement as their parent
sumer = Transform(NS)
d_parent = sumer.AddRequirement(_set("annable, taxable"))
d_ann = sumer.AddRequirement(_set("ann"), {d_parent})
d_tax = sumer.AddRequirement(_set("tax"), {d_parent})
sumer.AddProduct(_set("sum"))

# M samples are available; the target asks for a sum for the last N of them
M, N = 3, 2
haves = [Endpoint(NS, _set(f"{i+1}, annable, taxable")) for i in range(M)]

# The goal is written as a transform: for each wanted sample, require the
# sample itself and a "sum" that names that sample requirement as its parent.
target = Transform(NS)
for e in haves[-N:]:
    de = target.AddRequirement(e.properties)
    target.AddRequirement(_set("sum"), {de})

# ss = [Endpoint(NS, _set("sum")) for e in bs]
tr = [anner, taxer, sumer]
# %prun Solve(bs, ss, [anner, taxer, sumer])
# r = Solve(bs, ss, [anner, taxer, sumer])
# r

# sumer.Apply(test_have)
print("Start")
# %prun r = Solve(haves, target, tr)
r = Solve(haves, target, tr)
# bare expression: shown as the notebook cell's output
f"input size [{N}], states checked [{r.steps}], {r.message}, {len(target.requires)}"

Start


'input size [2], states checked [10], , 4'

In [92]:
r.solution

[Application(transform=<{annable}->{ann}>, used={<G000:annable,taxable,2>}, produced={<O000:ann>}, signature='1000-G000'),
 Application(transform=<{annable}->{ann}>, used={<H000:annable,taxable,3>}, produced={<P000:ann>}, signature='1000-H000'),
 Application(transform=<{taxable}->{tax}>, used={<G000:annable,taxable,2>}, produced={<R000:tax>}, signature='5000-G000'),
 Application(transform=<{taxable}->{tax}>, used={<H000:annable,taxable,3>}, produced={<S000:tax>}, signature='5000-H000'),
 Application(transform=<{annable-taxable},{ann},{tax}->{sum}>, used={<G000:annable,taxable,2>, <O000:ann>, <R000:tax>}, produced={<V000:sum>}, signature='9000-G000O000R000'),
 Application(transform=<{annable-taxable},{ann},{tax}->{sum}>, used={<H000:annable,taxable,3>, <S000:tax>, <P000:ann>}, produced={<W000:sum>}, signature='9000-H000P000S000'),
 Application(transform=<{annable-taxable-2},{sum},{annable-taxable-3},{sum}->>, used={<G000:annable,taxable,2>, <H000:annable,taxable,3>, <W000:sum>, <V000:su

In [71]:
# Example 2: dna -> assembly -> bins -> annotations -> tables -> figure,
# built in a fresh namespace.
transforms = []
NS = Namespace()

# dna -> assembled contigs
t = Transform(NS)
t.AddRequirement(_set("dna"))
t.AddProduct(_set("contigs, asm, annable"))
transforms.append(t)

# dna + its assembly -> binned contigs
t = Transform(NS)
r = t.AddRequirement(_set("dna"))
t.AddRequirement(_set("contigs, asm"), {r})
t.AddProduct(_set("contigs, bin, annable"))
transforms.append(t)

# anything annable + any db -> ann
t = Transform(NS)
t.AddRequirement(_set("annable"))
t.AddRequirement(_set("db"))
t.AddProduct(_set("ann"))
transforms.append(t)

# one ann per database (cog and kegg) -> table
t = Transform(NS)
r = t.AddRequirement(_set("db, cog"))
t.AddRequirement(_set("ann"), {r})
r = t.AddRequirement(_set("db, kegg"))
t.AddRequirement(_set("ann"), {r})
t.AddProduct(_set("table"))
transforms.append(t)

# a table for the assembly and a table for the bins, each naming both
# databases as parents -> figure
t = Transform(NS)
db1 = t.AddRequirement(_set("db, cog"))
db2 = t.AddRequirement(_set("db, kegg"))
r = t.AddRequirement(_set("contigs, asm"))
t.AddRequirement(_set("table"), {r, db1, db2})
r = t.AddRequirement(_set("contigs, bin"))
t.AddRequirement(_set("table"), {r, db1, db2})
t.AddProduct(_set("figure"))

# print(t.requires)

transforms.append(t)

# starting endpoints: the two databases and the raw dna
haves = [Endpoint(NS, _set(r)) for r in [
    "db, cog",
    "db, kegg",
    "dna",
]]

target = Transform(NS)
target.AddRequirement(_set("figure"))

# Design notes:
# - change to a target transform, where the inputs are the targets;
#   this captures the required parents better
# - can use parents in a requirement to get subtasks (parent -> req. dep.)
# NOTE: `r` is rebound here from a Dependency to the Solve result
r = Solve(haves, target, transforms)
print(r.steps)
for a in r.solution:
    print(a)
# print(r)

no futher opt
99996
Application(transform=<{dna}->{contigs-annable-asm}>, used={<X000:dna>}, produced={<fOM0:contigs,annable,asm>}, signature='1000-X000')
Application(transform=<{dna},{contigs-asm}->{contigs-annable-bin}>, used={<X000:dna>, <fOM0:contigs,annable,asm>}, produced={<gOM0:contigs,annable,bin>}, signature='5000-X000fOM0')
Application(transform=<{annable},{db}->{ann}>, used={<W000:kegg,db>, <fOM0:contigs,annable,asm>}, produced={<hOM0:ann>}, signature='A000-fOM0W000')
Application(transform=<{annable},{db}->{ann}>, used={<fOM0:contigs,annable,asm>, <V000:cog,db>}, produced={<iOM0:ann>}, signature='A000-fOM0V000')
Application(transform=<{annable},{db}->{ann}>, used={<W000:kegg,db>, <gOM0:contigs,annable,bin>}, produced={<jOM0:ann>}, signature='A000-gOM0W000')
Application(transform=<{cog-db},{ann},{kegg-db},{ann}->{table}>, used={<W000:kegg,db>, <hOM0:ann>, <iOM0:ann>, <V000:cog,db>}, produced={<lOM0:table>}, signature='F000-V000iOM0W000hOM0')
Application(transform=<{cog-db},{a

In [None]:

    # def _solve():
    #     todo: deque[State] = deque()
    #     todo.append(State(set(given), [], set()))
    #     MAX_S = 100_000
    #     steps = 0
    #     # _last_depth = 0
    #     while len(todo) > 0:
    #         steps += 1
    #         if steps>MAX_S: return Result([], "step limit", todo, steps)
    #         # curr = todo.popleft()
    #         curr = todo.pop()

    #         final_appl = _check_done(curr)
    #         if final_appl is not None: return Result(curr.plan+[final_appl], steps=steps)

    #         # _depth = len(curr.plan)
    #         # if _depth != _last_depth:
    #         #     todo = _deduplicate_states(curr, todo)
    #         #     _last_depth = _depth

    #         next_states = _get_next_states(curr)
    #         for n in next_states:
    #             todo.append(n)

    #     return Result([], "no sol", steps=steps)

In [None]:
    # plans: dict[Endpoint, Path] = {}
    # def _path_to(have: Iterable[Endpoint], target: Dependency) -> Path|None:
    #     if any(e.IsA(target) for e in have): return Path([])
    #     if target in plans: return plans[target]

    #     # DFS back from e
    #     for tr in transforms:
    #         if not any(d.IsA(target) for d in tr.produces): continue 
    #         for req in tr.requires:
    #             path_result = _path_to(have, req)
    #             if path_result is None: continue
    #             path_result.plan.append(tr)
    #             return path_result
    # x = [
    # # @dataclass
    # # class State:
    # #     have: Iterable[Endpoint]
    # #     targets: Iterable[Dependency]
    # #     plan: list[Transform]

    # # todo: deque[State] = deque(maxlen=64)
    # # todo.append(State([], [t for t in target.requires], []))
    # # while len(todo)>0:
    # #     _s = todo.popleft()
    # #     t = next(iter(_s.targets))
    # #     plan = 
    # ]

    # usage_signatures: dict[Transform, set[str]] = {t:set() for t in transforms}
    # def _solve(have: list[Endpoint], target: Transform, sigs: dict) -> list[Application]|None:
    #     possibilities = target.Apply(have, sigs[target])
    #     if len(possibilities)>0: return possibilities[0:1]

    #     for t in target.requires:
    #         path = _path_to(have, t)
    #         if path is None: return None
    #         fist_tr = path.plan[0]
    #         poss = fist_tr.Apply(have, sigs[fist_tr])
    #         if poss
                        
            


    # _solve(list(given), target)

In [62]:
# def Solve(given: Iterable[Endpoint], target: Transform, transforms: Iterable[Transform]):
#     @dataclass
#     class State:
#         have: list[Endpoint]
#         usage_signatures: dict[int, set[str]]
#         plan: list[Application]

#     transforms = list(transforms)
    
#     def _done(state: State):
#         appl = target.Apply(state.have, set())
#         return appl 

#     def _solve() -> Result:
#         MAXS = 10_000
#         todo: deque[State] = deque([State(
#             have = list(given),
#             plan = [],
#             usage_signatures={},
#         )], maxlen=MAXS)
        

#         def _deduplicate_states(current: State):
#             def _get_sig(s: State):
#                 haves_sig = '|'.join([e.Signature() for e in s.have])
#                 return haves_sig
#             seen = {_get_sig(current)}
#             new_todo: deque[State] = deque([], MAXS)
#             for s in todo:
#                 if _get_sig(s) in seen: continue
#                 new_todo.append(s)

#             if len(todo) != len(new_todo):
#                 for s in todo:
#                     print(s)
#                 print("-")
#                 for s in new_todo:
#                     print(s)
#                 print()
#             return new_todo

#         _steps = 0
#         _empty = set()
#         _last_depth = 0
#         while len(todo)>0:
#             _steps += 1
#             if _steps > MAXS: return Result([], f"step limit exceeded", steps=_steps)
#             _s = todo.popleft()

#             _target_applications = target.Apply(_s.have, _empty)
#             if len(_target_applications)>0:
#                 return Result(solution=_s.plan+[_target_applications[0]], steps=_steps)

#             _depth = len(_s.plan)
#             if _depth != _last_depth:
#                 todo = _deduplicate_states(_s)
#                 _last_depth = _depth

#             if _done(_s): return Result(_s.plan, steps=_steps)
#             for tr in transforms:
#                 possibilities = tr.Apply(_s.have, _s.usage_signatures.get(tr.hash, set()))
#                 # for app in possibilities:
#                 #     usage_sigs = _s.usage_signatures.copy()
#                 #     usage_sigs[tr.hash] = usage_sigs.get(tr.hash, set())|{app.signature}
#                 #     todo.append(State(
#                 #         have = _s.have+app.produced,
#                 #         plan = _s.plan+[app],
#                 #         usage_signatures = usage_sigs,
#                 #     ))

#                 if len(possibilities) == 0: continue
#                 usage_sigs = _s.usage_signatures.copy()
#                 new_have = _s.have.copy()
#                 for app in possibilities:
#                     usage_sigs[tr.hash] = usage_sigs.get(tr.hash, set())|{app.signature}
#                     new_have += app.produced
#                 todo.append(State(
#                     have = new_have,
#                     plan = _s.plan+possibilities,
#                     usage_signatures=usage_sigs
#                 ))
#         return Result([], f"ran out of things to try", steps = _steps)
    
#     res = _solve()
#     sol = res.solution
#     last_l = 0
#     while last_l != len(sol):
#         last_l = len(sol)
#         used = set()
#         for a in sol:
#             used |= a.used
#         sol = [a for a in sol if a.transform==target or any(e in used for e in a.produced)]
#     res.solution = sol
#     return res

In [65]:
# from __future__ import annotations
# import os, sys
# import asyncio
# from typing import Iterable, Callable, Any
# from pathlib import Path

# from limes_x.solver import DependencySolver, Plan, Dependency
# from limes_x.persistence import ProjectState, Instance
# from limes_x.compute_module import ComputeModule

# mpath = Path("./test_solver/")
# modules = [
#     ComputeModule(mpath.joinpath(d)) for d in os.listdir(mpath)
# ]
# print(modules)

# given = [
#     ("a", "./test_data/a1"),
#     ("a", "./test_data/a2"),
#     ("b", "./test_data/b1"),
# ]

# prj_path = "./cache/man_test01/"
# state = ProjectState(prj_path, on_exist="overwrite")
# for dtype, val in given:
#     state.RegisterInstance(Instance.Str(dtype, val))
# for m in modules:
#     state.RegisterInstance(Instance.ComputeModule(m))

# deps = []
# for k, inst in state._instances.items():
#     if not inst.IsPyType(ComputeModule): continue
#     deps.append(Dependency(inst.val.requires, inst.val.produces, k))

# solver = DependencySolver(deps)
# # plan = solver.Solve({"a"}, {"reuse", "linear", "branched"})
# plan = solver.Solve({"a"}, {"branched"})
# assert plan != False
# [state.GetInstance(m.ref_key) for m in plan]

In [66]:
# def make_dependency(module: ComputeModule):
#     return Dependency(module.requires, module.produces, module)

# modules = Path("./test_solver/")
# solver = Plan([
#     make_dependency(ComputeModule(p))
# for p in [
#     modules.joinpath(p) for p in os.listdir(modules)
# ]])
# plan = solver.Solve({"a"}, {"reuse", "linear", "branched"})
# plan

In [67]:
# from limes_x.compute_module import ComputeModule

# a = ComputeModule("./test_modules/copy/")
# b = ComputeModule("./test_modules/copy2/")

# a.requires, b.requires

In [68]:
# state = ProjectState("./cache/test_persist")
# ok = Instance("asdf", 1)
# ov = Instance("s", 2)
# state._lineage[ok] = [ov]
# state.Save()

# s2 = ProjectState.Load("./cache/test_persist")
# for k, v in s2._lineage.items():
#     _te = k.type, k.value, ok == k, [(i.type, i.value, i == ov) for i in v]
#     print(_te)

# ok._id