# import data

In [23]:
# from __future__ import annotations
# from pathlib import Path
# import yaml
# from dataclasses import dataclass, field
# from datetime import datetime as dt
# import hashlib

# from local.constants import WORKSPACE_ROOT

# def str_hash(s):
#     return int(hashlib.sha256(s.encode("utf-8", "replace")).hexdigest(), 16)

# @dataclass
# class DataType:
#     name: str
#     properties: dict[str, str]
#     library: DataTypeLibrary
#     _hash: int = None

#     def __hash__(self) -> int:
#         if self._hash is None:
#             self._hash = str_hash(''.join(self.AsSet()))
#         return self._hash
    
#     @classmethod
#     def SetFromDict(cls, raw: dict[str, str]):
#         return set(f"{k}={v}" for k, v in raw.items())

#     def AsSet(self):
#         return self.SetFromDict(self.properties)
    
# @dataclass
# class DataTypeLibrary:
#     source: Path
#     schema: str
#     ontology: dict
#     types: dict[str, DataType] = field(default_factory=dict)

#     def __getitem__(self, key: str) -> DataType:
#         return self.types[key]
    
#     def __in__(self, key: str) -> bool:
#         return key in self.types
    
#     @classmethod
#     def Load(cls, path: Path) -> DataTypeLibrary:
#         with open(path) as f:
#             d = yaml.safe_load(f)
#         lib = cls(path, d["schema"], d["ontology"])
#         types = {}
#         for k, v in d["types"].items():
#             types[k] = DataType(
#                 name=k,
#                 properties=v,
#                 library=lib,
#             )
#         lib.types = types
#         return lib

# @dataclass
# class DataInstance:
#     source: Path
#     type: DataType
#     _hash: int = None

#     def __hash__(self) -> int:
#         if self._hash is None:
#             self._hash = str_hash(str(self.source.resolve())+''.join(self.type.AsSet()))
#         return self._hash
    
#     @classmethod
#     def Register(cls, source: Path, type: DataType):
#         return cls(source, type)
    
#     def Pack(self):
#         return {
#             "source": str(self.source),
#             "type": self.type.name,
#             "properties": self.type.properties,
#         }

# @dataclass
# class DataInstanceLibrary:
#     description: str
#     types_library: DataTypeLibrary
#     manifest: dict[str, DataInstance] = field(default_factory=dict)
#     time_created: dt = field(default_factory=lambda: dt.now())
#     time_modified: dt = field(default_factory=lambda: dt.now())

#     def __getitem__(self, key: str):
#         return self.manifest[key]

#     @classmethod
#     def Load(cls, path: Path):
#         with open(path) as f:
#             d = yaml.safe_load(f)

#         class_attributes = set(cls.__annotations__.keys())
#         TYPE_LIB = "types_library"
#         d[TYPE_LIB] = DataTypeLibrary.Load(Path(d[TYPE_LIB]))
#         for k, v in d.items():
#             assert k in class_attributes, f"unexpected field [{k}]"
#             if k == "manifest":
#                 manifest = {}
#                 for kk, vv in v.items():
#                     type = d[TYPE_LIB][vv["type"]]
#                     manifest[kk] = DataInstance(
#                         source=Path(vv["source"]),
#                         type=type,
#                     )
#                 d[k] = manifest
#         return cls(**d)

#     def Dump(self, path: Path):
#         self.time_modified = dt.now()
#         with open(path, "w") as f:
#             d = {}
#             for k, v in self.__dict__.items():
#                 if k.startswith("_"): continue
#                 if callable(v): continue
#                 if k == "types_library":
#                     v = str(v.source)
#                 elif k == "manifest":
#                     v = {kk: vv.Pack() for kk, vv in v.items()}
#                 d[k] = v
#             yaml.safe_dump(d, f, indent=4)



In [24]:
from pathlib import Path
from metasmith.models.libraries import DataInstance, DataInstanceLibrary, DataTypeLibrary

from local.constants import WORKSPACE_ROOT

# Load the metagenomics type library and register the two local input files as data instances.
lib = DataTypeLibrary.Load(WORKSPACE_ROOT/"main/local_mock/prototypes/metagenomics.yml")
given_contigs = DataInstance.Register(
    WORKSPACE_ROOT/"scratch/test_ws/data/local/example.fna",
    lib["contigs"],
)
given_ref = DataInstance.Register(
    WORKSPACE_ROOT/"scratch/test_ws/data/local/uniprot_sprot.dmnd",
    lib["diamond_protein_reference"],
)

# Manifest: name -> registered data instance.
_manifest = {
    "contigs": given_contigs,
    "diamond_reference.uniprot_sprot": given_ref,
}
ilib = DataInstanceLibrary(
    description="test workspace",
    types_library=lib,
    manifest=_manifest,
)

# Cache location for the instance library; the dump/reload round-trip below is disabled.
ilib_path = Path("./cache/test.yml")
# ilib.Dump(ilib_path)
# ilib2 = DataInstanceLibrary.Load(ilib_path)
# ilib2.Dump(ilib_path.with_name("test2.yml"))

# generate workflow

In [25]:
from metasmith.models.libraries import TransformInstanceLibrary
# Directories holding transform definitions to search over.
transform_dirs = [
    Path("./transforms/simple_1"),
    # Path("./transforms/dupe_test"),
]
trlib = TransformInstanceLibrary.Load(transform_dirs)

# Reload the data-instance library from its cache file, replacing any in-memory `ilib`.
# NOTE(review): this presumably requires ./cache/test.yml to already exist (e.g. from an
# earlier ilib.Dump) — confirm before running on a fresh workspace.
ilib_path = Path("./cache/test.yml")
ilib = DataInstanceLibrary.Load(ilib_path)

In [92]:
from typing import Iterable
from pathlib import Path
from dataclasses import dataclass
from metasmith.models.libraries import DataInstance, DataType, TransformInstanceLibrary, TransformInstance
from metasmith.models.solver import *

@dataclass
class WorkflowStep:
    """A single concrete workflow step: the concrete form of a solver.Application.

    Constructed positionally as ``WorkflowStep(uses, produces, transform)``.
    """
    uses: list[DataInstance]  # data instances consumed by this step
    produces: list[DataInstance]  # data instances emitted by this step
    transform: TransformInstance  # the transform executed by this step

@dataclass
class WorkflowPlan:
    """A complete concrete workflow: the concrete form of a solver.Result.

    Holds the starting data instances, the data instances the plan delivers,
    and the ordered steps that connect them.
    """
    uses: set[DataInstance]  # instances the plan starts from
    produces: set[DataInstance]  # instances the plan delivers
    steps: list[WorkflowStep]  # steps in execution order

    def __len__(self) -> int:
        """Return how many steps the plan contains."""
        step_count = len(self.steps)
        return step_count

class WorkflowSolver:
    """Plans a workflow that turns given data instances into target data types.

    On construction, every transform in the TransformInstanceLibrary is mirrored
    as a solver ``Transform`` model in a shared ``Namespace``. ``Solve`` searches
    those models for a chain of applications that yields the requested types,
    then maps the abstract solution back onto concrete DataInstance /
    TransformInstance objects.
    """

    def __init__(
            self,
            lib: TransformInstanceLibrary,
        ) -> None:
        self._namespace = Namespace()
        self._transform_lib = lib
        # solver model -> the concrete transform it was built from
        self._transform_map: dict[Transform, TransformInstance] = {}
        # product dependency -> the prototype DataInstance from the transform's output signature
        self._prototype_instances: dict[Dependency, DataInstance] = {}
        def _parse_transform(tr: TransformInstance):
            # Build a solver Transform whose requirements/products mirror tr's signatures.
            ns = self._namespace
            model = Transform(ns)
            self._transform_map[model] = tr
            for x in tr.input_signature:
                model.AddRequirement(x.AsProperties())
            for x in tr.output_signature:
                dep = model.AddProduct(x.type.AsProperties())
                self._prototype_instances[dep] = x
            return model
        # iterating the library yields pairs; only the second element (the transform) is used
        self._transforms = [_parse_transform(t) for p, t in lib]

    def Solve(self, given: Iterable[DataInstance], target: Iterable[DataType]):
        """Find a plan that produces every type in ``target`` from ``given``.

        Args:
            given: concrete data instances available at the start.
            target: data types the plan must produce.

        Returns:
            A WorkflowPlan built from the first solution found (solutions are
            sorted by ``len``), or ``None`` when no solution exists.
        """
        def _solve(given: Iterable[Endpoint], target: Transform, transforms: Iterable[Transform], _debug=False):
            # Search state passed down one recursion level at a time.
            @dataclass
            class State:
                have: dict[Endpoint, Dependency]  # available endpoint -> its prototype dependency
                needed: set[Dependency]  # requirements accumulated along the current path
                target: Dependency|Transform  # what this level is trying to satisfy
                lineage_requirements: dict[Node, Endpoint]  # prototype -> endpoint required for it in the lineage
                seen_signatures: set[str]  # _solve_tr signatures on the current path (loop detection)
                depth: int  # recursion depth, bounded by HORIZON

            def _get_producers_of(target: Dependency):
                # Yield each transform that has at least one product which IsA target.
                for tr in transforms:
                    for p in tr.produces:
                        if p.IsA(target):
                            yield tr
                            break

            if _debug:
                log_path = Path("./cache/debug_log.txt")
                log_path.parent.mkdir(parents=True, exist_ok=True)
                log = open(log_path, "w")
                # the sentinel "END" closes the log instead of writing a line
                debug_print = lambda *args: log.write(" ".join(str(a) for a in args)+"\n") if args[0] != "END" else log.close()
            else:
                debug_print = lambda *args: None

            _apply_cache: dict[str, Application] = {}
            def _apply(target: Transform, inputs: Iterable[tuple[Endpoint, Node]]):
                # Memoize target.Apply per (transform, input binding). The transform key is
                # part of the signature so zero-input applications of different transforms
                # do not all collide on the empty key.
                inputs = list(inputs)
                sig = f"{target.key}:" + "".join(e.key+d.key for e, d in inputs)
                if sig in _apply_cache:
                    return _apply_cache[sig]
                appl = target.Apply(inputs)
                _apply_cache[sig] = appl
                return appl

            def _satisfies_lineage(tproto: Dependency, candidate: Endpoint):
                # True when every parent prototype of tproto is matched (IsA) by some parent of candidate.
                for tp_proto in tproto.parents:
                    if all(not p.IsA(tp_proto) for p, _ in candidate.Iterparents()):
                        return False
                return True

            HORIZON=64  # recursion depth at which a branch is abandoned
            def _solve_dep(s: State) -> list[DependencyResult]:
                # Collect every way to obtain an endpoint for the Dependency s.target:
                # reuse an endpoint already in s.have, or run a transform that produces it.
                if s.depth >= HORIZON:
                    if _debug: debug_print(f" <-  HORIZON", s.depth)
                    return []
                target: Dependency = s.target
                assert isinstance(target, Dependency), f"{s.target}, not dep"
                if _debug: debug_print(f" ->", s.target, s.lineage_requirements)
                if _debug: debug_print(f"   ", s.have.keys())

                candidates:list[DependencyResult] = []
                for e, eproto in s.have.items():
                    if not e.IsA(target): continue
                    acceptable = True
                    for rproto, r in s.lineage_requirements.items():
                        if e == r: continue
                        if eproto.IsA(rproto): # e is protype, but explicitly breaks lineage
                            acceptable=False; break

                        for p, pproto in e.Iterparents():
                            if rproto.IsA(pproto):
                                if p != r:
                                    acceptable=False; break
                        if not acceptable: break  # one violated requirement is enough to reject e

                    if not acceptable:
                        continue
                    else:
                        if _debug: debug_print(f"    ^candidate", e, eproto, e.parents)
                        if _debug: debug_print(f"    ^reqs.    ", s.lineage_requirements)
                        candidates.append(DependencyResult([], e))

                def _add_result(res: Result):
                    # Turn a transform solution into a candidate for `target`, provided the
                    # matching produced endpoint satisfies target's lineage.
                    ep: Endpoint|None = None
                    for e in res.application.produced:
                        if e.IsA(target):
                            ep = e; break
                    assert isinstance(ep, Endpoint)
                    if not _satisfies_lineage(target, ep): return
                    candidates.append(DependencyResult(
                        res.dependency_plan+[res.application],
                        ep,
                    ))

                for tr in _get_producers_of(target):
                    results = _solve_tr(State(s.have, s.needed, tr, s.lineage_requirements, s.seen_signatures, s.depth))
                    for res in results:
                        _add_result(res)

                if _debug: debug_print(f" <-", s.target, f"{len(candidates)} sol.", candidates[0].endpoint if len(candidates)>0 else None)
                return candidates

            _transform_cache: dict[str, list[Result]] = {}
            def _solve_tr(s: State) -> list[Result]:
                # Solve every requirement of the Transform s.target, combine compatible
                # input choices, and return the resulting solutions sorted by len().
                assert isinstance(s.target, Transform), f"{s.target} not tr"
                target: Transform = s.target
                if _debug: debug_print(f">>>{s.depth:02}", s.target, s.lineage_requirements)
                for h in s.have:
                    if _debug: debug_print(f"      ", h)

                # memoization key: available endpoints, the transform, and lineage constraints
                sig = "".join(e.key for e in s.have)
                sig += f":{s.target.key}"
                sig += ":"+"".join(e.key for e in s.lineage_requirements.values())
                if sig in _transform_cache:
                    if _debug: debug_print(f"<<<{s.depth:02} CACHED: {len(_transform_cache[sig])} solutions")
                    return _transform_cache[sig]
                # the same call is already on the current search path: treat it as a loop
                if sig in s.seen_signatures:
                    if _debug: debug_print(f"<<<{s.depth:02} FAIL: is loop")
                    return []

                # plans[i] holds every candidate for target.requires[i]
                plans: list[list[DependencyResult]] = []
                for req in s.target.requires:
                    # drop lineage constraints whose prototype this requirement itself IsA
                    req_p = {}
                    for proto, e in s.lineage_requirements.items():
                        if req.IsA(proto): continue
                        req_p[proto] = e

                    results = _solve_dep(State(s.have, s.needed|{req}, req, req_p, s.seen_signatures|{sig}, s.depth+1))

                    if len(results) == 0:
                        if _debug: debug_print(f"<<< FAIL", s.target, req)
                        return []
                    else:
                        plans.append(results)

                def _gather_valid_inputs():
                    # Depth-first walk choosing one candidate per requirement, keeping only
                    # combinations with no duplicated endpoint and consistent lineage.
                    valids: list[list[DependencyResult]] = []
                    # a transform with no requirements has exactly one (empty) input set;
                    # without this, plans[0] below would raise IndexError
                    if len(target.requires) == 0:
                        return [[]]
                    ii = 0
                    def _gather(req_i: int, req: Dependency, res: DependencyResult, deps: dict, used: set[Endpoint], inputs: list[DependencyResult]):
                        nonlocal ii; ii += 1
                        if _debug: debug_print(f"          ", deps)
                        if _debug: debug_print(f"    ___", req, req.parents)
                        if _debug: debug_print(f"        __", res.endpoint, list(res.endpoint.Iterparents()))
                        # an endpoint may fill at most one requirement of this transform
                        if res.endpoint in used:
                            if _debug: debug_print(f"    ___ FAIL: duplicate input", res.endpoint)
                            return

                        if not _satisfies_lineage(req, res.endpoint):
                            if _debug: debug_print(f"    ___ FAIL: unsatisfied lineage", req)
                            return

                        # for each parent prototype of req, the first ancestor of the candidate
                        # (after reversing Iterparents()) that IsA it must be the endpoint
                        # already chosen for that prototype in `deps`
                        for rproto in req.parents:
                            r = deps[rproto]
                            res_parents = list(res.endpoint.Iterparents())
                            res_parents.reverse()
                            for p, pproto in res_parents:
                                if not p.IsA(rproto): continue
                                if p!=r:
                                    if _debug: debug_print(f"    ___ FAIL: lineage mismatch", p, r)
                                    return
                                else:
                                    break # in the case of asm -> bin, the closest ancestor takes priority

                        if req_i >= len(target.requires)-1:
                            valids.append(inputs+[res])
                        else:
                            req_i += 1
                            for next_res in plans[req_i]:
                                _gather(req_i, target.requires[req_i], next_res, deps|{req:res.endpoint}, used|{res.endpoint}, inputs+[res])
                    for next_res in plans[0]:
                        _gather(0, target.requires[0], next_res, {}, set(), [])
                    total = 1
                    for options in plans:
                        total *= len(options)
                    if _debug: debug_print(f"    ## {ii} visited, {total} combos")
                    return valids

                if _debug: debug_print(f"<<<{s.depth:02}", s.target, s.lineage_requirements)
                if _debug: debug_print(f"     ", [len(x) for x in plans])
                solutions: list[Result] = []
                for inputs in _gather_valid_inputs():
                    my_appl = _apply(s.target, [(res.endpoint, req) for req, res in zip(s.target.requires, inputs)])
                    # merge the inputs' sub-plans, skipping applications whose products
                    # are all already produced by this or an earlier kept application
                    consolidated_plan: list[Application] = []
                    produced_sigs: set[str] = {p.Signature() for p in my_appl.produced}
                    for res in inputs:
                        for appl in res.plan:
                            if all(p.Signature() in produced_sigs for p in appl.produced): continue
                            consolidated_plan.append(appl)
                            produced_sigs = produced_sigs.union(p.Signature() for p in appl.produced)
                    solutions.append(Result(
                        my_appl,
                        consolidated_plan,
                    ))
                if _debug: debug_print(f"     ", f"{len(solutions)} sol.", solutions[0].application.produced if len(solutions)>0 else None)
                solutions.sort(key=len)
                _transform_cache[sig] = solutions
                return solutions

            input_tr = Transform(target._ns)
            # register each given endpoint as a product of a synthetic "input" transform
            given_dict = {g:input_tr.AddProduct(g.properties) for g in given}
            try:
                res = _solve_tr(State(given_dict, set(), target, {}, set(), 0))
            finally:
                # close the debug log even if the search raises
                if _debug: debug_print("END")
            return res

        # one solver Endpoint per given instance, mapped back to that instance
        data_instances = {Endpoint(self._namespace, x.type.AsProperties()):x for x in given}
        # the given DataInstances themselves (not their solver Endpoints), matching the
        # declared type of WorkflowPlan.uses
        given_instances = set(data_instances.values())
        target_tr = Transform(self._namespace)
        for x in target:
            target_tr.AddRequirement(x.AsProperties())
        solutions = _solve(data_instances, target_tr, self._transforms)
        if len(solutions) == 0: return None

        solution = solutions[0] # just pick first solution
        steps = []
        for appl in solution.dependency_plan:
            tr = self._transform_map[appl.transform]
            used = [data_instances[x] for x in appl.used]
            produced = []
            for e, dep in appl.produced.items():
                # a produced endpoint is represented by its transform's prototype output instance
                inst = self._prototype_instances[dep]
                data_instances[e] = inst
                produced.append(inst)
            steps.append(WorkflowStep(used, produced, tr))

        target_instances = set()
        for x in solution.application.used: # targets
            target_instances.add(data_instances[x])
        return WorkflowPlan(
            given_instances,
            target_instances,
            steps
        )

solver = WorkflowSolver(trlib)
plan = solver.Solve(
    [
        ilib["contigs"],
        ilib["diamond_reference.uniprot_sprot"],
    ],
    [
        lib.types["orf_annotations"],
    ]
)
# Solve returns None when no plan exists; fail with a clear message rather than
# an AttributeError on `plan.steps`.
if plan is None:
    raise RuntimeError("no workflow plan found for the requested targets")

# Invoke each step's protocol in plan order (called with None in this notebook).
for step in plan.steps:
    step.transform.protocol(None)

this is pprodigal!
this is diamond!
