In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the project's `lib` package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
%%html
<!-- Notebook-wide CSS overrides: make ipywidget output backgrounds transparent
     and point the Jupyter widget color / font-size variables at the VS Code
     editor theme variables (presumably so widgets blend in with the editor). -->
<style>
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [3]:
import asyncio
import json
from lib.clue import Clue, DeductiveSolver
from lib.rl.episode import Episode, EpisodeCompletion
from lib.rl.ppo import PPOLoss
from lib.rl.recipe import ComponentConfig, TuneRecipeConfig
from lib.rl.trainer import ExploreOptions, Trainer, vLLMConfig
from lib.utils import return_exception
import random
import re
import torch
from torchtune.models.llama3_1 import llama3_1_8b
from typing import Literal, Optional

# Few-shot examples used when building episodes. Each entry is read with the
# "prompt", "chain_of_thought" and "answer" keys further down in this cell.
# The encoding is pinned to UTF-8 so decoding does not depend on the machine's locale.
with open("./data/chain-of-thought-examples.json", encoding="utf-8") as f:
    chain_of_thought_examples: list[dict[str, str]] = json.load(f)


def get_variable_difficulty_game(
    return_first_solver_as_winner: Optional[bool] = None,
) -> Clue:
    """Create a randomly sized Clue game and play it with a randomly capable solver.

    The number of players, weapons, suspects and rooms is drawn at random (each
    count is derived from the previous one and capped by the size of the
    matching ``Clue`` pool). With 10% probability each, a "motive" and a "time"
    element are added as well. The deductive solver's optional checks are
    switched on according to a random difficulty level.

    Args:
        return_first_solver_as_winner: Forwarded to ``Clue.play``. When ``None``
            a coin flip decides the value.

    Returns:
        Whatever ``game.play(...)`` returns (annotated as ``Clue``).
    """
    num_players = random.randint(3, 6)

    # Weapons: player count shifted by -1..+5, at least 3, at most the pool size.
    weapon_target = num_players + random.randint(-1, 5)
    num_weapons = max(3, min(weapon_target, len(Clue.weapons)))

    # Suspects: at least as many as weapons (pool size permitting).
    suspect_target = num_weapons + random.randint(0, num_weapons - 1)
    num_suspects = min(suspect_target, len(Clue.suspects))

    # Rooms: at least as many as suspects (pool size permitting).
    room_target = num_suspects + random.randint(0, num_suspects - 2)
    num_rooms = min(room_target, len(Clue.rooms))

    elements = {}
    elements["suspect"] = random.sample(Clue.suspects, k=num_suspects)
    elements["weapon"] = random.sample(Clue.weapons, k=num_weapons)
    elements["room"] = random.sample(Clue.rooms, k=num_rooms)

    # Optional fourth element: motives.
    if random.random() < 0.1:
        motive_target = num_weapons + random.randint(-1, 3)
        num_motives = max(3, min(motive_target, len(Clue.motives)))
        elements["motive"] = random.sample(Clue.motives, k=num_motives)

    # Optional fifth element: times of day on a fixed grid.
    if random.random() < 0.1:
        frequency = random.choice([0.25, 0.5, 1.0])  # grid step, in hours
        start, end = 24.0 - frequency, 0.0
        # Widen the window one grid step at a time, either by pushing `end`
        # later or by pulling `start` earlier.
        num_steps = random.randint(1, num_weapons + 1)
        for _ in range(num_steps):
            if random.randint(0, 1):
                end += frequency
            else:
                start -= frequency

        def as_clock(hours: float) -> str:
            """Render fractional hours as zero-padded ``HH:MM``."""
            whole_hours = int(hours)
            minutes = int(60 * (hours - whole_hours))
            return f"{whole_hours:02d}:{minutes:02d}"

        elements["time"] = Clue.get_times(
            as_clock(start), as_clock(end), f"{int(frequency * 60)}min"
        )

    game = Clue(num_players=num_players, elements=elements)

    # Higher levels enable more of the solver's optional checks below.
    difficulty_level = num_players + random.randint(-2, 3)

    # The remaining DeductiveSolver options (note_cards_in_hand,
    # note_responses_to_suggestions, note_cards_that_players_do_not_have,
    # check_unique_card_placement_constraints,
    # check_player_hand_size_constraints) are left at their defaults.
    solver = DeductiveSolver(
        check_solution_has_one_and_only_one_card_per_element=difficulty_level > 1,
        check_one_of_constraints=difficulty_level > 2,
        check_inverse_one_of_constraints=difficulty_level > 3,
        merge_and_check_disjoint_inverse_one_of_constraints=difficulty_level > 4,
        exhaustively_test_possible_assignments=False,
    )

    if return_first_solver_as_winner is None:
        first_solver_wins = bool(random.randint(0, 1))
    else:
        first_solver_wins = return_first_solver_as_winner

    return game.play(
        deductive_solver=solver,
        cp_solver_max_solve_time_per_turn=0.01,
        check_cp_solver_grid=False,
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        return_first_solver_as_winner=first_solver_wins,
        print_playthrough=False,
        max_turns=100,
    )


def get_easy_game(return_first_solver_as_winner: Optional[bool] = None) -> Clue:
    """Create and play a minimal Clue game: 3 players, 3 suspects, 3 weapons, 3 rooms.

    All optional deductive-solver checks that this function sets are turned off.

    Args:
        return_first_solver_as_winner: Forwarded to ``Clue.play``; any falsy
            value (including ``None``) is passed on as ``False``.

    Returns:
        The ``Clue`` instance after ``play`` has been called on it.
    """
    pools = (
        ("suspect", Clue.suspects),
        ("weapon", Clue.weapons),
        ("room", Clue.rooms),
    )
    # "motive" and "time" elements are intentionally not used in easy games.
    game = Clue(
        num_players=3,
        elements={name: random.sample(pool, k=3) for name, pool in pools},
    )

    # Other DeductiveSolver options are left at their defaults.
    solver = DeductiveSolver(
        check_solution_has_one_and_only_one_card_per_element=False,
        check_one_of_constraints=False,
        check_inverse_one_of_constraints=False,
        merge_and_check_disjoint_inverse_one_of_constraints=False,
        exhaustively_test_possible_assignments=False,
    )

    first_solver_wins = (
        return_first_solver_as_winner if return_first_solver_as_winner else False
    )
    game.play(
        deductive_solver=solver,
        cp_solver_max_solve_time_per_turn=0.01,
        check_cp_solver_grid=False,
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        return_first_solver_as_winner=first_solver_wins,
        print_playthrough=False,
        max_turns=100,
    )
    return game


def _score_answer(answer: str, solution: dict[str, str]) -> float:
    """Return the fraction of `solution` fields that `answer` states correctly.

    For every ``key`` in `solution`, the first ``"<key>: <text>"`` occurrence in
    `answer` is located case-insensitively. The field counts as correct when
    the captured text equals the expected value after stripping and
    lower-casing both sides.

    The captured text may contain letters, digits, spaces, ``.``, ``:`` and
    ``-``. Digits must be allowed, otherwise a value that starts with a digit
    (e.g. a time of day, which the optional "time" element presumably uses)
    can never match and that field is always scored as wrong.
    """
    num_correct = 0
    for key, value in solution.items():
        # `re.escape` keeps a key containing regex metacharacters from
        # changing the meaning of the pattern.
        match = re.search(
            rf"{re.escape(key)}: ([A-Za-z0-9 \.:-]+)",
            answer,
            re.IGNORECASE,
        )
        if match and match.group(1).strip().lower() == value.strip().lower():
            num_correct += 1
    return num_correct / len(solution)


@return_exception
def sample_random_episode(
    difficulty: Literal["easy", "variable"] = "variable",
    example_probability: float = 0.0,
    max_prompt_characters: int = 8192,
    reward_follow_up_completion: bool = True,
    return_first_solver_as_winner: Optional[bool] = None,
) -> Episode:
    """Generate a Clue game and wrap it in an ``Episode`` with a reward callback.

    Args:
        difficulty: ``"easy"`` uses ``get_easy_game``; anything else uses
            ``get_variable_difficulty_game``.
        example_probability: Probability, evaluated each time the episode's
            ``examples`` callable runs, that one chain-of-thought example
            (picked once when the episode is created) is included.
        max_prompt_characters: Games are regenerated until the prompt is at
            most this many characters long.
        reward_follow_up_completion: When true, the reward is attached to the
            follow-up completion; otherwise to the completion that was passed in.
        return_first_solver_as_winner: Forwarded to the game factory.

    Returns:
        An ``Episode`` whose single user message is the game prompt.
    """
    make_game = get_easy_game if difficulty == "easy" else get_variable_difficulty_game
    while True:
        try:
            game = make_game(return_first_solver_as_winner=return_first_solver_as_winner)
            prompt, follow_up, solution = game.get_prompt_and_follow_up_and_solution()
        except ValueError:
            # A ValueError while generating the game or its prompt is treated
            # as a bad draw: try again with a fresh game.
            continue
        if len(prompt) <= max_prompt_characters:
            break

    async def reward_completion(completion: EpisodeCompletion) -> EpisodeCompletion:
        """Obtain the final answer for `completion`, score it, and set the reward."""
        # NOTE(review): presumably two messages means "prompt + one assistant
        # reply", i.e. the follow-up question has not been asked yet - confirm
        # how few-shot examples affect `completion.messages`.
        if len(completion.messages) == 2:
            follow_up_completion = await completion.follow_up(
                messages=[
                    {"role": "user", "content": follow_up},
                ]
            )
        else:
            follow_up_completion = completion
        answer = follow_up_completion.last_assistant_message.get("content")
        assert isinstance(answer, str)
        if reward_follow_up_completion:
            completion = follow_up_completion
        completion.reward = _score_answer(answer, solution)
        return completion

    async def on_sample(completions: list[EpisodeCompletion]) -> None:
        """Score all sampled completions concurrently, then commit each one."""
        for completion in await asyncio.gather(
            *[reward_completion(completion) for completion in completions]
        ):
            completion.commit()

    example = random.choice(chain_of_thought_examples)

    return Episode(
        messages=[{"role": "user", "content": prompt}],
        examples=lambda: (
            [
                {"role": "user", "content": example["prompt"]},
                {
                    "role": "assistant",
                    # The answer is appended after a separator only when it is non-empty.
                    "content": example["chain_of_thought"]
                    + (example["answer"] and f"\n\n---\n\n{example['answer']}"),
                },
            ]
            if random.random() < example_probability
            else []
        ),
        on_sample=on_sample,
    )


def train_episodes():
    """Yield an endless stream of training episodes.

    Every item is a fresh ``sample_random_episode`` result created with
    ``example_probability=0.33``.
    """
    episode_kwargs = {"example_probability": 0.33}
    while True:
        yield sample_random_episode(**episode_kwargs)


model_name = "rl25"

# Values that appear in several places below, named once so they stay in sync.
num_gpus = torch.cuda.device_count()
episodes_per_batch = 64 * num_gpus  # size of one training iteration and of the validation set
max_sequence_length = 16384  # shared by the tuning recipe and the vLLM server
max_completion_tokens = 1024  # generation cap for both explore and validation sampling

trainer = Trainer(
    base_model="NousResearch/Hermes-2-Theta-Llama-3-8B",
    output_dir=f"./models/{model_name}",
    # --- exploration / sampling ---
    explore_options=ExploreOptions(
        iterations=7,
        num_parents=6,
        branch_factor=3,
        patience=5,
        sample_probability_power=None,
        sampling_kwargs=dict(max_tokens=max_completion_tokens),
    ),
    train_episodes=train_episodes(),
    episodes_per_iteration=episodes_per_batch,
    max_mask_sequence_batch_size=1,
    # --- validation ---
    # Lazily generated, so validation episodes are only created when consumed.
    val_episodes=(sample_random_episode() for _ in range(episodes_per_batch)),
    val_patience=15,
    val_samples_per_episode=3,
    val_sampling_kwargs=dict(max_tokens=max_completion_tokens),
    # --- fine-tuning ---
    tune_model=llama3_1_8b,
    tune_model_type="LLAMA3",
    tune_recipe_config=TuneRecipeConfig(
        seed=42,
        shuffle=False,
        num_output_chunks=4,
        resume_from_checkpoint=False,
        batch_size=1,
        epochs=1,
        optimizer=ComponentConfig(
            # Other candidates: "bitsandbytes.optim.PagedAdamW8bit",
            # "bitsandbytes.optim.AdamW".
            "torch.optim.AdamW",
            lr=4e-6,
            fused=True,
        ),
        # NOTE(review): several coefficients are 0.0, which presumably turns
        # those loss terms off - confirm against lib.rl.ppo.PPOLoss.
        loss=ComponentConfig(
            PPOLoss,
            policy_coef=0.0,
            clip_epsilon=0.2,
            unclipped_policy_coef=0.0,
            tanh_log_policy_coef=0.8,
            value_coef=0.0,
            entropy_coef=0.0,
            entropy_target=0.75,
            entropy_target_coef=0.1,
            kl_coef=0.1,
            weighted_entropy_coef=0.1,
            weighted_kl_coef=0.0,
            weighted_ce_coef=0.0,
            normalize_values=False,
            normalize_advantages=False,
        ),
        compile=False,
        optimizer_in_bwd=False,
        gradient_accumulation_steps=1,
        enable_activation_checkpointing=True,
        enable_activation_offloading=False,
        custom_sharded_layers=["tok_embeddings", "output"],
        log_every_n_steps=1,
        log_peak_memory_stats=True,
    ),
    # tune_run=False,
    tune_sequence_length=max_sequence_length,
    # --- inference servers ---
    vllm_config=vLLMConfig(
        env={"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"},
        kwargs={
            "block_size": 32,
            "disable_log_requests": True,
            "enable_prefix_caching": True,
            "enforce_eager": True,
            "gpu_memory_utilization": 0.9,
            "max_model_len": max_sequence_length,
            "max_num_seqs": 512,
            "max_num_batched_tokens": max_sequence_length * 4,
            "return_tokens_as_token_ids": True,
            "swap_space": 32,
        },
        max_concurrent_samples=512,
        # min_time_between_requests=15 / (64 * 24),
        timeout=120 + 15 * num_gpus,
    ),
    # --- experiment tracking ---
    # The model name is used as both the W&B run name and the run id.
    wandb_kwargs={
        "name": model_name,
        "id": model_name,
    },
)

Resuming from /home/ubuntu/atreides/experiments/models/rl25/0001
INFO 12-12 23:27:17 config.py:446] This model supports multiple tasks: {'reward', 'embed', 'score', 'generate', 'classify'}. Defaulting to 'generate'.
INFO 12-12 23:27:17 llm_engine.py:250] Initializing an LLM engine (v0.1.dev3774+g0b7ca0f) with config: VllmConfig(model_config=<vllm.config.ModelConfig object at 0x74003ee5ee10>, cache_config=<vllm.config.CacheConfig object at 0x740031f025a0>, parallel_config=ParallelConfig(pipeline_parallel_size=1, tensor_parallel_size=1, worker_use_ray=False, max_parallel_loading_workers=None, disable_custom_all_reduce=False, tokenizer_pool_config=None, ray_workers_use_nsight=False, placement_group=None, distributed_executor_backend=None, worker_cls='vllm.worker.worker.Worker', sd_worker_cls='auto', world_size=1, rank=0), scheduler_config=SchedulerConfig(runner_type='generate', max_num_batched_tokens=8192, max_num_seqs=256, max_model_len=8192, num_lookahead_slots=0, delay_factor=0.0, enab

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbradhilton[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Start training with iterations=10. The bare top-level `await` relies on the
# IPython/Jupyter kernel's support for awaiting in cells; it is not valid in a plain script.
await trainer.train(iterations=10, verbosity=1)

Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl25/0001 --port=8003 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val: 0episode [00:00, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl25/config.yaml


1|62|Loss: 0.0487: 100%|██████████| 62/62 [17:16<00:00, 16.56s/it, entropy=0.4355, entropy_target=0.3145, kl_div=0.1720, policy=-0.0009, tanh_log_policy=0.0000, unclipped_policy=-0.0009, value=1.6936, weighted_ce=0.0004, weighted_entropy=-0.0005, weighted_kl_div=-0.0002] 

Saved iteration 2 model files to /home/ubuntu/atreides/experiments/models/rl25/0002
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl25/0002 --port=8003 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

Early stopping val evaluation due to expired patience (0 remaining episodes x 15 patience per episode = 0 seconds)
$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl25/config.yaml


1|57|Loss: 0.0293: 100%|██████████| 57/57 [15:54<00:00, 16.63s/it, entropy=0.6293, entropy_target=0.1207, kl_div=0.1727, policy=0.0004, tanh_log_policy=-0.0000, unclipped_policy=0.0004, value=1.8611, weighted_ce=-0.0001, weighted_entropy=0.0002, weighted_kl_div=0.0001]   

Saved iteration 3 model files to /home/ubuntu/atreides/experiments/models/rl25/0003
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl25/0003 --port=8003 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl25/config.yaml


1|69|Loss: 0.0304: 100%|██████████| 69/69 [19:10<00:00, 16.52s/it, entropy=0.5601, entropy_target=0.1899, kl_div=0.1153, policy=0.0021, tanh_log_policy=-0.0001, unclipped_policy=0.0020, value=1.5692, weighted_ce=-0.0000, weighted_entropy=0.0005, weighted_kl_div=-0.0003]  

Saved iteration 4 model files to /home/ubuntu/atreides/experiments/models/rl25/0004
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl25/0004 --port=8003 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl25/config.yaml


1|73|Loss: 0.0248: 100%|██████████| 73/73 [20:18<00:00, 16.64s/it, entropy=0.6586, entropy_target=0.0914, kl_div=0.1552, policy=-0.0009, tanh_log_policy=0.0001, unclipped_policy=-0.0010, value=1.9242, weighted_ce=0.0004, weighted_entropy=-0.0006, weighted_kl_div=-0.0001]  

Saved iteration 5 model files to /home/ubuntu/atreides/experiments/models/rl25/0005
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl25/0005 --port=8003 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl25/config.yaml


1|76|Loss: 0.0310: 100%|██████████| 76/76 [21:08<00:00, 16.43s/it, entropy=0.5290, entropy_target=0.2210, kl_div=0.0878, policy=-0.0008, tanh_log_policy=0.0000, unclipped_policy=-0.0009, value=2.1550, weighted_ce=0.0002, weighted_entropy=-0.0005, weighted_kl_div=-0.0001] 

Saved iteration 6 model files to /home/ubuntu/atreides/experiments/models/rl25/0006
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl25/0006 --port=8003 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl25/config.yaml


1|74|Loss: 0.0229: 100%|██████████| 74/74 [20:31<00:00, 16.42s/it, entropy=0.6260, entropy_target=0.1240, kl_div=0.1042, policy=-0.0004, tanh_log_policy=0.0000, unclipped_policy=-0.0004, value=2.6191, weighted_ce=0.0001, weighted_entropy=-0.0005, weighted_kl_div=-0.0001] 

Saved iteration 7 model files to /home/ubuntu/atreides/experiments/models/rl25/0007
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl25/0007 --port=8003 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

Early stopping val evaluation due to expired patience (0 remaining episodes x 15 patience per episode = 0 seconds)
