In [5]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
%%html
<style>
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [7]:
import asyncio
import json
from lib.clue import Clue, DeductiveSolver
from lib.rl.episode import Episode, EpisodeCompletion
from lib.rl.ppo import PPOLoss
from lib.rl.recipe import ComponentConfig, TuneRecipeConfig
from lib.rl.trainer import ExploreOptions, Trainer, vLLMConfig
from lib.utils import return_exception
import random
import re
import torch
from torchtune.models.llama3_1 import llama3_1_8b
from typing import Literal, Optional

# Few-shot chain-of-thought demonstrations. `sample_random_episode` may prepend one
# of these as an example exchange (it reads the "prompt", "chain_of_thought" and
# "answer" keys). JSON text is UTF-8 by spec, so decode explicitly instead of
# relying on the platform's locale default encoding.
with open("./data/chain-of-thought-examples.json", encoding="utf-8") as f:
    chain_of_thought_examples: list[dict[str, str]] = json.load(f)


def get_variable_difficulty_game(
    return_first_solver_as_winner: Optional[bool] = None,
) -> Clue:
    """Randomly size, play, and return a Clue game of varying difficulty.

    The player count is drawn from 3-6. The weapon, suspect and room counts are
    grown from it, each capped by the cards available on ``Clue``. With
    probability 0.1 each, a motive element and a time element are also added. A
    random ``difficulty_level`` decides which optional ``DeductiveSolver`` checks
    are switched on.

    Args:
        return_first_solver_as_winner: forwarded to ``game.play``. When ``None``,
            it is decided by a fair coin flip.

    Returns:
        The value returned by ``game.play(...)``.
    """
    player_count = random.randint(3, 6)
    weapon_count = max(3, min(player_count + random.randint(-1, 5), len(Clue.weapons)))
    suspect_count = min(
        weapon_count + random.randint(0, weapon_count - 1), len(Clue.suspects)
    )
    room_count = min(
        suspect_count + random.randint(0, suspect_count - 2), len(Clue.rooms)
    )

    elements = {}
    elements["suspect"] = random.sample(Clue.suspects, k=suspect_count)
    elements["weapon"] = random.sample(Clue.weapons, k=weapon_count)
    elements["room"] = random.sample(Clue.rooms, k=room_count)

    if random.random() < 0.1:
        motive_count = max(
            3, min(weapon_count + random.randint(-1, 3), len(Clue.motives))
        )
        elements["motive"] = random.sample(Clue.motives, k=motive_count)

    if random.random() < 0.1:
        # Build a time window that wraps past midnight: it starts one step before
        # 24:00 and ends at 00:00, then each extra step randomly extends one side.
        step = random.choice([0.25, 0.5, 1.0])
        first_slot, last_slot = 24.0 - step, 0.0
        for _ in range(random.randint(1, weapon_count + 1)):
            if random.randint(0, 1):
                last_slot += step
            else:
                first_slot -= step

        def as_clock(hours: float) -> str:
            # Fractional hours -> "HH:MM".
            whole = int(hours)
            return f"{whole:02d}:{int(60 * (hours - whole)):02d}"

        elements["time"] = Clue.get_times(
            as_clock(first_slot), as_clock(last_slot), f"{int(step * 60)}min"
        )

    game = Clue(num_players=player_count, elements=elements)

    # Higher levels switch on progressively more optional solver checks. Checks
    # not passed here keep DeductiveSolver's own defaults.
    difficulty_level = player_count + random.randint(-2, 3)
    solver = DeductiveSolver(
        check_solution_has_one_and_only_one_card_per_element=difficulty_level > 1,
        check_one_of_constraints=difficulty_level > 2,
        check_inverse_one_of_constraints=difficulty_level > 3,
        merge_and_check_disjoint_inverse_one_of_constraints=difficulty_level > 4,
        exhaustively_test_possible_assignments=False,
    )
    if return_first_solver_as_winner is None:
        return_first_solver_as_winner = bool(random.randint(0, 1))

    return game.play(
        deductive_solver=solver,
        cp_solver_max_solve_time_per_turn=0.01,
        check_cp_solver_grid=False,
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        return_first_solver_as_winner=return_first_solver_as_winner,
        print_playthrough=False,
        max_turns=100,
    )


def get_easy_game(return_first_solver_as_winner: Optional[bool] = None) -> Clue:
    """Play and return a small, easy Clue game.

    The game uses three players and three randomly sampled suspects, weapons and
    rooms (no motive or time element). The optional ``DeductiveSolver`` checks
    passed below are all disabled.

    Args:
        return_first_solver_as_winner: forwarded to ``game.play``. ``None`` is
            treated as ``False``, unlike ``get_variable_difficulty_game``, which
            flips a coin in that case.

    Returns:
        The value returned by ``game.play(...)``. The original version discarded
        this value and returned ``game``, unlike ``get_variable_difficulty_game``.
    """
    game = Clue(
        num_players=3,
        elements={
            "suspect": random.sample(Clue.suspects, k=3),
            "weapon": random.sample(Clue.weapons, k=3),
            "room": random.sample(Clue.rooms, k=3),
        },
    )
    # Return play()'s result rather than discarding it, matching
    # get_variable_difficulty_game, whose caller uses the returned game.
    return game.play(
        deductive_solver=DeductiveSolver(
            check_solution_has_one_and_only_one_card_per_element=False,
            check_one_of_constraints=False,
            check_inverse_one_of_constraints=False,
            merge_and_check_disjoint_inverse_one_of_constraints=False,
            exhaustively_test_possible_assignments=False,
        ),
        cp_solver_max_solve_time_per_turn=0.01,
        check_cp_solver_grid=False,
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        return_first_solver_as_winner=return_first_solver_as_winner or False,
        print_playthrough=False,
        max_turns=100,
    )


def _score_answer(answer: str, solution: dict[str, str]) -> float:
    """Return the fraction of ``solution`` entries stated correctly in ``answer``.

    For each ``key, value`` pair, the first case-insensitive match of
    ``"<key>: <text>"`` is located. ``<text>`` is a run of letters, digits,
    spaces, ``.``, ``:`` or ``-``. The pair counts as correct when ``<text>``
    equals ``value`` after stripping whitespace and lower-casing both sides.

    Digits are allowed so that clock-style values such as ``21:00`` can match;
    without them, a "time" element could never be scored correct. ``key`` is
    regex-escaped so it is always matched literally.
    """
    correct = 0
    for key, value in solution.items():
        match = re.search(
            rf"{re.escape(key)}: ([A-Za-z0-9 \.:-]+)", answer, re.IGNORECASE
        )
        if match and match.group(1).strip().lower() == value.strip().lower():
            correct += 1
    return correct / len(solution)


@return_exception
def sample_random_episode(
    difficulty: Literal["easy", "variable"] = "variable",
    example_probability: float = 0.0,
    max_prompt_characters: int = 8192,
    reward_follow_up_completion: bool = True,
    return_first_solver_as_winner: Optional[bool] = None,
) -> Episode:
    """Build an ``Episode`` around a freshly played Clue game.

    Games are regenerated until one produces a prompt of at most
    ``max_prompt_characters`` characters. Games whose prompt construction raises
    ``ValueError`` are skipped.

    Args:
        difficulty: ``"easy"`` uses ``get_easy_game``; any other value uses
            ``get_variable_difficulty_game``.
        example_probability: chance, rolled each time the episode's ``examples``
            callable is invoked, of prepending one chain-of-thought exchange.
        max_prompt_characters: longest prompt (in characters) that is accepted.
        reward_follow_up_completion: if True, the reward is attached to the
            follow-up completion; otherwise it goes on the original completion.
        return_first_solver_as_winner: forwarded to the game factory.

    Returns:
        An ``Episode`` whose ``on_sample`` hook scores and commits completions.
    """
    while True:
        try:
            game = (
                get_easy_game if difficulty == "easy" else get_variable_difficulty_game
            )(return_first_solver_as_winner=return_first_solver_as_winner)
            prompt, follow_up, solution = game.get_prompt_and_follow_up_and_solution()
        except ValueError:
            continue
        if len(prompt) <= max_prompt_characters:
            break

    async def reward_completion(completion: EpisodeCompletion) -> EpisodeCompletion:
        if len(completion.messages) == 2:
            # Presumably only the prompt and the first reply so far: ask the
            # follow-up question, with a token budget derived from the length of
            # the formatted solution.
            solution_text = "\n".join(
                f"{key}: {value}" for key, value in solution.items()
            )
            follow_up_completion = await completion.follow_up(
                messages=[{"role": "user", "content": follow_up}],
                max_tokens=10 + len(solution_text) // 2,
            )
        else:
            follow_up_completion = completion
        answer = follow_up_completion.last_assistant_message.get("content")
        assert isinstance(answer, str)
        if reward_follow_up_completion:
            completion = follow_up_completion
        completion.reward = _score_answer(answer, solution)
        # Subtract a penalty proportional to `all_absent_stop_tokens`. The divisor
        # depends on which completion carries the reward.
        completion.reward -= (
            completion.all_absent_stop_tokens
            / (3 if reward_follow_up_completion else 2)
            / len(solution)
        )
        return completion

    async def on_sample(completions: list[EpisodeCompletion]) -> None:
        # Score all completions concurrently, then commit each one.
        for completion in await asyncio.gather(
            *[reward_completion(completion) for completion in completions]
        ):
            completion.commit()

    example = random.choice(chain_of_thought_examples)

    return Episode(
        messages=[{"role": "user", "content": prompt}],
        # Re-rolled per call: either one example exchange or none.
        examples=lambda: (
            [
                {"role": "user", "content": example["prompt"]},
                {
                    "role": "assistant",
                    "content": example["chain_of_thought"]
                    + (example["answer"] and f"\n\n---\n\n{example['answer']}"),
                },
            ]
            if random.random() < example_probability
            else []
        ),
        on_sample=on_sample,
    )


def train_episodes():
    """Endlessly yield training episodes from ``sample_random_episode``.

    Every episode is sampled with ``example_probability=0.05``. Whether the
    follow-up completion is rewarded is decided per episode by a fair coin flip.
    """
    while True:
        reward_follow_up = random.random() < 0.5
        yield sample_random_episode(
            example_probability=0.05,
            reward_follow_up_completion=reward_follow_up,
        )


model_name = "rl30"

# Values shared by several settings below. Naming them once keeps related
# settings in sync, e.g. the vLLM context length and the torchtune sequence length.
NUM_GPUS = torch.cuda.device_count()
EPISODES_PER_GPU = 64
MAX_COMPLETION_TOKENS = 1024
MAX_SEQUENCE_LENGTH = 16384

# NOTE(review): the base checkpoint "NousResearch/Hermes-2-Theta-Llama-3-8B" is a
# Llama 3 model (its vLLM engine log reports max_seq_len=8192, rope_scaling=None),
# but the torchtune architecture used for training is `llama3_1_8b`. If that
# builder applies Llama-3.1-style RoPE scaling, training-time log-probs may not
# match the ones vLLM samples with. Confirm this pairing is intentional; it may
# have been chosen to allow sequences longer than 8192 tokens.
trainer = Trainer(
    base_model="NousResearch/Hermes-2-Theta-Llama-3-8B",
    output_dir=f"./models/{model_name}",
    explore_options=ExploreOptions(
        iterations=9,
        num_parents=5,
        branch_factor=3,
        patience=5,
        sample_probability_power=None,
        sampling_kwargs={"max_tokens": MAX_COMPLETION_TOKENS},
        split_method="prob",
        split_point_std_deviation=0.5,
    ),
    train_episodes=train_episodes(),
    episodes_per_iteration=EPISODES_PER_GPU * NUM_GPUS,
    max_mask_sequence_batch_size=1,
    # Lazy generator: validation episodes are sampled only as the Trainer iterates it.
    val_episodes=(
        sample_random_episode() for _ in range(EPISODES_PER_GPU * NUM_GPUS)
    ),
    val_patience=15,
    val_samples_per_episode=3,
    val_sampling_kwargs={"max_tokens": MAX_COMPLETION_TOKENS},
    tune_model=llama3_1_8b,
    tune_model_type="LLAMA3",
    tune_recipe_config=TuneRecipeConfig(
        seed=42,
        shuffle=False,
        num_output_chunks=4,
        resume_from_checkpoint=False,
        batch_size=1,
        epochs=1,
        optimizer=ComponentConfig(
            "torch.optim.AdamW",
            lr=4e-6,
            fused=True,
        ),
        loss=ComponentConfig(
            PPOLoss,
            policy_coef=0.0,
            clip_epsilon=0.2,
            unclipped_policy_coef=0.0,
            tanh_log_policy_coef=0.8,
            value_coef=0.0,
            entropy_coef=0.0,
            entropy_target=0.7,
            entropy_target_coef=0.2,
            kl_coef=0.3,
            weighted_entropy_coef=0.2,
            weighted_kl_coef=0.0,
            weighted_ce_coef=0.0,
            normalize_values=False,
            normalize_advantages=False,
        ),
        compile=False,
        optimizer_in_bwd=False,
        gradient_accumulation_steps=1,
        enable_activation_checkpointing=True,
        enable_activation_offloading=False,
        custom_sharded_layers=["tok_embeddings", "output"],
        log_every_n_steps=1,
        log_peak_memory_stats=True,
    ),
    tune_sequence_length=MAX_SEQUENCE_LENGTH,
    vllm_config=vLLMConfig(
        # Presumably lets vLLM accept a max_model_len above the model config's
        # native length (8192 per the engine log).
        env={"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"},
        kwargs=dict(
            block_size=32,
            disable_log_requests=True,
            enable_prefix_caching=True,
            enforce_eager=True,
            gpu_memory_utilization=0.9,
            max_model_len=MAX_SEQUENCE_LENGTH,
            max_num_seqs=512,
            max_num_batched_tokens=MAX_SEQUENCE_LENGTH * 4,
            preemption_mode="swap",
            return_tokens_as_token_ids=True,
            swap_space=64,
        ),
        max_concurrent_samples=512,
        timeout=120 + 15 * NUM_GPUS,
    ),
    wandb_kwargs=dict(
        name=model_name,
        id=model_name,
    ),
)

INFO 12-13 21:20:43 llm_engine.py:237] Initializing an LLM engine (v0.6.3.post1) with config: model='NousResearch/Hermes-2-Theta-Llama-3-8B', speculative_config=None, tokenizer='NousResearch/Hermes-2-Theta-Llama-3-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=NousResearch/Hermes-2-Theta-Llama-3-8B, num_scheduler_steps=1, chunked_prefill_enabled=Fal

In [None]:
# This run went very badly: the 7th iteration only reached an average score of 0.0713.

In [9]:
# Run `trainer.train` for 4 iterations (top-level `await` relies on IPython's autoawait).
# In the output below, model files for iterations 5-7 were saved; the exceptions
# from the last explore step are inspected in a later cell.
await trainer.train(iterations=4, verbosity=1)

val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

Early stopping val evaluation due to expired patience (0 remaining episodes x 15 patience per episode = 0 seconds)
$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl30/config.yaml


1|43|Loss: 0.0320: 100%|██████████| 43/43 [13:01<00:00, 18.09s/it, entropy=0.6986, entropy_target=0.0014, kl_div=0.1100, policy=0.0025, tanh_log_policy=-0.0002, unclipped_policy=-0.0010, value=1.5981, weighted_ce=-0.0002, weighted_entropy=0.0060, weighted_kl_div=0.0014] 

Saved iteration 5 model files to /home/ubuntu/atreides/experiments/models/rl30/0005
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl30/0005 --port=8001 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=64 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl30/config.yaml


1|40|Loss: 0.0787: 100%|██████████| 40/40 [12:10<00:00, 17.71s/it, entropy=0.9082, entropy_target=0.2082, kl_div=0.0969, policy=-0.0061, tanh_log_policy=0.0067, unclipped_policy=-0.0423, value=0.9072, weighted_ce=0.0132, weighted_entropy=-0.0130, weighted_kl_div=-0.0031] 

Saved iteration 6 model files to /home/ubuntu/atreides/experiments/models/rl30/0006
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl30/0006 --port=8001 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=64 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

Early stopping val evaluation due to expired patience (1 remaining episodes x 15 patience per episode = 15 seconds)
$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl30/config.yaml


1|67|Loss: 0.0442: 100%|██████████| 67/67 [20:14<00:00, 18.17s/it, entropy=0.7399, entropy_target=0.0399, kl_div=0.1170, policy=0.0104, tanh_log_policy=0.0006, unclipped_policy=0.0032, value=1.0848, weighted_ce=0.0042, weighted_entropy=-0.0029, weighted_kl_div=0.0011]    

Saved iteration 7 model files to /home/ubuntu/atreides/experiments/models/rl30/0007
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl30/0007 --port=8001 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=64 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

In [5]:
# Inspect the exceptions recorded for the most recent explore step. Here they are
# repeated openai.InternalServerError responses, presumably from the vLLM server.
trainer.explore_results[-1].exceptions

[openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError('Internal Server Error'),
 openai.InternalServerError(