In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import asyncio
from lib.clue import Clue, DeductiveSolver
from lib.rl import Completion, CompletionSampler, Episode, EpisodeSampler, EpisodeSamplerRouter
from lib.tokenizer import Tokenizer
from lib.vllm import start_vllm_server, vllm_server_metrics
import numpy as np
import random
import re
from typing import Callable, Coroutine

In [2]:
model = "NousResearch/Hermes-2-Theta-Llama-3-8B"

shutdown_server, client = await start_vllm_server(
    disable_log_requests=True,
    model=model,
)

No module named 'vllm._version'
  from vllm.version import __version__ as VLLM_VERSION
No module named 'vllm._version'
  from vllm.version import __version__ as VLLM_VERSION
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:01,  1.55it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.29it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:01<00:00,  1.69it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.55it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.53it/s]



In [181]:
shutdown_server()

In [8]:
completion_sampler = CompletionSampler(
    client,
    model=model,
)

In [10]:
def sample_random_episode() -> Episode:
    game = Clue(
        num_players=3,
        elements={
            "suspect": Clue.suspects[:3],
            "weapon": Clue.weapons[:3],
            "room": Clue.rooms[:3],
            # "motive": Clue.motives[:6],
            # "time": Clue.get_times("21:00", "03:00", "1h"),
        },
    )
    game.play(
        deductive_solver=DeductiveSolver(
            # note_cards_in_hand=False,
            # note_responses_to_suggestions=False,
            # note_cards_that_players_do_not_have=False,
            # check_unique_card_placement_constraints=False,
            # check_player_hand_size_constraints=False,
            check_solution_has_one_and_only_one_card_per_element=False,
            check_one_of_constraints=False,
            check_inverse_one_of_constraints=False,
            merge_and_check_disjoint_inverse_one_of_constraints=False,
            exhaustively_test_possible_assignments=False,
        ),
        cp_solver_max_solve_time_per_turn=0.05,
        check_cp_solver_grid=False,
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        print_playthrough=False,
    )
    prompt = game.get_prompt()
    follow_up = "Fill out your answer like this:\n" + "\n".join(
        f"{element.capitalize()}: <#{element.upper()}#>" for element in game.elements
    )

    async def reward_completion(completion: Completion) -> None:
        chat_completion = await client.chat.completions.create(
            messages=completion.all_message_params()
            + [
                {"role": "user", "content": follow_up},
            ],
            model=model,
            temperature=0.0,
        )
        answer = chat_completion.choices[0].message.content
        assert answer
        completion.reward = sum(
            [
                bool(
                    re.search(
                        f"{element}: {solution}",
                        answer,
                        re.IGNORECASE,
                    )
                )
                for element, solution in game.solution.items()
            ]
        ) / len(game.solution)

    async def on_sample(completions: list[Completion]) -> None:
        await asyncio.gather(
            *[reward_completion(completion) for completion in completions]
        )
        for completion in completions:
            completion.commit()

    return Episode(
        messages=[{"role": "user", "content": prompt}],
        on_sample=on_sample,
    )

In [12]:
tokenizer = Tokenizer(model)

In [None]:
tokenizer.llm.get_tokenizer()

In [13]:
branch_factor = 3
abs_buffer_size = 20
weighted_buffer_size = 40
buffer: list[Episode] = []
max_running = 100
episode_sampler_router = EpisodeSamplerRouter(
    EpisodeSampler(sample_random_episode),
    exploitation_factor=1.0,
    min_random_episode_sample_probability_half_life=40,
)

In [118]:
async def sample_completions(episode: Episode) -> None:
    if episode.completion.children:
        try:
            leaf = max(
                (
                    completion
                    for completion in episode.completion.leaves()
                    if any(
                        c.can_split() for c in completion.ancestors(including_self=True)
                    )
                ),
                key=lambda c: c.all_abs_advantage() / c.all_token_count(tokenizer),
            )
            parent = max(
                (c for c in leaf.ancestors(including_self=True) if c.can_split()),
                key=lambda c: abs(c.advantage()) * c.token_count(tokenizer),
            )
            assert parent.split(by="count"), "Unable to split completion"
        except BaseException as e:
            print(type(e), e)
            episode.task = asyncio.create_task(asyncio.sleep(float("inf")))
            return
    else:
        parent = episode.completion
    completions = await completion_sampler.sample_completions(
        parent,
        n=branch_factor - len(parent.children),
    )
    on_sample = episode.on_sample(completions)
    if isinstance(on_sample, Coroutine):
        await on_sample


async def get_episode() -> None:
    sampler = episode_sampler_router.get_sampler()
    episode = sampler.sample()
    if isinstance(episode, Coroutine):
        placeholder = Episode(
            messages=[],
            on_sample=lambda _: None,
        )
        buffer.append(placeholder)
        try:
            episode = await episode
        finally:
            buffer.remove(placeholder)
    episode.task = asyncio.create_task(sample_completions(episode))
    buffer.append(episode)
    try:
        await episode.task
    except BaseException as e:
        buffer.remove(episode)
        raise e
    if not episode.completion.children:
        return buffer.remove(episode)
    sampler.num_samples += 1

    if (
        episode.get_easier_episode
        and episode.min_value is not None
        and episode.completion.value() <= episode.min_value
    ):
        episode_sampler_router.other_samplers.append(
            EpisodeSampler(
                episode.get_easier_episode,
            )
        )
        return buffer.remove(episode)
    elif (
        episode.get_harder_episode
        and episode.max_value is not None
        and episode.completion.value() >= episode.max_value
    ):
        episode_sampler_router.other_samplers.append(
            EpisodeSampler(
                episode.get_harder_episode,
            )
        )
        return buffer.remove(episode)
    elif all(c.advantage() == 0 for c in episode.completion.children):
        return buffer.remove(episode)
    elif episode.get_similar_episode:
        episode_sampler_router.other_samplers.append(
            EpisodeSampler(
                episode.get_similar_episode,
            )
        )
    sampler.num_goldilocks += 1


async def enrich_episode() -> None:
    try:
        episode = min(
            (episode for episode in buffer if episode.task.done()),
            key=lambda episode: episode.num_samples(),
        )
        episode.task = asyncio.create_task(sample_completions(episode))
    except ValueError:
        pass


async def prepare_episodes() -> None:
    while True:
        await asyncio.sleep(2)
        running, pending = vllm_server_metrics()
        global max_running
        max_running = max(max_running, running)
        for _ in range(0, max_running - pending, branch_factor * 2):
            if (
                len(buffer) < abs_buffer_size
                or sum(e.weight for e in buffer) < weighted_buffer_size
            ):
                asyncio.create_task(get_episode())
            else:
                asyncio.create_task(enrich_episode())


prepare_episodes_task = asyncio.create_task(prepare_episodes())

In [117]:
prepare_episodes_task.cancel()

True

In [132]:
vllm_server_metrics()

(39, 0)

In [113]:
len(buffer) < abs_buffer_size or sum(e.weight for e in buffer) < weighted_buffer_size

False

In [120]:
len(buffer)

40

In [133]:
sum([len(list(episode.completion.descendants())) for episode in buffer]) / len(buffer)

46.65

In [122]:
len([episode for episode in buffer if episode.task.done()])

3

In [None]:
await prepare_episodes_task

In [137]:
completion = next(buffer[0].completion.leaves())

In [177]:
completion = max(buffer[0].completion.leaves(), key=lambda c: c.all_abs_advantage() / c.all_token_count(tokenizer))

In [154]:
len(list(completion.ancestors()))

42

In [174]:
tokenizer.decode(tokenizer.encode(completion.parent.message_params()))

'<|begin_of_text|><|im_start|>assistant\n suspect in the Dining Room.'

In [176]:
for completion in buffer[0].completion.descendants():
    completion._token_count = None

In [180]:
print(completion.all_message_params()[-1]["content"])

Let's analyze the cards shown to each player:

1. Blake shows the Lounge to Robert.
2. Robert shows the Knife to Blake.
3. Joel shows the Knife to Blake.
4. Blake shows the Candlestick to Robert.
5. Joel shows a card to Blake (must be Mrs. White or the Lead Pipe).
6. Blake shows a card to Joel (must be Mr. Green or the Dining Room).
7. Joel shows the Dining Room to Robert.
8. Joel shows a card to Blake (must be the Lead Pipe).
9. Blake shows a card to Robert (must be Mr. Green).

From the information above, we can determine the following:

- Robert has Miss Scarlet and the Knife.
- Joel has Mrs. White and the Dining Room.
- Blake has Mr. Green and the Candlestick.

Since Blake showed the Candlestick to Robert, and Robert has the Knife, we know that the Candlestick is not the Knife. Therefore, the Candlestick must be the weapon in the Lounge.

Similarly, since Blake showed a card to Joel (the Lead Pipe), and Joel has the Dining Room, we know that the Lead Pipe is not the Dining Room. Th

In [141]:
from dataclasses import dataclass
from pydantic import BaseModel
import torch

@dataclass
class Trajectory:
    episode: Episode
    terminus: Completion
    abs_advantage: float
    token_count: int

    def score(self) -> float:
        return self.episode.weight * self.abs_advantage / self.token_count


def best_trajectory(episode: Episode) -> Trajectory:
    return max(
        (
            Trajectory(
                episode=episode,
                terminus=completion,
                abs_advantage=completion.all_abs_advantage(),
                token_count=completion.all_token_count(tokenizer),
            )
            for completion in episode.completion.leaves()
        ),
        key=lambda t: t.abs_advantage / t.token_count,
    )

In [144]:
2 ** 13

8192

In [147]:
seqlen = 8192
rows = 64

tokens_file = torch.empty(rows * seqlen, dtype=torch.int64)
tokens_file.numpy().tofile("/tmp/tokens.bin")

advantages_file = torch.empty(rows * seqlen, dtype=torch.float32)
advantages_file.numpy().tofile("/tmp/advantages.bin")

logprobs_file = torch.empty(rows * seqlen, dtype=torch.float32)
logprobs_file.numpy().tofile("/tmp/logprobs.bin")

In [157]:
class Request(BaseModel):
    tokens_filename: str
    advantages_filename: str
    logprobs_filename: str
    rows: int
    seqlen: int
    start: int
    stop: int


def handle_request(request: Request) -> None:
    tokens = torch.from_file(
        request.tokens_filename,
        shared=True,
        size=request.rows * request.seqlen,
        dtype=torch.int64,
    ).view(-1, request.seqlen)
    advantages = torch.from_file(
        request.advantages_filename,
        shared=True,
        size=request.rows * request.seqlen,
        dtype=torch.float32,
    ).view(-1, request.seqlen)
    logprobs = torch.from_file(
        request.logprobs_filename,
        shared=True,
        size=request.rows * request.seqlen,
        dtype=torch.float32,
    ).view(-1, request.seqlen)
    trajectories = sorted(
        (best_trajectory(episode) for episode in buffer),
        key=lambda t: t.score(),
    )
    for row in range(request.start, request.stop):
        selected_trajectories: list[Trajectory] = []
        for i in range(0, len(trajectories), -1):
            if (
                trajectories[i].token_count
                + sum(t.token_count for t in selected_trajectories)
                > request.seqlen
            ):
                continue
            selected_trajectories.append(trajectories.pop(i))
        tokens[row] = tokenizer.encode(
            [
                trajectory.terminus.all_message_params()
                for trajectory in selected_trajectories
            ],  # type: ignore
            concatenate=True,
            seqlen=request.seqlen,
        )
        replacement_token = "<|reserved_special_token_250|>"
        mask = tokenizer.encode(
            [trajectory.terminus.all_message_params(replacement_token=replacement_token) for trajectory in selected_trajectories],  # type: ignore
            concatenate=True,
            seqlen=request.seqlen,
        ) == tokenizer.get_token_id(replacement_token)
        mask_size = mask.sum()
        advantages[row] = torch.full_like(mask, fill_value=torch.nan, dtype=torch.float32)
        advantages[row][mask] = torch.tensor(
            list(
                advantage
                for trajectory in selected_trajectories
                for advantage in trajectory.terminus.all_token_advantages()
            )[:mask_size]
        )
        logprobs[row] = torch.full_like(mask, fill_value=torch.nan, dtype=torch.float32)
        logprobs[row][mask] = torch.tensor(
            list(
                advantage
                for trajectory in selected_trajectories
                for advantage in trajectory.terminus.all_logprobs()
            )[:mask_size]
        )

handle_request(
    Request(
        tokens_filename="/tmp/tokens.bin",
        advantages_filename="/tmp/advantages.bin",
        logprobs_filename="/tmp/logprobs.bin",
        rows=64,
        seqlen=8192,
        start=0,
        stop=1,
    )
)

IndexError: list index out of range