In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from lib.vllm import start_vllm_server, vllm_server_metrics

# Model identifier shared by the vLLM server, the completion sampler and the
# tokenizer created in later cells.
model = "NousResearch/Hermes-2-Theta-Llama-3-8B"

# Start the vLLM server through the project helper. The helper returns a pair:
# a callable (invoked in a later cell, presumably to stop the server) and a
# client whose `chat.completions.create` is used for follow-up requests.
shutdown_server, client = await start_vllm_server(
    disable_log_requests=True,
    model=model,
)

No module named 'vllm._version'
  from vllm.version import __version__ as VLLM_VERSION
No module named 'vllm._version'
  from vllm.version import __version__ as VLLM_VERSION
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:02,  1.43it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.37it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:01<00:00,  2.00it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.69it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.66it/s]



In [None]:
# Invoke the shutdown handle returned by start_vllm_server (presumably stops
# the server). This cell has no execution count, i.e. it was not run in the
# saved session.
shutdown_server()

In [3]:
from lib.rl.sampler import CompletionSampler

# Sampler bound to the server client and model. Its `sample_completions`
# coroutine is what `sample_completions(episode)` below awaits to branch new
# completions off a parent completion.
completion_sampler = CompletionSampler(
    client,
    model=model,
)

In [4]:
import asyncio
from lib.rl.completion import Completion
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from typing import Callable, Coroutine, Optional


class Episode:
    """One prompt plus the bookkeeping needed to sample and schedule it.

    Attributes set by the constructor:
        completion: root `Completion` built from the prompt messages.
        on_sample: callback (sync or async) invoked with newly sampled completions.
        min_value / get_easier_episode: threshold and factory taken from the
            `get_easier_episode` pair, or `None` for both when it is not given.
        max_value / get_harder_episode: same, taken from `get_harder_episode`.
        get_similar_episode: optional factory for a comparable episode.
        weight: starts at 1.0.
        task: asyncio task associated with the episode; initialised with a
            zero-length sleep, so a running event loop is required.
    """

    def __init__(
        self,
        messages: list[ChatCompletionMessageParam],
        on_sample: Callable[[list[Completion]], None | Coroutine[None, None, None]],
        get_easier_episode: Optional[
            tuple[float, Callable[[], "Episode" | Coroutine[None, None, "Episode"]]]
        ] = None,
        get_similar_episode: Optional[
            Callable[[], "Episode" | Coroutine[None, None, "Episode"]]
        ] = None,
        get_harder_episode: Optional[
            tuple[float, Callable[[], "Episode" | Coroutine[None, None, "Episode"]]]
        ] = None,
    ) -> None:
        self.completion = Completion(messages=messages)  # type: ignore
        self.on_sample = on_sample

        # Each "easier"/"harder" argument is a (threshold, factory) pair; a
        # missing (falsy) argument yields None for both halves.
        self.min_value = get_easier_episode[0] if get_easier_episode else None
        self.max_value = get_harder_episode[0] if get_harder_episode else None
        self.get_easier_episode = get_easier_episode[1] if get_easier_episode else None
        self.get_similar_episode = get_similar_episode
        self.get_harder_episode = get_harder_episode[1] if get_harder_episode else None

        self.weight = 1.0
        # Trivial task that completes almost immediately once the loop runs.
        self.task = asyncio.create_task(asyncio.sleep(0))

In [5]:
from lib.clue import Clue, DeductiveSolver
import re


def sample_random_episode() -> Episode:
    """Generate a fresh Clue game and wrap its prompt in an `Episode`.

    A 3-player game with three suspects, weapons and rooms is played through
    with a reduced `DeductiveSolver` configuration. The resulting prompt
    becomes the episode's single user message.

    Returns:
        An `Episode` whose `on_sample` callback scores each sampled completion
        (fraction of solution elements the model names correctly in a
        follow-up answer) and then commits it.
    """
    game = Clue(
        num_players=3,
        elements={
            "suspect": Clue.suspects[:3],
            "weapon": Clue.weapons[:3],
            "room": Clue.rooms[:3],
            # "motive": Clue.motives[:6],
            # "time": Clue.get_times("21:00", "03:00", "1h"),
        },
    )
    game.play(
        deductive_solver=DeductiveSolver(
            # note_cards_in_hand=False,
            # note_responses_to_suggestions=False,
            # note_cards_that_players_do_not_have=False,
            # check_unique_card_placement_constraints=False,
            # check_player_hand_size_constraints=False,
            check_solution_has_one_and_only_one_card_per_element=False,
            check_one_of_constraints=False,
            check_inverse_one_of_constraints=False,
            merge_and_check_disjoint_inverse_one_of_constraints=False,
            exhaustively_test_possible_assignments=False,
        ),
        cp_solver_max_solve_time_per_turn=0.05,
        check_cp_solver_grid=False,
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        print_playthrough=False,
    )
    prompt = game.get_prompt()
    # Template the model is asked to fill in, one "Element: <#ELEMENT#>" line
    # per game element.
    follow_up = "Fill out your answer like this:\n" + "\n".join(
        f"{element.capitalize()}: <#{element.upper()}#>" for element in game.elements
    )

    async def reward_completion(completion: Completion) -> None:
        """Ask the model for a structured answer and set `completion.reward`."""
        chat_completion = await client.chat.completions.create(
            messages=completion.all_message_params()
            + [
                {"role": "user", "content": follow_up},
            ],
            model=model,
            temperature=0.0,
        )
        answer = chat_completion.choices[0].message.content
        assert answer
        # Escape both halves of the pattern so card names are matched
        # literally: an unescaped "." would match any character and an
        # unbalanced bracket would make the pattern invalid.
        correct = sum(
            bool(
                re.search(
                    f"{re.escape(str(element))}: {re.escape(str(solution))}",
                    answer,
                    re.IGNORECASE,
                )
            )
            for element, solution in game.solution.items()
        )
        completion.reward = correct / len(game.solution)

    async def on_sample(completions: list[Completion]) -> None:
        """Score all completions concurrently, then commit each one."""
        await asyncio.gather(
            *[reward_completion(completion) for completion in completions]
        )
        for completion in completions:
            completion.commit()

    return Episode(
        messages=[{"role": "user", "content": prompt}],
        on_sample=on_sample,
    )

In [6]:
import numpy as np

In [7]:
from lib.tokenizer import Tokenizer

# Project tokenizer wrapper for the same model; later cells pass it to the
# completion token-count helpers and call its encode/decode/get_token_id.
tokenizer = Tokenizer(model)

In [177]:
tokenizer.get_token_id("A")

32

In [8]:
import random


class EpisodeSampler:
    """An episode factory plus counters of how often its episodes were useful.

    `num_samples` and `num_goldilocks` both start at zero; they are incremented
    by the scheduling code, not by this class.
    """

    def __init__(
        self,
        sample: Callable[[], Episode | Coroutine[None, None, Episode]],
    ) -> None:
        self.num_goldilocks = 0
        self.num_samples = 0
        self.sample = sample

    def goldilocks_rate(self, prior: float, effective_sample_size: float) -> float:
        """Return the goldilocks rate smoothed towards `prior`.

        The prior contributes `effective_sample_size` pseudo-samples, of which
        a fraction `prior` count as goldilocks hits.
        """
        smoothed_hits = self.num_goldilocks + prior * effective_sample_size
        smoothed_total = self.num_samples + effective_sample_size
        return smoothed_hits / smoothed_total


# Number of completions requested per sampling call (passed as `n=` in
# sample_completions); also scales the loop step in prepare_episodes.
branch_factor = 2
# Added to (running - pending) in prepare_episodes when deciding how many
# tasks to launch per tick.
min_requests = 5
# prepare_episodes starts new episodes while the buffer holds fewer entries
# than this ...
abs_buffer_size = 40
# ... or while the summed `weight` of buffered episodes is below this.
weighted_buffer_size = 80
# Episodes currently held for sampling; appended to / removed from by get_episode.
buffer: list[Episode] = []
# Half-life, counted in random-sampler samples, of the lower bound applied to
# the random sampler's goldilocks rate in get_episode.
min_random_episode_sample_probability_half_life = 80
# Exponent applied to goldilocks rates when turning them into sampling weights.
exploitation_factor = 1.0
# Sampler for fresh random episodes, and the samplers registered later from
# episodes' easier/similar/harder factories.
random_sampler = EpisodeSampler(sample_random_episode)
other_samplers: list[EpisodeSampler] = []

In [10]:
def goldilocks_rate_prior_and_effective_sample_size() -> tuple[float, float]:
    """Pool the counters of all samplers into a prior and its strength.

    Returns:
        A pair `(prior, effective_sample_size)`. The prior is the pooled
        goldilocks fraction, or 1.0 when either pooled counter is zero. The
        effective sample size is the mean number of samples per sampler,
        with a lower bound of 1.
    """
    all_samplers = [random_sampler, *other_samplers]
    pooled_goldilocks = sum(s.num_goldilocks for s in all_samplers)
    pooled_samples = sum(s.num_samples for s in all_samplers)

    if pooled_goldilocks != 0 and pooled_samples != 0:
        prior = pooled_goldilocks / pooled_samples
    else:
        prior = 1.0

    samples_per_sampler = pooled_samples / len(all_samplers)
    return prior, max(samples_per_sampler, 1)


async def sample_completions(episode: Episode) -> None:
    """Sample `branch_factor` new completions for `episode` and score them.

    If the episode's root completion has no children yet, new completions are
    sampled directly from the root. Otherwise a splittable completion on the
    path to the highest-ranked leaf is split (by count) and the new
    completions are sampled from it.

    When no suitable completion can be found or split, the reason is printed
    and `episode.task` is replaced by a task that never finishes, so that
    `episode.task.done()` stays False and enrich_episode skips this episode.

    Args:
        episode: the episode to extend; its `on_sample` callback is invoked
            (and awaited if it returns a coroutine) with the new completions.
    """
    if episode.completion.children:
        try:
            # Leaf ranked highest by absolute advantage per token, restricted
            # to leaves that have at least one splittable completion on their
            # ancestor chain. `max` raises ValueError if there is none.
            leaf = max(
                (
                    completion
                    for completion in episode.completion.leaves()
                    if any(
                        c.can_split() for c in completion.ancestors(including_self=True)
                    )
                ),
                key=lambda c: c.all_abs_advantage() / c.all_token_count(tokenizer),
            )
            # Among that leaf's splittable ancestors, take the one with the
            # largest |advantage| * token count.
            parent = max(
                (c for c in leaf.ancestors(including_self=True) if c.can_split()),
                key=lambda c: abs(c.advantage()) * c.token_count(tokenizer),
            )
            # The split mutates the tree, so it must not live inside an
            # `assert`: assertions are stripped under `python -O`, which
            # would silently skip the split.
            if not parent.split(by="count"):
                raise AssertionError("Unable to split completion")
        except Exception as e:
            # Only ordinary errors are swallowed; KeyboardInterrupt,
            # SystemExit and cancellation propagate to the caller.
            print(type(e), e)
            # Park the episode behind a never-finishing task.
            episode.task = asyncio.create_task(asyncio.sleep(float("inf")))
            return
    else:
        parent = episode.completion
    completions = await completion_sampler.sample_completions(
        parent,
        n=branch_factor,
    )
    on_sample = episode.on_sample(completions)
    if isinstance(on_sample, Coroutine):
        await on_sample


def _choose_episode_sampler() -> EpisodeSampler:
    """Pick the sampler that produces the next episode.

    With no follow-up samplers registered, the random sampler is returned.
    Otherwise a two-level draw is made: first between the random sampler and
    the group of follow-up samplers, then (if the group wins) among the
    follow-up samplers, each level weighted by goldilocks rate raised to
    `exploitation_factor`.
    """
    if not other_samplers:
        return random_sampler

    prior, effective_sample_size = goldilocks_rate_prior_and_effective_sample_size()

    # Lower bound on the random sampler's rate; it halves every
    # `min_random_episode_sample_probability_half_life` random samples.
    random_rate_floor = 1.0 * np.exp(
        -np.log(2)
        / min_random_episode_sample_probability_half_life
        * random_sampler.num_samples
    )
    random_rate = max(
        random_sampler.goldilocks_rate(prior, effective_sample_size),
        random_rate_floor,
    )

    other_rates = np.array(
        [s.goldilocks_rate(prior, effective_sample_size) for s in other_samplers]
    )
    other_weights = other_rates**exploitation_factor
    other_weights /= other_weights.sum()
    # Rate expected from the follow-up group when drawing with those weights.
    expected_other_rate = other_rates @ other_weights

    group_weights = (
        np.array([random_rate, expected_other_rate]) ** exploitation_factor
    )
    group_weights /= group_weights.sum()

    if random.random() < group_weights[0]:
        return random_sampler
    return random.choices(other_samplers, weights=other_weights)[0]


async def get_episode() -> None:
    """Create one new episode, sample it once, and update sampler statistics.

    The episode stays in `buffer` only if its first round of completions
    produced children with non-zero advantages and its value did not cross the
    easier/harder thresholds. Crossing a threshold, or having a
    `get_similar_episode` factory, registers a new follow-up sampler.
    """
    sampler = _choose_episode_sampler()

    episode = sampler.sample()
    if isinstance(episode, Coroutine):
        # Hold a slot in the buffer while the episode is being built.
        stand_in = Episode(
            messages=[],
            on_sample=lambda _: None,
        )
        buffer.append(stand_in)
        try:
            episode = await episode
        finally:
            buffer.remove(stand_in)

    episode.task = asyncio.create_task(sample_completions(episode))
    buffer.append(episode)
    try:
        await episode.task
    except BaseException as error:
        buffer.remove(episode)
        raise error

    # Nothing was sampled: drop the episode without counting it.
    if not episode.completion.children:
        buffer.remove(episode)
        return
    sampler.num_samples += 1

    # Value at or below the lower threshold: register the easier factory.
    if (
        episode.get_easier_episode
        and episode.min_value is not None
        and episode.completion.value() <= episode.min_value
    ):
        other_samplers.append(EpisodeSampler(episode.get_easier_episode))
        buffer.remove(episode)
        return

    # Value at or above the upper threshold: register the harder factory.
    if (
        episode.get_harder_episode
        and episode.max_value is not None
        and episode.completion.value() >= episode.max_value
    ):
        other_samplers.append(EpisodeSampler(episode.get_harder_episode))
        buffer.remove(episode)
        return

    # Every child has zero advantage: drop the episode.
    if all(child.advantage() == 0 for child in episode.completion.children):
        buffer.remove(episode)
        return

    if episode.get_similar_episode:
        other_samplers.append(EpisodeSampler(episode.get_similar_episode))
    sampler.num_goldilocks += 1


async def enrich_episode() -> None:
    """Extend the smallest idle episode, or create a new one if none is idle.

    "Idle" means the episode's task is done; "smallest" means fewest
    descendants in its completion tree. The new sampling task is started but
    not awaited. A ValueError (e.g. from `min` on an empty list) falls back
    to `get_episode`.
    """

    def tree_size(candidate: Episode) -> int:
        return len(list(candidate.completion.descendants()))

    try:
        idle_episodes = [candidate for candidate in buffer if candidate.task.done()]
        smallest = min(idle_episodes, key=tree_size)
        smallest.task = asyncio.create_task(sample_completions(smallest))
    except ValueError:
        await get_episode()


async def prepare_episodes() -> None:
    """Background loop that keeps the episode buffer and the server busy.

    Every 5 seconds the two server counters are read and a number of tasks
    proportional to `running - pending + min_requests` is launched. Each task
    creates a new episode while the buffer is below either size target, and
    otherwise enriches an existing episode. The loop runs until the
    surrounding task is cancelled or an exception escapes.
    """
    # The event loop keeps only weak references to tasks, so an unreferenced
    # fire-and-forget task may be garbage-collected before it finishes.
    # Hold each task here until it is done.
    background_tasks: set[asyncio.Task[None]] = set()
    while True:
        await asyncio.sleep(5)
        running, pending = vllm_server_metrics()
        for _ in range(0, running - pending + min_requests, branch_factor * 2):
            if (
                len(buffer) < abs_buffer_size
                or sum(e.weight for e in buffer) < weighted_buffer_size
            ):
                task = asyncio.create_task(get_episode())
            else:
                task = asyncio.create_task(enrich_episode())
            background_tasks.add(task)
            task.add_done_callback(background_tasks.discard)


# Start the scheduling loop as a background task on the running event loop.
# The handle is kept so other cells can cancel or await it.
prepare_episodes_task = asyncio.create_task(prepare_episodes())

In [75]:
# Request cancellation of the background scheduling loop; tasks it has already
# spawned are separate tasks and are not cancelled by this call.
prepare_episodes_task.cancel()

True

In [85]:
vllm_server_metrics()

(0, 0)

In [83]:
len(buffer) < abs_buffer_size or sum(e.weight for e in buffer) < weighted_buffer_size

False

In [84]:
len(buffer)

83

In [81]:
sum([len(list(episode.completion.descendants())) for episode in buffer]) / len(buffer)

18.120481927710845

In [82]:
len([episode for episode in buffer if episode.task.done()])

83

In [None]:
# Block this cell on the scheduling task. Its body is a `while True` loop, so
# this only returns control when the task is cancelled or raises.
await prepare_episodes_task

In [86]:
from dataclasses import dataclass
from pydantic import BaseModel
import torch

@dataclass
class Trajectory:
    """A candidate training sequence: an episode and one of its completions.

    `abs_advantage` and `token_count` are stored when the trajectory is built
    (see `best_trajectory`) and are not recomputed afterwards.
    """

    episode: Episode  # episode the completion belongs to
    terminus: Completion  # completion at which the trajectory ends
    abs_advantage: float
    token_count: int

    def score(self) -> float:
        """Return the episode-weighted absolute advantage per token."""
        weighted_advantage = self.episode.weight * self.abs_advantage
        return weighted_advantage / self.token_count


def best_trajectory(episode: Episode) -> Trajectory:
    """Return the episode's leaf trajectory with the most advantage per token.

    One `Trajectory` is built per leaf of the episode's completion tree, using
    the leaf's `all_abs_advantage()` and `all_token_count(tokenizer)`. The
    episode weight is not part of the comparison.

    Raises:
        ValueError: if the completion tree yields no leaves.
    """
    candidates = [
        Trajectory(
            episode=episode,
            terminus=leaf,
            abs_advantage=leaf.all_abs_advantage(),
            token_count=leaf.all_token_count(tokenizer),
        )
        for leaf in episode.completion.leaves()
    ]

    def advantage_per_token(candidate: Trajectory) -> float:
        return candidate.abs_advantage / candidate.token_count

    return max(candidates, key=advantage_per_token)

In [87]:
# One best trajectory per buffered episode, ordered by ascending score, so the
# highest-scoring trajectory ends up last.
trajectories = [best_trajectory(episode) for episode in buffer]
trajectories.sort(key=lambda candidate: candidate.score())

In [97]:
# `trajectories` is sorted by ascending score, so the last entry is the
# highest-scoring trajectory.
trajectory = trajectories[-1]

In [98]:
[c.advantage() for c in trajectory.terminus.ancestors(including_self=True)]

[0.5555555555555556,
 -0.37037037037037046,
 0.5432098765432098,
 -0.36419753086419754,
 0.0]

In [174]:
tokenizer.llm.get_tokenizer().added_tokens_encoder

{'<|begin_of_text|>': 128000,
 '<|end_of_text|>': 128001,
 '<|im_start|>': 128002,
 '<|im_end|>': 128003,
 '<tool_call>': 128004,
 '<|reserved_special_token_3|>': 128005,
 '<|start_header_id|>': 128006,
 '<|end_header_id|>': 128007,
 '<tools>': 128008,
 '<|eot_id|>': 128009,
 '</tools>': 128010,
 '</tool_call>': 128011,
 '</tool_response>': 128012,
 '<|reserved_special_token_8|>': 128013,
 '<|reserved_special_token_9|>': 128014,
 '<|reserved_special_token_10|>': 128015,
 '<|reserved_special_token_11|>': 128016,
 '<|reserved_special_token_12|>': 128017,
 '<|reserved_special_token_13|>': 128018,
 '<|reserved_special_token_14|>': 128019,
 '<|reserved_special_token_15|>': 128020,
 '<|reserved_special_token_16|>': 128021,
 '<|reserved_special_token_17|>': 128022,
 '<|reserved_special_token_18|>': 128023,
 '<|reserved_special_token_19|>': 128024,
 '<|reserved_special_token_20|>': 128025,
 '<|reserved_special_token_21|>': 128026,
 '<|reserved_special_token_22|>': 128027,
 '<|reserved_special_

In [211]:
# Reserved special token used as a placeholder; an earlier cell shows it
# encodes to the single id 128255.
# NOTE(review): presumably all_message_params(replacement_token=...) swaps
# sampled tokens for this placeholder - confirm against lib.rl.completion.
replacement_token = "<|reserved_special_token_250|>"
seqlen = 1000
# Boolean mask over the encoded sequence: True where the token id equals the
# placeholder's id.
mask = tokenizer.encode(
    [trajectory.terminus.all_message_params(replacement_token=replacement_token)],  # type: ignore
    concatenate=True,
    seqlen=seqlen,
) == tokenizer.get_token_id(replacement_token)
# Float tensor shaped like the mask, NaN everywhere until filled below.
logprobs = torch.full_like(mask, fill_value=torch.nan, dtype=torch.float32)
# Fill masked positions with the trajectory's logprobs, truncated to the number
# of masked positions so the shapes agree.
# NOTE(review): the inspection cells below recorded mask.sum() == 315 but 316
# logprobs, so this slice drops the final logprob - confirm which side is off
# by one.
logprobs[mask] = torch.tensor(list(trajectory.terminus.all_logprobs())[:mask.sum()])
logprobs

tensor([        nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,     

In [203]:
mask.sum()

tensor(315)

In [200]:
torch.tensor(list(trajectory.terminus.all_logprobs())).shape

torch.Size([316])

In [112]:
tokenizer.llm.get_tokenizer().added_tokens_decoder

{128000: AddedToken("<|begin_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 128001: AddedToken("<|end_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 128002: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 128003: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 128004: AddedToken("<tool_call>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),
 128005: AddedToken("<|reserved_special_token_3|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),
 128006: AddedToken("<|start_header_id|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 128007: AddedToken("<|end_header_id|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 128008: AddedToken("<tools>", rstrip

In [114]:
tokenizer.llm.get_tokenizer().encode("<|reserved_special_token_250|>", add_special_tokens=False)

[128255]

In [127]:
%%timeit
"<|reserved_special_token_250|>" * 1000

335 ns ± 6.33 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [121]:
tokenizer.encode([{"role": param["role"], "content": param["content"].replace("On a cool spring day", "<|reserved_special_token_250|>")} for param in trajectory.terminus.all_message_params()])[:8]

tensor([128000, 128000, 128002,    882,    198, 128255,  66294,     11])

In [103]:
tokenizer.decode(tokenizer.encode(trajectory.terminus.all_message_params(), continue_final_message=False)[:5])

'<|begin_of_text|><|begin_of_text|><|im_start|>user\n'

In [None]:


class Request(BaseModel):
    """Parameters of one batch-filling request consumed by `handle_request`.

    The three filenames name files that `handle_request` maps as flat tensors
    of `rows * seqlen` elements and views as `(rows, seqlen)`.
    """
    # File mapped as the int64 token tensor.
    tokens_filename: str
    # File mapped as the float32 advantages tensor.
    advantages_filename: str
    # File mapped as the float32 logprobs tensor.
    logprobs_filename: str
    # Number of rows in each mapped tensor.
    rows: int
    # Number of elements per row (sequence length).
    seqlen: int
    # Half-open row range [start, stop) that handle_request iterates over.
    start: int
    stop: int


def handle_request(request: Request) -> None:
    """Fill rows `[request.start, request.stop)` of the token tensor.

    The three files named in the request are mapped with `shared=True` and
    viewed as `(rows, seqlen)` tensors. For each requested row, the best
    trajectories of the buffered episodes are packed greedily, highest score
    first, as long as their combined `token_count` fits into `request.seqlen`.
    The packed trajectories are then encoded into that row. A trajectory is
    used for at most one row.

    Args:
        request: filenames, tensor geometry and the row range to fill.
    """
    tokens = torch.from_file(
        request.tokens_filename,
        shared=True,
        size=request.rows * request.seqlen,
        dtype=torch.int64,
    ).view(-1, request.seqlen)
    # NOTE(review): `advantages` and `logprobs` are mapped but not written by
    # the code visible here; kept because the mapping itself touches the files.
    advantages = torch.from_file(
        request.advantages_filename,
        shared=True,
        size=request.rows * request.seqlen,
        dtype=torch.float32,
    ).view(-1, request.seqlen)
    logprobs = torch.from_file(
        request.logprobs_filename,
        shared=True,
        size=request.rows * request.seqlen,
        dtype=torch.float32,
    ).view(-1, request.seqlen)
    # Best trajectory per episode, highest score first.
    trajectories = sorted(
        (best_trajectory(episode) for episode in buffer),
        key=lambda t: t.score(),
        reverse=True,
    )
    for row in range(request.start, request.stop):
        # Greedy packing. The previous loop used range(0, len(...), -1), which
        # is always empty, so nothing was ever selected. Walk the candidates
        # best-first and keep the ones that do not fit for later rows.
        selected_trajectories: list[Trajectory] = []
        remaining_trajectories: list[Trajectory] = []
        used_tokens = 0
        for candidate in trajectories:
            if used_tokens + candidate.token_count > request.seqlen:
                remaining_trajectories.append(candidate)
                continue
            selected_trajectories.append(candidate)
            used_tokens += candidate.token_count
        trajectories = remaining_trajectories
        # NOTE(review): assumes `token_count` matches the encoded length per
        # trajectory; confirm how tokenizer.encode handles overflow past seqlen.
        tokens[row] = tokenizer.encode(
            [
                trajectory.terminus.all_message_params()
                for trajectory in selected_trajectories
            ],  # type: ignore
            concatenate=True,
            seqlen=request.seqlen,
        )