In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from lib.vllm import start_vllm_server, vllm_server_metrics

model = "NousResearch/Hermes-2-Theta-Llama-3-8B"

shutdown_server, client = await start_vllm_server(
    disable_log_requests=True,
    model=model,
)

No module named 'vllm._version'
  from vllm.version import __version__ as VLLM_VERSION
No module named 'vllm._version'
  from vllm.version import __version__ as VLLM_VERSION
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:01,  1.56it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.30it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:01<00:00,  1.69it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.56it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.54it/s]



In [26]:
# Calls the shutdown handle returned by start_vllm_server (it returned False in the
# recorded run). NOTE(review): this cell was executed out of order (In [26]); running
# it before the sampler cells presumably leaves `client` talking to a stopped server —
# keep it as the last step.
shutdown_server()

False

In [None]:
from lib.rl.sampler import CompletionSampler

# Wraps the vLLM `client`; sample_completions() draws child completions through it.
completion_sampler = CompletionSampler(client)

In [2]:
import asyncio
from lib.rl.completion import Completion
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from typing import Callable, Coroutine, Optional


class Episode:
    """A prompt whose completion tree is grown by repeated sampling.

    Args:
        messages: Chat messages used to build the root ``Completion``.
        on_sample: Called with each batch of sampled completions; may be a
            plain function or one returning a coroutine.
        get_easier_episode: Optional ``(min_value, factory)`` pair, stored as
            ``self.min_value`` and ``self.get_easier_episode``.
        get_similar_episode: Optional factory stored as ``self.get_similar_episode``.
        get_harder_episode: Optional ``(max_value, factory)`` pair, stored as
            ``self.max_value`` and ``self.get_harder_episode``.
    """

    def __init__(
        self,
        messages: list[ChatCompletionMessageParam],
        on_sample: Callable[[list[Completion]], None | Coroutine[None, None, None]],
        get_easier_episode: Optional[
            tuple[float, Callable[[], "Episode" | Coroutine[None, None, "Episode"]]]
        ] = None,
        get_similar_episode: Optional[
            Callable[[], "Episode" | Coroutine[None, None, "Episode"]]
        ] = None,
        get_harder_episode: Optional[
            tuple[float, Callable[[], "Episode" | Coroutine[None, None, "Episode"]]]
        ] = None,
    ) -> None:
        # Root node of the completion tree built from the prompt messages.
        self.completion = Completion(messages=messages)  # type: ignore
        self.on_sample = on_sample

        # Missing (threshold, factory) pairs default to (None, None).
        easier_pair = get_easier_episode or (None, None)
        harder_pair = get_harder_episode or (None, None)
        self.min_value = easier_pair[0]
        self.max_value = harder_pair[0]
        self.get_easier_episode = easier_pair[1]
        self.get_similar_episode = get_similar_episode
        self.get_harder_episode = harder_pair[1]

        # Multiplier used by Trajectory.score and the buffer-weight check.
        self.weight = 1.0
        # Trivial task that finishes right away, so `self.task.done()` turns True
        # before any real sampling task is attached.
        self.task = asyncio.create_task(asyncio.sleep(0))

In [None]:
async def on_sample(completions) -> None:
    """No-op sampling callback: ignores ``completions`` and returns None."""


episode = Episode(
    on_sample=on_sample,
    messages=[{"role": "user", "content": "Hello, how are you doing today?"}],
)

In [3]:
async def sample_random_episode() -> "Episode":
    """Draw a fresh, randomly chosen episode (the factory behind ``random_sampler``).

    Raises:
        NotImplementedError: always, until a real episode generator is written.
            The previous ``...`` body returned ``None``, which only failed later
            inside ``get_episode`` with an unrelated-looking ``AttributeError``.
    """
    raise NotImplementedError(
        "sample_random_episode() is a stub; implement it to return an Episode"
    )

In [23]:
import numpy as np

array([0.13405539, 0.13405539, 0.13405539, 0.13405539, 0.13405539,
       0.32972305])

In [None]:
from lib.tokenizer import Tokenizer

# Tokenizer used for token counts (token_count / all_token_count) and for
# encoding packed training rows in handle_request.
# NOTE(review): built from an empty string; presumably this should name the served
# model (e.g. `Tokenizer(model)`) so token counts match vLLM — confirm.
tokenizer = Tokenizer("")

In [None]:
import random


class EpisodeSampler:
    """An episode factory plus counters of how its episodes turned out.

    ``num_samples`` and ``num_goldilocks`` start at zero and are incremented by
    ``get_episode``.
    """

    def __init__(
        self,
        sample: Callable[[], Episode | Coroutine[None, None, Episode]],
    ) -> None:
        self.sample = sample
        self.num_samples = 0
        self.num_goldilocks = 0

    def goldilocks_rate(self, prior: float, effective_sample_size: float) -> float:
        """Goldilocks rate shrunk toward ``prior``.

        Equivalent to adding ``effective_sample_size`` pseudo-samples whose
        goldilocks fraction equals ``prior``.
        """
        hits = self.num_goldilocks + prior * effective_sample_size
        trials = self.num_samples + effective_sample_size
        return hits / trials


# --- Tuning knobs -------------------------------------------------------------
# Completions requested per sample_completions call (its `n=`); the main loop
# also steps its task budget by `branch_factor * 2`.
branch_factor = 2
# Added to `running - pending` when the main loop sizes its task budget.
min_requests = 10
# The main loop fetches new episodes (instead of enriching buffered ones) while
# the buffer holds fewer than `abs_buffer_size` episodes or its total episode
# weight is below `weighted_buffer_size`.
abs_buffer_size = 100
weighted_buffer_size = 200
# Number of random-sampler samples after which the floor on the random sampler's
# goldilocks rate (starting at 1.0) has halved.
min_random_episode_sample_probability_half_life = 100
# Exponent applied to goldilocks rates before normalising them into weights.
exploitation_factor = 1.0

# --- Scheduler state ----------------------------------------------------------
buffer: list[Episode] = []
random_sampler = EpisodeSampler(sample_random_episode)
other_samplers: list[EpisodeSampler] = []


def goldilocks_rate_prior_and_effective_sample_size() -> tuple[float, float]:
    """Pool every sampler's counters into a shared prior for ``goldilocks_rate``.

    Returns:
        ``(prior, effective_sample_size)``: ``prior`` is the pooled goldilocks
        fraction, or 1.0 while no goldilocks episode has been seen;
        ``effective_sample_size`` is the mean sample count per sampler, at least 1.
    """
    samplers = [random_sampler, *other_samplers]
    total_goldilocks = sum(s.num_goldilocks for s in samplers)
    total_samples = sum(s.num_samples for s in samplers)
    if total_goldilocks == 0 or total_samples == 0:
        prior = 1.0
    else:
        prior = total_goldilocks / total_samples
    effective_sample_size = max(total_samples / len(samplers), 1)
    return prior, effective_sample_size


async def sample_completions(episode: Episode) -> None:
    """Sample ``branch_factor`` new completions for ``episode`` and report them.

    * If the root completion has no children yet, sample directly from the root
      (the prompt).
    * Otherwise pick the leaf with the largest
      ``all_abs_advantage() / all_token_count()`` among leaves that have a
      splittable ancestor. Split that leaf's splittable ancestor with the largest
      ``|advantage()| * token_count()``, and sample from it.

    The sampled completions are passed to ``episode.on_sample``; a coroutine
    result is awaited.

    If no split point exists or splitting fails, nothing is sampled. Instead,
    ``episode.task`` is replaced by a task that never finishes, so
    ``task.done()`` stays False for this episode.
    """
    if not episode.completion.children:
        # First round for this episode: branch straight off the prompt.
        # (The previous condition was inverted, so fresh episodes never sampled.)
        parent = episode.completion
    else:
        try:
            leaf = max(
                (
                    completion
                    for completion in episode.completion.leaves()
                    if any(
                        c.can_split() for c in completion.ancestors(including_self=True)
                    )
                ),
                key=lambda c: c.all_abs_advantage() / c.all_token_count(tokenizer),
            )
            parent = max(
                (c for c in leaf.ancestors(including_self=True) if c.can_split()),
                key=lambda c: abs(c.advantage()) * c.token_count(tokenizer),
            )
            # Call split() outside `assert`: asserts are stripped under
            # `python -O`, which would silently skip the split itself.
            split_ok = bool(parent.split(by="count"))
        except Exception:
            # Deliberate best-effort (e.g. `max` over no candidates raises
            # ValueError): any failure means "nothing left to split". Unlike the
            # old bare `except:`, KeyboardInterrupt/SystemExit still propagate.
            split_ok = False
        if not split_ok:
            episode.task = asyncio.create_task(asyncio.sleep(float("inf")))
            return
    completions = await completion_sampler.sample_completions(
        parent,
        n=branch_factor,
    )
    on_sample = episode.on_sample(completions)
    if isinstance(on_sample, Coroutine):
        await on_sample


def _choose_sampler() -> EpisodeSampler:
    """Pick the sampler for the next episode: ``random_sampler`` or one of ``other_samplers``.

    The random sampler's smoothed goldilocks rate is floored by a value that
    starts at 1.0 and halves every ``min_random_episode_sample_probability_half_life``
    random samples. Other samplers are weighted by their smoothed rates raised to
    ``exploitation_factor``. The same exponent is applied to (random rate,
    weighted-average other rate) to choose between the two groups.
    """
    if not other_samplers:
        return random_sampler
    prior, effective_sample_size = goldilocks_rate_prior_and_effective_sample_size()
    min_random_goldilocks_rate = 1.0 * np.exp(
        -np.log(2)
        / min_random_episode_sample_probability_half_life
        * random_sampler.num_samples
    )
    random_goldilocks_rate = max(
        random_sampler.goldilocks_rate(prior, effective_sample_size),
        min_random_goldilocks_rate,
    )
    other_goldilocks_rates = np.array(
        [s.goldilocks_rate(prior, effective_sample_size) for s in other_samplers]
    )
    other_sampler_weights = other_goldilocks_rates**exploitation_factor
    other_sampler_weights /= other_sampler_weights.sum()
    other_expected_goldilocks_rate = other_goldilocks_rates @ other_sampler_weights
    hierarchical_weights = (
        np.array([random_goldilocks_rate, other_expected_goldilocks_rate])
        ** exploitation_factor
    )
    hierarchical_weights /= hierarchical_weights.sum()
    if random.random() < hierarchical_weights[0]:
        return random_sampler
    return random.choices(other_samplers, weights=other_sampler_weights)[0]


async def get_episode() -> None:
    """Fetch one new episode, run its first sampling round, and triage it.

    While a coroutine sampler resolves, a placeholder Episode sits in ``buffer``
    so the in-flight fetch counts toward ``len(buffer)``. After the first
    sampling round the episode is:

    * dropped if it has no children (the sampler's counters are untouched);
    * dropped, adding an easier-episode sampler, if its value is <= ``min_value``;
    * dropped, adding a harder-episode sampler, if its value is >= ``max_value``;
    * dropped if every child has zero advantage;
    * otherwise kept in ``buffer`` and counted as goldilocks (adding a
      similar-episode sampler when one is available).
    """
    sampler = _choose_sampler()
    episode = sampler.sample()
    if isinstance(episode, Coroutine):
        placeholder = Episode(
            messages=[],
            on_sample=lambda _: None,
        )
        buffer.append(placeholder)
        try:
            episode = await episode
        finally:
            buffer.remove(placeholder)
    episode.task = asyncio.create_task(sample_completions(episode))
    buffer.append(episode)
    try:
        await episode.task
    except BaseException:
        buffer.remove(episode)
        raise
    if not episode.completion.children:
        # Nothing was sampled. sample_completions may have swapped episode.task
        # for a never-finishing task; cancel it so it does not outlive the
        # episode (cancel() is a no-op on an already-finished task).
        episode.task.cancel()
        buffer.remove(episode)
        return
    sampler.num_samples += 1

    if (
        episode.get_easier_episode
        and episode.min_value is not None
        and episode.completion.value() <= episode.min_value
    ):
        other_samplers.append(EpisodeSampler(episode.get_easier_episode))
        buffer.remove(episode)
        return
    if (
        episode.get_harder_episode
        and episode.max_value is not None
        and episode.completion.value() >= episode.max_value
    ):
        other_samplers.append(EpisodeSampler(episode.get_harder_episode))
        buffer.remove(episode)
        return
    if all(c.advantage() == 0 for c in episode.completion.children):
        buffer.remove(episode)
        return
    if episode.get_similar_episode:
        other_samplers.append(EpisodeSampler(episode.get_similar_episode))
    sampler.num_goldilocks += 1


async def enrich_episode() -> None:
    """Sample more completions for the smallest idle buffered episode.

    "Idle" means the episode's current task is done and it already has sampled
    children. Requiring children excludes the placeholders that get_episode puts
    in ``buffer``. Those have a finished ``sleep(0)`` task and zero descendants,
    so they would otherwise always win the ``min`` and be sampled with an empty
    prompt. If no buffered episode is idle, a new episode is fetched instead.
    """
    episode = min(
        (e for e in buffer if e.task.done() and e.completion.children),
        key=lambda e: sum(1 for _ in e.completion.descendants()),
        default=None,
    )
    if episode is None:
        await get_episode()
        return
    episode.task = asyncio.create_task(sample_completions(episode))


from pydantic import BaseModel
import torch


class Request(BaseModel):
    """A request to fill rows of shared, file-backed training tensors.

    handle_request opens each file with ``torch.from_file(..., shared=True)``
    holding ``rows * seqlen`` elements, views it as ``(-1, seqlen)``, and
    iterates rows ``start`` (inclusive) to ``stop`` (exclusive).
    """

    # File read as an int64 tensor of token ids.
    tokens_filename: str
    # File read as a float32 tensor of advantages.
    advantages_filename: str
    # File read as a float32 tensor of logprobs.
    logprobs_filename: str
    # Number of rows in each file.
    rows: int
    # Elements (tokens) per row.
    seqlen: int
    # First row to fill (inclusive).
    start: int
    # Row to stop before (exclusive).
    stop: int


# Queue of incoming requests; the main loop handles all of them once per tick.
requests: list[Request] = []


class Trajectory(BaseModel):
    """One root-to-leaf path of an episode's completion tree, scored for packing.

    ``Episode`` is a plain Python class (and ``Completion`` may be too), so pydantic
    must be allowed to accept arbitrary types. Without this, defining the class
    raises a schema-generation error.
    """

    model_config = {"arbitrary_types_allowed": True}

    episode: Episode
    # Leaf completion that ends the path.
    terminus: Completion
    # terminus.all_abs_advantage()
    abs_advantage: float
    # terminus.all_token_count(tokenizer)
    token_count: int

    def score(self) -> float:
        """Episode-weighted absolute advantage per token (higher packs first).

        A zero token count is treated as 1 to avoid ZeroDivisionError.
        """
        return self.episode.weight * self.abs_advantage / max(self.token_count, 1)


def best_trajectory(episode: Episode) -> Trajectory:
    """Return the leaf path of ``episode`` with the most absolute advantage per token.

    Raises:
        ValueError: if ``episode.completion.leaves()`` yields nothing.
    """
    candidates = (
        Trajectory(
            episode=episode,
            terminus=leaf,
            abs_advantage=leaf.all_abs_advantage(),
            token_count=leaf.all_token_count(tokenizer),
        )
        for leaf in episode.completion.leaves()
    )
    # Treat a zero token count as 1 so an empty path cannot raise ZeroDivisionError.
    return max(candidates, key=lambda t: t.abs_advantage / max(t.token_count, 1))


def handle_request(request: Request) -> None:
    """Pack the best buffered trajectories into rows ``start:stop`` of the token file.

    One trajectory (best_trajectory) is taken per buffered episode that already
    has sampled children. Trajectories are visited in descending ``score()``
    order. Each row greedily takes every remaining trajectory that still fits in
    ``request.seqlen`` tokens, and each trajectory is used in at most one row.
    """

    def open_shared(filename: str, dtype: torch.dtype) -> torch.Tensor:
        # File-backed tensor opened with shared=True, viewed as (-1, seqlen).
        return torch.from_file(
            filename,
            shared=True,
            size=request.rows * request.seqlen,
            dtype=dtype,
        ).view(-1, request.seqlen)

    tokens = open_shared(request.tokens_filename, torch.int64)
    # TODO: advantages and logprobs are opened but never written yet.
    advantages = open_shared(request.advantages_filename, torch.float32)
    logprobs = open_shared(request.logprobs_filename, torch.float32)

    # Placeholders and episodes still awaiting their first round have no
    # children, so there is nothing to train on (and their token count may be 0).
    remaining = sorted(
        (best_trajectory(e) for e in buffer if e.completion.children),
        key=lambda t: t.score(),
        reverse=True,
    )
    for row in range(request.start, request.stop):
        # The old `range(0, len(trajectories), -1)` was always empty, so nothing
        # was ever packed; this is the intended greedy first-fit, best score first.
        selected: list[Trajectory] = []
        leftover: list[Trajectory] = []
        used_tokens = 0
        for trajectory in remaining:
            if used_tokens + trajectory.token_count > request.seqlen:
                leftover.append(trajectory)
            else:
                selected.append(trajectory)
                used_tokens += trajectory.token_count
        remaining = leftover
        tokens[row] = tokenizer.encode(
            [
                trajectory.terminus.all_message_params()
                for trajectory in selected
            ],  # type: ignore
            concatenate=True,
            seqlen=request.seqlen,
        )


while True:
    await asyncio.sleep(1)
    for request in requests:
        handle_request(request)
    requests = []
    running, pending = vllm_server_metrics()
    for _ in range(0, running - pending + min_requests, branch_factor * 2):
        if (
            len(buffer) < abs_buffer_size
            or sum(e.weight for e in buffer) < weighted_buffer_size
        ):
            asyncio.create_task(get_episode())
        else:
            asyncio.create_task(enrich_episode())