In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%%html
<style>
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [3]:
from openai import AsyncOpenAI

import os

# Client for the server that hosts the reference model. The original literals
# are kept as defaults so the notebook behaves exactly as before, but both can
# now be overridden through environment variables instead of editing (and
# committing) host details in the notebook.
reference_client = AsyncOpenAI(
    api_key=os.environ.get("REFERENCE_API_KEY", "default"),
    base_url=os.environ.get("REFERENCE_BASE_URL", "http://209.20.158.71:8000/v1"),
)

In [4]:
# Send one chat message to the reference endpoint. The returned ChatCompletion
# is displayed as the cell output, which doubles as a quick connectivity check.
await reference_client.chat.completions.create(
    messages=[{"role": "user", "content": "Hello!"}],
    model="NousResearch/Hermes-2-Theta-Llama-3-8B",
)

ChatCompletion(id='chat-2e9ecdec70154b01b133ff25722371e4', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content="Hi there! How can I help you today? Do you have any questions, or would you like to talk about a specific topic? I'm here to listen and provide information or guidance as needed.", refusal=None, role='assistant', audio=None, function_call=None, tool_calls=[]), stop_reason=None)], created=1733347839, model='NousResearch/Hermes-2-Theta-Llama-3-8B', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=41, prompt_tokens=11, total_tokens=52, completion_tokens_details=None, prompt_tokens_details=None), prompt_logprobs=None)

In [5]:
import asyncio
from lib.clue import Clue, DeductiveSolver
from lib.rl.episode import Episode, EpisodeCompletion
from lib.rl.ppo import PPOLoss
from lib.rl.recipe import ComponentConfig, TuneRecipeConfig
from lib.rl.trainer import Trainer
from lib.utils import return_exception
import torch
from torchtune.models.llama3_1 import llama3_1_8b
from torchtune.training.metric_logging import WandBLogger
import random
import re


def _score_answer(answer: str, solution: dict[str, str]) -> float:
    """Return the fraction of solution entries that `answer` gets right.

    For every `key`/`value` pair in `solution`, the first case-insensitive
    occurrence of "<key>: <text>" in `answer` is located, where `<text>` is a
    run of ASCII letters, spaces and periods. The entry counts as correct when
    that text, stripped and lower-cased, equals the stripped, lower-cased
    `value`.

    Args:
        answer: Free-form text to grade.
        solution: Mapping from element name to its expected value
            (presumably str -> str; confirm against
            `Clue.get_prompt_and_follow_up_and_solution`).

    Returns:
        Number of correct entries divided by `len(solution)`.

    Raises:
        ZeroDivisionError: If `solution` is empty.
    """
    correct = 0
    for key, value in solution.items():
        # Escape the key so it is always matched literally, even if it ever
        # contains regex metacharacters. Formatting through an f-string first
        # keeps the original str() conversion of the key.
        match = re.search(
            rf"{re.escape(f'{key}')}: ([A-Za-z \.]+)",
            answer,
            re.IGNORECASE,
        )
        if match and match.group(1).strip().lower() == value.strip().lower():
            correct += 1
    return correct / len(solution)


@return_exception
def sample_random_episode() -> Episode:
    """Generate one random Clue game and wrap it as an RL `Episode`.

    A 3-player game is built from 3 randomly sampled suspects, weapons and
    rooms, played through with a `DeductiveSolver` (several of its checks are
    switched off), and converted into a prompt, a follow-up question and the
    ground-truth solution.

    NOTE(review): `return_exception` comes from `lib.utils`; presumably it
    makes this function return (rather than raise) an exception — confirm.

    Returns:
        An `Episode` whose single user message is the game prompt and whose
        `on_sample` callback rewards and commits the sampled completions.
    """
    game = Clue(
        num_players=3,
        elements={
            "suspect": random.sample(Clue.suspects, k=3),
            "weapon": random.sample(Clue.weapons, k=3),
            "room": random.sample(Clue.rooms, k=3),
            # "motive": random.sample(Clue.motives, k=3),
            # "time": Clue.get_times("21:00", "03:00", "1h"),
        },
    )
    game.play(
        deductive_solver=DeductiveSolver(
            # note_cards_in_hand=False,
            # note_responses_to_suggestions=False,
            # note_cards_that_players_do_not_have=False,
            # check_unique_card_placement_constraints=False,
            # check_player_hand_size_constraints=False,
            check_solution_has_one_and_only_one_card_per_element=False,
            check_one_of_constraints=False,
            check_inverse_one_of_constraints=False,
            merge_and_check_disjoint_inverse_one_of_constraints=False,
            exhaustively_test_possible_assignments=False,
        ),
        cp_solver_max_solve_time_per_turn=0.01,
        check_cp_solver_grid=False,
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        print_playthrough=False,
    )
    prompt, follow_up, solution = game.get_prompt_and_follow_up_and_solution()

    async def reward_completion(completion: EpisodeCompletion) -> EpisodeCompletion:
        """Grade one completion, asking the follow-up question first if needed.

        When the completion holds exactly two messages, the follow-up question
        is sent and the resulting completion is graded; otherwise the
        completion itself is graded. The graded completion is returned with
        its `reward` set to the fraction of solution entries answered
        correctly.
        """
        if len(completion.messages) == 2:
            # Presumably the two messages are the prompt and the first
            # assistant reply, i.e. the follow-up has not been asked yet.
            graded = await completion.follow_up(
                messages=[
                    {"role": "user", "content": follow_up},
                ]
            )
        else:
            graded = completion
        answer = graded.last_assistant_message.get("content")
        assert isinstance(answer, str)
        graded.reward = _score_answer(answer, solution)
        return graded

    async def on_sample(completions: list[EpisodeCompletion]) -> None:
        """Reward all sampled completions concurrently, then commit each one."""
        for completion in await asyncio.gather(
            *[reward_completion(completion) for completion in completions]
        ):
            completion.commit()

    return Episode(
        messages=[{"role": "user", "content": prompt}],
        on_sample=on_sample,
    )


def train_episodes():
    """Endlessly yield freshly sampled training episodes.

    Each item is whatever `sample_random_episode()` returns; the generator
    never terminates on its own.
    """
    while True:
        episode = sample_random_episode()
        yield episode


# Name of this run; it is used for the output directory and for the W&B
# logger's run name and id below.
model_name = "rl1"

# Build the trainer. Several sizes below are multiplied by
# torch.cuda.device_count(), so they scale with the number of visible GPUs.
trainer = Trainer(
    base_model="NousResearch/Hermes-2-Theta-Llama-3-8B",
    output_dir=f"./models/{model_name}",
    samples_per_episode=27,
    branch_factor=3,
    # Reference client(s) paired with the model name to request from them.
    reference_clients_and_model=(
        [reference_client],
        "NousResearch/Hermes-2-Theta-Llama-3-8B",
    ),
    sample_probability_power=1 / 3,
    # Infinite generator of training episodes.
    train_episodes=train_episodes(),
    episodes_per_iteration=128 * torch.cuda.device_count(),
    patience_per_episode=5,
    patience_per_val_sample=10,
    sampling_kwargs={
        "max_tokens": 512,
    },
    max_mask_sequence_batch_size=1,
    # Generator expression: validation episodes are created lazily while it is
    # iterated, and the generator itself can only be iterated once.
    val_episodes=(
        sample_random_episode() for _ in range(32 * torch.cuda.device_count())
    ),
    val_samples_per_episode=3,
    torchrun_kwargs=dict(nnodes=1, nproc_per_node=torch.cuda.device_count()),
    tune_model=llama3_1_8b,
    tune_model_type="LLAMA3",
    # Settings for the fine-tuning recipe (see the "tune run" lines in the
    # cell outputs further down).
    tune_recipe_config=TuneRecipeConfig(
        seed=42,
        shuffle=False,
        num_output_chunks=4,
        resume_from_checkpoint=False,
        batch_size=2,
        epochs=1,
        metric_logger=ComponentConfig(
            WandBLogger,
            name=f"{model_name}_tune",
            resume="allow",
            id=f"{model_name}_tune",
        ),
        optimizer=ComponentConfig(
            "torch.optim.AdamW",
            # "bitsandbytes.optim.PagedAdamW8bit",
            # "bitsandbytes.optim.AdamW",
            # params=PLACEHOLDER,
            lr=4e-6,
            fused=True,
        ),
        # PPOLoss configuration. With these values only tanh_log_policy_coef,
        # entropy_target_coef and kl_coef are non-zero (presumably these are
        # the weights of the individual loss terms — confirm in lib.rl.ppo).
        loss=ComponentConfig(
            PPOLoss,
            policy_coef=0.0,
            clip_epsilon=0.2,
            unclipped_policy_coef=0.0,
            tanh_log_policy_coef=1.0,
            value_coef=0.0,
            entropy_coef=0.0,
            entropy_target=0.5,
            entropy_target_coef=0.1,
            kl_coef=0.05,
            weighted_entropy_coef=0.0,
            weighted_kl_coef=0.0,
            weighted_ce_coef=0.0,
            normalize_values=False,
            normalize_advantages=False,
        ),
        compile=False,
        optimizer_in_bwd=False,
        gradient_accumulation_steps=1,
        enable_activation_checkpointing=True,
        enable_activation_offloading=False,
        custom_sharded_layers=["tok_embeddings", "output"],
        log_every_n_steps=1,
        log_peak_memory_stats=True,
    ),
    # tune_run=False,
    tune_sequence_length=8192,
    # vllm_env={"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"},
    # These options show up as flags on the "$ vllm serve ..." command lines
    # printed in the cell outputs below.
    vllm_kwargs=dict(
        block_size=32,
        disable_log_requests=True,
        enable_prefix_caching=True,
        enforce_eager=True,
        gpu_memory_utilization=0.95,
        # max_model_len=16384,
        max_num_seqs=512 * torch.cuda.device_count(),
        max_num_batched_tokens=8192 * 4,
        return_tokens_as_token_ids=True,
        swap_space=8,
        # scheduling_policy="priority",
        # tensor_parallel_size=torch.cuda.device_count() // 8,
    ),
    vllm_max_concurrent_samples=512 * torch.cuda.device_count(),
    vllm_min_time_between_requests=0.0,
    vllm_num=torch.cuda.device_count(),
    vllm_timeout=120 + 15 * torch.cuda.device_count(),
)

Resuming from /home/ubuntu/atreides/experiments/models/rl1/0008
INFO 12-04 21:30:53 llm_engine.py:237] Initializing an LLM engine (v0.6.3.post1) with config: model='NousResearch/Hermes-2-Theta-Llama-3-8B', speculative_config=None, tokenizer='NousResearch/Hermes-2-Theta-Llama-3-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=NousResearch/Hermes-2-Thet

In [6]:
await trainer.eval("val")

Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl1/0007 --port=8001 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.95 --max-num-seqs=512 --max-num-batched-tokens=32768 --return-tokens-as-token-ids --swap-space=8 --api-key=default
INFO 12-04 21:13:49 api_server.py:528] vLLM API server version 0.6.3.post1
INFO 12-04 21:13:49 api_server.py:529] args: Namespace(subparser='serve', model_tag='/home/ubuntu/atreides/experiments/models/rl1/0007', config='', host=None, port=8001, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key='default', lora_modules=None, prompt_adapters=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=True, disable_frontend_multiprocessing=False, enable_auto_tool_choice=Fals

Loading pt checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
  state = torch.load(bin_file, map_location="cpu")
Loading pt checkpoint shards:  25% Completed | 1/4 [00:04<00:12,  4.20s/it]
Loading pt checkpoint shards:  50% Completed | 2/4 [00:04<00:04,  2.16s/it]
Loading pt checkpoint shards:  75% Completed | 3/4 [00:08<00:02,  2.65s/it]
Loading pt checkpoint shards: 100% Completed | 4/4 [00:11<00:00,  2.89s/it]
Loading pt checkpoint shards: 100% Completed | 4/4 [00:11<00:00,  2.85s/it]



INFO 12-04 21:14:11 model_runner.py:1067] Loading model weights took 14.9595 GB
INFO 12-04 21:14:12 gpu_executor.py:122] # GPU blocks: 14333, # CPU blocks: 2048
INFO 12-04 21:14:12 gpu_executor.py:126] Maximum concurrency for 8192 tokens per request: 55.99x
INFO 12-04 21:14:16 api_server.py:232] vLLM to use /tmp/tmpk1_41leh as PROMETHEUS_MULTIPROC_DIR
INFO 12-04 21:14:16 launcher.py:19] Available routes are:
INFO 12-04 21:14:16 launcher.py:27] Route: /openapi.json, Methods: HEAD, GET
INFO 12-04 21:14:16 launcher.py:27] Route: /docs, Methods: HEAD, GET
INFO 12-04 21:14:16 launcher.py:27] Route: /docs/oauth2-redirect, Methods: HEAD, GET
INFO 12-04 21:14:16 launcher.py:27] Route: /redoc, Methods: HEAD, GET
INFO 12-04 21:14:16 launcher.py:27] Route: /health, Methods: GET
INFO 12-04 21:14:16 launcher.py:27] Route: /tokenize, Methods: POST
INFO 12-04 21:14:16 launcher.py:27] Route: /detokenize, Methods: POST
INFO 12-04 21:14:16 launcher.py:27] Route: /v1/models, Methods: GET
INFO 12-04 21:14

INFO:     Started server process [30507]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on socket ('0.0.0.0', 8001) (Press CTRL+C to quit)


INFO:     127.0.0.1:60550 - "POST /v1/chat/completions HTTP/1.1" 200 OK
vLLM server started succesfully. Logs can be found at ./logs/vllm.log


val: 0episode [00:00, ?episode/s]

(0.8645833333333333, [])

In [8]:
explore_result = await trainer.explore(verbosity=1)

Starting 1 vLLM servers...
$ vllm serve NousResearch/Hermes-2-Theta-Llama-3-8B --port=8001 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.95 --max-num-seqs=512 --max-num-batched-tokens=32768 --return-tokens-as-token-ids --swap-space=8 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


explore:   0%|          | 0/128 [00:00<?, ?episode/s]

In [10]:
trainer.tune_recipe_config.metric_logger.name = f"{model_name}_tune"
trainer.tune_recipe_config.metric_logger.id = f"{model_name}_tune"

In [11]:
await trainer.tune(explore_result, verbosity=2)

$ tune run --nnodes=1 --nproc-per-node=1 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl1/config.yaml
Running with torchrun...


INFO:torchtune.utils._logging:Running FullFinetuneRecipe with resolved config:

batch_size: 2
checkpointer:
  _component_: torchtune.training.checkpointing._checkpointer.FullModelHFCheckpointer
  checkpoint_dir: /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725
  checkpoint_files:
  - /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725/model-00004-of-00004.safetensors
  - /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725/model-00001-of-00004.safetensors
  - /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725/model-00002-of-00004.safetensors
  - /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3

Saved iteration 1 model files to /home/ubuntu/atreides/experiments/models/rl1/0001


In [7]:
trainer.tune_recipe_config.loss.entropy_target_coef = 0.1
trainer.tune_recipe_config.loss.kl_coef = 0.05

In [6]:
await trainer.train(iterations=3, verbosity=1)

Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl1/0008 --port=8001 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.95 --max-num-seqs=512 --max-num-batched-tokens=32768 --return-tokens-as-token-ids --swap-space=8 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val: 0episode [00:00, ?episode/s]

explore:   0%|          | 0/128 [00:00<?, ?episode/s]

Early stopping val evaluation due to expired patience (0 remaining samples x 10 patience per sample = 0 seconds)
$ tune run --nnodes=1 --nproc-per-node=1 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl1/config.yaml


1|14|Loss: 0.0100: 100%|██████████| 14/14 [01:55<00:00,  7.70s/it, entropy=0.3990, entropy_target=0.1010, kl_div=0.0081, policy=-0.0241, tanh_log_policy_to_log=-0.0005, unclipped_policy=-0.0355, value=2.0078, weighted_ce=-0.0001, weighted_entropy=-0.0057, weighted_kl_div=0.0013]

Saved iteration 9 model files to /home/ubuntu/atreides/experiments/models/rl1/0009
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl1/0009 --port=8001 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.95 --max-num-seqs=512 --max-num-batched-tokens=32768 --return-tokens-as-token-ids --swap-space=8 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/32 [00:00<?, ?episode/s]

explore:   0%|          | 0/128 [00:00<?, ?episode/s]

$ tune run --nnodes=1 --nproc-per-node=1 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl1/config.yaml


1|2|Loss: 0.0125: 100%|██████████| 2/2 [00:23<00:00, 11.10s/it, entropy=0.2623, entropy_target=0.2377, kl_div=-0.0079, policy=0.0153, tanh_log_policy_to_log=-0.0109, unclipped_policy=-0.0259, value=2.2582, weighted_ce=-0.0234, weighted_entropy=-0.0041, weighted_kl_div=0.0032]

Saved iteration 10 model files to /home/ubuntu/atreides/experiments/models/rl1/0010
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl1/0010 --port=8001 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.95 --max-num-seqs=512 --max-num-batched-tokens=32768 --return-tokens-as-token-ids --swap-space=8 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/32 [00:00<?, ?episode/s]

explore:   0%|          | 0/128 [00:00<?, ?episode/s]

$ tune run --nnodes=1 --nproc-per-node=1 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl1/config.yaml


1|9|Loss: 0.0107: 100%|██████████| 9/9 [01:17<00:00,  7.83s/it, entropy=0.6588, entropy_target=0.1588, kl_div=0.0352, policy=0.0395, tanh_log_policy_to_log=-0.0069, unclipped_policy=-0.0120, value=2.0642, weighted_ce=-0.0189, weighted_entropy=-0.0088, weighted_kl_div=0.0081]  

Saved iteration 11 model files to /home/ubuntu/atreides/experiments/models/rl1/0011
Starting 1 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl1/0011 --port=8001 --block-size=32 --disable-log-requests --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.95 --max-num-seqs=512 --max-num-batched-tokens=32768 --return-tokens-as-token-ids --swap-space=8 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


val:   0%|          | 0/32 [00:00<?, ?episode/s]

In [6]:
explore_result = await trainer.explore()

explore:   0%|          | 0/128 [00:00<?, ?episode/s]

Early stopping exploration due to expired patience (36.11519999998988 remaining samples x 0.037037037037037035 patience per sample = 1.3375999999996253 seconds)


In [11]:
await trainer.tune(explore_result)

$ tune run --nnodes=1 --nproc-per-node=1 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl2/config.yaml
Running with torchrun...


INFO:torchtune.utils._logging:Running FullFinetuneRecipe with resolved config:

batch_size: 2
checkpointer:
  _component_: torchtune.training.checkpointing._checkpointer.FullModelHFCheckpointer
  checkpoint_dir: /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725
  checkpoint_files:
  - /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725/model-00004-of-00004.safetensors
  - /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725/model-00001-of-00004.safetensors
  - /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725/model-00002-of-00004.safetensors
  - /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3

CancelledError: 

In [11]:
from openai import AsyncOpenAI

completion_sampler = await trainer.get_completion_sampler()
client: AsyncOpenAI = completion_sampler.samplers[0].client  # type: ignore

In [None]:
vllm.client.chat.completions.create()

In [None]:
from typing import Any

episode = explore_result.episodes[0]
completion = next(iter(episode.completion.leaves()))
tokens = completion.all_tokens(trainer.tokenizer, cache=True).tolist()
plain_completion = await vllm.client.completions.create(
    model=trainer.model,
    prompt=tokens,
    max_tokens=1,
    extra_body={
        "prompt_logprobs": True,
    },
)
prompt_logprobs: list[dict[str, dict[str, Any]]] = plain_completion.choices[0].prompt_logprobs  # type: ignore
prompt_logprobs

In [56]:
assert len(prompt_logprobs) == len(tokens)

In [None]:
# Log-probability of each token in `tokens`, looked up in the matching
# `prompt_logprobs` entry by the token id rendered as a string. Entries that
# are falsy (e.g. None) yield NaN instead.
reference_logprobs = [
    prompt_logprob[str(token)]["logprob"] if prompt_logprob else torch.nan
    for token, prompt_logprob in zip(tokens, prompt_logprobs)
]
# Hand each completion in the chain (ancestors plus the leaf itself) the next
# `count` log-probs, consuming `reference_logprobs` from the front as we go.
# NOTE(review): this assumes reverse=True iterates root-first so the slices
# line up with token order — confirm in Completion.ancestors.
for c in completion.ancestors(including_self=True, reverse=True):
    count = c.token_count(trainer.tokenizer, cache=True)
    c.reference_logprobs, reference_logprobs = (
        torch.tensor(reference_logprobs[:count]),
        reference_logprobs[count:],
    )

completion.reference_logprobs

In [None]:
torch.tensor(
    [
        prompt_logprob[str(token)]["logprob"] if prompt_logprob else torch.nan
        for token, prompt_logprob in zip(tokens, prompt_logprobs)
    ]
)

In [None]:
len(prompt_logprobs), len(completion.all_tokens(trainer.tokenizer, cache=True).tolist())

In [None]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# open dump files that come from a trusted source.
with open("/tmp/err_execute_model_input_20241201-005319.pkl", "rb") as f:
    data = pickle.load(f)

data

In [None]:
list(data)

In [34]:
getattr(plain_completion, "prompt_logprobs", None)

In [None]:
plain_completion.choices[0].prompt_logprobs

In [None]:
completion.all_tokens(trainer.tokenizer, cache=True).tolist()

In [None]:
plain_completion

In [10]:
trainer.patience_per_val_sample = 1.0
trainer.patience_per_test_sample = 1.0
trainer.tune_recipe_config.optimizer.lr = 8e-6
trainer.tune_recipe_config.loss.clip_epsilon = 0.1
trainer.tune_recipe_config.loss.weighted_ce_coef = 0.2

In [None]:
await trainer.train(iterations=3)

In [None]:
trainer.eval_scores

In [5]:
from lib.rl.pack import packed_tensors_from_dir

tensors = packed_tensors_from_dir(
    dir="./models/rl/tensors", num_sequences=50, sequence_length=16384
)

In [None]:
tensors["mask"][0][2][2]

In [None]:
trainer.max_mask_sequence_batch_size = 16
# (eval_score, eval_exceptions),
(result,) = await asyncio.gather(
    # trainer.eval("val", 0, return_exceptions=True),
    trainer.explore(1, return_exceptions=True),
)
# print(f"Eval score: {eval_score:.2%}")
print(
    f"Generated {sum(completion.num_token_logprobs() for episode in result.episodes for completion in episode.completion.descendants()):,} tokens"
)
tensors = trainer.tensors(result.episodes)
(tensors["mask"] == result.tensors()["mask"]).all()

In [None]:
tensors = trainer.tensors(result.episodes)
(tensors["mask"] == result.tensors()["mask"]).all()

In [None]:
torch.tensor(tensors["advantages"].shape).prod()

In [None]:
tensors["mask"].shape

In [None]:
import matplotlib.pyplot as plt
import torch


def show(mask: torch.Tensor) -> None:
    """Display a 2-D tensor as a heatmap on a 10x10-inch figure.

    The image is drawn with the "inferno" colormap, gets a colorbar labelled
    "Relative Position", and has fixed title and axis labels.

    Args:
        mask: The values to draw; passed straight to `imshow`.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    image = ax.imshow(mask, cmap="inferno")
    fig.colorbar(image, ax=ax, label="Relative Position")
    ax.set_title("Relative Position Attention Mask")
    ax.set_xlabel("Target Position")
    ax.set_ylabel("Source Position")
    plt.show()


i = 1
_tensors = tensors
key = "input_pos"

show(
    _tensors["mask"][i].cumsum(dim=1)
    * (
        _tensors["mask"][i]
        & (
            ~torch.isnan(_tensors[key][i]).unsqueeze(0)
            & ~torch.isnan(_tensors[key][i]).unsqueeze(1)
        )
    )
)

In [None]:
result.tensors()["mask"].shape

In [None]:
(tensors["mask"] == result.tensors()["mask"]).all()

In [None]:
# Print every pair (i, j) for which mask i of `tensors` is identical to mask j
# of the explore result's tensors.
# result.tensors() does not depend on i or j, so it is built once up front
# instead of 127 * 127 times inside the loop (this assumes it returns the same
# tensors on every call, as the direct comparisons in neighbouring cells do).
result_masks = result.tensors()["mask"]
for i in range(127):
    for j in range(127):
        if (tensors["mask"][i] == result_masks[j]).all():
            print(i, j)

In [None]:
key = "advantages"
torch.isclose(
    tensors[key], result.tensors()[key], rtol=1e-5, atol=1e-8, equal_nan=True
).all()

In [None]:
torch.all((tensors["weights"] == result.tensors()["weights"]))

In [None]:
raise result.exceptions[1]

In [None]:
result.exceptions

In [None]:
2_033_717 / 4.25

In [None]:
2_064_056 / 6.66

In [None]:
2_064_056 / 10.75

In [None]:
2_071_601 / 8

In [None]:
2_071_601 / 11.75

In [None]:
4_119_041 / 16.5

In [None]:
await trainer.train(iterations=1)

In [None]:
val_score, episodes = await asyncio.gather(trainer.eval("val", 0), trainer.explore(1))

In [None]:
from torchtune.models.llama3_1 import llama3_1_8b
from torchtune.training import cleanup_before_training
from torchtune.training.metric_logging import DiskLogger
from typing import Any

from lib.recipes.rl import ComponentConfig, RLConfig, RLRecipe
from lib.rl.pack import PackedDataset, packed_tensors_to_dir
from lib.rl.ppo import PPOLoss


# Ask the trainer for everything the recipe needs for `episodes`: the packed
# tensors plus the checkpoint directory and files to start from.
tensors, checkpoint_dir, checkpoint_files = await trainer.tune_resources(episodes)

# Only referenced by the commented-out `params=PLACEHOLDER` line below.
PLACEHOLDER: Any = None

# Hand-built recipe configuration, written out to config.yaml in a later cell.
config = RLConfig(
    # Dataset
    # The packed tensors are written under <output_dir>/tensors and the values
    # returned by packed_tensors_to_dir become PackedDataset's keyword arguments.
    dataset=ComponentConfig(
        PackedDataset, **packed_tensors_to_dir(tensors, trainer.output_dir + "/tensors")
    ),
    seed=42,
    shuffle=False,
    # Model
    model=ComponentConfig(llama3_1_8b),
    num_output_chunks=4,
    # Checkpointer
    checkpointer=ComponentConfig(
        "torchtune.training.FullModelHFCheckpointer",
        checkpoint_dir=checkpoint_dir,
        checkpoint_files=checkpoint_files,
        recipe_checkpoint=None,
        output_dir=trainer.output_dir,
        model_type="LLAMA3",
    ),
    resume_from_checkpoint=False,
    # Fine-tuning arguments
    batch_size=4,
    epochs=1,
    optimizer=ComponentConfig(
        "torch.optim.AdamW",
        # "bitsandbytes.optim.PagedAdamW8bit",
        # "bitsandbytes.optim.AdamW",
        # params=PLACEHOLDER,
        lr=5e-6,
        fused=True,
    ),
    loss=ComponentConfig(
        PPOLoss,
        # clip_epsilon=0.3,
        # entropy_coef=0.0,
        # kl_coef=0.0,
        clip_epsilon=0.3,
        entropy_coef=0.025,
        kl_coef=0.025,
        normalize_advantages=False,
    ),
    max_steps_per_epoch=None,
    compile=False,
    optimizer_in_bwd=False,
    gradient_accumulation_steps=1,
    # Training env
    device="cuda",
    # Memory management
    enable_activation_checkpointing=True,
    enable_activation_offloading=False,
    custom_sharded_layers=["tok_embeddings", "output"],
    # Reduced precision
    dtype="bf16",
    # Logging
    # NOTE(review): absolute, machine-specific log directory.
    metric_logger=ComponentConfig(
        DiskLogger, log_dir="/home/ubuntu/atreides/experiments/logs"
    ),
    log_every_n_steps=1,
    log_peak_memory_stats=True,
)

# recipe = RLRecipe(config)
# recipe.setup(config)
# recipe.train()
# recipe.cleanup()
# del tensors, recipe
# cleanup_before_training()
# trainer.save(base_checkpoint_dir=checkpoint_dir)

In [19]:
from omegaconf import OmegaConf

dict_config = config.dict_config()
OmegaConf.save(dict_config, trainer.output_dir + "/config.yaml")

In [None]:
import os
import sys
from typing import IO

def _cli_flags(options: dict) -> list[str]:
    """Render keyword options as `--flag[=value]` command-line arguments.

    Underscores in each key become dashes. A value that is exactly `True`
    produces a bare `--flag`; every other value (including `1` and `False`)
    is rendered as `--flag=<value>`.

    Args:
        options: Mapping from option name to its value.

    Returns:
        One argument string per option, in the mapping's iteration order.
    """
    return [
        f"--{key.replace('_', '-')}" + ("" if value is True else f"={value}")
        for key, value in options.items()
    ]


# Options placed before the recipe name on the `tune run` command line.
torchrun_kwargs = {"nnodes": 1, "nproc_per_node": 2}
# Options appended after `--config <path>`.
kwargs = {}
# Extra environment variables for the child process launched in the next cell.
env = {"CUDA_LAUNCH_BLOCKING": "1"}

# Both option groups go through the same helper. Previously the second group
# compared with `!= True`, which would have turned a value of 1 into a bare
# flag because `1 == True` in Python.
args = [
    "tune",
    "run",
    *_cli_flags(torchrun_kwargs),
    "lib.recipes.rl.RLRecipe",
    "--config",
    trainer.output_dir + "/config.yaml",
    *_cli_flags(kwargs),
]
print(f"$ {' '.join(args)}")

In [None]:
# Launch the command assembled in `args` as a child process, capturing both
# of its output streams through pipes.
process = await asyncio.create_subprocess_exec(
    *args,
    stdout=asyncio.subprocess.PIPE,
    stderr=asyncio.subprocess.PIPE,
    env={
        # Start from the notebook's own environment, then apply the overrides
        # in `env` (an empty or None `env` adds nothing).
        **os.environ,
        **(env or {}),
    },
)


async def log_output(stream: asyncio.StreamReader, io: IO[str]) -> None:
    """Copy `stream` to `io` line by line until end-of-stream.

    Args:
        stream: Byte stream to read from.
        io: Text sink; every decoded line is written and flushed immediately.

    Bytes that are not valid UTF-8 are replaced with U+FFFD instead of raising
    UnicodeDecodeError, so a single malformed line cannot abort the forwarding.
    """
    # readline() returns an empty bytes object once the stream is exhausted.
    while line := await stream.readline():
        io.write(line.decode(errors="replace"))
        io.flush()


# Forward stdout and stderr concurrently, one task per available stream, and
# wait until every forwarding task has finished.
tasks = []
if process.stdout:
    tasks.append(asyncio.create_task(log_output(process.stdout, sys.stdout)))
if process.stderr:
    tasks.append(asyncio.create_task(log_output(process.stderr, sys.stderr)))
_ = await asyncio.gather(*tasks)

In [None]:
from lib.recipes.rl import recipe_main
import os
from torch import distributed as dist
from torchtune.training import is_distributed

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"


recipe_main(config)

In [7]:
from omegaconf import DictConfig, OmegaConf

dict_config = config.dict_config()
OmegaConf.save(dict_config, trainer.output_dir + "/config.yaml")

In [None]:
from lib.rl.completion import Completion


OmegaConf.create(
    OmegaConf.to_yaml(
        DictConfig(dict(name=f"{Completion.__module__}.{Completion.__name__}"))
    )
)

In [27]:
import traceback
import sys

traceback.clear_frames(sys.exc_info()[2])

In [28]:
cleanup_before_training()

In [None]:
trainer.save(base_checkpoint_dir=checkpoint_dir)

In [None]:
import matplotlib.pyplot as plt
import torch


def show(mask: torch.Tensor) -> None:
    """Display a 2-D tensor as a heatmap on a 10x10-inch figure.

    The image is drawn with the "inferno" colormap, gets a colorbar labelled
    "Relative Position", and has fixed title and axis labels.

    Args:
        mask: The values to draw; passed straight to `imshow`.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    image = ax.imshow(mask, cmap="inferno")
    fig.colorbar(image, ax=ax, label="Relative Position")
    ax.set_title("Relative Position Attention Mask")
    ax.set_xlabel("Target Position")
    ax.set_ylabel("Source Position")
    plt.show()


i = 1

show(
    tensors["mask"][i].cumsum(dim=1)
    * (
        tensors["mask"][i]
        & (
            ~torch.isnan(tensors["advantages"][i]).unsqueeze(0)
            & ~torch.isnan(tensors["advantages"][i]).unsqueeze(1)
        )
    )
)

In [None]:
from IPython.display import HTML

HTML(
    f'<div style="white-space: pre-wrap">{list(episodes[2].completion.leaves())[0].html(30.0)}</div>'
)

In [None]:
def mask_and_pos_ids(
    ids: torch.Tensor, parent_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build a tree-structured causal attention mask and per-token position IDs.

    Args:
        ids: Tensor of shape (batch_size, sequence_length) holding the node ID
            of every position.
        parent_ids: Tensor of the same shape holding, for every position, the
            ID of its node's parent (a root lists itself as its parent).

    Returns:
        A tuple `(mask, pos_ids)`:
        - mask: Boolean tensor of shape (batch_size, sequence_length,
          sequence_length). `mask[b, i, j]` is True when position `j` is not
          after position `i` and belongs to the same node as `i` or to one of
          its ancestors.
        - pos_ids: Tensor of shape (batch_size, sequence_length). For each
          position it is the largest count, over all rows that may attend to
          it, of attendable positions up to and including it, minus one.
    """
    attended_ids = ids.unsqueeze(1)  # (batch, 1, seq): ID at the attended position
    attending_ids = ids.unsqueeze(2)  # (batch, seq, 1): ID at the attending position
    mask = attended_ids == attending_ids

    # Repeatedly admit the current ancestor level, then climb one level up,
    # until admitting another level no longer changes the mask.
    while True:
        widened = mask | (attended_ids == parent_ids.unsqueeze(2))
        if not torch.any(mask != widened):
            break
        mask = widened
        # For every position, find the first position carrying its current
        # ancestor's ID and adopt that position's parent as the next ancestor.
        first_match = torch.argmax(
            (parent_ids.unsqueeze(2) == attended_ids).int(), dim=2
        )
        parent_ids = parent_ids.gather(1, first_match)

    # Causality: a position may only attend to itself and earlier positions.
    causal = torch.tril(torch.ones_like(mask, dtype=torch.bool, device=ids.device))
    mask = mask & causal

    attendable_counts = torch.where(mask, mask.cumsum(2), 0)
    pos_ids = (attendable_counts - 1).max(1).values
    return mask, pos_ids


def test_mask_and_pos_ids(
    ids: list[int],
    parent_ids: list[int],
    expected_mask: list[list[int]],
    expected_pos_ids: list[int],
):
    """Check `mask_and_pos_ids` against expected values for one sequence.

    Args:
        ids: Node ID of every position.
        parent_ids: Parent node ID of every position.
        expected_mask: Expected attention mask as rows of 0/1 integers.
        expected_pos_ids: Expected position ID of every position.

    Raises:
        AssertionError: If the mask or the position IDs differ; the message
            shows the value that was actually computed.
    """
    # Wrapping each list in another list adds a batch dimension of size 1.
    actual_mask, actual_pos_ids = mask_and_pos_ids(
        ids=torch.tensor([ids]), parent_ids=torch.tensor([parent_ids])
    )
    mask_matches = torch.all(actual_mask.int() == torch.tensor([expected_mask]))
    assert mask_matches, f"\n{actual_mask.int()[0]}"
    pos_ids_match = torch.all(actual_pos_ids == torch.tensor([expected_pos_ids]))
    assert pos_ids_match, f"{actual_pos_ids[0].tolist()}"


# Two single-token roots (each node is its own parent): no cross attention.
test_mask_and_pos_ids(
    ids=[0, 1],
    parent_ids=[0, 1],
    expected_mask=[[1, 0], [0, 1]],
    expected_pos_ids=[0, 0],
)

# Root 0 followed by a two-token child node 1.
test_mask_and_pos_ids(
    ids=[0, 1, 1],
    parent_ids=[0, 0, 0],
    expected_mask=[[1, 0, 0], [1, 1, 0], [1, 1, 1]],
    expected_pos_ids=[0, 1, 2],
)

# Linear chain 0 -> 1 -> 2 -> 3: equivalent to plain causal attention.
test_mask_and_pos_ids(
    ids=[0, 1, 2, 3],
    parent_ids=[0, 0, 1, 2],
    expected_mask=[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]],
    expected_pos_ids=[0, 1, 2, 3],
)

# Two independent two-token roots laid out back to back.
test_mask_and_pos_ids(
    ids=[0, 0, 1, 1],
    parent_ids=[0, 0, 1, 1],
    expected_mask=[[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]],
    expected_pos_ids=[0, 1, 0, 1],
)

# Two roots (0 and 1), each with one single-token child (2 under 0, 3 under 1).
test_mask_and_pos_ids(
    ids=[0, 1, 2, 3],
    parent_ids=[0, 1, 0, 1],
    expected_mask=[[1, 0, 0, 0], [0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1]],
    expected_pos_ids=[0, 0, 1, 1],
)

# Same two-root layout, but each child node spans two tokens.
test_mask_and_pos_ids(
    ids=[0, 1, 2, 2, 3, 3],
    parent_ids=[0, 1, 0, 0, 1, 1],
    expected_mask=[
        [1, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0],
        [1, 0, 1, 1, 0, 0],
        [0, 1, 0, 0, 1, 0],
        [0, 1, 0, 0, 1, 1],
    ],
    expected_pos_ids=[0, 0, 1, 2, 1, 2],
)

# Deeper tree: 0 -> 1, which branches into 2 and 3; two-token node 4 hangs
# under 2 and two-token node 5 under 3.
test_mask_and_pos_ids(
    ids=[0, 1, 2, 3, 4, 4, 5, 5],
    parent_ids=[0, 0, 1, 1, 2, 2, 3, 3],
    expected_mask=[
        [1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 0, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 1, 0, 0, 0],
        [1, 1, 1, 0, 1, 1, 0, 0],
        [1, 1, 0, 1, 0, 0, 1, 0],
        [1, 1, 0, 1, 0, 0, 1, 1],
    ],
    expected_pos_ids=[0, 1, 2, 2, 3, 4, 3, 4],
)

# IDs that are not in ascending order: node 1 is a child of root 2, and node 0
# is a separate root.
test_mask_and_pos_ids(
    ids=[2, 1, 0],
    parent_ids=[2, 2, 0],
    expected_mask=[
        [1, 0, 0],
        [1, 1, 0],
        [0, 0, 1],
    ],
    expected_pos_ids=[0, 1, 0],
)