In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import random
import re
import subprocess
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Generic, Optional, TypeVar

from dotenv import load_dotenv
from openai import AsyncOpenAI, OpenAI
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion import ChoiceLogprobs
from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob

from lib.clue import Clue, DeductiveSolver

load_dotenv()

True

In [None]:
# Start the vllm serve process and redirect stdout/stderr to log files
log_dir = "./logs"
os.makedirs(log_dir, exist_ok=True)
api_key = "sk-" + "".join(
    random.choices(
        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", k=32
    )
)  # Generate 32-char random key
# Name the server exposes the model under; the readiness probe below must use
# the same name (the notebook-level `model` variable is not defined yet here).
served_model_name = "default"
vllm_serve_command = [
    "vllm",
    "serve",
    str(Path("./models/test/0001").absolute()),
    "--api-key",
    api_key,
    "--served-model-name",
    served_model_name,
]
print(f"Starting vllm serve with command: {' '.join(vllm_serve_command)}")
# The child process receives its own copies of the log file descriptors, so the
# notebook's handles can be closed as soon as the process has been spawned.
with open(f"{log_dir}/vllm_stdout.log", "w") as stdout_log:
    with open(f"{log_dir}/vllm_stderr.log", "w") as stderr_log:
        vllm_process = subprocess.Popen(
            vllm_serve_command, stdout=stdout_log, stderr=stderr_log
        )

# To shut down the process later, use vllm_process.terminate() or vllm_process.kill()

client = AsyncOpenAI(api_key=api_key, base_url="http://localhost:8000/v1")
sync_client = OpenAI(api_key=api_key, base_url="http://localhost:8000/v1")

# Poll the server with a 1-token request until it answers or the timeout expires.
start_time = time.time()
timeout = 90.0
poll_interval = 1.0
failed = True
while failed:
    # Fail fast if the server process died instead of waiting for the timeout.
    exit_code = vllm_process.poll()
    if exit_code is not None:
        raise RuntimeError(
            f"vllm serve exited with code {exit_code}; "
            f"see {log_dir}/vllm_stderr.log"
        )
    try:
        sync_client.chat.completions.create(
            model=served_model_name,
            messages=[
                {"role": "user", "content": "Hi!"},
            ],
            max_tokens=1,
        )
        failed = False
    except Exception as error:
        if time.time() - start_time > timeout:
            # Chain the last probe error so the real cause is visible.
            raise TimeoutError("VLLM server failed to start") from error
        # Back off between probes instead of spinning on a refused connection.
        time.sleep(poll_interval)

print(f"VLLM server started in {time.time() - start_time:.2f} seconds ✅")

In [None]:
# Ask the `vllm serve` subprocess (created with subprocess.Popen in the startup
# cell) to terminate.
vllm_process.terminate()

In [3]:
# Rebind `client` to OpenRouter, authenticating with the OPENROUTER_API_KEY
# environment variable. This replaces the localhost AsyncOpenAI client that the
# vLLM startup cell assigns to the same name.
client = AsyncOpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=os.getenv("OPENROUTER_API_KEY"),
)

In [4]:
@dataclass
class Completion:
    """A node in a tree of chat completions.

    Each node holds the messages it contributes on top of its parent's
    conversation, the per-token logprobs for those messages (``None`` for
    messages that have none), a reward, and the completions sampled after it.
    """

    parent: Optional["Completion"] = None  # State
    messages: list[ChatCompletionMessageParam] = field(default_factory=list)  # Action
    logprobs: list[Optional[list[ChatCompletionTokenLogprob]]] = field(
        default_factory=list
    )  # Action
    reward: float = 0.0  # Reward
    # Next state, action, reward triples
    children: list["Completion"] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Keep `logprobs` index-aligned with `messages` when none were supplied.
        if len(self.logprobs) == 0:
            self.logprobs = [None] * len(self.messages)

    def all_messages(self) -> list[ChatCompletionMessageParam]:
        """Return the full conversation from the root down to this node.

        Always returns a new list, so callers can extend it without mutating
        this node's own ``messages``.
        """
        if self.parent is None:
            return list(self.messages)
        return self.parent.all_messages() + self.messages

    def value(self) -> float:
        """Return this node's reward plus the mean value of its children."""
        if not self.children:
            return self.reward
        return self.reward + (
            sum(c.value() for c in self.children) / len(self.children)
        )

    def advantage(self) -> float:
        """Return ``value()`` minus the parent's ``value()`` (0.0 for a root)."""
        if self.parent is None:
            return 0.0
        return self.value() - self.parent.value()

    def split(self) -> None:
        """Split this node in two at the midpoint of its token logprobs.

        Afterwards this node keeps the first half of the tokens (the prefix)
        and gains exactly one child holding the remainder (the suffix). The
        suffix takes over this node's previous children and its reward, so
        ``value()`` of this node is unchanged by the split while completions
        sampled later from the prefix are not credited with the reward earned
        by the original suffix.

        Raises:
            ValueError: if fewer than two tokens carry logprobs, since there is
                then no interior point to split at.
        """
        assert len(self.messages) == len(self.logprobs)
        total_tokens = sum(
            len(logprobs) for logprobs in self.logprobs if logprobs is not None
        )
        if total_tokens < 2:
            raise ValueError(
                "Cannot split a completion with fewer than two tokens of logprobs"
            )
        split = total_tokens // 2

        # Locate the message containing the split-th token. The loop always
        # breaks because 1 <= split <= total_tokens.
        cum_sum = 0
        for i, (message, logprobs) in enumerate(zip(self.messages, self.logprobs)):
            if not logprobs:
                continue
            local_split = split - cum_sum
            cum_sum += len(logprobs)
            if cum_sum >= split:
                break

        head_logprobs = logprobs[:local_split]
        tail_logprobs = logprobs[local_split:]
        suffix = "".join(logprob.token for logprob in tail_logprobs)
        message_content = message.get("content", "")
        # Validate before mutating anything so a failure leaves the tree intact.
        assert isinstance(message_content, str) and message_content.endswith(suffix)

        # The suffix node inherits the existing subtree and the earned reward.
        new_completion = Completion(
            parent=self, reward=self.reward, children=self.children
        )
        for child in new_completion.children:
            child.parent = new_completion

        new_messages = list(self.messages[i + 1 :])
        new_logprobs = list(self.logprobs[i + 1 :])
        if tail_logprobs:
            # Only emit a suffix message when the split falls inside a message;
            # on an exact message boundary there is no text to carry over.
            new_message = message.copy()
            new_message["content"] = suffix
            new_messages.insert(0, new_message)
            new_logprobs.insert(0, tail_logprobs)
        new_completion.messages = new_messages
        new_completion.logprobs = new_logprobs

        # Truncate this node to the prefix, keeping messages/logprobs aligned.
        head_message = message.copy()
        head_message["content"] = message_content.removesuffix(suffix)
        self.messages = self.messages[:i] + [head_message]
        self.logprobs = self.logprobs[:i] + [head_logprobs]
        self.reward = 0.0
        self.children = [new_completion]

In [5]:
import asyncio

# Root of the completion tree; sample_episode_completions attaches each episode
# node under it.
root = Completion()
# Model identifier passed to the chat.completions.create calls below.
model = "nousresearch/hermes-2-theta-llama-3-8b"


async def sample_episode_completions() -> None:
    """Play one Clue game, sample completions for its prompt, and grow the tree.

    Steps:
      1. Generate a 3-player game and its prompt.
      2. Sample 3 completions with logprobs and score each one by asking the
         model to fill in an answer template and matching it to the solution.
      3. If the episode's value is strictly between 0 and 1, sample 2 more
         completions, split the completion with the largest absolute
         advantage, and sample 2 continuations from its prefix.
      4. Otherwise (value <= 0 or >= 1) recurse to sample a fresh episode.

    Scored completions are attached to their parent, and the episode node is
    attached to the module-level ``root`` the first time one of its
    completions is scored.

    NOTE(review): the recursive branches create a new game with the same
    settings and there is no retry limit.
    """
    game = Clue(
        num_players=3,
        elements={
            "suspect": Clue.suspects[:3],
            "weapon": Clue.weapons[:3],
            "room": Clue.rooms[:3],
            # "motive": Clue.motives[:6],
            # "time": Clue.get_times("21:00", "03:00", "1h"),
        },
    )
    game.play(
        deductive_solver=DeductiveSolver(
            # note_cards_in_hand=False,
            # note_responses_to_suggestions=False,
            # note_cards_that_players_do_not_have=False,
            # check_unique_card_placement_constraints=False,
            # check_player_hand_size_constraints=False,
            check_solution_has_one_and_only_one_card_per_element=False,
            check_one_of_constraints=False,
            check_inverse_one_of_constraints=False,
            merge_and_check_disjoint_inverse_one_of_constraints=False,
            exhaustively_test_possible_assignments=False,
        ),
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        print_playthrough=False,
    )
    prompt = game.get_prompt()
    follow_up = "Fill out your answer like this:\n" + "\n".join(
        f"{element.capitalize()}: <#{element.upper()}#>" for element in game.elements
    )
    episode = Completion(messages=[{"role": "user", "content": prompt}])

    chat_completion = await client.chat.completions.create(
        messages=episode.all_messages(),
        model=model,
        logprobs=True,
        n=3,
    )

    async def process_choice(choice: Choice, parent: Completion) -> None:
        """Score one sampled choice and attach it under ``parent``.

        The choice is dropped (not attached) when the model returned no
        content or when the follow-up answer is empty.
        """
        assert (
            choice.logprobs is not None
        ), "The API response did not include token logprobs for this choice"
        content = choice.message.content
        if not content:
            # Nothing to score or to train on.
            return
        completion = Completion(
            messages=[
                {"role": "assistant", "content": content},
            ],
            logprobs=[choice.logprobs.content],
        )
        # The completion is not linked to `parent` yet, so build the scoring
        # conversation from the parent's history explicitly; otherwise the
        # model would be asked to answer without ever seeing the game prompt.
        # NOTE(review): when `parent` is a split prefix this yields two
        # consecutive assistant messages (prefix, then continuation).
        chat_completion = await client.chat.completions.create(
            messages=parent.all_messages()
            + completion.messages
            + [{"role": "user", "content": follow_up}],
            model=model,
        )
        answer = chat_completion.choices[0].message.content
        if not answer:
            return
        # Fraction of solution elements stated correctly in the answer. Names
        # are escaped so characters such as "." in "Mr. Green" match literally.
        completion.reward = sum(
            [
                bool(
                    re.search(
                        f"{re.escape(element)}: {re.escape(solution)}",
                        answer,
                        re.IGNORECASE,
                    )
                )
                for element, solution in game.solution.items()
            ]
        ) / len(game.solution)
        completion.parent = parent
        parent.children.append(completion)
        if parent.parent is None:
            parent.parent = root
            root.children.append(parent)

    await asyncio.gather(
        *[process_choice(choice, episode) for choice in chat_completion.choices]
    )

    if episode.value() <= 0.0:
        # Sample an easier episode
        await sample_episode_completions()
    elif 0.0 < episode.value() < 1.0:
        # Sample more completions
        chat_completion = await client.chat.completions.create(
            messages=episode.all_messages(),
            model=model,
            logprobs=True,
            n=2,
        )
        await asyncio.gather(
            *[process_choice(choice, episode) for choice in chat_completion.choices]
        )
        outlier_completion = max(episode.children, key=lambda c: abs(c.advantage()))
        outlier_completion.split()
        chat_completion = await client.chat.completions.create(
            messages=outlier_completion.all_messages(),
            model=model,
            logprobs=True,
            n=2,
            # Continue the truncated assistant message instead of opening a new
            # turn; vLLM documents the two options as mutually exclusive, so
            # the generation prompt is switched off explicitly.
            extra_body=dict(continue_final_message=True, add_generation_prompt=False),
        )
        await asyncio.gather(
            *[
                process_choice(choice, outlier_completion)
                for choice in chat_completion.choices
            ]
        )
    else:
        # Sample a harder episode
        await sample_episode_completions()


# Run one sampling pass; results are attached under the module-level `root`.
await sample_episode_completions()

AssertionError: 

In [7]:
# Display the tree root to inspect which completions were attached.
root

Completion(parent=None, messages=[], logprobs=[], reward=0.0, children=[])

In [None]:
@dataclass
class Completion:
    """One chat message with optional logprobs and a link to the previous node.

    NOTE(review): running this cell rebinds the name ``Completion``, replacing
    the tree-node class of the same name defined earlier in the notebook
    (which has ``parent``/``messages``/``children`` fields instead).
    """
    # The chat message carried by this node.
    message: ChatCompletionMessageParam
    # Logprobs for the message, or None when none are available.
    logprobs: Optional[ChoiceLogprobs] = None
    # The preceding completion, or None for the first one.
    previous: Optional["Completion"] = None

In [3]:
T = TypeVar("T")  # Type of the payload stored at each tree node


@dataclass
class Tree(Generic[T]):
    """Generic tree node: a payload plus an ordered list of child subtrees."""
    # Payload stored at this node.
    state: T
    # Child subtrees; each instance gets its own fresh list by default.
    children: list["Tree[T]"] = field(default_factory=list)


@dataclass
class Completion:
    """One chat message with optional logprobs, used as the payload of a ``Tree``.

    NOTE(review): running this cell rebinds the name ``Completion``, replacing
    any class of the same name defined in earlier cells.
    """
    # The chat message carried by this node.
    message: ChatCompletionMessageParam
    # Logprobs for the message, or None when none are available.
    logprobs: Optional[ChoiceLogprobs] = None


@dataclass
class Episode:
    """One game episode: the tree of sampled messages and the game's solution."""
    # Tree of completions rooted at the user prompt.
    rollouts: Tree[Completion]
    # Solution mapping; sample_episode fills it from `game.solution`, whose
    # items are (element, solution card) pairs.
    solution: dict[str, str]


async def sample_episode(debug: bool = False) -> Episode:
    """Play one Clue game, sample responses to its prompt, and return the episode.

    The returned ``Episode.rollouts`` tree has this shape for every sampled
    choice::

        user prompt
        └── assistant response   (logprobs of the response)
            └── user follow-up   (no logprobs)
                └── assistant answer (logprobs of the answer)

    Args:
        debug: when true, print the game playthrough and every message.

    Returns:
        The episode, with one response subtree per choice returned by the API.

    NOTE(review): relies on the notebook-level ``client`` and ``model``
    variables, which are defined in other cells.
    """
    game = Clue(
        num_players=3,
        elements={
            "suspect": Clue.suspects[:3],
            "weapon": Clue.weapons[:3],
            "room": Clue.rooms[:3],
            # "motive": Clue.motives[:6],
            # "time": Clue.get_times("21:00", "03:00", "1h"),
        },
    )
    game.play(
        deductive_solver=DeductiveSolver(
            # note_cards_in_hand=False,
            # note_responses_to_suggestions=False,
            # note_cards_that_players_do_not_have=False,
            # check_unique_card_placement_constraints=False,
            # check_player_hand_size_constraints=False,
            check_solution_has_one_and_only_one_card_per_element=False,
            check_one_of_constraints=False,
            check_inverse_one_of_constraints=False,
            merge_and_check_disjoint_inverse_one_of_constraints=False,
            exhaustively_test_possible_assignments=False,
        ),
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        print_playthrough=debug,
    )
    prompt = game.get_prompt()
    # The answer template does not depend on the sampled choice, so build it once.
    follow_up = "Fill out your answer like this:\n" + "\n".join(
        f"{element.capitalize()}: <#{element.upper()}#>" for element in game.elements
    )
    episode = Episode(
        rollouts=Tree(Completion(message={"role": "user", "content": prompt})),
        solution=game.solution,
    )
    if debug:
        print("\nUser:")
        print(prompt)
    completion = await client.chat.completions.create(
        messages=[episode.rollouts.state.message],
        model=model,
        logprobs=True,
    )
    for choice in completion.choices:
        response = choice.message.content or choice.message.refusal
        assert response
        node = Tree(
            Completion(
                message={
                    "role": "assistant",
                    "content": response,
                },
                logprobs=choice.logprobs,
            )
        )
        episode.rollouts.children.append(node)
        if debug:
            print("\nAssistant:")
            print(response)
            print("\nUser:")
            print(follow_up)
        # Use a separate name so the outer `completion` being iterated over is
        # not shadowed by the follow-up request.
        answer_completion = await client.chat.completions.create(
            messages=[
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": response},
                {"role": "user", "content": follow_up},
            ],
            model=model,
            logprobs=True,
        )
        answer_choice = answer_completion.choices[0]
        answer = answer_choice.message.content
        assert answer
        if debug:
            print("\nAssistant:")
            print(answer)
        # Fraction of solution elements stated correctly in the answer. Names
        # are escaped so characters such as "." in "Mr. Green" match literally.
        score = sum(
            [
                bool(
                    re.search(
                        f"{re.escape(element)}: {re.escape(solution)}",
                        answer,
                        re.IGNORECASE,
                    )
                )
                for element, solution in game.solution.items()
            ]
        ) / len(game.solution)
        print(f"Score: {score:.2f}")
        # Record the follow-up exchange beneath the response it belongs to.
        # User messages carry no logprobs; the answer keeps its own.
        node.children.append(
            Tree(
                Completion(message={"role": "user", "content": follow_up}),
                [
                    Tree(
                        Completion(
                            message={"role": "assistant", "content": answer},
                            logprobs=answer_choice.logprobs,
                        )
                    )
                ],
            )
        )
    return episode


async def get_rollout(debug: bool = False) -> Tree:
    """Play one Clue game and run a single prompt -> response -> answer rollout.

    The rollout is returned as a linear ``Tree`` of ``Completion`` states::

        user prompt -> assistant response -> user follow-up -> assistant answer

    The score (fraction of solution elements found in the answer) is printed;
    ``Tree`` has no field to store it, so it is not part of the return value.

    Args:
        debug: when true, print the game playthrough and every message.

    NOTE(review): relies on the notebook-level ``client`` and ``model``
    variables, which are defined in other cells.
    """
    game = Clue(
        num_players=3,
        elements={
            "suspect": Clue.suspects[:3],
            "weapon": Clue.weapons[:3],
            "room": Clue.rooms[:3],
            # "motive": Clue.motives[:6],
            # "time": Clue.get_times("21:00", "03:00", "1h"),
        },
    )
    game.play(
        deductive_solver=DeductiveSolver(
            # note_cards_in_hand=False,
            # note_responses_to_suggestions=False,
            # note_cards_that_players_do_not_have=False,
            # check_unique_card_placement_constraints=False,
            # check_player_hand_size_constraints=False,
            check_solution_has_one_and_only_one_card_per_element=False,
            check_one_of_constraints=False,
            check_inverse_one_of_constraints=False,
            merge_and_check_disjoint_inverse_one_of_constraints=False,
            exhaustively_test_possible_assignments=False,
        ),
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        print_playthrough=debug,
    )
    prompt = game.get_prompt()
    if debug:
        print("\nUser:")
        print(prompt)
    completion = await client.chat.completions.create(
        messages=[
            {"role": "user", "content": prompt},
        ],
        model=model,
        # logprobs=True,
        # extra_body=dict(prompt_logprobs=True),
    )
    response_choice = completion.choices[0]
    response = response_choice.message.content
    assert response
    follow_up = "Fill out your answer like this:\n" + "\n".join(
        f"{element.capitalize()}: <#{element.upper()}#>" for element in game.elements
    )
    if debug:
        print("\nAssistant:")
        print(response)
        print("\nUser:")
        print(follow_up)
    completion = await client.chat.completions.create(
        messages=[
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response},
            {"role": "user", "content": follow_up},
        ],
        model=model,
        # logprobs=True,
        # extra_body=dict(prompt_logprobs=True),
    )
    answer_choice = completion.choices[0]
    answer = answer_choice.message.content
    assert answer
    if debug:
        print("\nAssistant:")
        print(answer)
    # Fraction of solution elements stated correctly in the answer. Names are
    # escaped so characters such as "." in "Mr. Green" match literally.
    score = sum(
        [
            bool(
                re.search(
                    f"{re.escape(element)}: {re.escape(solution)}",
                    answer,
                    re.IGNORECASE,
                )
            )
            for element, solution in game.solution.items()
        ]
    ) / len(game.solution)
    print(f"Score: {score:.2f}")
    # `Tree` only accepts (state, children), so the conversation is nested as a
    # chain of single-child nodes rather than passed as five positional values.
    return Tree(
        Completion(message={"role": "user", "content": prompt}),
        [
            Tree(
                Completion(
                    message={"role": "assistant", "content": response},
                    logprobs=response_choice.logprobs,
                ),
                [
                    Tree(
                        Completion(message={"role": "user", "content": follow_up}),
                        [
                            Tree(
                                Completion(
                                    message={"role": "assistant", "content": answer},
                                    logprobs=answer_choice.logprobs,
                                )
                            )
                        ],
                    )
                ],
            )
        ],
    )


# Run a single rollout with debug printing of the playthrough and messages.
await get_rollout(debug=True)

Player 1's Hand: {'Hall', 'Miss Scarlet'}
Player 2's Hand: {'Mr. Green', 'Lead Pipe'}
Player 3's Hand: {'Knife', 'Dining Room'}
Solution: {'suspect': 'Mrs. White', 'weapon': 'Candlestick', 'room': 'Lounge'}
Player                1  2  3
Element Card                 
Suspect Miss Scarlet  ✓  ✗  ✗
        Mr. Green     ✗  ✓  ✗
        Mrs. White    ✗  ✗  ✗
Weapon  Candlestick   ✗  ✗  ✗
        Knife         ✗  ✗  ✓
        Lead Pipe     ✗  ✓  ✗
Room    Hall          ✓  ✗  ✗
        Lounge        ✗  ✗  ✗
        Dining Room   ✗  ✗  ✓
Player 1's Simple Solver Grid:
Player                1  2  3
Element Card                 
Suspect Miss Scarlet  ✓  ✗  ✗
        Mr. Green     ✗      
        Mrs. White    ✗      
Weapon  Candlestick   ✗      
        Knife         ✗      
        Lead Pipe     ✗      
Room    Hall          ✓  ✗  ✗
        Lounge        ✗      
        Dining Room   ✗      
Player 1's CP-SAT Solver Grid:
Player                1  2  3
Element Card                 
Suspect Mis

Rollout(prompt="On a cool autumn afternoon Kennedy, Charles, and Isaac and sat down to play a friendly deduction game.\n\nThey assembled 3 groups of cards, each for a different type of data composed of the following:\n\nSuspect:\n- Miss Scarlet\n- Mr. Green\n- Mrs. White\n\nWeapon:\n- Candlestick\n- Knife\n- Lead Pipe\n\nRoom:\n- Hall\n- Lounge\n- Dining Room\n\nAfter randomly (and blindly) choosing one card from each group and placing them in the center of the table facedown, they shuffled the remaining cards and dealt out the following to each player:\n\n- Kennedy: 2 cards (Hall and Miss Scarlet)\n- Charles: 2 cards\n- Isaac: 2 cards\n\nThe game proceeded as follows:\n\n1. On their turn, a player asked about a set of exactly 3 cards, one from each of the game's categories. (Note: Players could ask about any cards, including those in their own hand.)\n2. The player directed this question to the other players in clockwise order, starting with the player to their left.\n3. If a player h

In [28]:
import re

# Score the answer: fraction of solution elements that appear in it as
# "<element>: <card>" (case-insensitive). The final expression is the cell output.
assert answer
sum(
    re.search(f"{element}: {solution}", answer, re.IGNORECASE) is not None
    for element, solution in game.solution.items()
) / len(game.solution)

0.6666666666666666

In [12]:
# Preview the answer template: one "<Element>: <#ELEMENT#>" line per game element.
answer_template_lines = [
    f"{element.capitalize()}: <#{element.upper()}#>" for element in game.elements
]
print("Fill in out your answer like this:\n" + "\n".join(answer_template_lines))

Fill in out your answer like this:
Suspect: <#SUSPECT#>
Weapon: <#WEAPON#>
Room: <#ROOM#>


In [16]:
# Ask the `vllm serve` subprocess held in `vllm_process` to terminate.
vllm_process.terminate()

In [3]:
from lib.clue import Clue, DeductiveSolver

# Set up a 3-player game using the first three cards of each category.
game_elements = {
    "suspect": Clue.suspects[:3],
    "weapon": Clue.weapons[:3],
    "room": Clue.rooms[:3],
    # "motive": Clue.motives[:6],
    # "time": Clue.get_times("21:00", "03:00", "1h"),
}
game = Clue(num_players=3, elements=game_elements)

# Solver configuration: the options listed as keyword arguments are switched
# off; the commented-out ones are left at their defaults.
solver = DeductiveSolver(
    # note_cards_in_hand=False,
    # note_responses_to_suggestions=False,
    # note_cards_that_players_do_not_have=False,
    # check_unique_card_placement_constraints=False,
    # check_player_hand_size_constraints=False,
    check_solution_has_one_and_only_one_card_per_element=False,
    check_one_of_constraints=False,
    check_inverse_one_of_constraints=False,
    merge_and_check_disjoint_inverse_one_of_constraints=False,
    exhaustively_test_possible_assignments=False,
)
game.play(
    deductive_solver=solver,
    check_if_deductive_solver_and_cp_solver_grids_match=False,
)

# NOTE: despite its name, `rollouts` holds the prompt returned by get_prompt().
rollouts = game.get_prompt()

Player 1's Hand: {'Mrs. White', 'Candlestick'}
Player 2's Hand: {'Dining Room', 'Lounge'}
Player 3's Hand: {'Miss Scarlet', 'Lead Pipe'}
Solution: {'suspect': 'Mr. Green', 'weapon': 'Knife', 'room': 'Hall'}
Player                1  2  3
Element Card                 
Suspect Miss Scarlet  ✗  ✗  ✓
        Mr. Green     ✗  ✗  ✗
        Mrs. White    ✓  ✗  ✗
Weapon  Candlestick   ✓  ✗  ✗
        Knife         ✗  ✗  ✗
        Lead Pipe     ✗  ✗  ✓
Room    Hall          ✗  ✗  ✗
        Lounge        ✗  ✓  ✗
        Dining Room   ✗  ✓  ✗
Player 1's Simple Solver Grid:
Player                1  2  3
Element Card                 
Suspect Miss Scarlet  ✗      
        Mr. Green     ✗      
        Mrs. White    ✓  ✗  ✗
Weapon  Candlestick   ✓  ✗  ✗
        Knife         ✗      
        Lead Pipe     ✗      
Room    Hall          ✗      
        Lounge        ✗      
        Dining Room   ✗      
Player 1's CP-SAT Solver Grid:
Player                1  2  3
Element Card                 
Suspect Mis

In [None]:
from vllm.sampling_params import SamplingParams

# Literal game prompt (surrounding whitespace stripped) that is sent as the
# single user message below. NOTE: despite its name, `rollouts` is a prompt string.
rollouts = """
On a warm spring day Summer, Giselle and Connor sat down to play a casual mystery game.

They assembled 3 decks of cards, each for a separate type of information composed of the following:

Suspect:
- Miss Scarlet
- Mr. Green
- Mrs. White

Weapon:
- Candlestick
- Knife
- Lead Pipe

Room:
- Hall
- Lounge
- Dining Room

After randomly (and blindly) choosing one card from each group and placing them in the middle of the table facedown, they shuffled the remaining cards and dealt out the following to each player:

- Summer: 2 cards
- Giselle: 2 cards ('Lounge', 'Miss Scarlet')
- Connor: 2 cards

The game proceeded as follows:

1. On their turn, a player asked about a set of exactly 3 cards, one from each of the game's categories. (Note: Players could ask about any cards, including those in their own hand.)
2. The player directed this question to the other players in clockwise order, starting with the player to their left.
3. If a player had one or more of the asked-about cards, they had to show one of those cards (of their choice) to the asking player privately. The turn then ended, and play passed to the next player.
4. If a player did not have any of the asked-about cards, they said so, and the question passed to the next player in clockwise order.
5. This continued until either:
    a) A player showed a card to the asking player, or
    b) All the queried players had stated they didn't have any of the asked-about cards.
6. After a player's turn ended (either by being shown a card or having all queried players pass), play moved to the next player in clockwise order.

Here is how the game played out:

Summer asked if anyone had 'Mrs. White' or 'Knife' or 'Dining Room':
- Giselle did not have any of the cards
- Connor showed Summer a card

Giselle asked if anyone had 'Mrs. White' or 'Knife' or 'Lounge':
- Connor did not have any of the cards
- Summer did not have any of the cards

Connor asked if anyone had 'Miss Scarlet' or 'Candlestick' or 'Hall':
- Summer did not have any of the cards
- Giselle showed Connor 'Miss Scarlet'

Summer asked if anyone had 'Mr. Green' or 'Knife' or 'Hall':
- Giselle did not have any of the cards
- Connor did not have any of the cards

At this point, Giselle was able to correctly infer the solution and win the game.

What were the facedown cards in the middle of the table?
""".strip()

# Send the prompt as one single-message conversation, allowing up to 10,000
# generated tokens. `* 1` controls how many copies of the conversation are sent.
# NOTE(review): `llm` is not defined in any visible cell — presumably a vLLM
# engine object created elsewhere; confirm before running.
output = llm.chat([[dict(role="user", content=rollouts)]] * 1, sampling_params=SamplingParams(max_tokens=10_000))  # type: ignore