In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# local packages (e.g. the lib.rl modules imported below) apply without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
%%html
<style>
/* Drop the opaque background behind ipywidget output so the notebook theme shows through. */
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
/* Point the Jupyter widget color and font-size variables at the VS Code editor's values. */
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [3]:
import asyncio
import httpx
import json
from lib.rl.episode import Episode, EpisodeCompletion
from lib.rl.ppo import PPOLoss
from lib.rl.recipe import ComponentConfig, TuneRecipeConfig
from lib.rl.trainer import ExploreOptions, Trainer, vLLMConfig
import random
import re
import torch
from torchtune.models.llama3_1 import llama3_1_8b
from typing import AsyncIterable, Literal, Optional

# Worked examples used as optional few-shot context; each record is read below
# through its "prompt", "chain_of_thought" and "answer" keys.
# JSON text is UTF-8 by specification, so decode it explicitly rather than
# relying on the platform's locale-dependent default encoding.
with open("./data/chain-of-thought-examples.json", encoding="utf-8") as f:
    chain_of_thought_examples: list[dict[str, str]] = json.load(f)


async def sample_random_episode(
    difficulty: Literal["easy", "variable"] = "variable",
    example_probability: float = 0.0,
    max_prompt_characters: int = 8192,
    reward_follow_up_completion: bool = True,
    return_first_solver_as_winner: Optional[bool] = None,
) -> Episode:
    """Fetch episode data from the episode server and wrap it in an ``Episode``.

    ``/new-episode-data`` is requested repeatedly until the returned prompt is
    at most ``max_prompt_characters`` characters long. The resulting
    ``Episode`` carries an ``on_sample`` callback that scores every sampled
    completion and then commits it.

    Scoring: the reward is the fraction of ``solution`` entries whose value
    follows ``"<key>: "`` in the answer text (case-insensitive, surrounding
    whitespace ignored), minus
    ``all_absent_stop_tokens / (3 if reward_follow_up_completion else 2) / len(solution)``.

    Args:
        difficulty: Forwarded to the server as the ``difficulty`` query parameter.
        example_probability: Chance, drawn each time the ``examples`` callable
            is invoked, of supplying one worked example. The example itself is
            picked at random once per episode.
        max_prompt_characters: Prompts longer than this are discarded and a
            new episode is requested.
        reward_follow_up_completion: When true, the reward is attached to the
            follow-up completion and that completion is the one committed;
            otherwise the original completion is rewarded. Also selects the
            penalty divisor (3 vs. 2).
        return_first_solver_as_winner: Forwarded to the server as a query
            parameter.

    Returns:
        An ``Episode`` whose single user message is the fetched prompt.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    # One client for all retries, so the connection is reused instead of being
    # re-established for every over-long prompt.
    async with httpx.AsyncClient() as client:
        while True:
            response = await client.get(
                "http://0.0.0.0:2218/new-episode-data",
                params=dict(
                    difficulty=difficulty,
                    return_first_solver_as_winner=return_first_solver_as_winner,
                ),
            )
            response.raise_for_status()
            result = response.json()
            prompt = result["prompt"]
            follow_up = result["follow_up"]
            solution = result["solution"]
            if len(prompt) <= max_prompt_characters:
                break

    # Token budget for the follow-up answer: a fixed allowance of 10 plus half
    # the character length of the "key: value" rendering of the solution.
    follow_up_max_tokens = (
        10
        + len("\n".join(f"{key}: {value}" for key, value in solution.items())) // 2
    )

    def fraction_correct(answer: str) -> float:
        """Return the share of solution entries that ``answer`` states correctly."""
        correct = 0
        for key, value in solution.items():
            # The key is escaped so it is matched literally; without this a key
            # containing regex metacharacters (parentheses, "?", ...) would
            # alter the pattern or shift the capture group. The capture stops
            # at the first character outside letters, space, ".", ":" and "-".
            match = re.search(
                rf"{re.escape(key)}: ([A-Za-z \.:-]+)",
                answer,
                re.IGNORECASE,
            )
            if match and match.group(1).strip().lower() == value.strip().lower():
                correct += 1
        return correct / len(solution)

    async def reward_completion(completion: EpisodeCompletion) -> EpisodeCompletion:
        """Score one completion, asking the follow-up question first if needed."""
        # Exactly two messages is taken to mean the follow-up has not been
        # asked yet (presumably user prompt + assistant reply — confirm
        # against EpisodeCompletion).
        if len(completion.messages) == 2:
            follow_up_completion = await completion.follow_up(
                messages=[
                    {"role": "user", "content": follow_up},
                ],
                max_tokens=follow_up_max_tokens,
            )
        else:
            follow_up_completion = completion
        answer = follow_up_completion.last_assistant_message.get("content")
        assert isinstance(answer, str)
        if reward_follow_up_completion:
            completion = follow_up_completion
        completion.reward = fraction_correct(answer)
        completion.reward -= (
            completion.all_absent_stop_tokens
            / (3 if reward_follow_up_completion else 2)
            / len(solution)
        )
        return completion

    async def on_sample(completions: list[EpisodeCompletion]) -> None:
        """Score all sampled completions concurrently, then commit each result."""
        for completion in await asyncio.gather(
            *[reward_completion(completion) for completion in completions]
        ):
            completion.commit()

    example = random.choice(chain_of_thought_examples)

    return Episode(
        messages=[{"role": "user", "content": prompt}],
        examples=lambda: (
            [
                {"role": "user", "content": example["prompt"]},
                {
                    "role": "assistant",
                    # An empty answer contributes nothing; otherwise it is
                    # appended after a "---" separator.
                    "content": example["chain_of_thought"]
                    + (example["answer"] and f"\n\n---\n\n{example['answer']}"),
                },
            ]
            if random.random() < example_probability
            else []
        ),
        on_sample=on_sample,
    )


# Episodes per training iteration: 64 for every visible CUDA device
# (this evaluates to 0 on a machine without GPUs).
episodes_per_iteration = 64 * torch.cuda.device_count()


async def train_episodes() -> AsyncIterable[Episode | BaseException]:
    """Yield training episodes forever, keeping a fixed number of fetches in flight.

    Before every wait the pool of pending ``sample_random_episode`` tasks is
    topped up to ``episodes_per_iteration``. Each finished task is yielded as
    its ``Episode``, or as the exception it raised, so one failed fetch does
    not end the stream.

    Closing the generator cancels the fetches that are still pending.
    """
    pending: set[asyncio.Task[Episode | BaseException]] = set()
    try:
        while True:
            pending.update(
                asyncio.create_task(sample_random_episode())
                for _ in range(episodes_per_iteration - len(pending))
            )
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                try:
                    outcome = task.result()
                except BaseException as e:
                    outcome = e
                # Yield outside the try block: otherwise GeneratorExit (from
                # aclose()) or an exception thrown in by the consumer would be
                # caught by the handler above and yielded back, which makes
                # aclose() fail with "async generator ignored GeneratorExit".
                yield outcome
    finally:
        # Do not leave orphaned fetch tasks running once the stream is closed.
        for task in pending:
            task.cancel()


async def val_episodes() -> AsyncIterable[Episode | BaseException]:
    """Yield one batch of validation episodes in completion order.

    Starts ``64 * torch.cuda.device_count()`` concurrent
    ``sample_random_episode`` fetches and yields each result as it finishes:
    the ``Episode`` on success, or the raised exception on failure. The
    generator ends once every fetch has been reported.

    Closing the generator early cancels the fetches that are still pending.
    """
    # Explicit tasks (rather than awaiting futures from asyncio.as_completed)
    # let results be read with task.result(), so cancelling the *consumer*
    # while it waits propagates normally instead of being caught and yielded
    # as if it were an episode failure.
    pending: set[asyncio.Task[Episode]] = {
        asyncio.create_task(sample_random_episode())
        for _ in range(64 * torch.cuda.device_count())
    }
    try:
        while pending:
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                try:
                    outcome = task.result()
                except BaseException as e:
                    outcome = e
                # Yield outside the try block so GeneratorExit from aclose()
                # is not caught and re-yielded ("async generator ignored
                # GeneratorExit").
                yield outcome
    finally:
        for task in pending:
            task.cancel()


# Run name: used for the checkpoint directory and for the wandb run name/id.
model_name = "rl33"

# Options passed to the Trainer as `explore_options`.
explore_options = ExploreOptions(
    iterations=8,
    num_parents=5,
    branch_factor=3,
    patience=5,
    sample_probability_power=None,
    sampling_kwargs={"max_tokens": 1024},
    split_method="prob",
    split_point_std_deviation=0.5,
)

# Optimizer component for the tune recipe.
# Alternative optimizer paths kept for reference:
#   "bitsandbytes.optim.PagedAdamW8bit", "bitsandbytes.optim.AdamW"
# (an explicit `params=PLACEHOLDER` argument was also left disabled).
optimizer_config = ComponentConfig(
    "torch.optim.AdamW",
    lr=4e-6,
    fused=True,
)

# Loss component: PPOLoss with the coefficients for its individual terms.
loss_config = ComponentConfig(
    PPOLoss,
    policy_coef=0.0,
    clip_epsilon=0.2,
    unclipped_policy_coef=0.0,
    tanh_log_policy_coef=0.9,
    value_coef=0.0,
    entropy_coef=0.0,
    entropy_target=0.75,
    entropy_target_coef=0.15,
    kl_coef=0.25,
    weighted_entropy_coef=0.1,
    weighted_kl_coef=0.0,
    weighted_ce_coef=0.0,
    normalize_values=False,
    normalize_advantages=False,
)

# Recipe configuration for the fine-tuning step.
tune_recipe_config = TuneRecipeConfig(
    seed=42,
    shuffle=True,
    num_output_chunks=4,
    resume_from_checkpoint=False,
    batch_size=1,
    epochs=1,
    max_steps_per_epoch=32,
    optimizer=optimizer_config,
    loss=loss_config,
    compile=False,
    optimizer_in_bwd=False,
    gradient_accumulation_steps=1,
    enable_activation_checkpointing=True,
    enable_activation_offloading=False,
    custom_sharded_layers=["tok_embeddings", "output"],
    log_every_n_steps=1,
    log_peak_memory_stats=True,
)

# vLLM server settings; the timeout grows by 15 for every visible CUDA device.
vllm_config = vLLMConfig(
    env={"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"},
    kwargs={
        "block_size": 32,
        "disable_log_requests": True,
        "enable_prefix_caching": True,
        "enforce_eager": True,
        "gpu_memory_utilization": 0.9,
        "max_model_len": 16384,
        "max_num_seqs": 2048,
        "max_num_batched_tokens": 16384,
        "preemption_mode": "swap",
        "return_tokens_as_token_ids": True,
        "swap_space": 32,
    },
    max_concurrent_samples=2048,
    timeout=120 + 15 * torch.cuda.device_count(),
)

# Assemble the trainer from the pieces above.
# (`tune_run=False` was left disabled in the original call.)
trainer = Trainer(
    base_model="NousResearch/Hermes-2-Theta-Llama-3-8B",
    output_dir=f"./models/{model_name}",
    explore_options=explore_options,
    train_episodes=train_episodes(),
    episodes_per_iteration=episodes_per_iteration,
    max_mask_sequence_batch_size=1,
    val_episodes=val_episodes(),
    val_patience=15,
    val_samples_per_episode=3,
    val_sampling_kwargs={"max_tokens": 1024},
    tune_model=llama3_1_8b,
    tune_model_type="LLAMA3",
    tune_recipe_config=tune_recipe_config,
    tune_sequence_length=16384,
    vllm_config=vllm_config,
    wandb_kwargs={"name": model_name, "id": model_name},
)

INFO 12-14 19:43:07 llm_engine.py:237] Initializing an LLM engine (v0.6.3.post1) with config: model='NousResearch/Hermes-2-Theta-Llama-3-8B', speculative_config=None, tokenizer='NousResearch/Hermes-2-Theta-Llama-3-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=NousResearch/Hermes-2-Theta-Llama-3-8B, num_scheduler_steps=1, chunked_prefill_enabled=Fal

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbradhilton[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [26]:
# Run four training iterations.
# NOTE(review): execution counts jump from In [3] to In [26], and the captured
# output below refers to models/rl29 while model_name above is "rl33" — these
# outputs come from an earlier kernel state, not a fresh Restart & Run All.
await trainer.train(iterations=4, verbosity=1)

val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

Early stopping val evaluation due to expired patience (0 remaining episodes x 15 patience per episode = 0 seconds)
$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl29/config.yaml


1|29|Loss: 0.0443: 100%|██████████| 29/29 [08:11<00:00, 16.48s/it, entropy=0.6889, entropy_target=0.0611, kl_div=0.1234, policy=0.0093, tanh_log_policy=0.0020, unclipped_policy=0.0037, value=1.1347, weighted_ce=0.0159, weighted_entropy=-0.0248, weighted_kl_div=-0.0119]  

Saved iteration 6 model files to /home/ubuntu/atreides/experiments/models/rl29/0006
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl29/0006 --port=8000 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl29/config.yaml


1|33|Loss: 0.0624: 100%|██████████| 33/33 [09:16<00:00, 16.57s/it, entropy=0.9417, entropy_target=0.1917, kl_div=0.1481, policy=0.0050, tanh_log_policy=-0.0035, unclipped_policy=-0.0007, value=1.4439, weighted_ce=-0.0004, weighted_entropy=0.0019, weighted_kl_div=0.0019] 

Saved iteration 7 model files to /home/ubuntu/atreides/experiments/models/rl29/0007
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl29/0007 --port=8000 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl29/config.yaml


1|41|Loss: 0.0672: 100%|██████████| 41/41 [11:30<00:00, 16.55s/it, entropy=0.4560, entropy_target=0.2940, kl_div=0.0907, policy=-0.0152, tanh_log_policy=0.0001, unclipped_policy=-0.0215, value=2.0385, weighted_ce=0.0001, weighted_entropy=-0.0028, weighted_kl_div=0.0015]  

Saved iteration 8 model files to /home/ubuntu/atreides/experiments/models/rl29/0008
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl29/0008 --port=8000 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

Early stopping val evaluation due to expired patience (0 remaining episodes x 15 patience per episode = 0 seconds)
$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl29/config.yaml


1|54|Loss: 0.0401: 100%|██████████| 54/54 [15:29<00:00, 17.21s/it, entropy=0.6053, entropy_target=0.1447, kl_div=0.0715, policy=0.0013, tanh_log_policy=0.0019, unclipped_policy=-0.0015, value=1.5641, weighted_ce=-0.0115, weighted_entropy=0.0113, weighted_kl_div=0.0115]    

Saved iteration 9 model files to /home/ubuntu/atreides/experiments/models/rl29/0009
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl29/0009 --port=8000 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

In [27]:
# Run four more training iterations (same call as the previous cell; the
# captured log continues from iteration 10).
await trainer.train(iterations=4, verbosity=1)

val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

Early stopping val evaluation due to expired patience (0 remaining episodes x 15 patience per episode = 0 seconds)
$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl29/config.yaml


1|185|Loss: 0.0443: 100%|██████████| 185/185 [51:20<00:00, 16.53s/it, entropy=0.6468, entropy_target=0.1032, kl_div=0.0934, policy=-0.0291, tanh_log_policy=0.0026, unclipped_policy=-0.0323, value=1.9048, weighted_ce=0.0188, weighted_entropy=-0.0312, weighted_kl_div=-0.0066]  

Saved iteration 10 model files to /home/ubuntu/atreides/experiments/models/rl29/0010
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl29/0010 --port=8000 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl29/config.yaml


1|101|Loss: 0.0376: 100%|██████████| 101/101 [28:08<00:00, 16.61s/it, entropy=0.6686, entropy_target=0.0814, kl_div=0.0844, policy=-0.0282, tanh_log_policy=0.0022, unclipped_policy=-0.0298, value=1.0489, weighted_ce=0.0084, weighted_entropy=-0.0236, weighted_kl_div=-0.0026]

Saved iteration 11 model files to /home/ubuntu/atreides/experiments/models/rl29/0011
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl29/0011 --port=8001 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl29/config.yaml


1|127|Loss: 0.0216: 100%|██████████| 128/128 [35:40<00:00, 16.64s/it, entropy=0.7464, entropy_target=0.0036, kl_div=0.0928, policy=-0.0009, tanh_log_policy=-0.0017, unclipped_policy=-0.0051, value=5.0248, weighted_ce=-0.0075, weighted_entropy=0.0059, weighted_kl_div=0.0063] 

Saved iteration 12 model files to /home/ubuntu/atreides/experiments/models/rl29/0012
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl29/0012 --port=8001 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl29/config.yaml


1|143|Loss: 0.0361: 100%|██████████| 143/143 [39:47<00:00, 16.62s/it, entropy=0.6294, entropy_target=0.1206, kl_div=0.0729, policy=0.0015, tanh_log_policy=-0.0001, unclipped_policy=0.0004, value=1.1154, weighted_ce=-0.0008, weighted_entropy=0.0014, weighted_kl_div=0.0005]   

Saved iteration 13 model files to /home/ubuntu/atreides/experiments/models/rl29/0013
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl29/0013 --port=8000 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

Early stopping val evaluation due to expired patience (0 remaining episodes x 15 patience per episode = 0 seconds)


In [28]:
# Longer run: request twelve further training iterations.
await trainer.train(iterations=12, verbosity=1)

val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

Early stopping val evaluation due to expired patience (0 remaining episodes x 15 patience per episode = 0 seconds)
Early stopping exploration due to expired patience (0 remaining episodes x 5 patience per episode = 0 seconds)
$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl29/config.yaml


1|100|Loss: 0.0312:  45%|████▌     | 100/222 [27:53<33:44, 16.59s/it, entropy=0.8347, entropy_target=0.0847, kl_div=0.0736, policy=-0.0045, tanh_log_policy=-0.0004, unclipped_policy=-0.0056, value=2.1442, weighted_ce=0.0013, weighted_entropy=-0.0045, weighted_kl_div=-0.0014]