In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from lib.vllm import start_vllm_server, vllm_server_metrics

model = "NousResearch/Hermes-2-Theta-Llama-3-8B"

shutdown_server, client = await start_vllm_server(
    disable_log_requests=True,
    model=model,
)

No module named 'vllm._version'
  from vllm.version import __version__ as VLLM_VERSION
No module named 'vllm._version'
  from vllm.version import __version__ as VLLM_VERSION
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:01,  1.57it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.30it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:01<00:00,  1.70it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.57it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.55it/s]



In [4]:
import asyncio
from lib.clue import Clue, DeductiveSolver
from lib.rl import (
    Completion,
    CompletionSampler,
    Episode,
    EpisodeBuffer,
    EpisodeSampler,
    EpisodeSamplerRouter,
)
from lib.tokenizer import Tokenizer
import re

In [5]:
shutdown_server()

True

In [5]:
def sample_random_episode() -> Episode:
    game = Clue(
        num_players=3,
        elements={
            "suspect": Clue.suspects[:3],
            "weapon": Clue.weapons[:3],
            "room": Clue.rooms[:3],
            # "motive": Clue.motives[:6],
            # "time": Clue.get_times("21:00", "03:00", "1h"),
        },
    )
    game.play(
        deductive_solver=DeductiveSolver(
            # note_cards_in_hand=False,
            # note_responses_to_suggestions=False,
            # note_cards_that_players_do_not_have=False,
            # check_unique_card_placement_constraints=False,
            # check_player_hand_size_constraints=False,
            check_solution_has_one_and_only_one_card_per_element=False,
            check_one_of_constraints=False,
            check_inverse_one_of_constraints=False,
            merge_and_check_disjoint_inverse_one_of_constraints=False,
            exhaustively_test_possible_assignments=False,
        ),
        cp_solver_max_solve_time_per_turn=0.05,
        check_cp_solver_grid=False,
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        print_playthrough=False,
    )
    prompt = game.get_prompt()
    follow_up = "Fill out your answer like this:\n" + "\n".join(
        f"{element.capitalize()}: <#{element.upper()}#>" for element in game.elements
    )

    async def reward_completion(completion: Completion) -> None:
        chat_completion = await client.chat.completions.create(
            messages=completion.all_message_params()
            + [
                {"role": "user", "content": follow_up},
            ],
            model=model,
            temperature=0.0,
        )
        answer = chat_completion.choices[0].message.content
        assert answer
        completion.reward = sum(
            [
                bool(
                    re.search(
                        f"{element}: {solution}",
                        answer,
                        re.IGNORECASE,
                    )
                )
                for element, solution in game.solution.items()
            ]
        ) / len(game.solution)

    async def on_sample(completions: list[Completion]) -> None:
        await asyncio.gather(
            *[reward_completion(completion) for completion in completions]
        )
        for completion in completions:
            completion.commit()

    return Episode(
        messages=[{"role": "user", "content": prompt}],
        on_sample=on_sample,
    )

In [6]:
completion_sampler = CompletionSampler(
    client,
    model=model,
)

tokenizer = Tokenizer(model)

episode_buffer = EpisodeBuffer(
    episode_sampler_router=EpisodeSamplerRouter(
        EpisodeSampler(sample_random_episode),
        exploitation_factor=1.0,
        min_random_episode_sample_probability_half_life=40,
    ),
    completion_sampler=completion_sampler,
    tokenizer=tokenizer,
    branch_factor=2,
    split_method="count",
    size=512,
    episode_decay=0.7,
    completion_decay=0.7,
)

In [7]:
episode_buffer.size = 512

In [196]:
episode_buffer.stop_buffering()

True

In [201]:
running, pending = vllm_server_metrics()
print(f"Running/Pending/HWM vLLM Requests: {running}/{pending}/{episode_buffer.max_running_requests}")
print("Number of Episodes:", len(episode_buffer.episodes))
print(
    "Pending Tasks:",
    sum(task._state == "PENDING" for task in episode_buffer.tasks.values()),
)
print(
    "Sampled Completions:",
    sum(episode.num_samples() for episode in episode_buffer.episodes),
)
print(
    "Average Absolute Advantage Per Token:",
    sum(
        episode.best_leaf(tokenizer).all_abs_advantage_per_token(tokenizer)
        for episode in episode_buffer.episodes
    )
    / len(episode_buffer.episodes),
)

Running/Pending/HWM vLLM Requests: 0/0/162
Number of Episodes: 512
Pending Tasks: 0
Sampled Completions: 7713
Average Absolute Advantage Per Token: 0.0015422665585660134


In [None]:
from fastapi import FastAPI
from pydantic import BaseModel
import torch

app = FastAPI()


class Request(BaseModel):
    tokens_filename: str
    advantages_filename: str
    logprobs_filename: str
    rows: int
    seqlen: int
    start: int
    stop: int


class Trajectory:
    episode: Episode
    terminus: Completion
    abs_advantage: float
    token_count: int

    def __init__(self, episode: Episode, tokenizer: Tokenizer) -> None:
        self.episode = episode
        self.terminus = episode.best_leaf(tokenizer)
        self.abs_advantage = self.terminus.all_abs_advantage()
        self.token_count = self.terminus.all_token_count(tokenizer)

    def score(self) -> float:
        return self.episode.weight * self.abs_advantage / self.token_count


@app.post("/write-trajectories")
def write_trajectories(request: Request) -> None:
    tokens = torch.from_file(
        request.tokens_filename,
        shared=True,
        size=request.rows * request.seqlen,
        dtype=torch.int64,
    ).view(-1, request.seqlen)
    advantages = torch.from_file(
        request.advantages_filename,
        shared=True,
        size=request.rows * request.seqlen,
        dtype=torch.float32,
    ).view(-1, request.seqlen)
    logprobs = torch.from_file(
        request.logprobs_filename,
        shared=True,
        size=request.rows * request.seqlen,
        dtype=torch.float32,
    ).view(-1, request.seqlen)
    trajectories: list[Trajectory] = []
    for row in range(request.start, request.stop):
        if not trajectories:
            trajectories = sorted(
                (
                    Trajectory(episode=episode, tokenizer=tokenizer)
                    for episode in episode_buffer.episodes
                ),
                key=lambda trajectory: trajectory.score(),
            )
        selected_trajectories: list[Trajectory] = []
        for i in range(0, len(trajectories), -1):
            if (
                trajectories[i].token_count
                + sum(t.token_count for t in selected_trajectories)
                > request.seqlen
            ):
                continue
            selected_trajectories.append(trajectories.pop(i))
        for trajectory in selected_trajectories:
            trajectory.episode.weight *= episode_buffer.episode_decay
            for completion in trajectory.terminus.ancestors(including_self=True):
                completion.weight *= episode_buffer.completion_decay
        tokens[row] = tokenizer.encode(
            [
                trajectory.terminus.all_message_params()
                for trajectory in selected_trajectories
            ],  # type: ignore
            concatenate=True,
            seqlen=request.seqlen,
        )
        replacement_token = "<|reserved_special_token_250|>"
        mask = tokenizer.encode(
            [trajectory.terminus.all_message_params(replacement_token=replacement_token) for trajectory in selected_trajectories],  # type: ignore
            concatenate=True,
            seqlen=request.seqlen,
        ) == tokenizer.get_token_id(replacement_token)
        mask_size = mask.sum()
        advantages[row] = torch.full_like(
            mask, fill_value=torch.nan, dtype=torch.float32
        )
        advantages[row][mask] = torch.tensor(
            list(
                advantage
                for trajectory in selected_trajectories
                for advantage in trajectory.terminus.all_token_advantages()
            )[:mask_size]
        )
        logprobs[row] = torch.full_like(mask, fill_value=torch.nan, dtype=torch.float32)
        logprobs[row][mask] = torch.tensor(
            list(
                advantage
                for trajectory in selected_trajectories
                for advantage in trajectory.terminus.all_logprobs()
            )[:mask_size]
        )

In [85]:
for episode in episode_buffer.episodes:
    for descendent in episode.completion.descendants(including_self=True):
        descendent._cached_value = None

In [172]:
all(
    uncached is cached
    for uncached, cached in zip(
        *(
            [
                episode.best_leaf(tokenizer, cache=cache)
                for episode in episode_buffer.episodes
            ]
            for cache in [False, True]
        )
    )
)

True

In [97]:
# %%timeit
best_leaves = [episode.best_leaf(tokenizer, cache=True) for episode in episode_buffer.episodes]

In [98]:
scores = [leaf.all_abs_advantage_per_token(tokenizer, cache=True) for leaf in best_leaves]

In [99]:
print(
    "Average Absolute Advantage Per Token:",
    sum(
        episode.best_leaf(tokenizer).all_abs_advantage_per_token(tokenizer)
        for episode in episode_buffer.episodes
    )
    / len(episode_buffer.episodes),
)

Average Absolute Advantage Per Token: 0.0043192790392456125


In [33]:
sum(1 for _ in episode_buffer.episodes[0].completion.descendants())

7

In [20]:
leaf = episode_buffer.episodes[0].best_leaf(tokenizer)
leaf.all_abs_advantage() / leaf.all_token_count(tokenizer)

0.00029788501638367595

In [180]:
print(completion.all_message_params()[-1]["content"])

Let's analyze the cards shown to each player:

1. Blake shows the Lounge to Robert.
2. Robert shows the Knife to Blake.
3. Joel shows the Knife to Blake.
4. Blake shows the Candlestick to Robert.
5. Joel shows a card to Blake (must be Mrs. White or the Lead Pipe).
6. Blake shows a card to Joel (must be Mr. Green or the Dining Room).
7. Joel shows the Dining Room to Robert.
8. Joel shows a card to Blake (must be the Lead Pipe).
9. Blake shows a card to Robert (must be Mr. Green).

From the information above, we can determine the following:

- Robert has Miss Scarlet and the Knife.
- Joel has Mrs. White and the Dining Room.
- Blake has Mr. Green and the Candlestick.

Since Blake showed the Candlestick to Robert, and Robert has the Knife, we know that the Candlestick is not the Knife. Therefore, the Candlestick must be the weapon in the Lounge.

Similarly, since Blake showed a card to Joel (the Lead Pipe), and Joel has the Dining Room, we know that the Lead Pipe is not the Dining Room. Th

In [141]:
from dataclasses import dataclass
from pydantic import BaseModel
import torch

@dataclass
class Trajectory:
    episode: Episode
    terminus: Completion
    abs_advantage: float
    token_count: int

    def score(self) -> float:
        return self.episode.weight * self.abs_advantage / self.token_count


def best_trajectory(episode: Episode) -> Trajectory:
    return max(
        (
            Trajectory(
                episode=episode,
                terminus=completion,
                abs_advantage=completion.all_abs_advantage(),
                token_count=completion.all_token_count(tokenizer),
            )
            for completion in episode.completion.leaves()
        ),
        key=lambda t: t.abs_advantage / t.token_count,
    )

In [144]:
2 ** 13

8192

In [147]:
seqlen = 8192
rows = 64

tokens_file = torch.empty(rows * seqlen, dtype=torch.int64)
tokens_file.numpy().tofile("/tmp/tokens.bin")

advantages_file = torch.empty(rows * seqlen, dtype=torch.float32)
advantages_file.numpy().tofile("/tmp/advantages.bin")

logprobs_file = torch.empty(rows * seqlen, dtype=torch.float32)
logprobs_file.numpy().tofile("/tmp/logprobs.bin")

In [157]:
class Request(BaseModel):
    tokens_filename: str
    advantages_filename: str
    logprobs_filename: str
    rows: int
    seqlen: int
    start: int
    stop: int


def handle_request(request: Request) -> None:
    tokens = torch.from_file(
        request.tokens_filename,
        shared=True,
        size=request.rows * request.seqlen,
        dtype=torch.int64,
    ).view(-1, request.seqlen)
    advantages = torch.from_file(
        request.advantages_filename,
        shared=True,
        size=request.rows * request.seqlen,
        dtype=torch.float32,
    ).view(-1, request.seqlen)
    logprobs = torch.from_file(
        request.logprobs_filename,
        shared=True,
        size=request.rows * request.seqlen,
        dtype=torch.float32,
    ).view(-1, request.seqlen)
    trajectories = sorted(
        (best_trajectory(episode) for episode in buffer),
        key=lambda t: t.score(),
    )
    for row in range(request.start, request.stop):
        selected_trajectories: list[Trajectory] = []
        for i in range(0, len(trajectories), -1):
            if (
                trajectories[i].token_count
                + sum(t.token_count for t in selected_trajectories)
                > request.seqlen
            ):
                continue
            selected_trajectories.append(trajectories.pop(i))
        tokens[row] = tokenizer.encode(
            [
                trajectory.terminus.all_message_params()
                for trajectory in selected_trajectories
            ],  # type: ignore
            concatenate=True,
            seqlen=request.seqlen,
        )
        replacement_token = "<|reserved_special_token_250|>"
        mask = tokenizer.encode(
            [trajectory.terminus.all_message_params(replacement_token=replacement_token) for trajectory in selected_trajectories],  # type: ignore
            concatenate=True,
            seqlen=request.seqlen,
        ) == tokenizer.get_token_id(replacement_token)
        mask_size = mask.sum()
        advantages[row] = torch.full_like(mask, fill_value=torch.nan, dtype=torch.float32)
        advantages[row][mask] = torch.tensor(
            list(
                advantage
                for trajectory in selected_trajectories
                for advantage in trajectory.terminus.all_token_advantages()
            )[:mask_size]
        )
        logprobs[row] = torch.full_like(mask, fill_value=torch.nan, dtype=torch.float32)
        logprobs[row][mask] = torch.tensor(
            list(
                advantage
                for trajectory in selected_trajectories
                for advantage in trajectory.terminus.all_logprobs()
            )[:mask_size]
        )

handle_request(
    Request(
        tokens_filename="/tmp/tokens.bin",
        advantages_filename="/tmp/advantages.bin",
        logprobs_filename="/tmp/logprobs.bin",
        rows=64,
        seqlen=8192,
        start=0,
        stop=1,
    )
)

IndexError: list index out of range