In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the project's `lib` package) are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [17]:
%%html
<style>
/* Make the background of widget output cells transparent. */
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
/* Point the widget color and font-size variables at the VS Code editor theme variables. */
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [18]:
import asyncio
import json
from lib.clue import Clue, DeductiveSolver
from lib.rl.episode import Episode, EpisodeCompletion
from lib.rl.ppo import PPOLoss
from lib.rl.recipe import ComponentConfig, TuneRecipeConfig
from lib.rl.trainer import ExploreOptions, Trainer, vLLMConfig
from lib.utils import return_exception
import random
import re
import torch
from torchtune.models.llama3_1 import llama3_1_8b
from typing import Literal, Optional

# Chain-of-thought examples; sample_random_episode() reads the "prompt",
# "chain_of_thought" and "answer" keys of each entry.
# Decode explicitly as UTF-8 instead of relying on the platform's locale
# default, so non-ASCII text in the examples loads identically everywhere.
with open("./data/chain-of-thought-examples.json", encoding="utf-8") as f:
    chain_of_thought_examples: list[dict[str, str]] = json.load(f)


def get_variable_difficulty_game(
    return_first_solver_as_winner: Optional[bool] = None,
) -> Clue:
    """Build a randomly sized Clue game and play it with a randomly capable solver.

    The number of players (3-6) is drawn first; the weapon, suspect and room
    counts are then derived from it with random offsets and capped by the card
    lists available on ``Clue``. With probability 0.1 each, a "motive" element
    and a "time" element are added as well.

    Args:
        return_first_solver_as_winner: Forwarded to ``Clue.play``. When ``None``
            a fair coin flip decides the value.

    Returns:
        Whatever ``game.play(...)`` returns (annotated as ``Clue``).
    """
    # NOTE: the `random` calls below are kept in their original order so that a
    # seeded RNG keeps producing the same sequence of games.
    num_players = random.randint(3, 6)

    weapon_target = num_players + random.randint(-1, 5)
    num_weapons = max(3, min(weapon_target, len(Clue.weapons)))

    suspect_target = num_weapons + random.randint(0, num_weapons - 1)
    num_suspects = min(suspect_target, len(Clue.suspects))

    room_target = num_suspects + random.randint(0, num_suspects - 2)
    num_rooms = min(room_target, len(Clue.rooms))

    elements = {}
    elements["suspect"] = random.sample(Clue.suspects, k=num_suspects)
    elements["weapon"] = random.sample(Clue.weapons, k=num_weapons)
    elements["room"] = random.sample(Clue.rooms, k=num_rooms)

    # Optional fourth element: motives (10% of games).
    if random.random() < 0.1:
        motive_target = num_weapons + random.randint(-1, 3)
        num_motives = max(3, min(motive_target, len(Clue.motives)))
        elements["motive"] = random.sample(Clue.motives, k=num_motives)

    # Optional element: a window of clock times (10% of games).
    if random.random() < 0.1:
        step = random.choice([0.25, 0.5, 1.0])  # hours between consecutive times
        window_start = 24.0 - step
        window_end = 0.0
        # Widen the window one step at a time, on a randomly chosen side.
        for _ in range(random.randint(1, num_weapons + 1)):
            if not random.randint(0, 1):
                window_start -= step
            else:
                window_end += step

        def as_clock(hours: float) -> str:
            """Format fractional hours as a zero-padded ``HH:MM`` string."""
            whole_hours = int(hours)
            minutes = int(60 * (hours - whole_hours))
            return f"{whole_hours:02d}:{minutes:02d}"

        elements["time"] = Clue.get_times(
            as_clock(window_start), as_clock(window_end), f"{int(step * 60)}min"
        )

    game = Clue(num_players=num_players, elements=elements)

    # Each additional difficulty level switches on one more deductive check.
    difficulty_level = num_players + random.randint(-2, 3)
    solver = DeductiveSolver(
        check_solution_has_one_and_only_one_card_per_element=difficulty_level >= 2,
        check_one_of_constraints=difficulty_level >= 3,
        check_inverse_one_of_constraints=difficulty_level >= 4,
        merge_and_check_disjoint_inverse_one_of_constraints=difficulty_level >= 5,
        exhaustively_test_possible_assignments=False,
    )

    if return_first_solver_as_winner is None:
        first_solver_wins = bool(random.randint(0, 1))
    else:
        first_solver_wins = return_first_solver_as_winner

    return game.play(
        deductive_solver=solver,
        cp_solver_max_solve_time_per_turn=0.01,
        check_cp_solver_grid=False,
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        return_first_solver_as_winner=first_solver_wins,
        print_playthrough=False,
        max_turns=100,
    )


def get_easy_game(return_first_solver_as_winner: Optional[bool] = None) -> Clue:
    """Create and play the smallest game: 3 players, 3 suspects, 3 weapons, 3 rooms.

    Every optional ``DeductiveSolver`` check named below is switched off.

    Args:
        return_first_solver_as_winner: Forwarded to ``Clue.play``; ``None`` is
            treated as ``False``.

    Returns:
        The ``Clue`` instance after ``play`` has been called on it.
    """
    cards_per_element = 3
    # Dict-literal order fixes the order in which the RNG is consumed.
    elements = {
        "suspect": random.sample(Clue.suspects, k=cards_per_element),
        "weapon": random.sample(Clue.weapons, k=cards_per_element),
        "room": random.sample(Clue.rooms, k=cards_per_element),
    }
    game = Clue(num_players=3, elements=elements)

    disabled_checks = dict.fromkeys(
        (
            "check_solution_has_one_and_only_one_card_per_element",
            "check_one_of_constraints",
            "check_inverse_one_of_constraints",
            "merge_and_check_disjoint_inverse_one_of_constraints",
            "exhaustively_test_possible_assignments",
        ),
        False,
    )
    play_options = dict(
        deductive_solver=DeductiveSolver(**disabled_checks),
        cp_solver_max_solve_time_per_turn=0.01,
        check_cp_solver_grid=False,
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        return_first_solver_as_winner=return_first_solver_as_winner or False,
        print_playthrough=False,
        max_turns=100,
    )
    game.play(**play_options)
    return game


def _solution_value_matches(answer: str, key: str, value: str) -> bool:
    """Return True if ``answer`` states ``value`` right after ``"<key>: "``.

    The text following the first ``"<key>: "`` occurrence is captured and
    compared with ``value``, ignoring case and surrounding whitespace.

    Two capture classes are tried in turn. The first is the original
    letters-only class. The second also admits digits, so values that contain
    digits (e.g. clock times such as "21:00") can be matched at all; with the
    letters-only class alone such values could never score. Trying the original
    class first keeps every answer that matched before matching now.
    """
    expected = value.strip().lower()
    for value_class in (r"[A-Za-z \.:-]+", r"[0-9A-Za-z \.:-]+"):
        match = re.search(
            # re.escape: the key is data, not a pattern.
            rf"{re.escape(key)}: ({value_class})",
            answer,
            re.IGNORECASE,
        )
        if match and match.group(1).strip().lower() == expected:
            return True
    return False


@return_exception
def sample_random_episode(
    difficulty: Literal["easy", "variable"] = "variable",
    example_probability: float = 0.0,
    max_prompt_characters: int = 8192,
    reward_follow_up_completion: bool = True,
    return_first_solver_as_winner: Optional[bool] = None,
) -> Episode:
    """Sample a Clue game and wrap it in an ``Episode`` with a reward callback.

    Games are generated until one yields a prompt of at most
    ``max_prompt_characters`` characters; a ``ValueError`` raised while
    building the game or its prompt simply triggers another attempt.

    Args:
        difficulty: ``"easy"`` uses ``get_easy_game``; anything else uses
            ``get_variable_difficulty_game``.
        example_probability: Probability of attaching one randomly chosen
            chain-of-thought example to the episode.
        max_prompt_characters: Longest prompt accepted; longer ones are
            discarded and a new game is sampled.
        reward_follow_up_completion: When True the reward is written to the
            follow-up completion, otherwise to the completion passed in.
        return_first_solver_as_winner: Forwarded to the game builder.

    Returns:
        The constructed ``Episode``.
        NOTE(review): the ``return_exception`` decorator presumably returns a
        raised exception instead of propagating it - confirm in ``lib.utils``.
    """
    make_game = (
        get_easy_game if difficulty == "easy" else get_variable_difficulty_game
    )
    while True:
        try:
            game = make_game(
                return_first_solver_as_winner=return_first_solver_as_winner
            )
            prompt, follow_up, solution = game.get_prompt_and_follow_up_and_solution()
        except ValueError:
            continue
        if len(prompt) <= max_prompt_characters:
            break

    async def reward_completion(completion: EpisodeCompletion) -> EpisodeCompletion:
        """Score a completion by the fraction of solution entries it gets right."""
        if len(completion.messages) == 2:
            # Only two messages so far: ask the follow-up question to get the answer.
            follow_up_completion = await completion.follow_up(
                messages=[
                    {"role": "user", "content": follow_up},
                ],
                max_tokens=16 + len(solution) * 16,
            )
        else:
            follow_up_completion = completion
        answer = follow_up_completion.last_assistant_message.get("content")
        assert isinstance(answer, str)
        if reward_follow_up_completion:
            completion = follow_up_completion
        # One point per solution entry stated correctly, normalized to [0, 1].
        completion.reward = sum(
            _solution_value_matches(answer, key, value)
            for key, value in solution.items()
        ) / len(solution)
        return completion

    async def on_sample(completions: list[EpisodeCompletion]) -> None:
        """Reward all sampled completions concurrently, then commit each one."""
        for completion in await asyncio.gather(
            *[reward_completion(completion) for completion in completions]
        ):
            completion.commit()

    example = (
        random.choice(chain_of_thought_examples)
        if random.random() < example_probability
        else None
    )

    return Episode(
        messages=[{"role": "user", "content": prompt}],
        examples=(
            [
                {"role": "user", "content": example["prompt"]},
                {
                    "role": "assistant",
                    # An empty answer contributes nothing after the chain of thought.
                    "content": example["chain_of_thought"]
                    + (example["answer"] and f"\n\n---\n\n{example['answer']}"),
                },
            ]
            if example
            else []
        ),
        on_sample=on_sample,
    )


def train_episodes():
    """Endlessly yield freshly sampled episodes, one per request."""
    while True:
        episode = sample_random_episode()
        yield episode


# Run identifier: reused for the output directory and as the wandb run name and id.
model_name = "rl24"

# Assemble the Trainer: exploration/sampling options, episode sources, the
# torchtune recipe (optimizer + PPO loss) and the vLLM server settings.
trainer = Trainer(
    base_model="NousResearch/Hermes-2-Theta-Llama-3-8B",
    output_dir=f"./models/{model_name}",
    explore_options=ExploreOptions(
        iterations=7,
        num_parents=6,
        branch_factor=3,
        patience=5,
        sample_probability_power=None,
        sampling_kwargs={
            "max_tokens": 1024,
            # NOTE(review): left empty here. A later cell builds a `logit_bias`
            # dict, but nothing in view passes it back into these options.
            "logit_bias": {

            },
        },
    ),
    train_episodes=train_episodes(),
    # 64 episodes per visible CUDA device.
    episodes_per_iteration=64 * torch.cuda.device_count(),
    max_mask_sequence_batch_size=1,
    # Generator expression: validation episodes are only sampled when it is
    # iterated, and it can be consumed once.
    val_episodes=(
        sample_random_episode() for _ in range(64 * torch.cuda.device_count())
    ),
    val_patience=15,
    val_samples_per_episode=3,
    val_sampling_kwargs={"max_tokens": 1024},
    tune_model=llama3_1_8b,
    tune_model_type="LLAMA3",
    tune_recipe_config=TuneRecipeConfig(
        seed=42,
        shuffle=False,
        num_output_chunks=4,
        resume_from_checkpoint=False,
        batch_size=1,
        epochs=1,
        optimizer=ComponentConfig(
            "torch.optim.AdamW",
            # Alternative optimizers, left disabled:
            # "bitsandbytes.optim.PagedAdamW8bit",
            # "bitsandbytes.optim.AdamW",
            # params=PLACEHOLDER,
            lr=4e-6,
            fused=True,
        ),
        # Of the policy-named coefficients only tanh_log_policy_coef is nonzero;
        # value_coef and entropy_coef are zeroed as well.
        loss=ComponentConfig(
            PPOLoss,
            policy_coef=0.0,
            clip_epsilon=0.2,
            unclipped_policy_coef=0.0,
            tanh_log_policy_coef=0.8,
            value_coef=0.0,
            entropy_coef=0.0,
            entropy_target=0.6,
            entropy_target_coef=0.15,
            kl_coef=0.15,
            weighted_entropy_coef=0.1,
            weighted_kl_coef=0.2,
            weighted_ce_coef=0.0,
            normalize_values=False,
            normalize_advantages=False,
        ),
        compile=False,
        optimizer_in_bwd=False,
        gradient_accumulation_steps=1,
        enable_activation_checkpointing=True,
        enable_activation_offloading=False,
        custom_sharded_layers=["tok_embeddings", "output"],
        log_every_n_steps=1,
        log_peak_memory_stats=True,
    ),
    # tune_run=False,
    # Same value as vLLM's max_model_len below.
    tune_sequence_length=16384,
    vllm_config=vLLMConfig(
        env={"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"},
        # The recorded `vllm serve` command lines below show these rendered as
        # CLI flags (e.g. --block-size=32).
        kwargs=dict(
            block_size=32,
            disable_log_requests=True,
            enable_prefix_caching=True,
            enforce_eager=True,
            gpu_memory_utilization=0.85,
            max_model_len=16384,
            max_num_seqs=512,
            max_num_batched_tokens=16384 * 4,
            return_tokens_as_token_ids=True,
            swap_space=32,
            preemption_mode="swap",
        ),
        max_concurrent_samples=512,
        # 120 plus 15 per CUDA device; units are not visible here (presumably seconds).
        timeout=120 + 15 * torch.cuda.device_count(),
    ),
    wandb_kwargs=dict(
        name=model_name,
        id=model_name,
    ),
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbradhilton[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
def tokens(string: str) -> list[int]:
    """Encode ``string`` without special tokens, keeping it only if it is one token.

    Returns the encoder's result when it has at most one token id, and an
    empty list when the string is split into several tokens.
    """
    tokenizer = trainer.tokenizer.llm.get_tokenizer()
    token_ids = tokenizer.encode(string, add_special_tokens=False)
    return token_ids if len(token_ids) <= 1 else []

# (word, bias) pairs for reasoning-marker words, written as weight tiers and
# flattened in tier order (highest bias first).
word_biases = [
    (word, bias)
    for bias, words in [
        # Critical thinking markers (highest importance)
        (0.9, ["wait"]),  # frequent correction initiator
        (0.8, ["hmm", "alternatively"]),  # uncertainty marker; key for exploring options
        # Important correction/uncertainty markers
        (0.7, ["perhaps", "actually", "rather"]),  # tentative; corrective; refining
        # Analysis progression markers
        (0.6, ["first", "similarly", "therefore", "instead"]),
        # Mistake acknowledgment
        (0.5, ["wrong", "incorrect", "no"]),
        # Supporting markers (lower importance)
        (0.4, ["but", "so", "thus", "now", "given", "suppose"]),
        # Light correction markers
        (0.3, ["oh", "oops", "right"]),
    ]
    for word in words
]

# Map token id -> bias. Each word is tried as written, capitalized, and with a
# leading space; tokens() contributes nothing for variants that are not a
# single token.
logit_bias = {
    token_id: weight
    for base_word, weight in word_biases
    for variant in (base_word, base_word.capitalize(), " " + base_word)
    for token_id in tokens(variant)
}
logit_bias

In [50]:
# Scratch probe: in the recorded output "...Hmm" encodes to two token ids, so
# the tokens() helper above would return [] for that string.
trainer.tokenizer.llm.get_tokenizer().encode("...Hmm", add_special_tokens=False)

[1131, 81122]

In [22]:
# Scratch probe: encode a one-message conversation to inspect the resulting token ids.
trainer.tokenizer.encode([{"role": "user", "content": "But wait"}])

tensor([128000, 128002,    882,    198,   4071,   3868])

In [8]:
# Run the trainer's "val" evaluation; the recorded result is a (float, list) tuple.
# NOTE(review): execution counts in this notebook are out of order (this cell is
# In [8], the training cell below is In [4]), so outputs may not reflect a
# top-to-bottom run.
await trainer.eval("val", verbosity=1)

Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl24/0005 --port=8000 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.85 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=32 --preemption-mode=swap --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val: 0episode [00:00, ?episode/s]

(0.20867374870072775, [])

In [4]:
# Run three training iterations. In the recorded log each one starts a vLLM
# server, shows val/explore progress, runs the tune recipe, and saves the next
# numbered model directory.
await trainer.train(iterations=3, verbosity=1)

Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl24/0001 --port=8000 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.85 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val: 0episode [00:00, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl24/config.yaml


1|53|Loss: 0.0343: 100%|██████████| 53/53 [15:55<00:00, 17.74s/it, entropy=0.6107, entropy_target=0.0107, kl_div=0.2170, policy=-0.0028, tanh_log_policy=0.0002, unclipped_policy=-0.0045, value=1.6921, weighted_ce=0.0023, weighted_entropy=-0.0026, weighted_kl_div=-0.0014] 

Saved iteration 2 model files to /home/ubuntu/atreides/experiments/models/rl24/0002
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl24/0002 --port=8000 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.85 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl24/config.yaml


1|68|Loss: 0.0379: 100%|██████████| 68/68 [20:17<00:00, 17.77s/it, entropy=0.4352, entropy_target=0.1648, kl_div=0.0902, policy=0.0004, tanh_log_policy=-0.0003, unclipped_policy=-0.0006, value=2.6518, weighted_ce=-0.0016, weighted_entropy=0.0020, weighted_kl_div=0.0005]   

Saved iteration 3 model files to /home/ubuntu/atreides/experiments/models/rl24/0003
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl24/0003 --port=8000 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.85 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

<class 'ProcessLookupError'> 
$ tune run lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl24/config.yaml


1|6|Loss: 0.0219: 100%|██████████| 6/6 [01:58<00:00, 18.32s/it, entropy=0.6080, entropy_target=0.0080, kl_div=0.1713, policy=0.0608, tanh_log_policy=-0.0041, unclipped_policy=0.0551, value=1.4917, weighted_ce=-0.0248, weighted_entropy=0.0363, weighted_kl_div=0.0096]  

Saved iteration 4 model files to /home/ubuntu/atreides/experiments/models/rl24/0004
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl24/0004 --port=8000 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.85 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=32 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]