In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%%html
<style>
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [3]:
from openai import AsyncOpenAI, Timeout

import os

# Async client for the server that hosts the reference model. The endpoint and
# key default to the values this notebook has always used, but can be overridden
# per machine with REFERENCE_BASE_URL / REFERENCE_API_KEY so the notebook does
# not have to be edited when the server moves.
reference_client = AsyncOpenAI(
    api_key=os.environ.get("REFERENCE_API_KEY", "default"),
    base_url=os.environ.get("REFERENCE_BASE_URL", "http://209.20.157.218:8000/v1"),
    # Overall timeout of 600 with a separate, shorter connect timeout of 60.
    timeout=Timeout(600, connect=60),
)

In [4]:
# Request a single chat completion from the reference server, asking for
# log-probabilities and the 20 most likely alternatives at every generated token.
chat_completion = await reference_client.chat.completions.create(
    messages=[{"role": "user", "content": "Hello!"}],
    model="NousResearch/Hermes-2-Theta-Llama-3-8B",
    logprobs=True,
    top_logprobs=20,
)
# Display the alternatives returned for the first generated token.
chat_completion.choices[0].logprobs.content[0].top_logprobs

[TopLogprob(token='Hello', bytes=[72, 101, 108, 108, 111], logprob=-0.00798675324767828),
 TopLogprob(token='Hi', bytes=[72, 105], logprob=-4.834799289703369),
 TopLogprob(token='Hey', bytes=[72, 101, 121], logprob=-12.700722694396973),
 TopLogprob(token='Greetings', bytes=[71, 114, 101, 101, 116, 105, 110, 103, 115], logprob=-12.879494667053223),
 TopLogprob(token='Welcome', bytes=[87, 101, 108, 99, 111, 109, 101], logprob=-14.667202949523926),
 TopLogprob(token='Hallo', bytes=[72, 97, 108, 108, 111], logprob=-14.845974922180176),
 TopLogprob(token=' Hello', bytes=[32, 72, 101, 108, 108, 111], logprob=-15.203516960144043),
 TopLogprob(token='Hola', bytes=[72, 111, 108, 97], logprob=-16.633684158325195),
 TopLogprob(token='hello', bytes=[104, 101, 108, 108, 111], logprob=-18.60016441345215),
 TopLogprob(token='Bonjour', bytes=[66, 111, 110, 106, 111, 117, 114], logprob=-19.136476516723633),
 TopLogprob(token='Nice', bytes=[78, 105, 99, 101], logprob=-19.4940185546875),
 TopLogprob(toke

In [5]:
import torch
# One row per generated token, one column per returned top-k alternative.
logprobs = []
for token_logprob in chat_completion.choices[0].logprobs.content:
    logprobs.append([alternative.logprob for alternative in token_logprob.top_logprobs])

# Categorical renormalises `probs` along the last dimension, so each entry is the
# entropy of a token's distribution restricted to the returned alternatives.
top_k_probs = torch.tensor(logprobs).exp()
entropy = torch.distributions.Categorical(probs=top_k_probs).entropy()
entropy

tensor([4.6440e-02, 4.0069e-02, 6.9702e-01, 2.8915e-02, 6.7287e-06, 6.7746e-01,
        2.9725e-04, 9.6658e-04, 1.3310e-04, 1.3671e+00, 2.9643e-05, 3.3668e-02,
        5.3873e-02, 1.5532e-04, 1.2505e-01, 2.8845e-01, 9.9986e-07, 1.2106e-07,
        1.0927e-05, 1.2103e+00, 4.1731e-01, 2.5629e-02, 1.4311e-03, 4.3575e-05,
        1.8644e-01, 4.9029e-03, 8.1870e-01, 1.0115e-02])

In [6]:
import asyncio
from lib.clue import Clue, DeductiveSolver
from lib.rl.episode import Episode, EpisodeCompletion
from lib.rl.ppo import PPOLoss
from lib.rl.recipe import ComponentConfig, TuneRecipeConfig
from lib.rl.trainer import Trainer
from lib.utils import return_exception
import random
import re
import torch
from torchtune.models.llama3_1 import llama3_1_8b
from typing import Literal, Optional


def get_variable_difficulty_game(
    return_first_solver_as_winner: Optional[bool] = None,
) -> Clue:
    """Build a randomly sized Clue game and play it with a randomly capable solver.

    Draws 3-6 players, then sizes the weapon, suspect and room decks from one
    another (each capped by the number of cards ``Clue`` offers). With
    probability 0.1 each, a "motive" deck and a "time" deck are added as extra
    categories. A difficulty level derived from the player count decides which
    ``DeductiveSolver`` checks are enabled for the playthrough.

    Args:
        return_first_solver_as_winner: Forwarded to ``Clue.play``. When ``None``
            a coin flip decides the value.

    Returns:
        Whatever ``game.play(...)`` returns (annotated here as a ``Clue``).
    """
    num_players = random.randint(3, 6)

    # Deck sizes are derived in a chain: weapons from players, suspects from
    # weapons, rooms from suspects.
    weapon_draw = num_players + random.randint(-1, 5)
    num_weapons = max(3, min(weapon_draw, len(Clue.weapons)))
    suspect_draw = num_weapons + random.randint(0, num_weapons - 1)
    num_suspects = min(suspect_draw, len(Clue.suspects))
    room_draw = num_suspects + random.randint(0, num_suspects - 2)
    num_rooms = min(room_draw, len(Clue.rooms))

    elements = {}
    elements["suspect"] = random.sample(Clue.suspects, k=num_suspects)
    elements["weapon"] = random.sample(Clue.weapons, k=num_weapons)
    elements["room"] = random.sample(Clue.rooms, k=num_rooms)

    # Optional "motive" category, included with probability 0.1.
    include_motive = random.random() < 0.1
    if include_motive:
        motive_draw = num_weapons + random.randint(-1, 3)
        num_motives = max(3, min(motive_draw, len(Clue.motives)))
        elements["motive"] = random.sample(Clue.motives, k=num_motives)

    # Optional "time" category, included with probability 0.1.
    include_time = random.random() < 0.1
    if include_time:
        frequency = random.choice([0.25, 0.5, 1.0])
        start, end = 24.0 - frequency, 0.0
        # Each step moves one end of the window by one `frequency` increment:
        # either `start` earlier or `end` later.
        widen_steps = random.randint(1, num_weapons + 1)
        for _ in range(widen_steps):
            if random.randint(0, 1) == 0:
                start -= frequency
            else:
                end += frequency

        def format_time(time: float) -> str:
            # Fractional hours -> zero-padded "HH:MM".
            hours = int(time)
            minutes = int(60 * (time - hours))
            return f"{hours:02d}:{minutes:02d}"

        elements["time"] = Clue.get_times(
            format_time(start), format_time(end), f"{int(frequency * 60)}min"
        )

    game = Clue(num_players=num_players, elements=elements)
    difficulty_level = num_players + random.randint(-2, 3)

    # Each additional difficulty level enables one more deduction technique.
    # Solver options not listed here are left at their DeductiveSolver defaults.
    solver = DeductiveSolver(
        check_solution_has_one_and_only_one_card_per_element=difficulty_level > 1,
        check_one_of_constraints=difficulty_level > 2,
        check_inverse_one_of_constraints=difficulty_level > 3,
        merge_and_check_disjoint_inverse_one_of_constraints=difficulty_level > 4,
        exhaustively_test_possible_assignments=False,
    )

    if return_first_solver_as_winner is None:
        first_solver_wins = bool(random.randint(0, 1))
    else:
        first_solver_wins = return_first_solver_as_winner

    return game.play(
        deductive_solver=solver,
        cp_solver_max_solve_time_per_turn=0.01,
        check_cp_solver_grid=False,
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        return_first_solver_as_winner=first_solver_wins,
        print_playthrough=False,
        max_turns=100,
    )


def get_easy_game(return_first_solver_as_winner: Optional[bool] = None) -> Clue:
    """Build and play the smallest game: 3 players and 3 cards per category.

    Only the "suspect", "weapon" and "room" categories are used, and every
    optional ``DeductiveSolver`` check passed here is switched off.

    Args:
        return_first_solver_as_winner: Forwarded to ``Clue.play``; a falsy value
            (including ``None``) is passed on as ``False``.

    Returns:
        The ``Clue`` instance after ``play`` has been called on it.
    """
    decks = (
        ("suspect", Clue.suspects),
        ("weapon", Clue.weapons),
        ("room", Clue.rooms),
    )
    game = Clue(
        num_players=3,
        elements={category: random.sample(cards, k=3) for category, cards in decks},
    )

    # Solver options not listed here are left at their DeductiveSolver defaults.
    solver = DeductiveSolver(
        check_solution_has_one_and_only_one_card_per_element=False,
        check_one_of_constraints=False,
        check_inverse_one_of_constraints=False,
        merge_and_check_disjoint_inverse_one_of_constraints=False,
        exhaustively_test_possible_assignments=False,
    )
    first_solver_wins = (
        return_first_solver_as_winner if return_first_solver_as_winner else False
    )

    game.play(
        deductive_solver=solver,
        cp_solver_max_solve_time_per_turn=0.01,
        check_cp_solver_grid=False,
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        return_first_solver_as_winner=first_solver_wins,
        print_playthrough=False,
        max_turns=100,
    )
    return game


def _score_answer(answer: str, solution: dict[str, str]) -> float:
    """Return the fraction of ``solution`` entries that ``answer`` states correctly.

    For every key, the first ``"<key>: <text>"`` occurrence in ``answer`` is
    located case-insensitively. The captured text is then compared with the
    expected value, ignoring case and surrounding whitespace.
    """
    num_correct = 0
    for key, value in solution.items():
        # Digits are part of the captured character class so that values such as
        # "21:00" can match; a letters-only capture can never equal a value that
        # contains a digit. The key is escaped so it is always matched literally.
        match = re.search(
            rf"{re.escape(str(key))}: ([A-Za-z0-9 \.:-]+)",
            answer,
            re.IGNORECASE,
        )
        if match and match.group(1).strip().lower() == value.strip().lower():
            num_correct += 1
    return num_correct / len(solution)


@return_exception
def sample_random_episode(
    difficulty: Literal["easy", "variable"] = "variable",
    max_prompt_characters: int = 8192,
    reward_follow_up_completion: bool = True,
    return_first_solver_as_winner: Optional[bool] = None,
) -> Episode:
    """Generate a Clue game and wrap its prompt in an ``Episode`` with a reward hook.

    Games are regenerated until one produces a prompt of at most
    ``max_prompt_characters`` characters. A ``ValueError`` raised while
    generating a game is treated as "try again".

    NOTE(review): ``return_exception`` comes from ``lib.utils``; presumably it
    makes the call return a raised exception instead of propagating it — confirm.

    Args:
        difficulty: ``"easy"`` uses ``get_easy_game``; any other value uses
            ``get_variable_difficulty_game``.
        max_prompt_characters: Upper bound on the accepted prompt length.
        reward_follow_up_completion: If true, the reward is attached to the
            follow-up completion; otherwise to the completion passed in.
        return_first_solver_as_winner: Forwarded to the game factory.

    Returns:
        An ``Episode`` whose single user message is the game prompt and whose
        ``on_sample`` callback scores and commits each sampled completion.
    """
    make_game = get_easy_game if difficulty == "easy" else get_variable_difficulty_game
    while True:
        try:
            game = make_game(
                return_first_solver_as_winner=return_first_solver_as_winner
            )
            prompt, follow_up, solution = game.get_prompt_and_follow_up_and_solution()
        except ValueError:
            continue
        if len(prompt) <= max_prompt_characters:
            break

    async def reward_completion(completion: EpisodeCompletion) -> EpisodeCompletion:
        """Score one completion against ``solution`` and return the rewarded completion."""
        # With exactly two messages the follow-up question is sent first;
        # otherwise the completion itself is graded.
        if len(completion.messages) == 2:
            follow_up_completion = await completion.follow_up(
                messages=[
                    {"role": "user", "content": follow_up},
                ]
            )
        else:
            follow_up_completion = completion
        answer = follow_up_completion.last_assistant_message.get("content")
        assert isinstance(answer, str)
        if reward_follow_up_completion:
            completion = follow_up_completion
        completion.reward = _score_answer(answer, solution)
        return completion

    async def on_sample(completions: list[EpisodeCompletion]) -> None:
        """Reward all sampled completions concurrently, then commit each result."""
        for completion in await asyncio.gather(
            *[reward_completion(completion) for completion in completions]
        ):
            completion.commit()

    return Episode(
        messages=[{"role": "user", "content": prompt}],
        on_sample=on_sample,
    )


def train_episodes():
    """Yield an endless stream of results from ``sample_random_episode()``.

    Each item is produced on demand with the default arguments.
    """
    while True:
        episode = sample_random_episode()
        yield episode


model_name = "rl5"

# Values the configuration below uses in more than one place.
base_model = "NousResearch/Hermes-2-Theta-Llama-3-8B"
num_gpus = torch.cuda.device_count()
max_sequence_length = 16384

# Optimizer for the tune step.
# Disabled alternatives: "bitsandbytes.optim.PagedAdamW8bit", "bitsandbytes.optim.AdamW".
optimizer_config = ComponentConfig(
    "torch.optim.AdamW",
    lr=5e-6,
    fused=True,
)

# Loss terms; a coefficient of 0.0 leaves the corresponding term switched off.
loss_config = ComponentConfig(
    PPOLoss,
    policy_coef=0.0,
    clip_epsilon=0.2,
    unclipped_policy_coef=0.0,
    tanh_log_policy_coef=0.8,
    value_coef=0.0,
    entropy_coef=0.0,
    entropy_target=0.6,
    entropy_target_coef=0.05,
    kl_coef=0.05,
    weighted_entropy_coef=0.2,
    weighted_kl_coef=0.0,
    weighted_ce_coef=0.0,
    normalize_values=False,
    normalize_advantages=False,
)

tune_recipe_config = TuneRecipeConfig(
    seed=42,
    shuffle=False,
    num_output_chunks=4,
    resume_from_checkpoint=False,
    batch_size=1,
    epochs=1,
    optimizer=optimizer_config,
    loss=loss_config,
    compile=False,
    optimizer_in_bwd=False,
    gradient_accumulation_steps=1,
    enable_activation_checkpointing=True,
    enable_activation_offloading=False,
    custom_sharded_layers=["tok_embeddings", "output"],
    log_every_n_steps=1,
    log_peak_memory_stats=True,
)

# Options handed to the vLLM servers.
# Disabled options: scheduling_policy="priority",
# tensor_parallel_size=torch.cuda.device_count() // 8.
vllm_kwargs = {
    "block_size": 32,
    "disable_log_requests": True,
    "enable_prefix_caching": True,
    "enforce_eager": True,
    "gpu_memory_utilization": 0.95,
    "max_model_len": max_sequence_length,
    "max_num_seqs": 512,
    "max_num_batched_tokens": max_sequence_length * 4,
    "return_tokens_as_token_ids": True,
    "swap_space": 8,
}

trainer = Trainer(
    base_model=base_model,
    output_dir=f"./models/{model_name}",
    samples_per_episode=81,
    branch_factor=3,
    reference_clients_and_model=([reference_client], base_model),
    sample_probability_power=0,
    train_episodes=train_episodes(),
    episodes_per_iteration=64 * num_gpus,
    patience_per_episode=5,
    patience_per_val_sample=10,
    sampling_kwargs={"max_tokens": 1024},
    max_mask_sequence_batch_size=1,
    # Generator expression: validation episodes are produced lazily.
    val_episodes=(sample_random_episode() for _ in range(64 * num_gpus)),
    val_samples_per_episode=3,
    torchrun_kwargs={"nnodes": 1, "nproc_per_node": num_gpus},
    tune_model=llama3_1_8b,
    tune_model_type="LLAMA3",
    tune_recipe_config=tune_recipe_config,
    # tune_run=False,
    tune_sequence_length=max_sequence_length,
    vllm_env={"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"},
    vllm_kwargs=vllm_kwargs,
    vllm_max_concurrent_samples=512 * num_gpus,
    vllm_min_time_between_requests=0.0,
    vllm_num=num_gpus,
    vllm_timeout=120 + 15 * num_gpus,
    wandb_kwargs={"name": model_name, "id": model_name},
)

Resuming from /home/ubuntu/atreides/experiments/models/rl5/0002
INFO 12-06 20:12:09 llm_engine.py:237] Initializing an LLM engine (v0.6.3.post1) with config: model='NousResearch/Hermes-2-Theta-Llama-3-8B', speculative_config=None, tokenizer='NousResearch/Hermes-2-Theta-Llama-3-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=NousResearch/Hermes-2-Thet

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbradhilton[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
await trainer.train(iterations=2, verbosity=1)

Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl5/0002 --port=8001 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.95 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=8 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val: 0episode [00:00, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

In [8]:
trainer.eval_exceptions["val"]

[]

In [9]:
await trainer.eval("val")

val:   0%|          | 0/64 [00:00<?, ?episode/s]

(0.18576218313537646, [])

In [18]:
list(trainer.eval_episodes["val"][0].completion.leaves())[0].value()

0.7777777777777778

In [19]:
list(trainer.eval_episodes["val"][0].completion.leaves())[0].all_message_params()

[{'role': 'user',
  'content': "On a cool spring day Audrey, Samantha, and Noah sat down to play a casual deduction game.\n\nThey gathered 3 decks of cards, each for a different type of data composed of the following:\n\nSuspect:\n- Madame Rose\n- Colonel Mustard\n- Sgt. Gray\n\nWeapon:\n- Rope\n- Knife\n- Poison\n\nRoom:\n- Kitchen\n- Study\n- Fountain\n\nAfter randomly (and blindly) choosing one card from each group and placing them in the center of the table facedown, they shuffled the remaining cards and dealt out the following to each player:\n\n- Audrey: 2 cards (Poison and Study)\n- Samantha: 2 cards\n- Noah: 2 cards\n\nThe game proceeded as follows:\n\n1. On their turn, a player asked about a set of exactly 3 cards, one from each of the game's categories. (Note: Players could ask about any cards, including those in their own hand.)\n2. The player directed this question to the other players in clockwise order, starting with the player to their left.\n3. If a player had one or mo

In [6]:
await trainer.stop_vllms()

In [11]:
explore_result = await trainer.explore(verbosity=1)

explore:   0%|          | 0/128 [00:00<?, ?episode/s]

In [8]:
explore_result = await trainer.explore(verbosity=1)

explore:   0%|          | 0/128 [00:00<?, ?episode/s]

In [9]:
explore_result.exceptions

[openai.APIConnectionError('Connection error.'),
 openai.APIConnectionError('Connection error.')]

In [10]:
trainer.tune_recipe_config.metric_logger.name = f"{model_name}_tune"
trainer.tune_recipe_config.metric_logger.id = f"{model_name}_tune"

In [8]:
explore_result.tensors()['tokens'].shape

torch.Size([70, 16384])

In [13]:
trainer.tune_recipe_config.batch_size = 1

In [14]:
await trainer.tune(explore_result, verbosity=2)

$ tune run --nnodes=1 --nproc-per-node=1 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl3/config.yaml
Running with torchrun...


INFO:torchtune.utils._logging:Running FullFinetuneRecipe with resolved config:

batch_size: 1
checkpointer:
  _component_: torchtune.training.checkpointing._checkpointer.FullModelHFCheckpointer
  checkpoint_dir: /home/ubuntu/atreides/experiments/models/rl3/0001
  checkpoint_files:
  - /home/ubuntu/atreides/experiments/models/rl3/0001/hf_model_0003_0.pt
  - /home/ubuntu/atreides/experiments/models/rl3/0001/hf_model_0004_0.pt
  - /home/ubuntu/atreides/experiments/models/rl3/0001/hf_model_0001_0.pt
  - /home/ubuntu/atreides/experiments/models/rl3/0001/hf_model_0002_0.pt
  model_type: LLAMA3
  output_dir: /home/ubuntu/atreides/experiments/models/rl3
  recipe_checkpoint: null
compile: false
custom_sharded_layers:
- tok_embeddings
- output
dataset:
  _component_: lib.rl.pack.PackedDataset
  dir: /home/ubuntu/atreides/experiments/models/rl3/tensors
  num_sequences: 74
  sequence_length: 16384
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
ep

Saved iteration 2 model files to /home/ubuntu/atreides/experiments/models/rl3/0002


In [7]:
trainer.tune_recipe_config.loss.entropy_target_coef = 0.1
trainer.tune_recipe_config.loss.kl_coef = 0.05

In [52]:
await trainer.train(iterations=4, verbosity=1)

val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/128 [00:00<?, ?episode/s]

Exception ignored in: <function vLLM.__del__ at 0x7b46653f8720>
Traceback (most recent call last):
  File "/home/ubuntu/atreides/experiments/lib/vllm.py", line 101, in __del__
    self.process.terminate()
  File "/home/ubuntu/.local/share/uv/python/cpython-3.12.7-linux-x86_64-gnu/lib/python3.12/asyncio/subprocess.py", line 143, in terminate
    self._transport.terminate()
  File "/home/ubuntu/.local/share/uv/python/cpython-3.12.7-linux-x86_64-gnu/lib/python3.12/asyncio/base_subprocess.py", line 149, in terminate
    self._check_proc()
  File "/home/ubuntu/.local/share/uv/python/cpython-3.12.7-linux-x86_64-gnu/lib/python3.12/asyncio/base_subprocess.py", line 142, in _check_proc
    raise ProcessLookupError()
ProcessLookupError: 


$ tune run --nnodes=1 --nproc-per-node=1 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl3/config.yaml


1|22|Loss: 0.0106: 100%|██████████| 22/22 [04:21<00:00, 11.49s/it, entropy=0.5910, entropy_target=0.0910, kl_div=0.0428, policy=0.0069, tanh_log_policy_to_log=-0.0006, unclipped_policy=-0.0006, value=1.3990, weighted_ce=-0.0050, weighted_entropy=0.0047, weighted_kl_div=0.0011]  

Saved iteration 15 model files to /home/ubuntu/atreides/experiments/models/rl3/0015
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl3/0015 --port=8006 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.95 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=8 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/128 [00:00<?, ?episode/s]

$ tune run --nnodes=1 --nproc-per-node=1 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl3/config.yaml


1|20|Loss: 0.0076: 100%|██████████| 20/20 [03:57<00:00, 11.48s/it, entropy=0.5419, entropy_target=0.0419, kl_div=0.0327, policy=0.1791, tanh_log_policy_to_log=0.0018, unclipped_policy=0.1565, value=2.0940, weighted_ce=-0.0292, weighted_entropy=-0.0079, weighted_kl_div=-0.0013]

Saved iteration 16 model files to /home/ubuntu/atreides/experiments/models/rl3/0016
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl3/0016 --port=8006 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.95 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=8 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/128 [00:00<?, ?episode/s]

$ tune run --nnodes=1 --nproc-per-node=1 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl3/config.yaml


1|25|Loss: 0.0030: 100%|██████████| 26/26 [05:08<00:00, 11.55s/it, entropy=0.5285, entropy_target=0.0285, kl_div=0.0293, policy=0.0086, tanh_log_policy_to_log=-0.0013, unclipped_policy=-0.0057, value=2.1253, weighted_ce=-0.0212, weighted_entropy=0.0078, weighted_kl_div=-0.0007] 

Saved iteration 17 model files to /home/ubuntu/atreides/experiments/models/rl3/0017
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl3/0017 --port=8006 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.95 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=8 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/128 [00:00<?, ?episode/s]

Exception ignored in: <function vLLM.__del__ at 0x7b46653f8720>
Traceback (most recent call last):
  File "/home/ubuntu/atreides/experiments/lib/vllm.py", line 101, in __del__
    self.process.terminate()
  File "/home/ubuntu/.local/share/uv/python/cpython-3.12.7-linux-x86_64-gnu/lib/python3.12/asyncio/subprocess.py", line 143, in terminate
    self._transport.terminate()
  File "/home/ubuntu/.local/share/uv/python/cpython-3.12.7-linux-x86_64-gnu/lib/python3.12/asyncio/base_subprocess.py", line 149, in terminate
    self._check_proc()
  File "/home/ubuntu/.local/share/uv/python/cpython-3.12.7-linux-x86_64-gnu/lib/python3.12/asyncio/base_subprocess.py", line 142, in _check_proc
    raise ProcessLookupError()
ProcessLookupError: 


Early stopping val evaluation due to expired patience (0 remaining samples x 10 patience per sample = 0 seconds)
$ tune run --nnodes=1 --nproc-per-node=1 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl3/config.yaml


1|19|Loss: 0.0085: 100%|██████████| 20/20 [04:08<00:00, 11.54s/it, entropy=0.4671, entropy_target=0.0329, kl_div=0.0323, policy=0.0063, tanh_log_policy_to_log=0.0036, unclipped_policy=-0.0249, value=1.4129, weighted_ce=0.0180, weighted_entropy=-0.0283, weighted_kl_div=-0.0014]  

Saved iteration 18 model files to /home/ubuntu/atreides/experiments/models/rl3/0018
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl3/0018 --port=8006 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.95 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=65536 --return-tokens-as-token-ids --swap-space=8 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/64 [00:00<?, ?episode/s]

In [70]:
terminus = list(trainer.eval_episodes["val"][10].completion.leaves(model='/home/ubuntu/atreides/experiments/models/rl3/0018'))[2]
print(terminus.value())
for message in terminus.all_message_params():
    print(message["role"].capitalize() + ":")
    print(message["content"], end="\n\n")

0.55
User:
On a warm autumn morning Conner, Arianna, Kaleb, and Daisy sat down to play a casual deduction game.

They assembled 3 stacks of cards, each for a different type of data composed of the following:

Suspect:
- Madame Rose
- Mr. Green
- Colonel Mustard
- Sgt. Gray
- Monsieur Brunette
- Professor Plum

Weapon:
- Knife
- Revolver
- Candlestick
- Lead Pipe
- Poison

Room:
- Billiard Room
- Cloak Room
- Courtyard
- Hall
- Dining Room
- Studio
- Library
- Fountain
- Kitchen

After randomly (and blindly) choosing one card from each group and placing them in the middle of the table facedown, they shuffled the remaining cards and dealt out the following to each player:

- Conner: 4 cards (Billiard Room, Revolver, Studio, and Candlestick)
- Arianna: 4 cards
- Kaleb: 4 cards
- Daisy: 5 cards

The game proceeded as follows:

1. On their turn, a player asked about a set of exactly 3 cards, one from each of the game's categories. (Note: Players could ask about any cards, including those in

In [6]:
explore_result = await trainer.explore()

explore:   0%|          | 0/128 [00:00<?, ?episode/s]

Early stopping exploration due to expired patience (36.11519999998988 remaining samples x 0.037037037037037035 patience per sample = 1.3375999999996253 seconds)


In [11]:
await trainer.tune(explore_result)

$ tune run --nnodes=1 --nproc-per-node=1 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl2/config.yaml
Running with torchrun...


INFO:torchtune.utils._logging:Running FullFinetuneRecipe with resolved config:

batch_size: 2
checkpointer:
  _component_: torchtune.training.checkpointing._checkpointer.FullModelHFCheckpointer
  checkpoint_dir: /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725
  checkpoint_files:
  - /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725/model-00004-of-00004.safetensors
  - /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725/model-00001-of-00004.safetensors
  - /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725/model-00002-of-00004.safetensors
  - /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3

CancelledError: 

In [11]:
from openai import AsyncOpenAI

completion_sampler = await trainer.get_completion_sampler()
client: AsyncOpenAI = completion_sampler.samplers[0].client  # type: ignore

In [None]:
vllm.client.chat.completions.create()

In [None]:
from typing import Any

# Take the first leaf completion of the first explored episode and convert its
# tokens to a plain Python list so they can be sent as the prompt.
episode = explore_result.episodes[0]
completion = next(iter(episode.completion.leaves()))
tokens = completion.all_tokens(trainer.tokenizer, cache=True).tolist()
# NOTE(review): `vllm` is not bound in any cell visible in this notebook; the
# preceding cell binds `client` instead — confirm which handle is intended.
# Only one new token is requested; `extra_body` asks the server to also return
# log-probabilities for the prompt tokens.
plain_completion = await vllm.client.completions.create(
    model=trainer.model,
    prompt=tokens,
    max_tokens=1,
    extra_body={
        "prompt_logprobs": True,
    },
)
# `prompt_logprobs` is not a declared field of the client's response type, hence
# the explicit annotation and the type-ignore.
prompt_logprobs: list[dict[str, dict[str, Any]]] = plain_completion.choices[0].prompt_logprobs  # type: ignore
prompt_logprobs

In [56]:
assert len(prompt_logprobs) == len(tokens)

In [None]:
# For every prompt token, look up the log-probability stored under that token's
# id (as a string key). Positions whose entry is empty/None become NaN.
reference_logprobs = [
    prompt_logprob[str(token)]["logprob"] if prompt_logprob else torch.nan
    for token, prompt_logprob in zip(tokens, prompt_logprobs)
]
# Walk the completion chain and hand each completion the next `count` values,
# keeping the remainder for the completions that follow.
# NOTE(review): presumably `reverse=True` yields the root first so the slices
# line up with token order — confirm against `ancestors`.
for c in completion.ancestors(including_self=True, reverse=True):
    count = c.token_count(trainer.tokenizer, cache=True)
    c.reference_logprobs, reference_logprobs = (
        torch.tensor(reference_logprobs[:count]),
        reference_logprobs[count:],
    )

completion.reference_logprobs

In [None]:
torch.tensor(
    [
        prompt_logprob[str(token)]["logprob"] if prompt_logprob else torch.nan
        for token, prompt_logprob in zip(tokens, prompt_logprobs)
    ]
)

In [None]:
len(prompt_logprobs), len(completion.all_tokens(trainer.tokenizer, cache=True).tolist())

In [None]:
import pickle

# Load a pickled dump from /tmp for inspection.
# NOTE(review): pickle.load can execute arbitrary code while deserialising; only
# open dump files whose origin is trusted.
with open("/tmp/err_execute_model_input_20241201-005319.pkl", "rb") as f:
    data = pickle.load(f)

data

In [None]:
list(data)

In [34]:
getattr(plain_completion, "prompt_logprobs", None)

In [None]:
plain_completion.choices[0].prompt_logprobs

In [None]:
completion.all_tokens(trainer.tokenizer, cache=True).tolist()

In [None]:
plain_completion

In [10]:
trainer.patience_per_val_sample = 1.0
trainer.patience_per_test_sample = 1.0
trainer.tune_recipe_config.optimizer.lr = 8e-6
trainer.tune_recipe_config.loss.clip_epsilon = 0.1
trainer.tune_recipe_config.loss.weighted_ce_coef = 0.2

In [None]:
await trainer.train(iterations=3)

In [None]:
trainer.eval_scores

In [5]:
from lib.rl.pack import packed_tensors_from_dir

tensors = packed_tensors_from_dir(
    dir="./models/rl/tensors", num_sequences=50, sequence_length=16384
)

In [None]:
tensors["mask"][0][2][2]

In [None]:
trainer.max_mask_sequence_batch_size = 16
# (eval_score, eval_exceptions),
(result,) = await asyncio.gather(
    # trainer.eval("val", 0, return_exceptions=True),
    trainer.explore(1, return_exceptions=True),
)
# print(f"Eval score: {eval_score:.2%}")
print(
    f"Generated {sum(completion.num_token_logprobs() for episode in result.episodes for completion in episode.completion.descendants()):,} tokens"
)
tensors = trainer.tensors(result.episodes)
(tensors["mask"] == result.tensors()["mask"]).all()

In [None]:
tensors = trainer.tensors(result.episodes)
(tensors["mask"] == result.tensors()["mask"]).all()

In [None]:
torch.tensor(tensors["advantages"].shape).prod()

In [None]:
tensors["mask"].shape

In [None]:
import matplotlib.pyplot as plt
import torch


def show(mask: torch.Tensor) -> None:
    """Render ``mask`` as a 10x10-inch heat map with a colour bar and axis labels."""
    fig, ax = plt.subplots(figsize=(10, 10))
    image = ax.imshow(mask, cmap="inferno")
    fig.colorbar(image, ax=ax, label="Relative Position")
    ax.set_title("Relative Position Attention Mask")
    ax.set_xlabel("Target Position")
    ax.set_ylabel("Source Position")
    plt.show()


i = 1
_tensors = tensors
key = "input_pos"

# Plot the running count of attended positions along each row of sequence `i`.
# Entries where either position has a NaN `key` value are zeroed out.
sequence_mask = _tensors["mask"][i]
is_valid = ~torch.isnan(_tensors[key][i])
valid_pairs = is_valid.unsqueeze(0) & is_valid.unsqueeze(1)
show(sequence_mask.cumsum(dim=1) * (sequence_mask & valid_pairs))

In [None]:
result.tensors()["mask"].shape

In [None]:
(tensors["mask"] == result.tensors()["mask"]).all()

In [None]:
# Print every (i, j) pair for which packed sequence i has exactly the same mask
# as sequence j of the explore result.
# The tensors are fetched once instead of on every inner iteration, and the loop
# bounds come from the data rather than a fixed 127. A fixed bound raises
# IndexError when there are fewer sequences and skips rows when there are more.
packed_masks = tensors["mask"]
result_masks = result.tensors()["mask"]
for i in range(len(packed_masks)):
    for j in range(len(result_masks)):
        if (packed_masks[i] == result_masks[j]).all():
            print(i, j)

In [None]:
key = "advantages"
torch.isclose(
    tensors[key], result.tensors()[key], rtol=1e-5, atol=1e-8, equal_nan=True
).all()

In [None]:
torch.all((tensors["weights"] == result.tensors()["weights"]))

In [None]:
raise result.exceptions[1]

In [None]:
result.exceptions

In [None]:
2_033_717 / 4.25

In [None]:
2_064_056 / 6.66

In [None]:
2_064_056 / 10.75

In [None]:
2_071_601 / 8

In [None]:
2_071_601 / 11.75

In [None]:
4_119_041 / 16.5

In [None]:
await trainer.train(iterations=1)

In [None]:
val_score, episodes = await asyncio.gather(trainer.eval("val", 0), trainer.explore(1))

In [None]:
from torchtune.models.llama3_1 import llama3_1_8b
from torchtune.training import cleanup_before_training
from torchtune.training.metric_logging import DiskLogger
from typing import Any

from lib.recipes.rl import ComponentConfig, RLConfig, RLRecipe
from lib.rl.pack import PackedDataset, packed_tensors_to_dir
from lib.rl.ppo import PPOLoss


# Obtain the training tensors plus the checkpoint directory/files to start from,
# all derived from the explored `episodes`.
tensors, checkpoint_dir, checkpoint_files = await trainer.tune_resources(episodes)

# Only referenced by the commented-out `params=PLACEHOLDER` optimizer argument below.
PLACEHOLDER: Any = None

# Full configuration for the RL fine-tuning recipe (consumed further down via
# `config.dict_config()` and `recipe_main(config)`).
config = RLConfig(
    # Dataset
    # The tensors are written under `<output_dir>/tensors`; whatever
    # `packed_tensors_to_dir` returns is forwarded as PackedDataset keyword arguments.
    dataset=ComponentConfig(
        PackedDataset, **packed_tensors_to_dir(tensors, trainer.output_dir + "/tensors")
    ),
    seed=42,
    shuffle=False,
    # Model
    model=ComponentConfig(llama3_1_8b),
    num_output_chunks=4,
    # Checkpointer
    checkpointer=ComponentConfig(
        "torchtune.training.FullModelHFCheckpointer",
        checkpoint_dir=checkpoint_dir,
        checkpoint_files=checkpoint_files,
        recipe_checkpoint=None,
        output_dir=trainer.output_dir,
        model_type="LLAMA3",
    ),
    resume_from_checkpoint=False,
    # Fine-tuning arguments
    batch_size=4,
    epochs=1,
    optimizer=ComponentConfig(
        "torch.optim.AdamW",
        # "bitsandbytes.optim.PagedAdamW8bit",
        # "bitsandbytes.optim.AdamW",
        # params=PLACEHOLDER,
        lr=5e-6,
        fused=True,
    ),
    loss=ComponentConfig(
        PPOLoss,
        # The commented-out values are an alternative setting; the active ones follow.
        # clip_epsilon=0.3,
        # entropy_coef=0.0,
        # kl_coef=0.0,
        clip_epsilon=0.3,
        entropy_coef=0.025,
        kl_coef=0.025,
        normalize_advantages=False,
    ),
    max_steps_per_epoch=None,
    compile=False,
    optimizer_in_bwd=False,
    gradient_accumulation_steps=1,
    # Training env
    device="cuda",
    # Memory management
    enable_activation_checkpointing=True,
    enable_activation_offloading=False,
    custom_sharded_layers=["tok_embeddings", "output"],
    # Reduced precision
    dtype="bf16",
    # Logging
    # NOTE(review): hard-coded absolute, machine-specific log directory — consider
    # deriving it from `trainer.output_dir` so the notebook runs elsewhere.
    metric_logger=ComponentConfig(
        DiskLogger, log_dir="/home/ubuntu/atreides/experiments/logs"
    ),
    log_every_n_steps=1,
    log_peak_memory_stats=True,
)

# recipe = RLRecipe(config)
# recipe.setup(config)
# recipe.train()
# recipe.cleanup()
# del tensors, recipe
# cleanup_before_training()
# trainer.save(base_checkpoint_dir=checkpoint_dir)

In [19]:
from omegaconf import OmegaConf

# Convert the recipe configuration and write it as YAML into the trainer's
# output directory.
dict_config = config.dict_config()
OmegaConf.save(dict_config, f"{trainer.output_dir}/config.yaml")

In [None]:
import os
import sys
from typing import IO

def _format_flags(options: dict) -> list[str]:
    """Render a mapping as command-line flags.

    Keys have underscores replaced by dashes.  A value that is exactly ``True``
    yields a bare ``--flag``; every other value yields ``--flag=value``.  The
    identity test matters: ``1 != True`` is ``False`` in Python, so an equality
    test would silently drop ``=1`` from numeric options.
    """
    return [
        f"--{key.replace('_', '-')}{f'={value}' if value is not True else ''}"
        for key, value in options.items()
    ]


# Launcher options (placed before the recipe name), recipe options (placed after
# the config path) and extra environment variables for the child process.
torchrun_kwargs = {"nnodes": 1, "nproc_per_node": 2}
kwargs = {}
env = {"CUDA_LAUNCH_BLOCKING": "1"}

# argv for `tune run`, pointing at the config.yaml in the trainer output directory.
args = [
    "tune",
    "run",
    *_format_flags(torchrun_kwargs),
    "lib.recipes.rl.RLRecipe",
    "--config",
    trainer.output_dir + "/config.yaml",
    *_format_flags(kwargs),
]
print(f"$ {' '.join(args)}")

In [None]:
process = await asyncio.create_subprocess_exec(
    *args,
    stdout=asyncio.subprocess.PIPE,
    stderr=asyncio.subprocess.PIPE,
    env={
        **os.environ,
        **(env or {}),
    },
)


async def log_output(stream: asyncio.StreamReader, io: IO[str]) -> None:
    """Relay a byte stream to a text sink, line by line, until end of stream.

    Args:
        stream: Reader to drain, e.g. a subprocess's stdout or stderr pipe.
        io: Text sink receiving each decoded line; it is flushed after every
            line so output shows up as soon as it is produced.

    Bytes that are not valid UTF-8 are replaced with U+FFFD instead of raising
    ``UnicodeDecodeError``: an exception here would end the relay task and
    leave the pipe undrained.
    """
    # `readline()` returns b"" at end of stream, which ends the loop.
    while line := await stream.readline():
        io.write(line.decode(errors="replace"))
        io.flush()


tasks = []
if process.stdout:
    tasks.append(asyncio.create_task(log_output(process.stdout, sys.stdout)))
if process.stderr:
    tasks.append(asyncio.create_task(log_output(process.stderr, sys.stderr)))
_ = await asyncio.gather(*tasks)

In [None]:
from lib.recipes.rl import recipe_main
import os
from torch import distributed as dist
from torchtune.training import is_distributed

# Single-process settings (one rank, world size one, local rendezvous address)
# exported before running the recipe inside this kernel.
# NOTE(review): presumably read by torch.distributed's environment-variable
# initialisation — confirm against `recipe_main`.
os.environ.update(
    {
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "29500",
        "WORLD_SIZE": "1",
        "RANK": "0",
    }
)


recipe_main(config)

In [7]:
from omegaconf import DictConfig, OmegaConf

# Convert the recipe configuration and write it as YAML into the trainer's
# output directory.
dict_config = config.dict_config()
OmegaConf.save(dict_config, f"{trainer.output_dir}/config.yaml")

In [None]:
from lib.rl.completion import Completion


# Round-trip a one-key config through YAML; the value is the fully qualified
# name (module path plus class name) of `Completion`.
OmegaConf.create(
    OmegaConf.to_yaml(
        DictConfig({"name": f"{Completion.__module__}.{Completion.__name__}"})
    )
)

In [27]:
import traceback
import sys

def clear_last_traceback_frames() -> None:
    """Drop local-variable references held by the most recent traceback(s).

    ``sys.exc_info()`` only describes an exception that is being handled right
    now, so when called outside an ``except`` block (as in a fresh notebook
    cell) it returns ``(None, None, None)`` and clearing it does nothing.  The
    traceback of the last unhandled exception lives in ``sys.last_traceback``
    (and ``sys.last_exc.__traceback__`` on Python 3.12+), so those are cleared
    too.  ``traceback.clear_frames`` accepts ``None`` and skips frames that are
    still executing.
    """
    last_exc = getattr(sys, "last_exc", None)
    candidates = (
        getattr(sys, "last_traceback", None),
        getattr(last_exc, "__traceback__", None),
        sys.exc_info()[2],  # kept so the original behaviour is a subset
    )
    for tb in candidates:
        traceback.clear_frames(tb)


clear_last_traceback_frames()

In [28]:
# Call torchtune's pre-training cleanup helper.
# NOTE(review): presumably frees cached memory before the next run — confirm
# against the torchtune documentation.
cleanup_before_training()

In [None]:
# Save via the trainer, passing the checkpoint directory obtained earlier from
# `trainer.tune_resources(...)` as the base checkpoint directory.
trainer.save(base_checkpoint_dir=checkpoint_dir)

In [None]:
import matplotlib.pyplot as plt
import torch


def show(mask: torch.Tensor) -> None:
    """Draw a matrix as a 10x10-inch "inferno" heat map with a colour bar.

    Args:
        mask: Two-dimensional data handed to ``imshow``; rows run along the
            y-axis and columns along the x-axis.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    image = ax.imshow(mask, cmap="inferno")
    fig.colorbar(image, label="Relative Position")
    ax.set_title("Relative Position Attention Mask")
    ax.set_xlabel("Target Position")
    ax.set_ylabel("Source Position")
    plt.show()


# Batch entry to plot.
i = 1

# A cell (row, col) is kept only when the mask allows it and neither the row's
# nor the column's advantage is NaN; kept cells are coloured by the cumulative
# sum of the mask along each row.
show(
    (
        tensors["mask"][i]
        & ~torch.isnan(tensors["advantages"][i]).unsqueeze(0)
        & ~torch.isnan(tensors["advantages"][i]).unsqueeze(1)
    )
    * tensors["mask"][i].cumsum(dim=1)
)

In [None]:
from IPython.display import HTML

# Render, with whitespace preserved, the HTML of the first item returned by
# `episodes[2].completion.leaves()`.
# NOTE(review): the meaning of the 30.0 argument to `.html()` is not visible
# here — TODO document it.
HTML(
    f'<div style="white-space: pre-wrap">{list(episodes[2].completion.leaves())[0].html(30.0)}</div>'
)

In [None]:
def mask_and_pos_ids(
    ids: torch.Tensor, parent_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Creates a causal ancestor-attention mask and per-position IDs from node IDs and their parent IDs.

    Args:
        ids: A tensor of shape (batch_size, sequence_length) containing node IDs
        parent_ids: A tensor of shape (batch_size, sequence_length) containing parent IDs for each node.
            A root is expressed as a position whose parent ID equals its own ID.

    Returns:
        A tuple containing:
        - mask: A boolean tensor of shape (batch_size, sequence_length, sequence_length) where
          mask[b, i, j] is True when position i may attend to position j: j is at or before i
          (causal masking) and carries either the same node ID as i or the ID of one of i's ancestors.
        - pos_ids: An integer tensor of shape (batch_size, sequence_length). Entry j is, maximised
          over the rows i that attend to j, the number of allowed columns in row i up to and
          including j, minus one. The diagonal is always allowed, so every entry is >= 0.
    """
    # mask[b, i, j]: position j carries the same node ID as position i.
    mask = ids.unsqueeze(1) == ids.unsqueeze(2)
    # _mask additionally allows columns whose node ID is row i's parent ID.
    _mask = mask | (ids.unsqueeze(1) == parent_ids.unsqueeze(2))
    # Climb one ancestor level per pass until the mask stops growing.
    while torch.any(mask != _mask):
        # Replace each position's parent ID with that parent's own parent ID, looked up at the
        # first position whose node ID equals the current parent ID.
        # NOTE(review): if no position carries the parent ID, argmax over an all-zero row
        # returns index 0, so position 0's parent ID is used — confirm inputs never hit this.
        parent_ids = parent_ids.gather(
            1, torch.argmax((parent_ids.unsqueeze(2) == ids.unsqueeze(1)).int(), dim=2)
        )
        mask = _mask
        _mask = mask | (ids.unsqueeze(1) == parent_ids.unsqueeze(2))
    # Causal constraint: keep only columns at or before the row index (lower triangle).
    mask &= torch.tril(torch.ones_like(mask, dtype=torch.bool, device=ids.device))
    # mask = torch.linalg.matrix_power(mask.float(), mask.size(1) - 1) > 0
    # Running count of allowed columns along each row, minus one, kept only where allowed
    # (disallowed cells become -1); the max over rows then picks one value per column.
    pos_ids = (torch.where(mask, mask.cumsum(2), 0) - 1).max(1).values
    return mask, pos_ids


def test_mask_and_pos_ids(
    ids: list[int],
    parent_ids: list[int],
    expected_mask: list[list[int]],
    expected_pos_ids: list[int],
):
    """Check ``mask_and_pos_ids`` on a single sequence.

    The inputs are wrapped in a batch of size one. The computed mask (cast to
    integers) and position IDs must equal the expected values exactly; on
    failure the assertion message shows the computed value.
    """
    actual_mask, actual_pos_ids = mask_and_pos_ids(
        ids=torch.tensor([ids]), parent_ids=torch.tensor([parent_ids])
    )

    mask_matches = actual_mask.int() == torch.tensor([expected_mask])
    assert torch.all(mask_matches), f"\n{actual_mask.int()[0]}"

    pos_ids_match = actual_pos_ids == torch.tensor([expected_pos_ids])
    assert torch.all(pos_ids_match), f"{actual_pos_ids[0].tolist()}"


# Two unrelated roots (each node is its own parent): no cross attention.
test_mask_and_pos_ids(
    ids=[0, 1],
    parent_ids=[0, 1],
    expected_mask=[[1, 0], [0, 1]],
    expected_pos_ids=[0, 0],
)

# Root 0 followed by a two-position child node 1: plain causal attention.
test_mask_and_pos_ids(
    ids=[0, 1, 1],
    parent_ids=[0, 0, 0],
    expected_mask=[[1, 0, 0], [1, 1, 0], [1, 1, 1]],
    expected_pos_ids=[0, 1, 2],
)

# Linear chain 0 <- 1 <- 2 <- 3: ancestors must be resolved transitively.
test_mask_and_pos_ids(
    ids=[0, 1, 2, 3],
    parent_ids=[0, 0, 1, 2],
    expected_mask=[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]],
    expected_pos_ids=[0, 1, 2, 3],
)

# Two independent two-position roots laid out back to back.
test_mask_and_pos_ids(
    ids=[0, 0, 1, 1],
    parent_ids=[0, 0, 1, 1],
    expected_mask=[[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]],
    expected_pos_ids=[0, 1, 0, 1],
)

# Two roots (0 and 1) followed by one child each (2 under 0, 3 under 1).
test_mask_and_pos_ids(
    ids=[0, 1, 2, 3],
    parent_ids=[0, 1, 0, 1],
    expected_mask=[[1, 0, 0, 0], [0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1]],
    expected_pos_ids=[0, 0, 1, 1],
)

# Same layout as above, but each child node spans two positions.
test_mask_and_pos_ids(
    ids=[0, 1, 2, 2, 3, 3],
    parent_ids=[0, 1, 0, 0, 1, 1],
    expected_mask=[
        [1, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0],
        [1, 0, 1, 1, 0, 0],
        [0, 1, 0, 0, 1, 0],
        [0, 1, 0, 0, 1, 1],
    ],
    expected_pos_ids=[0, 0, 1, 2, 1, 2],
)

# Branching tree: 0 <- 1, then siblings 2 and 3 under 1, with 4 under 2 and 5 under 3.
# Siblings do not attend to each other and share the same position IDs.
test_mask_and_pos_ids(
    ids=[0, 1, 2, 3, 4, 4, 5, 5],
    parent_ids=[0, 0, 1, 1, 2, 2, 3, 3],
    expected_mask=[
        [1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 0, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 1, 0, 0, 0],
        [1, 1, 1, 0, 1, 1, 0, 0],
        [1, 1, 0, 1, 0, 0, 1, 0],
        [1, 1, 0, 1, 0, 0, 1, 1],
    ],
    expected_pos_ids=[0, 1, 2, 2, 3, 4, 3, 4],
)

# Node IDs need not be increasing: root 2, its child 1, then an unrelated root 0.
test_mask_and_pos_ids(
    ids=[2, 1, 0],
    parent_ids=[2, 2, 0],
    expected_mask=[
        [1, 0, 0],
        [1, 1, 0],
        [0, 0, 1],
    ],
    expected_pos_ids=[0, 1, 0],
)