In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules (e.g. the `lib` package used below) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
%%html
<!-- Styling for ipywidget output: transparent cell background, and widget text colour / font size taken from the VS Code editor theme variables. -->
<style>
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [4]:
import asyncio
from lib.clue import Clue, DeductiveSolver
from lib.rl.episode import Episode, EpisodeCompletion
from lib.rl.trainer import Trainer
import random
import re


def sample_random_episode() -> Episode:
    """Simulate a random 3-player Clue game and wrap it as an RL episode.

    Three suspects, three weapons and three rooms are drawn at random (without
    replacement) from the `Clue` class attributes, the game is played to
    completion with a reduced `DeductiveSolver`, and the resulting game prompt
    becomes the episode's single user message.

    Reward: the fraction of solution elements for which the text
    ``"<element>: <solution>"`` occurs (case-insensitively) in the model's
    final answer.

    Returns:
        An `Episode` whose `on_sample` callback rewards and commits every
        sampled completion.
    """
    game = Clue(
        num_players=3,
        elements={
            "suspect": random.sample(Clue.suspects, k=3),
            "weapon": random.sample(Clue.weapons, k=3),
            "room": random.sample(Clue.rooms, k=3),
            # "motive": random.choices(Clue.motives, k=3),
            # "time": random.choices(Clue.get_times("21:00", "03:00", "1h"), k=3),
        },
    )
    game.play(
        # The five most expensive deduction passes are switched off here; the
        # commented-out flags are left at their defaults.
        deductive_solver=DeductiveSolver(
            # note_cards_in_hand=False,
            # note_responses_to_suggestions=False,
            # note_cards_that_players_do_not_have=False,
            # check_unique_card_placement_constraints=False,
            # check_player_hand_size_constraints=False,
            check_solution_has_one_and_only_one_card_per_element=False,
            check_one_of_constraints=False,
            check_inverse_one_of_constraints=False,
            merge_and_check_disjoint_inverse_one_of_constraints=False,
            exhaustively_test_possible_assignments=False,
        ),
        cp_solver_max_solve_time_per_turn=0.01,
        check_cp_solver_grid=False,
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        print_playthrough=False,
    )
    prompt = game.get_prompt()
    # Answer template sent as a second user turn, one line per game element,
    # e.g. "Suspect: <#SUSPECT#>".
    follow_up = "Fill out your answer like this:\n" + "\n".join(
        f"{element.capitalize()}: <#{element.upper()}#>" for element in game.elements
    )

    async def reward_completion(completion: EpisodeCompletion) -> EpisodeCompletion:
        """Score one completion against the game's solution and return it."""
        # With exactly two messages (presumably the user prompt plus one
        # assistant reply - TODO confirm) the model has not yet seen the answer
        # template, so ask for the formatted answer in a follow-up turn.
        if len(completion.messages) == 2:
            follow_up_completion = await completion.follow_up(
                messages=[
                    {"role": "user", "content": follow_up},
                ]
            )
        else:
            follow_up_completion = completion
        answer = follow_up_completion.last_assistant_message.get("content")
        assert isinstance(answer, str)
        # Escape both parts before building the pattern: card names are plain
        # text, and an unescaped regex metacharacter in one of them would
        # either match too loosely or raise re.error.
        num_correct = sum(
            bool(
                re.search(
                    f"{re.escape(str(element))}: {re.escape(str(solution))}",
                    answer,
                    re.IGNORECASE,
                )
            )
            for element, solution in game.solution.items()
        )
        # The reward is stored on the original completion even when the answer
        # was read from the follow-up completion.
        completion.reward = num_correct / len(game.solution)
        return completion

    async def on_sample(completions: list[EpisodeCompletion]) -> None:
        """Reward all sampled completions concurrently, then commit each one."""
        for completion in await asyncio.gather(
            *[reward_completion(completion) for completion in completions]
        ):
            completion.commit()

    return Episode(
        messages=[{"role": "user", "content": prompt}],
        on_sample=on_sample,
    )


def train_episodes():
    """Endless generator of training episodes.

    Every `next()` call simulates a brand-new random game via
    `sample_random_episode`; the generator never raises StopIteration on its
    own.
    """
    while True:
        fresh_episode = sample_random_episode()
        yield fresh_episode


# Build the RL trainer around the Hermes-2-Theta Llama-3 8B base model.
# Training episodes come from the endless `train_episodes()` generator, while
# the validation set is a fixed list of 32 episodes sampled once, right here.
# NOTE(review): the exact semantics of samples_per_episode, branch_factor and
# episodes_per_iteration are defined by lib.rl.trainer.Trainer and are not
# visible in this notebook - confirm there before tuning them.
trainer = Trainer(
    base_model="NousResearch/Hermes-2-Theta-Llama-3-8B",
    samples_per_episode=8,
    branch_factor=2,
    train_episodes=train_episodes(),
    episodes_per_iteration=32,
    val_episodes=[sample_random_episode() for _ in range(32)],
    tune_sequence_length=8192,
    # Forwarded to the vLLM server; they show up as --disable-log-requests and
    # --scheduling-policy=priority in the `vllm serve` command echoed below.
    vllm_kwargs=dict(disable_log_requests=True, scheduling_policy="priority"),
    vllm_max_concurrent_requests=256,
)

config.json:   0%|          | 0.00/716 [00:00<?, ?B/s]

INFO 11-22 03:40:01 llm_engine.py:237] Initializing an LLM engine (vdev) with config: model='NousResearch/Hermes-2-Theta-Llama-3-8B', speculative_config=None, tokenizer='NousResearch/Hermes-2-Theta-Llama-3-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=NousResearch/Hermes-2-Theta-Llama-3-8B, use_v2_block_manager=True, num_scheduler_steps=1, chunked_

tokenizer_config.json:   0%|          | 0.00/56.3k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/444 [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/169 [00:00<?, ?B/s]

In [5]:
# Run the validation pass and one exploration pass concurrently;
# asyncio.gather returns their results in argument order.
# NOTE(review): the meaning of the positional arguments (0 and 1) is defined
# by Trainer and not visible here - presumably an iteration index / count;
# confirm against lib.rl.trainer.
val_score, episodes = await asyncio.gather(trainer.eval("val", 0), trainer.explore(1))

$ vllm serve NousResearch/Hermes-2-Theta-Llama-3-8B --disable-log-requests --scheduling-policy=priority --api-key=default


No module named 'vllm._version'
  from vllm.version import __version__ as VLLM_VERSION


INFO 11-22 03:40:06 api_server.py:528] vLLM API server version dev
INFO 11-22 03:40:06 api_server.py:529] args: Namespace(subparser='serve', model_tag='NousResearch/Hermes-2-Theta-Llama-3-8B', config='', host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key='default', lora_modules=None, prompt_adapters=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='NousResearch/Hermes-2-Theta-Llama-3-8B', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='auto', config_format='auto', dtype='auto', kv_cache_dtype='au

No module named 'vllm._version'
  from vllm.version import __version__ as VLLM_VERSION


INFO 11-22 03:40:17 llm_engine.py:237] Initializing an LLM engine (vdev) with config: model='NousResearch/Hermes-2-Theta-Llama-3-8B', speculative_config=None, tokenizer='NousResearch/Hermes-2-Theta-Llama-3-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=NousResearch/Hermes-2-Theta-Llama-3-8B, use_v2_block_manager=True, num_scheduler_steps=1, chunked_

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:02,  1.32it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.21it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:01<00:00,  1.67it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.33it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.36it/s]



INFO 11-22 03:41:04 model_runner.py:1071] Loading model weights took 14.9595 GB
INFO 11-22 03:41:05 gpu_executor.py:122] # GPU blocks: 27864, # CPU blocks: 2048
INFO 11-22 03:41:05 gpu_executor.py:126] Maximum concurrency for 8192 tokens per request: 54.42x
INFO 11-22 03:41:07 model_runner.py:1402] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 11-22 03:41:07 model_runner.py:1406] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 11-22 03:41:16 model_runner.py:1530] Graph capturing finished in 9 secs.
INFO 11-22 03:41:17 api_server.py:232] vLLM to use /tmp/tmpvc_tgll2 as PROMETHEUS_MULTIPROC_DIR
INFO 11-22 03:41:17 launcher.py:19] Available rou

INFO:     Started server process [4333]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on socket ('0.0.0.0', 8000) (Press CTRL+C to quit)


INFO:     127.0.0.1:44182 - "POST /v1/chat/completions HTTP/1.1" 200 OK
vLLM server started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/32 [00:00<?, ?episode/s]

explore:   0%|          | 0/32 [00:00<?, ?episode/s]

In [6]:
# Terminate the vLLM server process held by the trainer and reset the
# trainer's cached handles to None.
# NOTE(review): presumably done to release the GPU before the fine-tuning
# cell below, and so that a later call starts a fresh server - confirm.
vllm = await trainer.vllm()
vllm.process.terminate()
trainer._vllm_task, trainer._completion_sampler = None, None

In [22]:
import os
import subprocess
from torchtune.models.llama3_1 import llama3_1_8b
from torchtune.training import cleanup_before_training
from torchtune.training.metric_logging import DiskLogger
from typing import Any

from lib.recipes.rl import ComponentConfig, RLConfig, RLRecipe
from lib.rl.pack import PackedDataset, packed_tensors
from lib.rl.ppo import PPOLoss

# Pack the explored episodes into training tensors of length
# `trainer.tune_sequence_length`.
# `trajectories_per_episode` is int(samples_per_episode * sample_fraction)
# only when the fraction is below 1.0; otherwise None is passed
# (presumably meaning "no cap" - confirm against lib.rl.pack.packed_tensors).
tensors = packed_tensors(
    episodes,
    model=trainer.model,
    sequence_length=trainer.tune_sequence_length,
    trajectories_per_episode=(
        int(trainer.samples_per_episode * trainer.tune_episode_sample_fraction)
        if trainer.tune_episode_sample_fraction < 1.0
        else None
    ),
    tokenizer=trainer.tokenizer,
)

# Resolve the local Hugging Face snapshot directory of the current model:
# `huggingface-cli download` prints that directory on stdout.
# The command is passed as an argument list (no shell), so the model name is
# never interpreted by a shell, and HF_HUB_ENABLE_HF_TRANSFER is supplied
# through `env`. `check=True` makes a failed download raise
# CalledProcessError instead of silently producing an empty checkpoint_dir.
checkpoint_dir = subprocess.run(
    ["huggingface-cli", "download", str(trainer.model)],
    env={**os.environ, "HF_HUB_ENABLE_HF_TRANSFER": "1"},
    capture_output=True,
    text=True,
    check=True,
).stdout.strip()
print("Checkpoint directory:", checkpoint_dir)

# Directory that receives the fine-tuned checkpoints (created if missing).
# NOTE(review): absolute, machine-specific path - adjust when running elsewhere.
checkpoint_output_dir = "/home/ubuntu/atreides/experiments/models/rl"
os.makedirs(checkpoint_output_dir, exist_ok=True)

# Stand-in value passed as `params=` to the optimizer config below.
# NOTE(review): presumably the recipe substitutes the real model parameters
# when it instantiates the optimizer - confirm in lib.recipes.rl.
PLACEHOLDER: Any = None

# Full configuration for one RL fine-tuning run over the packed tensors.
# Component entries are given either as an imported callable or as a dotted
# import path string, followed by keyword arguments for that component.
config = RLConfig(
    # Dataset
    dataset=ComponentConfig(PackedDataset, tensors=tensors),
    seed=42,
    shuffle=False,
    # Model
    # NOTE(review): the model builder is torchtune's llama3_1_8b while the
    # checkpoint is a Llama-3 (not 3.1) fine-tune and model_type is "LLAMA3";
    # confirm that the 3.1 builder is intended for these weights.
    model=ComponentConfig(llama3_1_8b),
    num_output_chunks=4,
    # Checkpointer
    checkpointer=ComponentConfig(
        "torchtune.training.FullModelHFCheckpointer",
        checkpoint_dir=checkpoint_dir,
        # The four safetensors shards of the downloaded snapshot.
        checkpoint_files=[
            "model-00001-of-00004.safetensors",
            "model-00002-of-00004.safetensors",
            "model-00003-of-00004.safetensors",
            "model-00004-of-00004.safetensors",
        ],
        recipe_checkpoint=None,
        output_dir=checkpoint_output_dir,
        model_type="LLAMA3",
    ),
    resume_from_checkpoint=False,
    # Fine-tuning arguments
    batch_size=4,
    epochs=1,
    optimizer=ComponentConfig(
        # AdamW,
        "bitsandbytes.optim.PagedAdamW8bit",
        # Placeholder only (None); see the PLACEHOLDER definition above the config.
        params=PLACEHOLDER,
        lr=5e-6,
        # fused=True,
    ),
    loss=ComponentConfig(
        PPOLoss,
        # Earlier setting without entropy / KL terms, kept for reference:
        # clip_epsilon=0.3,
        # entropy_coef=0.0,
        # kl_coef=0.0,
        clip_epsilon=0.3,
        entropy_coef=0.025,
        kl_coef=0.025,
    ),
    max_steps_per_epoch=None,
    compile=False,
    optimizer_in_bwd=False,
    gradient_accumulation_steps=1,
    # Training env
    device="cuda",
    # Memory management
    enable_activation_checkpointing=True,
    enable_activation_offloading=False,
    custom_sharded_layers=["tok_embeddings", "output"],
    # Reduced precision
    dtype="bf16",
    # Logging
    # NOTE(review): absolute, machine-specific log paths.
    metric_logger=ComponentConfig(
        DiskLogger, log_dir="/home/ubuntu/atreides/experiments/logs"
    ),
    output_dir="/home/ubuntu/atreides/experiments/logs",
    log_every_n_steps=1,
    log_peak_memory_stats=True,
)

# Instantiate the recipe, set it up from the same config, train, and let it
# clean up after itself.
recipe = RLRecipe(config)
recipe.setup(config)
recipe.train()
recipe.cleanup()
# Drop the only reference to the recipe and call torchtune's
# cleanup_before_training helper.
# NOTE(review): presumably this releases GPU memory held by the recipe so the
# next step can use it - confirm.
del recipe
cleanup_before_training()

Packed sequences in 0.02s ✓
Prepared tensors in 0.61s ✓
Created mask in 1.14s ✓


INFO:torchtune.utils._logging:Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. Enabling activation offloading should reduce memory further.
DEBUG:torchtune.utils._logging:Setting manual seed to local seed 42. Local seed is seed + rank = 42 + 0


Checkpoint directory: /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725
Writing logs to /home/ubuntu/atreides/experiments/logs/log_1732249559.txt


INFO:torchtune.utils._logging:FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...
INFO:torchtune.utils._logging:Instantiating model and loading checkpoint took 2.22 secs
INFO:torchtune.utils._logging:Memory stats after model init:
	GPU peak memory allocation: 15.08 GiB
	GPU peak memory reserved: 15.20 GiB
	GPU peak memory active: 15.08 GiB
INFO:torchtune.utils._logging:Optimizer is initialized.
INFO:torchtune.utils._logging:Loss is initialized.
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
1|2|Loss: 0.0204: 100%|██████████| 2/2 [00:34<00:00, 17.32s/it, entropy=0.2895, kl_div=0.0556, policy=0.0263]INFO:torchtune.utils._logging:Saving checkpoint. This may take some time. Retrieving full model state dict...
INFO:torchtune.utils._logging:Ge

In [10]:
import os

# Stand-alone copy of the snapshot lookup from the training cell:
# `huggingface-cli download` prints the local snapshot directory on stdout.
# The command is passed as an argument list (no shell) with
# HF_HUB_ENABLE_HF_TRANSFER supplied through `env`, and `check=True` makes a
# failed download raise CalledProcessError instead of silently yielding an
# empty checkpoint_dir.
checkpoint_dir = subprocess.run(
    ["huggingface-cli", "download", str(trainer.model)],
    env={**os.environ, "HF_HUB_ENABLE_HF_TRANSFER": "1"},
    capture_output=True,
    text=True,
    check=True,
).stdout.strip()
print("Checkpoint directory:", checkpoint_dir)

Checkpoint directory: /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725


In [23]:
import glob
import os
import re
import shutil

# Latest epoch = largest trailing number among the "hf_model_<n>_<epoch>.pt"
# files sitting directly in checkpoint_output_dir; -1 when there are none.
checkpoint_name = re.compile(r"hf_model_\d+_(\d+)\.pt")
epoch_numbers = [
    int(found.group(1))
    for found in map(
        checkpoint_name.search,
        glob.glob(f"{checkpoint_output_dir}/hf_model_*_*.pt"),
    )
    if found
]
epoch = max(epoch_numbers, default=-1)

# Next iteration = one past the highest all-digit sub-directory name already
# present in checkpoint_output_dir (so the first iteration is 0).
existing_iterations = [
    int(name)
    for name in os.listdir(checkpoint_output_dir)
    if name.isdigit() and os.path.isdir(os.path.join(checkpoint_output_dir, name))
]
iteration = max(existing_iterations, default=-1) + 1

# Create a new directory for this iteration
iteration_dir = f"{checkpoint_output_dir}/{iteration:04d}"
os.makedirs(iteration_dir, exist_ok=True)

# Copy configuration files (non-model files) to the iteration directory.
# Entries that are not regular files (e.g. sub-directories of the snapshot)
# are skipped: shutil.copy2 cannot copy a directory and would raise.
weight_suffixes = (".safetensors", ".pt", ".ckpt", ".bin", ".pth", ".h5")
for file in os.listdir(checkpoint_dir):
    src = os.path.join(checkpoint_dir, file)
    if not os.path.isfile(src) or file.endswith(weight_suffixes):
        continue
    dst = os.path.join(iteration_dir, file)
    shutil.copy2(src, dst)

# Relocate the newest epoch's weight files into the iteration directory ...
target_dir = f"{checkpoint_output_dir}/{iteration:04d}"
for src in glob.glob(f"{checkpoint_output_dir}/hf_model_*_{epoch}.pt"):
    shutil.move(src, f"{target_dir}/{os.path.basename(src)}")

# ... then discard whatever checkpoint files of other epochs are still left
# in the root of checkpoint_output_dir.
for stale in filter(os.path.isfile, glob.glob(f"{checkpoint_output_dir}/hf_model_*_*.pt")):
    os.remove(stale)

print(f"Saved model checkpoint to {target_dir}")

In [None]:
# Manual variant of the checkpoint-filing cell: file the current checkpoint
# under a hand-picked iteration number.
# The variable is deliberately not called `iter`, which would shadow the
# builtin iter() for the rest of the kernel session.
manual_iteration = 1
manual_iteration_dir = os.path.join(checkpoint_output_dir, f"{manual_iteration:04d}")

# Create target directory if it doesn't exist
os.makedirs(manual_iteration_dir, exist_ok=True)

# Copy all non-weight files from the downloaded snapshot into the target.
# The destination is built from checkpoint_output_dir, the directory that was
# just created; building it from checkpoint_dir would point into the Hugging
# Face cache at a sub-directory that does not exist.
for file in os.listdir(checkpoint_dir):
    src = os.path.join(checkpoint_dir, file)
    # shutil.copy2 cannot copy directories, so only regular files are taken.
    if os.path.isfile(src) and not file.endswith((".safetensors", ".pt")):
        dst = os.path.join(manual_iteration_dir, file)
        shutil.copy2(src, dst)

# Move all .pt files from checkpoint_output_dir into the iteration directory
for file in os.listdir(checkpoint_output_dir):
    if file.endswith(".pt"):
        src = os.path.join(checkpoint_output_dir, file)
        dst = os.path.join(manual_iteration_dir, file)
        shutil.move(src, dst)

In [7]:
# Drop the `recipe` reference if one is still alive, then run torchtune's
# cleanup_before_training helper.
# `globals().pop` replaces a bare `del recipe`: the training cell already
# deletes `recipe` when it finishes, so `del` would raise NameError whenever
# this cell runs after a completed training cell.
globals().pop("recipe", None)
cleanup_before_training()

In [None]:
import matplotlib.pyplot as plt
import torch


def show(mask: torch.Tensor) -> None:
    """Display a 2-D tensor as a 10x10-inch "inferno" heat map.

    The figure gets a colour bar labelled "Relative Position", a fixed title,
    and x / y axis labels "Target Position" / "Source Position".

    Args:
        mask: The matrix to draw; passed to `imshow` unchanged.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    heatmap = ax.imshow(mask, cmap="inferno")
    fig.colorbar(heatmap, ax=ax, label="Relative Position")
    ax.set_title("Relative Position Attention Mask")
    ax.set_xlabel("Target Position")
    ax.set_ylabel("Source Position")
    plt.show()

# Index (along the first dimension) of the packed sequence to visualise.
i = 1

# Plot, for every row of the attention mask, the running count of attended
# positions (cumsum along dim 1). The count is zeroed wherever the mask is
# False or where either of the two positions has a NaN advantage.
# NOTE(review): NaN advantages presumably mark positions that are not trained
# on (e.g. prompt or padding tokens) - confirm against lib.rl.pack.
show(
    tensors["mask"][i].cumsum(dim=1)
    * (
        tensors["mask"][i]
        & (
            ~torch.isnan(tensors["advantages"][i]).unsqueeze(0)
            & ~torch.isnan(tensors["advantages"][i]).unsqueeze(1)
        )
    )
)

In [None]:
from IPython.display import HTML

# Render the first leaf of the third episode's completion tree as HTML,
# preserving whitespace and line breaks via `white-space: pre-wrap`.
# NOTE(review): the meaning of the 30.0 argument to `.html()` is defined in
# lib.rl and not visible here - confirm before changing it.
HTML(
    f'<div style="white-space: pre-wrap">{list(episodes[2].completion.leaves())[0].html(30.0)}</div>'
)

In [None]:
def mask_and_pos_ids(
    ids: torch.Tensor, parent_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build a tree-structured causal attention mask and per-token position IDs.

    Every position carries a node ID and the ID of that node's parent (a root
    lists itself as parent). A position may attend to an earlier-or-equal
    position whose node ID is its own node ID or the ID of any ancestor.

    Args:
        ids: Integer tensor of shape (batch_size, sequence_length) with the
            node ID of each position.
        parent_ids: Integer tensor of the same shape with the parent node ID
            of each position.

    Returns:
        A tuple ``(mask, pos_ids)``:
        - mask: Boolean tensor of shape (batch_size, sequence_length,
          sequence_length); ``mask[b, i, j]`` is True when ``j <= i`` and
          position ``j`` belongs to position ``i``'s node or to one of its
          ancestors.
        - pos_ids: Integer tensor of shape (batch_size, sequence_length). For
          each column ``j`` it is the maximum, over the rows that attend to
          ``j``, of the number of positions up to and including ``j`` that the
          row attends to, minus one. In the examples exercised below this is
          the token's index along its own root-to-node path.
    """
    # (batch, 1, seq): node id of each candidate column j, broadcast over rows.
    column_ids = ids.unsqueeze(1)

    # Level 0: a position sees every position that shares its node id.
    mask = column_ids == ids.unsqueeze(2)

    # Repeatedly add one more ancestor level until the mask stops growing.
    ancestor_ids = parent_ids
    while True:
        widened = mask | (column_ids == ancestor_ids.unsqueeze(2))
        if torch.equal(widened, mask):
            break
        mask = widened
        # Step every row one level up: find the first position whose node id
        # equals the current ancestor id and take that position's ancestor.
        # (argmax yields index 0 when nothing matches.)
        first_match = (ancestor_ids.unsqueeze(2) == column_ids).int().argmax(dim=2)
        ancestor_ids = ancestor_ids.gather(1, first_match)

    # Causality: keep only the lower triangle (j <= i) of every batch entry.
    lower_triangle = torch.tril(torch.ones_like(mask))
    mask = mask & lower_triangle

    # Running count of attended positions per row; unattended cells count as 0
    # and therefore become -1, so they never win the column-wise maximum.
    attended_so_far = mask.cumsum(2)
    pos_ids = (attended_so_far.masked_fill(~mask, 0) - 1).max(1).values
    return mask, pos_ids


def test_mask_and_pos_ids(
    ids: list[int],
    parent_ids: list[int],
    expected_mask: list[list[int]],
    expected_pos_ids: list[int],
) -> None:
    """Check `mask_and_pos_ids` on a single sequence.

    The inputs are wrapped into a batch of one, and both outputs are compared
    element-wise with the expectations. On failure the assertion message
    shows the value that was actually computed.

    Args:
        ids: Node ID per position.
        parent_ids: Parent node ID per position.
        expected_mask: Expected attention mask as rows of 0/1.
        expected_pos_ids: Expected position ID per position.
    """
    actual_mask, actual_pos_ids = mask_and_pos_ids(
        ids=torch.tensor([ids]), parent_ids=torch.tensor([parent_ids])
    )

    mask_as_int = actual_mask.int()
    mask_ok = (mask_as_int == torch.tensor([expected_mask])).all()
    assert mask_ok, f"\n{mask_as_int[0]}"

    pos_ids_ok = (actual_pos_ids == torch.tensor([expected_pos_ids])).all()
    assert pos_ids_ok, f"{actual_pos_ids[0].tolist()}"


# Table of hand-checked examples for mask_and_pos_ids, run in order.
# A node whose parent id equals its own id is a root.
mask_and_pos_ids_cases = [
    # Two unrelated single-token roots.
    dict(
        ids=[0, 1],
        parent_ids=[0, 1],
        expected_mask=[[1, 0], [0, 1]],
        expected_pos_ids=[0, 0],
    ),
    # A root followed by a two-token child.
    dict(
        ids=[0, 1, 1],
        parent_ids=[0, 0, 0],
        expected_mask=[[1, 0, 0], [1, 1, 0], [1, 1, 1]],
        expected_pos_ids=[0, 1, 2],
    ),
    # A straight chain 0 <- 1 <- 2 <- 3.
    dict(
        ids=[0, 1, 2, 3],
        parent_ids=[0, 0, 1, 2],
        expected_mask=[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]],
        expected_pos_ids=[0, 1, 2, 3],
    ),
    # Two unrelated two-token roots.
    dict(
        ids=[0, 0, 1, 1],
        parent_ids=[0, 0, 1, 1],
        expected_mask=[[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]],
        expected_pos_ids=[0, 1, 0, 1],
    ),
    # Two roots, each with one child, interleaved in the sequence.
    dict(
        ids=[0, 1, 2, 3],
        parent_ids=[0, 1, 0, 1],
        expected_mask=[[1, 0, 0, 0], [0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1]],
        expected_pos_ids=[0, 0, 1, 1],
    ),
    # Two roots, each with one two-token child.
    dict(
        ids=[0, 1, 2, 2, 3, 3],
        parent_ids=[0, 1, 0, 0, 1, 1],
        expected_mask=[
            [1, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0],
            [1, 0, 1, 0, 0, 0],
            [1, 0, 1, 1, 0, 0],
            [0, 1, 0, 0, 1, 0],
            [0, 1, 0, 0, 1, 1],
        ],
        expected_pos_ids=[0, 0, 1, 2, 1, 2],
    ),
    # Three-level tree: 0 <- 1, then branches 1 <- 2 <- 4 and 1 <- 3 <- 5.
    dict(
        ids=[0, 1, 2, 3, 4, 4, 5, 5],
        parent_ids=[0, 0, 1, 1, 2, 2, 3, 3],
        expected_mask=[
            [1, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 0, 0],
            [1, 1, 0, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 1, 0, 0, 0],
            [1, 1, 1, 0, 1, 1, 0, 0],
            [1, 1, 0, 1, 0, 0, 1, 0],
            [1, 1, 0, 1, 0, 0, 1, 1],
        ],
        expected_pos_ids=[0, 1, 2, 2, 3, 4, 3, 4],
    ),
    # Node ids that are not in ascending order.
    dict(
        ids=[2, 1, 0],
        parent_ids=[2, 2, 0],
        expected_mask=[
            [1, 0, 0],
            [1, 1, 0],
            [0, 0, 1],
        ],
        expected_pos_ids=[0, 1, 0],
    ),
]

for case in mask_and_pos_ids_cases:
    test_mask_and_pos_ids(**case)