In [None]:
from typing import Protocol, Self, TypedDict, NotRequired
import os
import enum
import dataclasses
import anyio
from anyio.streams.memory import MemoryObjectSendStream
import dagger
from dagger import dag
import logic
import statemachine
#from workspace import Workspace
from models.utils import loop_completion
from models.common import AsyncLLM, Message, TextRaw, ToolUse, ThinkingBlock, ContentBlock

In [None]:
class Workspace(Protocol):
    ...

In [None]:
@dataclasses.dataclass
class ExecResult:
    """Exit code and output streams read from a dagger container."""

    exit_code: int  # value returned by the container's exit_code()
    stdout: str  # value returned by the container's stdout()
    stderr: str  # value returned by the container's stderr()

    @classmethod
    async def from_dagger(cls, ctr: dagger.Container) -> Self:
        """Query ``ctr`` and package its exit code, stdout and stderr.

        The three container queries are awaited sequentially, in the order
        exit code, stdout, stderr.
        """
        code = await ctr.exit_code()
        out = await ctr.stdout()
        err = await ctr.stderr()
        return cls(exit_code=code, stdout=out, stderr=err)

@dataclasses.dataclass
class NodeData:
    """Payload stored on a search-tree node: workspace, messages and files."""

    workspace: Workspace  # workspace associated with this node
    messages: list[Message]  # messages produced at this node
    # presumably path -> file contents; confirm against callers
    files: dict[str, str] = dataclasses.field(default_factory=dict)

    def head(self) -> Message:
        """Return the node's single message, which must be from the assistant.

        Raises:
            ValueError: if ``messages`` does not hold exactly one entry, or if
                that entry's ``role`` is not ``"assistant"``.
        """
        count = len(self.messages)
        if count != 1:
            raise ValueError(f"Expected 1 got {count} messages: {self.messages}")
        (only,) = self.messages
        if only.role != "assistant":
            raise ValueError(f"Expected assistant role in message: {self.messages}")
        return only

    def dump(self) -> dict:
        """Return a dict with ``messages`` and ``files`` keys.

        The values are the node's own list and dict objects, not copies.
        """
        return dict(messages=self.messages, files=self.files)

In [None]:
async def node_completion(m_client: AsyncLLM, nodes: list[logic.Node[NodeData]], **kwargs) -> list[logic.Node[NodeData]]:
    """Concurrently expand every node in ``nodes`` with one LLM-generated child.

    For each input node, the messages of every node on its trajectory (as
    returned by ``node.get_trajectory()``) are flattened into a history and
    passed to ``loop_completion`` along with ``kwargs``. The returned message
    becomes the only message of a new child node. That child receives
    ``workspace.clone()`` and a shallow copy of the parent's ``files``, and is
    appended to its parent's ``children``.

    Args:
        m_client: LLM client forwarded to ``loop_completion``.
        nodes: Nodes to expand; each gets exactly one new child.
        **kwargs: Extra keyword arguments forwarded to ``loop_completion``.

    Returns:
        The newly created child nodes, in the order their completions
        finished. This is not necessarily the order of ``nodes``.
    """
    async def node_fn(node: logic.Node[NodeData], tx: MemoryObjectSendStream[logic.Node[NodeData]]) -> None:
        # Enter the context first so this task's sender clone is closed even if
        # the completion or the workspace clone raises.
        async with tx:
            history = [m for n in node.get_trajectory() for m in n.data.messages]
            new_node = logic.Node[NodeData](
                data=NodeData(
                    workspace=node.data.workspace.clone(),
                    messages=[await loop_completion(m_client, history, **kwargs)],
                    files=node.data.files.copy(),
                ),
                parent=node,
            )
            await tx.send(new_node)

    result: list[logic.Node[NodeData]] = []
    tx, rx = anyio.create_memory_object_stream[logic.Node[NodeData]]()
    async with anyio.create_task_group() as tg:
        # Every worker gets its own clone. The original sender is closed once
        # all workers are spawned, so the receive loop ends after the last
        # worker closes its clone.
        async with tx:
            for node in nodes:
                tg.start_soon(node_fn, node, tx.clone())
        async with rx:
            async for new_node in rx:
                # Children are attached here, in the single receiving loop,
                # rather than inside the worker tasks.
                new_node.parent.children.append(new_node)  # pyright: ignore[reportOptionalMemberAccess]
                result.append(new_node)
    return result