In [None]:
import asyncio
from lib.rl.completion_sampler import CompletionSampler
from lib.rl.episode import Episode
from lib.tokenizer import Tokenizer
from typing import Callable, Protocol


class SearchStrategy(Protocol):
    """Structural type for an asynchronous search-strategy callable.

    A strategy is a coroutine function taking the five arguments below;
    ``simple_search`` in this module has exactly this parameter list.
    ``__call__`` is declared ``async`` so that coroutine functions satisfy
    the protocol: a plain ``def ... -> None`` would describe a synchronous
    callable, which an ``async def`` implementation does not match.

    Args:
        sampler: Completion sampler handed to the strategy.
        tokenizer: Tokenizer handed to the strategy.
        ready_episodes: Queue of episodes for the strategy to consume.
        done_episodes: Queue on which the strategy places episodes it has
            finished with.
        update_progress: Callback the strategy invokes with a progress
            amount.
    """

    async def __call__(
        self,
        sampler: CompletionSampler,
        tokenizer: Tokenizer,
        ready_episodes: asyncio.Queue[Episode],
        done_episodes: asyncio.Queue[Episode],
        update_progress: Callable[[float], None],
    ) -> None: ...


# Strong references to in-flight sampling tasks. The event loop itself keeps
# only weak references to tasks, so a task nothing else refers to may be
# garbage-collected before it completes. Each task removes itself when done.
_simple_search_tasks: set[asyncio.Task[bool]] = set()


async def simple_search(
    sampler: CompletionSampler,
    tokenizer: Tokenizer,
    ready_episodes: asyncio.Queue[Episode],
    done_episodes: asyncio.Queue[Episode],
    update_progress: Callable[[float], None],
) -> None:
    """Sample one completion per ready episode, without branching.

    Episodes are taken from ``ready_episodes`` until a falsy item is
    received. For each one, ``episode.sample_completions`` is scheduled as a
    background task with ``num_parents=1`` and ``branch_factor=1``. When that
    task finishes, the same episode is put on ``done_episodes`` and
    ``update_progress(1)`` is called.

    The function returns as soon as the falsy item arrives; it does not wait
    for tasks that are still running.

    Args:
        sampler: Passed through to ``Episode.sample_completions``.
        tokenizer: Passed through to ``Episode.sample_completions``.
        ready_episodes: Source queue of episodes to sample.
        done_episodes: Destination queue for episodes whose task finished.
        update_progress: Called with ``1`` once per finished task.
    """
    while episode := await ready_episodes.get():
        task = asyncio.create_task(
            episode.sample_completions(
                sampler, tokenizer, num_parents=1, branch_factor=1
            )
        )
        _simple_search_tasks.add(task)
        task.add_done_callback(_simple_search_tasks.discard)

        # Bind this iteration's episode as a default argument. A plain closure
        # would read the loop variable when the task finishes, by which time
        # it may already refer to a later episode (or the falsy terminator).
        # The callback runs whether the task succeeded, raised, or was
        # cancelled; the task's outcome is not inspected here.
        def done_callback(
            _: asyncio.Task[bool], finished: Episode = episode
        ) -> None:
            done_episodes.put_nowait(finished)
            update_progress(1)

        task.add_done_callback(done_callback)