In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%%html
<style>
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [3]:
import asyncio
import httpx
import json
from lib.rl.episode import Episode, EpisodeCompletion
from lib.rl.ppo import PPOLoss
from lib.rl.recipe import ComponentConfig, TuneRecipeConfig
from lib.rl.trainer import ExploreOptions, Trainer, vLLMConfig
import random
import re
import torch
from torchtune.models.llama3_1 import llama3_1_8b
from typing import Any, AsyncIterable, Literal, Optional

with open("./data/chain-of-thought-examples.json") as f:
    chain_of_thought_examples: list[dict[str, str]] = json.load(f)

client = httpx.AsyncClient(
    timeout=httpx.Timeout(5.0, read=30.0),
    limits=httpx.Limits(max_connections=512, max_keepalive_connections=512),
)


async def sample_random_episode(
    difficulty: Literal["easy", "variable"] = "variable",
    example_probability: float = 0.0,
    max_prompt_characters: int = 8192,
    reward_follow_up_completion: bool = True,
    return_first_solver_as_winner: Optional[bool] = None,
    length_penalty: float = 0.0,
) -> Episode:
    while True:
        params: dict[str, Any] = {
            "difficulty": difficulty,
        }
        if return_first_solver_as_winner is not None:
            params["return_first_solver_as_winner"] = return_first_solver_as_winner
        try:
            response = await client.get(
                "http://0.0.0.0:2218/new-episode-data",
                params=params,
            )
            response.raise_for_status()
        except httpx.TimeoutException:
            continue
        result = response.json()
        prompt = result["prompt"]
        follow_up = result["follow_up"]
        solution = result["solution"]
        if len(prompt) <= max_prompt_characters:
            break

    async def reward_completion(completion: EpisodeCompletion) -> EpisodeCompletion:
        if len(completion.messages) == 2:
            follow_up_completion = await completion.follow_up(
                messages=[
                    {"role": "user", "content": follow_up},
                ],
                max_tokens=10
                + len("\n".join(f"{key}: {value}" for key, value in solution.items()))
                // 2,
            )
        else:
            follow_up_completion = completion
        answer = follow_up_completion.last_assistant_message.get("content")
        assert isinstance(answer, str)
        if reward_follow_up_completion:
            completion = follow_up_completion
        completion.reward = sum(
            [
                bool(
                    # Find first match of key followed by colon and capture following text
                    (
                        match := re.search(
                            rf"{key}: ([A-Za-z \.:-]+)",
                            answer,
                            re.IGNORECASE,
                        )
                    )
                    # Check if captured group matches expected value
                    and match.group(1).strip().lower() == value.strip().lower()
                )
                for key, value in solution.items()
            ]
        ) / len(solution)
        completion.reward -= (
            completion.all_absent_stop_tokens
            / (3 if reward_follow_up_completion else 2)
            / len(solution)
        )
        completion.reward -= (
            completion.completion_tokens
            / (len(prompt) + len(solution) * 10)
            * length_penalty
        )
        return completion

    async def on_sample(completions: list[EpisodeCompletion]) -> None:
        for completion in await asyncio.gather(
            *[reward_completion(completion) for completion in completions]
        ):
            completion.commit()

    example = random.choice(chain_of_thought_examples)

    return Episode(
        messages=[{"role": "user", "content": prompt}],
        examples=lambda: (
            [
                {"role": "user", "content": example["prompt"]},
                {
                    "role": "assistant",
                    "content": example["chain_of_thought"]
                    + (example["answer"] and f"\n\n---\n\n{example['answer']}"),
                },
            ]
            if random.random() < example_probability
            else []
        ),
        on_sample=on_sample,
    )


episodes_per_iteration = 64 * torch.cuda.device_count()


async def train_episodes() -> AsyncIterable[Episode | BaseException]:
    pending: set[asyncio.Task[Episode | BaseException]] = set()
    while True:
        pending.update(
            asyncio.create_task(sample_random_episode(example_probability=0.2))
            for _ in range(episodes_per_iteration - len(pending))
        )
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            try:
                yield task.result()
            except BaseException as e:
                yield e


async def val_episodes() -> AsyncIterable[Episode | BaseException]:
    for fut in asyncio.as_completed(
        sample_random_episode() for _ in range(64 * torch.cuda.device_count())
    ):
        try:
            yield await fut
        except BaseException as e:
            yield e


model_name = "rl38"

trainer = Trainer(
    base_model="NousResearch/Hermes-2-Theta-Llama-3-8B",
    output_dir=f"./models/{model_name}",
    explore_options=ExploreOptions(
        iterations=7,
        num_parents=6,
        branch_factor=3,
        patience=5,
        sample_probability_power=None,
        sampling_kwargs={
            "max_tokens": 2048
        },
        # split_method="prob",
        # split_point_std_deviation=0.5,
    ),
    train_episodes=train_episodes(),
    episodes_per_iteration=episodes_per_iteration,
    max_mask_sequence_batch_size=1,
    val_episodes=val_episodes(),
    val_patience=15,
    val_samples_per_episode=3,
    val_sampling_kwargs={"max_tokens": 2048},
    tune_model=llama3_1_8b,
    tune_model_type="LLAMA3",
    tune_recipe_config=TuneRecipeConfig(
        seed=42,
        shuffle=True,
        num_output_chunks=4,
        resume_from_checkpoint=False,
        batch_size=1,
        epochs=1,
        # max_steps_per_epoch=32,
        optimizer=ComponentConfig(
            "torch.optim.AdamW",
            # "bitsandbytes.optim.PagedAdamW8bit",
            # "bitsandbytes.optim.AdamW",
            # params=PLACEHOLDER,
            lr=4e-6,
            fused=True,
        ),
        loss=ComponentConfig(
            PPOLoss,
            policy_coef=0.0,
            clip_epsilon=0.2,
            unclipped_policy_coef=0.0,
            tanh_log_policy_coef=0.8,
            value_coef=0.0,
            entropy_coef=0.0,
            entropy_target=0.6,
            entropy_target_coef=0.05,
            kl_coef=0.05,
            weighted_entropy_coef=0.2,
            weighted_kl_coef=0.0,
            weighted_ce_coef=0.0,
            normalize_values=False,
            normalize_advantages=False,
        ),
        compile=False,
        optimizer_in_bwd=False,
        gradient_accumulation_steps=1,
        enable_activation_checkpointing=True,
        enable_activation_offloading=False,
        custom_sharded_layers=["tok_embeddings", "output"],
        log_every_n_steps=1,
        log_peak_memory_stats=True,
    ),
    # tune_run=False,
    tune_sequence_length=16384,
    vllm_config=vLLMConfig(
        env={"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"},
        kwargs=dict(
            block_size=32,
            disable_log_requests=True,
            enable_chunked_prefill=True,
            enable_prefix_caching=True,
            enforce_eager=True,
            gpu_memory_utilization=0.9,
            max_model_len=16384,
            max_num_seqs=512,
            max_num_batched_tokens=16384,
            preemption_mode="swap",
            return_tokens_as_token_ids=True,
            swap_space=100,
        ),
        max_concurrent_samples=512,
        min_time_between_requests=0.0,
        timeout=120 + 15 * torch.cuda.device_count(),
    ),
    wandb_kwargs=dict(
        name=model_name,
        id=model_name,
    ),
)

INFO 12-17 18:15:22 llm_engine.py:237] Initializing an LLM engine (v0.6.3.post1) with config: model='NousResearch/Hermes-2-Theta-Llama-3-8B', speculative_config=None, tokenizer='NousResearch/Hermes-2-Theta-Llama-3-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=NousResearch/Hermes-2-Theta-Llama-3-8B, num_scheduler_steps=1, chunked_prefill_enabled=Fal

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbradhilton[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [77]:
await trainer.train(iterations=10, verbosity=1)

Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl38/0001 --port=8002 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default


In [4]:
result = await trainer.explore(verbosity=1)

Starting 1 vLLM servers...
$ vllm serve NousResearch/Hermes-2-Theta-Llama-3-8B --port=8002 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


explore:   0%|          | 0/64 [00:00<?, ?episode/s]

In [41]:
list(
    token_logprob.logprob
    for token_logprob in list(result.episodes[23].completion.leaves())[
        0
    ].all_token_logprobs()
)

[-2.573509693145752,
 -1.5497195136049413e-06,
 -3.683499380713329e-05,
 -1.1517176628112793,
 -0.5885536670684814,
 -0.16787488758563995,
 -0.03347204998135567,
 -0.017174215987324715,
 -0.011563891544938087,
 -6.958192348480225,
 -0.04683351516723633,
 -0.35457366704940796,
 -0.2585309147834778,
 -0.7922614812850952,
 -2.3859448432922363,
 -0.3589121699333191,
 -9.238292841473594e-05,
 -0.0005824061809107661,
 -0.806515097618103,
 -0.009681769646704197,
 -0.47959402203559875,
 -1.8232364654541016,
 -0.23573361337184906,
 -0.056180961430072784,
 -2.094564914703369,
 -1.9626911878585815,
 -0.18298567831516266,
 -4.172316494077677e-06,
 -0.004867489915341139,
 0.0,
 -0.5472233295440674,
 0.0,
 -4.7205765440594405e-05,
 -0.00012206286191940308,
 -1.0607349872589111,
 -1.7881377516459906e-06,
 -0.6786861419677734,
 -6.198863957251888e-06,
 -0.001359015703201294,
 0.0,
 -3.814689989667386e-06,
 0.0,
 -0.26380497217178345,
 0.0,
 -0.00022504181833937764,
 -9.65590606938349e-06,
 -0.00033730

In [69]:
(~torch.isnan(result.tensors()["values"])).sum()



tensor(989145)

In [73]:
trainer.tune_recipe_config.loss.normalize_values = False

In [76]:
await trainer.tune(result)

$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl38/config.yaml


DEBUG:torchtune.utils._logging:Training is not distributed. If you want to train on multiple GPUs and are using the tune CLI, specify --nnodes 1 and --nproc_per_node [num_gpus]
INFO:torchtune.utils._logging:Running FullFinetuneRecipe with resolved config:

batch_size: 1
checkpointer:
  _component_: torchtune.training.checkpointing._checkpointer.FullModelHFCheckpointer
  checkpoint_dir: /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725
  checkpoint_files:
  - /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725/model-00004-of-00004.safetensors
  - /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725/model-00001-of-00004.safetensors
  - /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b

Saved iteration 1 model files to /home/ubuntu/atreides/experiments/models/rl38/0001


In [4]:
def tokens(string: str) -> list[int]:
    tokens = trainer.tokenizer.llm.get_tokenizer().encode(
        string, add_special_tokens=False  # type: ignore
    )
    if len(tokens) > 1:
        return []
    return tokens


word_biases = [
    # Critical thinking markers (highest importance)
    ("wait", 1.0),  # Frequent correction initiator
    ("hmm", 1.8),  # Common uncertainty marker
    ("alternatively", 1.4),  # Key for exploring options
    # Important correction/uncertainty markers
    ("perhaps", 1.2),  # Common for tentative suggestions
    ("actually", 1.2),  # Important for corrections
    ("rather", 1.0),  # Refinement marker
    ("however", 0.6),  # Common for corrections
    # Analysis progression markers
    ("first", 0.6),
    ("similarly", 0.4),
    ("therefore", 0.4),
    ("instead", 0.6),
    # Mistake acknowledgment
    ("wrong", 1.0),
    ("incorrect", 1.0),
    ("no", 0.2),
    # Supporting markers (lower importance)
    ("but", 0.2),
    ("so", 0.1),
    ("thus", 0.4),
    ("now", 0.1),
    ("given", 0.4),
    ("suppose", 0.4),
    # Light correction markers
    ("oh", 1.0),
    ("oops", 1.0),
    ("right", 0.3),
]

logit_bias = {
    str(token): bias * 6
    for word, bias in word_biases
    for transform in [lambda x: x, str.capitalize, lambda x: " " + x]
    for token in tokens(transform(word))
}
logit_bias

{'11748': 6.0,
 '14524': 6.0,
 '3868': 6.0,
 '81122': 10.8,
 '88601': 10.8,
 '93114': 8.399999999999999,
 '69487': 8.399999999999999,
 '66372': 7.199999999999999,
 '32576': 7.199999999999999,
 '8530': 7.199999999999999,
 '74128': 7.199999999999999,
 '53692': 7.199999999999999,
 '3604': 7.199999999999999,
 '74303': 6.0,
 '65002': 6.0,
 '4856': 6.0,
 '98936': 3.5999999999999996,
 '11458': 3.5999999999999996,
 '4869': 3.5999999999999996,
 '3983': 3.5999999999999996,
 '5451': 3.5999999999999996,
 '1176': 3.5999999999999996,
 '68791': 2.4000000000000004,
 '30293': 2.4000000000000004,
 '55915': 2.4000000000000004,
 '9093': 2.4000000000000004,
 '65937': 3.5999999999999996,
 '31887': 3.5999999999999996,
 '4619': 3.5999999999999996,
 '35970': 6.0,
 '30285': 6.0,
 '5076': 6.0,
 '63054': 6.0,
 '41568': 6.0,
 '15465': 6.0,
 '2201': 1.2000000000000002,
 '2822': 1.2000000000000002,
 '912': 1.2000000000000002,
 '8248': 1.2000000000000002,
 '4071': 1.2000000000000002,
 '719': 1.2000000000000002,
 '708

In [5]:
trainer.explore_options.sampling_kwargs["logit_bias"] = logit_bias
trainer.explore_options.sampling_kwargs["frequency_penalty"] = 0.5

In [111]:
trainer.explore_options.iterations = 1
trainer.explore_options.num_parents = 1
trainer.explore_options.branch_factor = 2

In [112]:
trainer.explore_options.patience = 60

In [113]:
explore_result = await trainer.explore(verbosity=1)

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

In [59]:
len(list(explore_result.episodes[0].completion.leaves()))

2

In [115]:
print(list(explore_result.episodes[3].completion.leaves())[0].all_message_params()[1]["content"])

Let's first analyze the given information:

1. Skylar has Madame Rose and Candlestick.
2. Leslie showed Monsieur Brunette to Skylar.
3. Leslie showed Miss Scarlet to Skylar.
4. Leslie showed a card to Kayla (unknown which one).
5. Kayla showed Leslie a card (unknown which one).
6. Kayla showed Sophie a card (unknown which one).
7. Kayla showed Sophie another card (unknown which one).
8. Sophie showed Skylar the Carriage House.
9. Skylar showed Leslie the Candlestick.
10. Skylar showed Kayla the Candlestick.

From point 9, we can determine that the facedown card in the center of the table for "Weapon" was actually already revealed, so there was no "Candlestick" in play.

Now, let's look at point 7 again: "Kayla showed Sophie another card (unknown which one)." Since we know that Kayla has 3 cards and she has already shown two cards to other players, we can deduce that this third card must be the remaining weapon, Rope.

Similarly, from point 10: "Skylar showed Kayla the Candlestick." Sin

In [None]:
await trainer.tune(trainer.explore_results[-1])

In [6]:
await trainer.train(iterations=10, verbosity=1)

Starting 1 vLLM servers...
$ vllm serve NousResearch/Hermes-2-Theta-Llama-3-8B --port=8000 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val: 0episode [00:00, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]