In [None]:
import asyncio
import random
from dataclasses import dataclass, field
from typing import Callable, Protocol

from lib.rl.completion import Completion, SplitMethod
from lib.rl.completion_sampler import CompletionSampler, SamplingKwargs
from lib.rl.episode import Episode
from lib.tokenizer import Tokenizer


class SearchStrategy(Protocol):
    """Interface for search strategies that expand episodes with completions.

    An implementation is an async callable. It takes episodes from
    ``ready_episodes``, samples completions for them using the sampler and
    tokenizer, reports fractional progress through ``update_progress``, and
    puts finished episodes on ``done_episodes``.
    """

    # ``async`` because every implementation is a coroutine function.
    # Parameters are positional-only because implementations name the first
    # one differently (``sampler`` vs ``completion_sampler``).
    async def __call__(
        self,
        sampler: CompletionSampler,
        tokenizer: Tokenizer,
        ready_episodes: asyncio.Queue[Episode],
        done_episodes: asyncio.Queue[Episode],
        update_progress: Callable[[float], None],
        /,
    ) -> None:
        """Process episodes from ``ready_episodes`` into ``done_episodes``."""
        ...


@dataclass
class SimpleSearch(SearchStrategy):
    """Sample a fixed number of completions from each episode's root.

    Attributes:
        num_samples: Passed as ``branch_factor`` to
            ``Episode.sample_completions`` with a single parent.
    """

    num_samples: int

    async def __call__(
        self,
        sampler: CompletionSampler,
        tokenizer: Tokenizer,
        ready_episodes: asyncio.Queue[Episode],
        done_episodes: asyncio.Queue[Episode],
        update_progress: Callable[[float], None],
    ) -> None:
        """Sample every episode from ``ready_episodes`` until a falsy item arrives.

        Each episode is sampled in a background task. When that task
        finishes, whether or not it succeeded, the episode is put on
        ``done_episodes`` and ``update_progress(1)`` is called.
        """
        # The event loop only keeps weak references to tasks, so hold strong
        # ones until each task has finished.
        in_flight: set[asyncio.Task[bool]] = set()
        while episode := await ready_episodes.get():
            task = asyncio.create_task(
                episode.sample_completions(
                    sampler, tokenizer, num_parents=1, branch_factor=self.num_samples
                )
            )

            # Bind ``episode`` as a default argument. A plain closure would
            # read the loop variable when the callback fires, and by then
            # the loop may have moved on to a later episode.
            def done_callback(
                finished: asyncio.Task[bool], episode: Episode = episode
            ) -> None:
                in_flight.discard(finished)
                done_episodes.put_nowait(episode)
                update_progress(1)

            in_flight.add(task)
            task.add_done_callback(done_callback)


@dataclass
class TreeSearch(SearchStrategy):
    """Grow a completion tree for each episode until its leaves reach ``depth``.

    The root is first sampled with ``branch_factor`` branches. Each time a
    sampling task finishes, the following happens for every leaf with fewer
    than ``depth`` ancestors:

    - the leaf is split into ``depth - ancestors + 1`` segments;
    - each segment except the last is passed as ``parent`` to
      ``Episode._sample_completions`` with ``branch_factor``.

    An episode is done once no sampling task for it remains pending.

    Attributes:
        branch_factor: Branch factor requested at every expansion point.
        depth: Target leaf depth, counted as the number of ancestors.
        split_method: Passed as ``by`` to the leaf's ``split``.
        split_separators: Passed to ``split`` and ``_sample_completions``.
    """

    branch_factor: int
    depth: int
    split_method: SplitMethod = "count"
    # ``dataclass`` rejects a bare ``set()`` default with ValueError, so give
    # each instance its own empty set.
    split_separators: set[str] = field(default_factory=set)

    async def __call__(
        self,
        completion_sampler: CompletionSampler,
        tokenizer: Tokenizer,
        ready_episodes: asyncio.Queue[Episode],
        done_episodes: asyncio.Queue[Episode],
        update_progress: Callable[[float], None],
    ) -> None:
        """Expand every episode from ``ready_episodes`` until a falsy item arrives.

        Each episode is expanded in a background task and put on
        ``done_episodes`` when its expansion finishes. After each finished
        sampling task, ``update_progress`` receives the increase in leaf count
        divided by ``branch_factor ** depth``. Episodes are assigned
        priorities 1, 2, 3, ... in arrival order.
        """
        model = await completion_sampler.get_model()
        expected_leaves = self.branch_factor**self.depth
        # The event loop only keeps weak references to tasks, so hold strong
        # ones to the fire-and-forget expansions.
        expansions: set[asyncio.Task[None]] = set()

        async def expand(episode: Episode, priority: int) -> None:
            # ``priority`` is a parameter, not a closure variable. The outer
            # loop keeps incrementing its counter while this coroutine runs.
            pending: set[asyncio.Task] = {
                asyncio.create_task(
                    episode.sample_completions(
                        completion_sampler,
                        tokenizer,
                        num_parents=1,
                        branch_factor=self.branch_factor,
                        priority=priority,
                    )
                )
            }

            num_leaves = 0
            while pending:
                # NOTE(review): finished tasks are dropped without calling
                # ``result()``, so sampling errors are not surfaced here.
                _, pending = await asyncio.wait(
                    pending, return_when=asyncio.FIRST_COMPLETED
                )
                new_num_leaves = 0
                for leaf in episode.completion.leaves(model=model):
                    new_num_leaves += 1
                    depth = sum(1 for _ in leaf.ancestors())
                    num_partitions = self.depth - depth + 1
                    if num_partitions <= 1:
                        # Leaf already has ``self.depth`` or more ancestors.
                        continue
                    parents = list(
                        leaf.split(
                            by=self.split_method,
                            at=(
                                split / num_partitions
                                for split in range(1, num_partitions)
                            ),
                            separators=self.split_separators,
                            cache=True,
                        )
                    )[:-1]
                    # Every segment but the last becomes a new sampling parent.
                    for parent in parents:
                        pending.add(
                            asyncio.create_task(
                                episode._sample_completions(
                                    parent=parent,
                                    model=model,
                                    completion_sampler=completion_sampler,
                                    tokenizer=tokenizer,
                                    branch_factor=self.branch_factor,
                                    fork_decay=1.0,
                                    recovery_pattern=None,
                                    split_separators=self.split_separators,
                                    sampling_kwargs=SamplingKwargs(),
                                    priority=priority,
                                )
                            )
                        )
                # Report the increase in leaves since the previous round.
                update_progress((new_num_leaves - num_leaves) / expected_leaves)
                num_leaves = new_num_leaves

            await done_episodes.put(episode)

        priority = 0
        while episode := await ready_episodes.get():
            priority += 1
            task = asyncio.create_task(expand(episode, priority))
            expansions.add(task)
            task.add_done_callback(expansions.discard)


@dataclass
class VineSearch(SearchStrategy):
    """Branch each episode along a single randomly chosen leaf (the "vine").

    The search runs in three steps:

    1. The root is first sampled with ``branch_factor`` branches.
    2. One resulting leaf is drawn at random, weighted by
       ``exploration_weight(leaf)``.
    3. That leaf is split into ``depth`` segments, and each segment except
       the last is passed as ``parent`` to ``Episode._sample_completions``
       with ``branch_factor``. These samples run concurrently.

    Attributes:
        branch_factor: Branch factor requested at the root and at each
            split point.
        depth: Number of segments the chosen leaf is split into.
        exploration_weight: Relative weight used when choosing the leaf.
            The default weights all leaves equally.
        split_method: Passed as ``by`` to the leaf's ``split``.
        split_separators: Passed to ``split`` and ``_sample_completions``.
    """

    branch_factor: int
    depth: int
    exploration_weight: Callable[[Completion], float] = lambda _: 1.0
    split_method: SplitMethod = "count"
    # ``dataclass`` rejects a bare ``set()`` default with ValueError, so give
    # each instance its own empty set.
    split_separators: set[str] = field(default_factory=set)

    async def __call__(
        self,
        completion_sampler: CompletionSampler,
        tokenizer: Tokenizer,
        ready_episodes: asyncio.Queue[Episode],
        done_episodes: asyncio.Queue[Episode],
        update_progress: Callable[[float], None],
    ) -> None:
        """Expand every episode from ``ready_episodes`` until a falsy item arrives.

        Each episode is expanded in a background task and put on
        ``done_episodes`` when done. Progress is reported in two increments
        that sum to 1 per episode: ``branch_factor / num_samples`` after the
        root is sampled, then the remainder after the vine samples finish.
        Here ``num_samples = (branch_factor - 1) * depth + 1``. Episodes are
        assigned priorities 1, 2, 3, ... in arrival order.
        """
        model = await completion_sampler.get_model()
        # The event loop only keeps weak references to tasks, so hold strong
        # ones to the fire-and-forget expansions.
        expansions: set[asyncio.Task[None]] = set()

        async def expand(episode: Episode, priority: int) -> None:
            # ``priority`` is a parameter, not a closure variable. The outer
            # loop keeps incrementing its counter while this coroutine runs.
            num_samples = (self.branch_factor - 1) * self.depth + 1
            await episode.sample_completions(
                completion_sampler,
                tokenizer,
                num_parents=1,
                branch_factor=self.branch_factor,
                priority=priority,
            )
            update_progress(self.branch_factor / num_samples)
            # Enumerate leaves once so the population and the weights are
            # guaranteed to line up.
            leaves = list(episode.completion.leaves(model=model))
            vine = random.choices(
                leaves,
                weights=[self.exploration_weight(leaf) for leaf in leaves],
                k=1,
            )[0]
            parents = list(
                vine.split(
                    by=self.split_method,
                    at=(split / self.depth for split in range(1, self.depth)),
                    separators=self.split_separators,
                    cache=True,
                )
            )[:-1]
            await asyncio.gather(
                *(
                    episode._sample_completions(
                        parent=parent,
                        model=model,
                        completion_sampler=completion_sampler,
                        tokenizer=tokenizer,
                        branch_factor=self.branch_factor,
                        fork_decay=1.0,
                        recovery_pattern=None,
                        split_separators=self.split_separators,
                        sampling_kwargs=SamplingKwargs(),
                        priority=priority,
                    )
                    for parent in parents
                )
            )
            update_progress((num_samples - self.branch_factor) / num_samples)
            await done_episodes.put(episode)

        priority = 0
        while episode := await ready_episodes.get():
            priority += 1
            task = asyncio.create_task(expand(episode, priority))
            expansions.add(task)
            task.add_done_callback(expansions.discard)