In [15]:
from __future__ import annotations
from dataclasses import dataclass, field
import os, sys
from typing import Any, Iterable, Literal
import hashlib
import numpy as np
from limes_x.utils import KeyGenerator

class Namespace:
    """Shared context that issues unique keys and caches node signatures.

    Every Hashable created with the same Namespace draws its key from
    NewKey, and Node.Signature stores its memoised result in
    node_signatures.
    """

    def __init__(self) -> None:
        # signature cache, keyed by node key
        self.node_signatures: dict[str, str] = {}
        self._keygen = KeyGenerator()
        # every key handed out so far; given to the generator as a blacklist
        self._keys: set[str] = set()

    def NewKey(self):
        """Return a key that has not been issued by this namespace before.

        Keys are generated with length 4 until one million keys have been
        issued, then with length 8.
        """
        length = 4 if len(self._keys) < 1_000_000 else 8
        key = self._keygen.GenerateUID(l=length, blacklist=self._keys)
        # Record the key: without this the blacklist stays empty, so
        # duplicates are never excluded and the length switch never happens.
        # (Harmless if GenerateUID already adds to the blacklist itself.)
        self._keys.add(key)
        return key

class Hashable:
    """Base class for objects identified by a namespace-issued key.

    The integer stored in `hash` is the MD5 digest of the key; it backs both
    `__hash__` and `__eq__`.
    """

    def __init__(self, ns: Namespace) -> None:
        self.namespace = ns
        self.key = ns.NewKey()
        digest = hashlib.md5(self.key.encode("latin1")).hexdigest()
        self.hash = int(digest, 16)

    def __hash__(self) -> int:
        return self.hash

    def __eq__(self, __value: object) -> bool:
        # Anything exposing a `hash` attribute is comparable; everything
        # else is simply unequal.
        missing = object()
        other_hash = getattr(__value, "hash", missing)
        return other_hash is not missing and self.hash == other_hash

class Node(Hashable):
    """A keyed node described by a set of property strings and parent nodes."""

    def __init__(
        self,
        ns: Namespace,
        properties: set[str],
        parents: set[Node],
    ) -> None:
        super().__init__(ns)
        self.namespace = ns
        self.properties = properties
        self.parents = parents

    def __str__(self) -> str:
        joined = ",".join(self.properties)
        return f"<{self.key}:{joined}>"

    def __repr__(self) -> str:
        return str(self)

    # Earlier idea, kept for reference: structural equality where
    # x == y if x is a "subset" of y, that is, x has at least all features of y
    # def __eq__(self, __value: object) -> bool:
    #     if not isinstance(__value, Node): return False
    #     # if not __value.properties.issubset(self.properties): return False
    #     for p in __value.properties:
    #         if p not in self.properties: return False
    #     for p in __value.parents:
    #         if all(p != op for op in self.parents): return False
    #     return True

    def IsA(self, other: Node) -> bool:
        """Return True when this node has every property of `other`.

        Parents are not considered.
        """
        # if not other.parents.issubset(self.parents): return False
        return other.properties.issubset(self.properties)

    def Signature(self):
        """Return a string built from sorted properties and parent signatures.

        The result is memoised per key in the namespace's signature cache.
        """
        memo = self.namespace.node_signatures
        try:
            return memo[self.key]
        except KeyError:
            pass
        own = "".join(sorted(self.properties))
        inherited = "".join(sorted(p.Signature() for p in self.parents))
        memo[self.key] = own + inherited
        return memo[self.key]

    def MatchesMemberOf(self, collection: Iterable[Node]):
        """Return True if any member of `collection` compares equal to this node."""
        for member in collection:
            if self == member:
                return True
        return False

class Dependency(Node):
    """Node used by Transform to describe one required or produced item."""

    def __init__(self, namespace: Namespace, properties: set[str], parents: set[Node]) -> None:
        super().__init__(ns=namespace, properties=properties, parents=parents)

class Endpoint(Node):
    """Concrete node: something that is available, produced, or wanted.

    Args:
        namespace: namespace that issues this endpoint's key.
        properties: property strings describing the endpoint.
        parents: endpoints this one descends from; omitted means no parents.
    """

    def __init__(self, namespace: Namespace, properties: set[str], parents: set[Node] | None = None) -> None:
        # A fresh set per instance: a literal `set()` default would be one
        # object shared by every endpoint created without parents.
        super().__init__(namespace, properties, set() if parents is None else parents)

class Transform(Hashable):
    """A rule that consumes endpoints matching `requires` and emits `produces`.

    Requirements and products are Dependency prototypes. A requirement may
    name earlier requirements as parents; when the transform is applied, the
    endpoint chosen for such a requirement must have the endpoint chosen for
    each parent requirement among its own parents.
    """

    def __init__(self, ns: Namespace) -> None:
        super().__init__(ns)
        self.requires: list[Dependency] = []
        self.produces: list[Dependency] = []
        self._ns = ns
        # index into `requires` -> parent requirements declared for that slot
        self._input_group_map: dict[int, list[Dependency]] = {}
        # not read anywhere in this class; kept so the attribute still exists
        self._key = ns.NewKey()

    def __str__(self) -> str:
        def _props(d: Dependency):
            return "{"+"-".join(d.properties)+"}"
        return f"<{','.join(_props(r) for r in self.requires)}->{','.join(_props(p) for p in self.produces)}>"

    def __repr__(self): return f"{self}"

    def AddRequirement(self, properties: Iterable[str], parents: set[Dependency] | None = None):
        """Append a requirement and return its Dependency prototype.

        `parents` must already have been added as requirements of this
        transform.
        """
        return self._add_dependency(self.requires, properties, parents)

    def AddProduct(self, properties: Iterable[str], parents: set[Dependency] | None = None):
        """Append a product and return its Dependency prototype."""
        return self._add_dependency(self.produces, properties, parents)

    def _add_dependency(self, destination: list[Dependency], properties: Iterable[str], parents: set[Dependency] | None = None):
        """Create a Dependency, append it to `destination`, and return it.

        When `destination` is the requirement list, the parents are validated
        and recorded against the new requirement's index.
        """
        # fresh set when omitted, so prototypes never share a default object
        _parents: Any = set() if parents is None else parents
        _dep = Dependency(properties=set(properties), parents=_parents, namespace=self._ns)
        # assert not any(e.IsA(_dep) for e in destination), f"prev. dep ⊆ new dep"
        # assert not any(_dep.IsA(e) for e in destination), f"new dep ⊆ prev. dep "
        destination.append(_dep)
        # identity, not list equality: we want to know which list was passed
        if destination is self.requires:
            i = len(self.requires)-1
            for p in _parents:
                assert p in self.requires, f"{p} not added as a requirement"
            self._input_group_map[i] = self._input_group_map.get(i, [])+list(_parents)
        return _dep

    def _sig(self, endpoints: Iterable[Endpoint]):
        """Concatenate the keys of `endpoints`, in order."""
        return "".join(e.key for e in endpoints)

    def Apply(self, have: Iterable[Endpoint], blacklist: set[str]):
        """List every way this transform can be applied to `have`.

        Args:
            have: endpoints currently available.
            blacklist: input signatures (see `_sig`) to skip.

        Returns:
            One Application per admissible combination of inputs; an empty
            list if some requirement has no matching endpoint.
        """
        # `have` is scanned once per requirement; materialise it so that a
        # one-shot iterable is not exhausted after the first requirement
        available = list(have)
        matches: list[list[Endpoint]] = []
        for req in self.requires:
            _m = [m for m in available if m.IsA(req)]
            if not _m: return []
            matches.append(_m)

        # requirement -> its slot, so the endpoint chosen for a parent
        # requirement can be looked up by position
        slot_of = {req: i for i, req in enumerate(self.requires)}

        # can reduce exponential trial here by enforcing the input groups first
        def _possible_configs(i: int, choosen: list[Endpoint]) -> list[list[Endpoint]]:
            if i >= len(self.requires): return [choosen]
            candidates = matches[i]
            for prototype in self._input_group_map.get(i, []):
                # The parent was added as a requirement before being used as
                # a parent, so its endpoint is already in `choosen`. Take it
                # by slot: searching `choosen` for the first endpoint that
                # IsA the prototype can pick an endpoint chosen for another
                # requirement when properties overlap.
                j = slot_of.get(prototype)
                if j is None or j >= len(choosen): return []
                parent = choosen[j]
                candidates = [c for c in candidates if parent in c.parents]
            configs = []
            for c in candidates:
                configs += _possible_configs(i+1, choosen+[c])
            return configs
        configs = _possible_configs(0, [])

        # todo: next optimization is DFS, with saved subplans

        applications: list[Application] = []
        for input_set in configs:
            sis = set(input_set)
            sig = self._sig(input_set)
            if sig in blacklist: continue
            # products descend from the inputs and from the inputs' parents
            _parents = sis|{p for g in [e.parents for e in input_set] for p in g}
            produced = [
                Endpoint(
                    namespace=self._ns,
                    properties=out.properties,
                    parents=_parents
                )
            for out in self.produces]
            applications.append(Application(self, sis, produced, sig))
        return applications

@dataclass
class Application:
    """One concrete use of a Transform on a specific set of input endpoints."""
    # the transform that was applied
    transform: Transform
    # the endpoints chosen to fill the transform's requirements
    used: set[Endpoint]
    # endpoints created by this application, one per product of the transform
    produced: list[Endpoint]
    # concatenated keys of the chosen inputs, in requirement order
    signature: str

@dataclass
class Result:
    """Outcome returned by Solve."""
    # applications forming the plan, in order; empty when no plan is returned
    solution: list[Application]
    # status text such as "step limit exceeded"; left empty when a plan is found
    message: str = ""
    # optional supporting data; not set by the active code in view
    evidence: Any = None
    # number of search states examined before returning
    steps: int = 0
    
    
def Solve(given: Iterable[Endpoint], targets: Iterable[Endpoint], transforms: Iterable[Transform], max_steps: int = 9999):
    """Search for a sequence of transform applications that yields all targets.

    Depth-first search over states (endpoints held, targets left, plan so far).
    A target counts as met by an endpoint that has all of the target's
    properties and has every parent of the target among its own parents.

    Args:
        given: endpoints available at the start.
        targets: endpoints to obtain.
        transforms: transforms that may be applied.
        max_steps: maximum number of states to examine before giving up.

    Returns:
        Result whose `solution` is the plan found. `solution` is empty, with
        an explanatory `message`, when every target is already given, the
        step limit is exceeded, or the search space is exhausted.
    """
    # materialise first: each of these is iterated more than once below
    given = list(given)
    targets = list(targets)
    transforms = list(transforms)

    def _satisfies(candidate: Endpoint, target: Endpoint) -> bool:
        return candidate.IsA(target) and all(tp in candidate.parents for tp in target.parents)

    # A given endpoint meets a target when the *given* IsA the target (the
    # reverse test would accept a given that lacks target properties).
    if all(any(_satisfies(g, t) for g in given) for t in targets): return Result([], "given all targets")

    @dataclass
    class State:
        targets: list[Endpoint]
        have: list[Endpoint]
        usage_signatures: dict[str, set[str]]
        plan: list[Application]
        last_tr_i: int

    # transforms with at least one product matching the properties of some
    # target; only their applications are checked against the targets
    solution_transforms: set[Transform] = {
        tr for tr in transforms
        if any(p.IsA(t) for p in tr.produces for t in targets)
    }
    todo = [State(
        targets = targets,
        have = given,
        plan = [],
        last_tr_i = 0,
        usage_signatures={},
    )]
    _steps = 0
    while todo:
        _steps += 1
        if _steps > max_steps: return Result([], "step limit exceeded", steps=_steps)
        _s = todo.pop()
        # Rotate the order in which transforms are tried from state to state.
        # NOTE: tri may equal len(transforms), which gives the unrotated order
        # (same as 0); kept as-is so the search order does not change.
        tri = _s.last_tr_i+1
        if tri>len(transforms): tri = 0
        trs = transforms[tri:]+transforms[:tri]
        for tr in trs:
            # input combinations this transform was already applied to on
            # the path leading to this state
            used_sigs = _s.usage_signatures.get(tr.key, set())
            for app in tr.Apply(_s.have, used_sigs):
                new_targets = _s.targets
                if tr in solution_transforms:
                    new_targets = [
                        t for t in new_targets
                        if not any(_satisfies(p, t) for p in app.produced)
                    ]
                    if not new_targets: return Result(_s.plan+[app], steps=_steps)
                sigs = _s.usage_signatures.copy()
                sigs[tr.key] = used_sigs|{app.signature}
                todo.append(State(
                    targets = new_targets,
                    have = _s.have+app.produced,
                    last_tr_i = tri,
                    plan = _s.plan+[app],
                    usage_signatures = sigs,
                ))
    return Result([], "ran out of things to try", steps = _steps)

# x, steps = Solve([asm, bin], [sum_asm, sum_bin], [anner, taxer, sumer])
# x, steps

# x, steps = Solve([asm, bin]+bs, [sum_asm, sum_bin]+ss, [anner, taxer, sumer])
# x, steps

# namespace shared by every transform and endpoint created below
NS = Namespace()
def _set(s: str):
    return set(s.split(", "))
# d_ann = sumer.AddRequirement(_set("ann"))
# d_tax = sumer.AddRequirement(_set("tax"))

# Example: three transforms over N independent inputs.

# "annable" -> "ann"
anner = Transform(NS)
anner.AddRequirement(_set("annable"))
anner.AddProduct(_set("ann"))

# "taxable" -> "tax"
taxer = Transform(NS)
taxer.AddRequirement(_set("taxable"))
taxer.AddProduct(_set("tax"))

# "sum" needs an "annable, taxable" input plus an "ann" and a "tax" that
# both have that same input as a parent
sumer = Transform(NS)
d_parent = sumer.AddRequirement(_set("annable, taxable"))
d_ann = sumer.AddRequirement(_set("ann"), {d_parent})
d_tax = sumer.AddRequirement(_set("tax"), {d_parent})
sumer.AddProduct(_set("sum"))

# N inputs, each tagged with its own number, and one "sum" target per input
# (the target names the input as its parent)
N = 3
haves = [Endpoint(NS, _set(f"{i+1}, annable, taxable")) for i in range(N)]
targets = [Endpoint(NS, _set("sum"), {e}) for e in haves]


# ss = [Endpoint(NS, _set("sum")) for e in bs]
tr = [anner, taxer, sumer]
# %prun Solve(bs, ss, [anner, taxer, sumer])
# r = Solve(bs, ss, [anner, taxer, sumer])
# r

# hand-built inputs (each input with its own "ann" and "tax") for trying
# sumer.Apply directly; not passed to Solve below
test_have = []
for b in haves[:N]:
    test_have.append(b)
    test_have.append(Endpoint(NS, _set("ann"), {b}))
    test_have.append(Endpoint(NS, _set("tax"), {b}))

# sumer.Apply(test_have)
print("Start")
# %prun r = Solve(bs, ss, tr)
r = Solve(haves, targets, tr)
# bare expression: shown as the notebook cell's output
f"input size [{N}], states checked [{r.steps}], {r.message}, {len(targets)}"

Start


'input size [3], states checked [9], , 3'

In [16]:
len(r.solution), r.steps

(9, 9)

In [17]:
r.solution

[Application(transform=<{annable}->{ann}>, used={<LGfH:3,annable,taxable>}, produced=[<UHDk:ann>], signature='LGfH'),
 Application(transform=<{taxable}->{tax}>, used={<LGfH:3,annable,taxable>}, produced=[<VXyY:tax>], signature='LGfH'),
 Application(transform=<{annable-taxable},{ann},{tax}->{sum}>, used={<LGfH:3,annable,taxable>, <VXyY:tax>, <UHDk:ann>}, produced=[<srBF:sum>], signature='LGfHUHDkVXyY'),
 Application(transform=<{taxable}->{tax}>, used={<iuNk:2,annable,taxable>}, produced=[<a5Lt:tax>], signature='iuNk'),
 Application(transform=<{annable}->{ann}>, used={<iuNk:2,annable,taxable>}, produced=[<vXVT:ann>], signature='iuNk'),
 Application(transform=<{taxable}->{tax}>, used={<nI13:annable,taxable,1>}, produced=[<X24o:tax>], signature='nI13'),
 Application(transform=<{annable-taxable},{ann},{tax}->{sum}>, used={<vXVT:ann>, <a5Lt:tax>, <iuNk:2,annable,taxable>}, produced=[<9QdS:sum>], signature='iuNkvXVTa5Lt'),
 Application(transform=<{annable}->{ann}>, used={<nI13:annable,taxable

In [14]:
# Example: a chain of five transforms from "dna" and two "db" inputs.
# NS, _set, Transform, Endpoint and Solve are not defined in this cell;
# presumably they come from the earlier cell.
transforms = []

# "dna" -> "contigs, asm, annable"
t = Transform(NS)
t.AddRequirement(_set("dna"))
t.AddProduct(_set("contigs, asm, annable"))
transforms.append(t)

# "dna" plus a "contigs, asm" that has that dna as parent -> "contigs, bin, annable"
t = Transform(NS)
r = t.AddRequirement(_set("dna"))
t.AddRequirement(_set("contigs, asm"), {r})
t.AddProduct(_set("contigs, bin, annable"))
transforms.append(t)

# "annable" and "db" -> "ann"
t = Transform(NS)
t.AddRequirement(_set("annable"))
t.AddRequirement(_set("db"))
t.AddProduct(_set("ann"))
transforms.append(t)

# an "ann" whose parent is the "db, cog" input and an "ann" whose parent is
# the "db, kegg" input -> "table"
t = Transform(NS)
r = t.AddRequirement(_set("db, cog"))
t.AddRequirement(_set("ann"), {r})
r = t.AddRequirement(_set("db, kegg"))
t.AddRequirement(_set("ann"), {r})
t.AddProduct(_set("table"))
transforms.append(t)

# a "table" whose parent is "contigs, asm" and a "table" whose parent is
# "contigs, bin" -> "figure"
t = Transform(NS)
r = t.AddRequirement(_set("contigs, asm"))
t.AddRequirement(_set("table"), {r})
r = t.AddRequirement(_set("contigs, bin"))
t.AddRequirement(_set("table"), {r})
t.AddProduct(_set("figure"))

# print(t.requires)

transforms.append(t)

# starting endpoints (no parents)
haves = [Endpoint(NS, _set(r)) for r in [
    "db, cog",
    "db, kegg",
    "dna",
]]

# wanted endpoints; only "table" is enabled
targets = [Endpoint(NS, _set(r)) for r in [
    # "asm",
    # "bin",
    "table",
    # "figure",
]]

# change to target transform, where inputs are targets
#   this captures the required parents better
# can use parents in req. to get subtasks (parent -> req. dep.) 
# note: `r` is reused here for the Result (above it held requirements)
r = Solve(haves, targets, transforms)
for a in r.solution:
    print(a)

Application(transform=<{dna}->{contigs-asm-annable}>, used={<JWpT:dna>}, produced=[<6R9Z:contigs,asm,annable>], signature='JWpT')
Application(transform=<{dna},{contigs-asm}->{contigs-annable-bin}>, used={<JWpT:dna>, <6R9Z:contigs,asm,annable>}, produced=[<52OE:contigs,annable,bin>], signature='JWpT6R9Z')
Application(transform=<{annable},{db}->{ann}>, used={<BmHI:kegg,db>, <52OE:contigs,annable,bin>}, produced=[<rHZf:ann>], signature='52OEBmHI')
Application(transform=<{annable},{db}->{ann}>, used={<odCd:cog,db>, <52OE:contigs,annable,bin>}, produced=[<BsfZ:ann>], signature='52OEodCd')
Application(transform=<{cog-db},{ann},{kegg-db},{ann}->{table}>, used={<BmHI:kegg,db>, <rHZf:ann>, <BsfZ:ann>, <odCd:cog,db>}, produced=[<Mexy:table>], signature='odCdBsfZBmHIrHZf')


In [62]:
# def Solve(given: Iterable[Endpoint], targets: Iterable[Endpoint], transforms: Iterable[Transform]):
#     if all(any(t.IsA(g) for g in given) for t in targets): return Result([], "given all targets")

#     @dataclass
#     class State:
#         targets: set[Endpoint]
#         usage_signatures: set[str]
#         plan: list[Application]

#     prod_map: dict[str, set[Transform]] = {}
#     for tr in transforms:
#         for prod in tr.produces:
#             for prop in prod.properties:
#                 prod_map[prop] = prod_map.get(prop, set()) | {tr}

#     # sub_solutions: dict[Endpoint, list[Transform]] = {}
#     def _solve_target(target: Node):
#         # if can't produce a property, can't produce target
#         if any(p not in prod_map for p in target.properties): return []
#         # get transforms that can create all properties
#         candidates: None|set[Transform] = None
#         for p in target.properties:
#             if candidates is None: candidates = prod_map[p]
#             else: candidates = candidates ^ prod_map[p]
#         if candidates is None or len(candidates) == 0: return []
#         # ensure each transform can create the target: all target properties must be on the same product
#         valid_transforms: list[Transform] = []
#         for tr in candidates:
#             for prod in tr.produces:
#                 # print(prod.properties, target.properties)
#                 # print(prod.parents, target.parents)
#                 if not prod.IsA(target): continue
#                 valid_transforms.append(tr)
#                 break
#         return valid_transforms

#     given_props = {p for g in [g.properties for g in given] for p in g}
#     def _in_given(n: Node):
#         if not n.properties.issubset(given_props): return False
#         useable = [g for g in given if g.IsA(n)]
#         if len(useable) == 0: return False
#         return useable
    
#     def signature():
#         pass

#     s_given: set[Endpoint] = set(given)
#     s_targets: set[Endpoint] = set(targets)
#     todo: list[State] = [State(
#         targets=s_targets-s_given,
#         usage_signatures=set(),
#         plan=[],
#     )]
#     while len(todo)>0:
#         print(">")
#         _s = todo.pop()
#         if len(_s.targets) == 0: return Result(_s.plan) # solved!
        
#         valid_transforms = {tr for g in [_solve_target(t) for t in _s.targets] for tr in g}
#         if len(valid_transforms) == 0: return Result(_s.plan, "no valid transforms for", _s.targets)
#         print(valid_transforms, _s.targets)
#         for tr in valid_transforms:
#             reqs, pending = [], set()
#             for r in tr.requires:
#                 useable_givens = _in_given(r)
#                 _continue = False
#                 if useable_givens:
#                     for g in useable_givens:
#                         if g in used_givens: continue
#                         reqs.append(g)
#                         used_givens.add(g)
#                         _continue=True; break
#                 if _continue: continue
#                 n = Endpoint

#                 reqs.append(r)
#                 pending.add(r)
        
#             produced = set()
#             for p in tr.produces:
#                 for t in _s.targets:
#                     if t in produced: continue # comparison using exact hash
#                     if p.IsA(t):
#                         produced.add(t)
#                         break
#             todo.append(State(
#                 targets=_s.targets-produced|pending,
#                 usage_signatures=_s.usage_signatures.copy(),
#                 plan=_s.plan+[Application(tr, reqs, produced)],
#             ))

#     return Result([], "todo exhausted")

In [65]:
# from __future__ import annotations
# import os, sys
# import asyncio
# from typing import Iterable, Callable, Any
# from pathlib import Path

# from limes_x.solver import DependencySolver, Plan, Dependency
# from limes_x.persistence import ProjectState, Instance
# from limes_x.compute_module import ComputeModule

# mpath = Path("./test_solver/")
# modules = [
#     ComputeModule(mpath.joinpath(d)) for d in os.listdir(mpath)
# ]
# print(modules)

# given = [
#     ("a", "./test_data/a1"),
#     ("a", "./test_data/a2"),
#     ("b", "./test_data/b1"),
# ]

# prj_path = "./cache/man_test01/"
# state = ProjectState(prj_path, on_exist="overwrite")
# for dtype, val in given:
#     state.RegisterInstance(Instance.Str(dtype, val))
# for m in modules:
#     state.RegisterInstance(Instance.ComputeModule(m))

# deps = []
# for k, inst in state._instances.items():
#     if not inst.IsPyType(ComputeModule): continue
#     deps.append(Dependency(inst.val.requires, inst.val.produces, k))

# solver = DependencySolver(deps)
# # plan = solver.Solve({"a"}, {"reuse", "linear", "branched"})
# plan = solver.Solve({"a"}, {"branched"})
# assert plan != False
# [state.GetInstance(m.ref_key) for m in plan]

In [66]:
# def make_dependency(module: ComputeModule):
#     return Dependency(module.requires, module.produces, module)

# modules = Path("./test_solver/")
# solver = Plan([
#     make_dependency(ComputeModule(p))
# for p in [
#     modules.joinpath(p) for p in os.listdir(modules)
# ]])
# plan = solver.Solve({"a"}, {"reuse", "linear", "branched"})
# plan

In [67]:
# from limes_x.compute_module import ComputeModule

# a = ComputeModule("./test_modules/copy/")
# b = ComputeModule("./test_modules/copy2/")

# a.requires, b.requires

In [68]:
# state = ProjectState("./cache/test_persist")
# ok = Instance("asdf", 1)
# ov = Instance("s", 2)
# state._lineage[ok] = [ov]
# state.Save()

# s2 = ProjectState.Load("./cache/test_persist")
# for k, v in s2._lineage.items():
#     _te = k.type, k.value, ok == k, [(i.type, i.value, i == ov) for i in v]
#     print(_te)

# ok._id