In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%%html
<style>
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [3]:
from lib.langfuse import langfuse
# Turn off the imported langfuse object's `enabled` flag for this session.
# The auth check below is commented out, so it is not executed.
langfuse.enabled = False
# langfuse.auth_check()

In [4]:
import json
from lib.rl.episode import Episode, EpisodeCompletion
import random
import re
from typing import TypedDict


# Schema of one entry in the temporal clue puzzle JSON files:
#   num_clues - integer clue count stored with the puzzle
#   prompt    - puzzle text presented to the model
#   solution  - mapping from answer key to the expected answer string
TemporalCluePuzzle = TypedDict(
    "TemporalCluePuzzle",
    {
        "num_clues": int,
        "prompt": str,
        "solution": dict[str, str],
    },
)


# Read the puzzle set inside a context manager so the file handle is closed
# as soon as parsing finishes (the bare `open(...)` used to leak it).
with open("./data/temporal-clue-puzzles.json") as puzzles_file:
    temporal_clue_puzzles: list[TemporalCluePuzzle] = json.load(puzzles_file)

# Seed the global RNG before shuffling so the puzzle order is the same on every
# fresh run of the notebook.
random.seed(42)
random.shuffle(temporal_clue_puzzles)

In [5]:
from lib.clue import Clue

# Read the worked examples inside a context manager so the file handle is
# closed after parsing (the bare `open(...)` used to leak it).
with open("./data/chain-of-thought-examples.json") as examples_file:
    chain_of_thought_examples: list[dict[str, str]] = json.load(examples_file)

# Drop the examples at original positions 6 and 3. The higher index is removed
# first so that removing it does not shift the lower one.
chain_of_thought_examples.pop(6)
chain_of_thought_examples.pop(3)

def get_episode(puzzle: TemporalCluePuzzle) -> Episode:
    """Build an episode for one temporal clue puzzle.

    The user message is the puzzle prompt with its answer-format instruction
    reworded to ask for verified answers first. One chain-of-thought example,
    drawn at random from ``chain_of_thought_examples``, is attached as a
    user/assistant example pair.

    Args:
        puzzle: Puzzle whose ``prompt`` is shown to the model and whose
            ``solution`` maps each answer key to the expected answer text.

    Returns:
        An ``Episode`` whose ``on_sample`` hook commits, for every completion,
        the fraction of solution keys that were answered correctly.
    """

    def on_sample(completions: list[EpisodeCompletion]) -> None:
        """Commit a reward between 0 and 1 to each sampled completion."""
        for completion in completions:
            content = completion.last_assistant_message.get("content")
            assert isinstance(content, str)
            num_correct = 0
            for key, value in puzzle["solution"].items():
                # Escape the key so any regex metacharacter in it is matched
                # literally instead of being interpreted as pattern syntax.
                answers = re.findall(
                    rf"{re.escape(key)}\. ([A-Za-z \.:-]+)", content
                )
                # Only the last "<key>. <answer>" occurrence is graded.
                if answers and answers[-1].strip().lower() == value.lower():
                    num_correct += 1
            completion.commit(reward=num_correct / len(puzzle["solution"]))

    # `random.choices(..., k=1)` is kept (rather than `random.choice`) so the
    # global RNG is consumed exactly as before and seeded runs stay comparable.
    example = random.choices(chain_of_thought_examples, k=1)[0]

    # Append the final answer after a separator only when one is present; this
    # also avoids `str + None` if an example has a null answer.
    example_reply = example["chain_of_thought"]
    if example["answer"]:
        example_reply += f"\n\n---\n\n{example['answer']}"

    return Episode(
        messages=[
            {
                "role": "user",
                "content": puzzle["prompt"].replace(
                    "Fill out your final answers in the following format:",
                    "After verifiably finding all the correct answers, fill out your final answers in the following format:",
                ),
            },
        ],
        on_sample=on_sample,
        examples=[
            {"role": "user", "content": example["prompt"]},
            {"role": "assistant", "content": example_reply},
        ],
    )


# One episode per puzzle, in the shuffled puzzle order.
temporal_clue_episodes = list(map(get_episode, temporal_clue_puzzles))

In [6]:
# Keep the first 64 episodes and replace everything after them with episodes
# built from the second puzzle file. The file is opened in a context manager so
# the handle is closed after parsing (the bare `open(...)` used to leak it).
# NOTE: `get_episode` is rebound by later cells (zebra grid, MATH), so this cell
# must run while it still refers to the temporal clue version defined above.
with open("./data/temporal-clue-puzzles-2.json") as puzzles_file:
    temporal_clue_episodes[64:] = [
        get_episode(puzzle) for puzzle in json.load(puzzles_file)
    ]

In [7]:
import polars as pl

# Read the ZebraLogicBench grid-mode test split from the `hf://` dataset path
# and convert it to a list of row dicts.
zebra_grid_questions = pl.read_parquet(
    "hf://datasets/allenai/ZebraLogicBench-private/grid_mode/test-00000-of-00001.parquet"
).to_dicts()
# Shuffles in place with the global `random` state, so the resulting order
# depends on every RNG call made by earlier cells.
random.shuffle(zebra_grid_questions)


def get_episode(question: dict) -> Episode:
    """Build an episode for one zebra-logic grid question.

    The prompt is the puzzle text followed by a markdown table that has the
    solution's header row, a dashed separator row, and one blank row per
    solution row for the model to fill in.

    Args:
        question: Row dict with a ``puzzle`` string and a ``solution`` dict
            holding ``header`` (column names) and ``rows`` (lists of cell
            values).

    Returns:
        An ``Episode`` whose ``on_sample`` hook commits the fraction of grid
        cells the completion filled in correctly.
    """
    header = question["solution"]["header"]
    rows = question["solution"]["rows"]

    # Built from plain concatenation: the previous version reused `"` inside a
    # `"`-delimited f-string, which is a SyntaxError before Python 3.12.
    blank_row = f"| {' | '.join(' ' * len(column) for column in header)} |\n"
    prompt = (
        f"{question['puzzle']}\n"
        "Fill in the grid with the correct values:\n"
        "\n"
        f"| {' | '.join(header)} |\n"
        f"| {' | '.join('-' * len(column) for column in header)} |\n"
        + blank_row * len(rows)
    )

    # One lazy capture group per column of a table row.
    pattern = re.compile(
        r"\| " + r"\|".join(r"(.*?)" for _ in header) + r" \|"
    )

    def on_sample(completions: list[EpisodeCompletion]):
        """Commit cell-level accuracy as the reward for each completion."""
        for completion in completions:
            assert "content" in completion.last_assistant_message and isinstance(
                completion.last_assistant_message["content"], str
            )
            content = completion.last_assistant_message["content"]
            num_cells = sum(len(row) for row in rows)
            # `.groups()` always yields a tuple of cells; `re.findall` returns
            # bare strings when there is a single column, which would make the
            # zip below iterate over characters instead of cells.
            table_rows = [match.groups() for match in pattern.finditer(content)]
            num_correct = 0
            # Grade only the last len(rows) table rows found in the completion.
            for cells, row in zip(table_rows[-len(rows) :], rows):
                for cell, value in zip(cells, row):
                    if cell.strip().lower() == value.lower():
                        num_correct += 1
            completion.commit(reward=num_correct / num_cells)

    return Episode(
        messages=[{"role": "user", "content": prompt}],
        on_sample=on_sample,
    )

# One episode per zebra grid question, in the shuffled question order.
zebra_grid_episodes = list(map(get_episode, zebra_grid_questions))

In [8]:
from datasets import load_dataset

# Matches a `\boxed{...}` answer and captures the text up to the first `}`.
pattern = re.compile(r"\\boxed{([^}]+)}")

# Holds the most recent solution text handed to `get_episode`, which writes it
# through a `global` statement.
question_solution = None

# Materialise the MATH training split as a list, then shuffle it in place with
# the global `random` state.
math_questions = list(
    load_dataset("lighteval/MATH", "all")["train"].to_iterable_dataset()  # type: ignore
)
random.shuffle(math_questions)


def _extract_boxed(text: str) -> list[str]:
    """Return the contents of every ``\\boxed{...}`` in ``text``, in order.

    Braces are matched by depth, so a nested answer such as
    ``\\boxed{\\frac{1}{2}}`` is returned whole instead of being cut at the
    first ``}``. Empty boxes and boxes that are never closed are skipped.
    """
    marker = "\\boxed{"
    found: list[str] = []
    start = text.find(marker)
    while start != -1:
        body_start = start + len(marker)
        depth = 1
        position = body_start
        while position < len(text) and depth > 0:
            char = text[position]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
            position += 1
        if depth == 0:
            body = text[body_start : position - 1]
            if body:
                found.append(body)
            resume = position
        else:
            # Unclosed box: keep scanning after this marker so later,
            # well-formed boxes are still found.
            resume = body_start
        start = text.find(marker, resume)
    return found


def get_episode(question: dict) -> Episode:
    """Build an episode for one MATH question.

    Args:
        question: Row dict with a ``problem`` statement and a worked
            ``solution`` that contains a ``\\boxed{...}`` final answer.

    Returns:
        An ``Episode`` whose ``on_sample`` hook rewards a completion with
        ``0.9 ** k``, where ``k`` counts how many boxed attempts come after
        the latest attempt equal to the reference answer (``k = 0`` when it is
        the final attempt), or 0 if no attempt matches.

    Raises:
        AssertionError: If no boxed answer can be found in the solution.
    """
    # Only the first segment is an f-string, so the braces in the plain
    # segments must be single: the previous `{{}}` was sent to the model
    # literally as `\boxed{{}}`.
    prompt = (
        f"{question['problem']}\n\n"
        "Solve this math problem and show your work. Your final answer MUST be "
        "formatted in a LaTeX box using \\boxed{}. For example: "
        "$1+1=\\boxed{2}$\n\n"
        "You can submit multiple attempts. Each attempt should end with a boxed "
        "answer. Your last answer will be weighted the most, but you can get "
        "partial credit if an earlier answer is correct. If after multiple "
        "attempts you decide an earlier answer is the correct one, just submit "
        "it again to get full credit."
    )

    # Kept for inspection from other cells: records the solution text of the
    # question currently being processed.
    global question_solution
    question_solution = question["solution"]

    boxed = _extract_boxed(question["solution"])
    if boxed:
        solution = boxed[0]
    else:
        # Fall back to the module-level regex so any solution accepted before
        # (e.g. one with unbalanced braces) is still accepted.
        fallback = re.search(pattern, question["solution"])
        assert fallback is not None, question["solution"]
        solution = fallback.group(1)

    def on_sample(completions: list[EpisodeCompletion]):
        """Commit a reward that favours a correct answer given last."""
        for completion in completions:
            content = completion.last_assistant_message.get("content")
            assert isinstance(content, str)
            # Latest attempt first, so index 0 is the final answer.
            attempts = _extract_boxed(content)[::-1]
            try:
                reward = 0.9 ** attempts.index(solution)
            except ValueError:
                reward = 0
            completion.commit(reward=reward)

    return Episode(
        messages=[{"role": "user", "content": prompt}],
        on_sample=on_sample,
    )


# Take the first 2048 shuffled questions, keep those whose solution contains a
# boxed answer, and build one episode for each.
math_episodes = list(
    map(
        get_episode,
        filter(
            lambda question: re.search(pattern, question["solution"]) is not None,
            math_questions[:2048],
        ),
    )
)

In [9]:
import asyncio
from dataclasses import dataclass, field
from lib.rl.completion import SplitMethod
from lib.rl.completion_sampler import CompletionSampler, SamplingKwargs
from lib.rl.trainer import ExploreImpl, ExploreOptions
from lib.tokenizer import Tokenizer
import numpy as np
from typing import Callable


@dataclass
class DefaultExploreImpl(ExploreImpl):
    """Explore each episode by calling ``sample_completions`` repeatedly.

    Every episode taken from ``ready_episodes`` is explored in its own task
    using the settings in ``explore_options``. The finished episode, or the
    exception that stopped it, is put on ``done_episodes``.
    """

    explore_options: ExploreOptions

    async def __call__(
        self,
        completion_sampler: CompletionSampler,
        tokenizer: Tokenizer,
        ready_episodes: asyncio.Queue[Episode],
        done_episodes: asyncio.Queue[Episode | BaseException],
        update_progress: Callable[[float], None],
    ) -> None:
        """Dispatch episodes until a falsy item is read from ``ready_episodes``."""
        # The event loop only keeps weak references to tasks, so hold strong
        # ones here while this coroutine is running to stop an in-flight task
        # from being garbage-collected.
        tasks: set[asyncio.Task[Episode]] = set()

        def done_callback(task: asyncio.Task[Episode]) -> None:
            tasks.discard(task)
            try:
                done_episodes.put_nowait(task.result())
            except BaseException as exception:
                done_episodes.put_nowait(exception)

        priority = 1
        while episode := await ready_episodes.get():
            task = asyncio.create_task(
                self._explore_episode(
                    completion_sampler, tokenizer, episode, update_progress, priority
                )
            )
            tasks.add(task)
            task.add_done_callback(done_callback)
            # Each later episode is submitted with a larger priority number.
            priority += 1

    async def _explore_episode(
        self,
        completion_sampler: CompletionSampler,
        tokenizer: Tokenizer,
        episode: Episode,
        update_progress: Callable[[float], None],
        priority: int,
    ) -> Episode:
        """Run ``explore_options.iterations`` sampling rounds on ``episode``.

        Progress is reported in equal fractions, one per round, and the same
        episode object is returned once all rounds are done.
        """
        for _ in range(self.explore_options.iterations):
            await episode.sample_completions(
                completion_sampler=completion_sampler,
                tokenizer=tokenizer,
                num_parents=self.explore_options.num_parents,
                branch_factor=self.explore_options.branch_factor,
                get_recovery_pattern=self.explore_options.get_recovery_pattern,
                # Falls back to num_parents when max_split_points is unset/0.
                max_splits_per_completion=self.explore_options.max_split_points
                or self.explore_options.num_parents,
                priority=priority,
                sample_probability_power=self.explore_options.get_sample_probability_power(),
                sampling_kwargs=self.explore_options.sampling_kwargs,
                split_by=self.explore_options.split_method,
                split_separators=self.explore_options.split_separators,
            )
            update_progress(1 / self.explore_options.iterations)
        return episode


@dataclass
class SimpleExploreImpl(ExploreImpl):
    """Explore each episode with one ``sample_completions`` call.

    The call uses a single parent and ``num_samples`` branches.
    """

    num_samples: int
    sampling_kwargs: SamplingKwargs | None = None

    async def __call__(
        self,
        completion_sampler: CompletionSampler,
        tokenizer: Tokenizer,
        ready_episodes: asyncio.Queue[Episode],
        done_episodes: asyncio.Queue[Episode | BaseException],
        update_progress: Callable[[float], None],
    ) -> None:
        """Dispatch episodes until a falsy item is read from ``ready_episodes``.

        A successfully sampled episode is put on ``done_episodes`` and counted
        as one unit of progress; if sampling fails, the exception is put on
        the queue instead.
        """
        # The event loop only keeps weak references to tasks, so hold strong
        # ones here while this coroutine is running.
        tasks: set[asyncio.Task] = set()

        while episode := await ready_episodes.get():
            task = asyncio.create_task(
                episode.sample_completions(
                    completion_sampler,
                    tokenizer,
                    num_parents=1,
                    branch_factor=self.num_samples,
                    sampling_kwargs=self.sampling_kwargs,
                )
            )
            tasks.add(task)

            # `episode=episode` binds the current episode at definition time;
            # a plain closure would see whichever episode the loop reached last.
            def done_callback(task: asyncio.Task[bool], episode=episode) -> None:
                tasks.discard(task)
                try:
                    # Re-raise any sampling failure. Previously the result was
                    # never inspected, so a failed task was reported as a
                    # finished episode and its exception was never retrieved.
                    task.result()
                    done_episodes.put_nowait(episode)
                    update_progress(1)
                except BaseException as e:
                    done_episodes.put_nowait(e)

            task.add_done_callback(done_callback)


@dataclass
class TreeExploreImpl(ExploreImpl):
    """Explore each episode by growing a tree of completions.

    Root completions are sampled first. Each resulting leaf is then split into
    ``depth - leaf.depth() + 1`` partitions, and ``branch_factor`` new
    completions are sampled from every split point except the last one
    returned by ``leaf.split``. This repeats until no sampling task is pending.
    """

    branch_factor: int
    depth: int
    # Number of root samples; falls back to branch_factor when unset.
    num_roots: int | None = None
    # Softmax temperature used when picking an existing leaf to carry over.
    best_leaf_sampling_temperature: float = 0.01
    sampling_kwargs: SamplingKwargs | None = None
    split_method: SplitMethod = "count"
    split_separators: set[str] = field(default_factory=set)

    async def __call__(
        self,
        completion_sampler: CompletionSampler,
        tokenizer: Tokenizer,
        ready_episodes: asyncio.Queue[Episode],
        done_episodes: asyncio.Queue[Episode | BaseException],
        update_progress: Callable[[float], None],
    ) -> None:
        """Dispatch episodes until a falsy item is read from ``ready_episodes``.

        Each episode is expanded in its own task. The expanded episode, or the
        first exception raised by one of its sampling tasks, is put on
        ``done_episodes``.
        """
        model = await completion_sampler.get_model()

        async def expand(episode: Episode, priority: int) -> None:
            num_roots = self.num_roots or self.branch_factor

            # If there are existing trajectories, we'll sample one
            # of the best ones to stabilize and improve training.
            # Initialised up front so the `branch_factor` expression below
            # cannot hit an unbound name when there are no leaves.
            best_leaf = None
            leaves = list(episode.completion.leaves())
            if leaves:
                rewards = [leaf.cumulative_reward() for leaf in leaves]
                # Subtracting the maximum leaves the sampling distribution
                # unchanged but keeps np.exp from overflowing when rewards are
                # large relative to the temperature.
                max_reward = max(rewards)
                best_leaf = random.choices(
                    leaves,
                    weights=[
                        np.exp(
                            (reward - max_reward)
                            / self.best_leaf_sampling_temperature
                        )
                        for reward in rewards
                    ],
                    k=1,
                )[0]
                best_leaf = best_leaf.recursive_copy(model=model)
                best_leaf.commit()
                while best_leaf.parent and best_leaf.parent.parent is None:
                    best_leaf = best_leaf.merge()

            pending: set[asyncio.Task] = {
                asyncio.create_task(
                    episode.sample_completions(
                        completion_sampler,
                        tokenizer,
                        num_parents=1,
                        # One root slot is given up when a leaf was carried over.
                        branch_factor=num_roots - 1 if best_leaf else num_roots,
                        priority=priority,
                        sampling_kwargs=self.sampling_kwargs,
                        split_by=self.split_method,
                        split_separators=self.split_separators,
                    )
                )
            }

            num_leaves = 0
            while pending:
                finished, pending = await asyncio.wait(
                    pending, return_when=asyncio.FIRST_COMPLETED
                )
                for task in finished:
                    try:
                        task.result()
                    except BaseException as e:
                        # Report the first failure and stop expanding this episode.
                        await done_episodes.put(e)
                        return
                _num_leaves = 0
                for leaf in episode.completion.leaves(model=model):
                    _num_leaves += 1
                    num_partitions = self.depth - leaf.depth() + 1
                    if num_partitions > 1:
                        parents = list(
                            leaf.split(
                                by=self.split_method,
                                at=(
                                    split / num_partitions
                                    for split in range(1, num_partitions)
                                ),
                                separators=self.split_separators,
                                cache=True,
                            )
                        )[:-1]
                        for parent in parents:
                            pending.add(
                                asyncio.create_task(
                                    episode._sample_completions(
                                        parent=parent,
                                        model=model,
                                        completion_sampler=completion_sampler,
                                        tokenizer=tokenizer,
                                        branch_factor=self.branch_factor,
                                        fork_decay=1.0,
                                        recovery_pattern=None,
                                        split_separators=self.split_separators,
                                        sampling_kwargs=self.sampling_kwargs
                                        or SamplingKwargs(),
                                        priority=priority,
                                    )
                                )
                            )
                # Progress is the growth in leaf count relative to
                # num_roots * branch_factor ** (depth - 1) leaves.
                update_progress(
                    (_num_leaves - num_leaves)
                    / (num_roots * (self.branch_factor ** (self.depth - 1)))
                )
                num_leaves = _num_leaves

            await done_episodes.put(episode)

        # The event loop only keeps weak references to tasks, so hold strong
        # ones here while this coroutine is running.
        expansions: set[asyncio.Task[None]] = set()

        priority = 0
        while episode := await ready_episodes.get():
            priority += 1
            expansion = asyncio.create_task(expand(episode, priority))
            expansions.add(expansion)
            expansion.add_done_callback(expansions.discard)


@dataclass
class IterativeVineExploreImpl(ExploreImpl):
    """Explore each episode by repeatedly branching off its best leaf.

    The first round samples ``branch_factor`` completions from the episode's
    root completion. Each later round takes the leaf with the highest
    ``reward``, splits it once, and samples ``branch_factor`` completions from
    the first piece returned by the split. ``depth`` rounds are run in total.
    """

    branch_factor: int
    depth: int
    sampling_kwargs: SamplingKwargs | None = None
    split_method: SplitMethod = "count"
    split_separators: set[str] = field(default_factory=set)

    async def __call__(
        self,
        completion_sampler: CompletionSampler,
        tokenizer: Tokenizer,
        ready_episodes: asyncio.Queue[Episode],
        done_episodes: asyncio.Queue[Episode | BaseException],
        update_progress: Callable[[float], None],
    ) -> None:
        """Dispatch episodes until a falsy item is read from ``ready_episodes``.

        Each episode is processed in its own task. The finished episode, or
        the exception that interrupted it, is put on ``done_episodes``.
        """
        model = await completion_sampler.get_model()

        async def iterate_vine(episode: Episode, priority: int) -> None:
            try:
                for depth in range(self.depth):
                    if depth == 0:
                        parent = episode.completion
                    else:
                        parent = max(
                            episode.completion.leaves(model=model),
                            key=lambda leaf: leaf.reward,
                        )
                        # The split fraction grows each round:
                        # 1/depth, 1/(depth - 1), ..., 1/2.
                        parent = list(
                            parent.split(
                                by=self.split_method,
                                at=[1 / (self.depth - depth + 1)],
                                separators=self.split_separators,
                                cache=True,
                            )
                        )[0]
                    await episode._sample_completions(
                        parent=parent,
                        model=model,
                        completion_sampler=completion_sampler,
                        tokenizer=tokenizer,
                        branch_factor=self.branch_factor,
                        fork_decay=1.0,
                        recovery_pattern=None,
                        split_separators=self.split_separators,
                        sampling_kwargs=self.sampling_kwargs or SamplingKwargs(),
                        priority=priority,
                    )
                    update_progress(1 / self.depth)
            except asyncio.CancelledError:
                # Let cancellation of this task propagate normally.
                raise
            except BaseException as e:
                # Without this, a failure here was lost with the task and the
                # episode never reached `done_episodes`.
                await done_episodes.put(e)
                return

            await done_episodes.put(episode)

        # The event loop only keeps weak references to tasks, so hold strong
        # ones here while this coroutine is running.
        vines: set[asyncio.Task[None]] = set()

        priority = 0
        while episode := await ready_episodes.get():
            priority += 1
            vine = asyncio.create_task(iterate_vine(episode, priority))
            vines.add(vine)
            vine.add_done_callback(vines.discard)

In [11]:
from aioitertools.helpers import maybe_await
import asyncio
from collections import Counter
import itertools as it
from lib import clue
from lib.rl.episode import Episode
from lib.rl.ppo import PPOLoss
from lib.rl.recipe import ComponentConfig, TuneRecipeConfig
from lib.rl.trainer import Eval, ExploreOptions, Trainer, vLLMConfig
import torch
from torchtune.models.llama3_1 import llama3_1_8b
from typing import AsyncIterable


# Sixteen episodes for every CUDA device that torch can see.
episodes_per_iteration = torch.cuda.device_count() * 16


async def train_episodes(
    revisit_frequency: float = 0.0,
) -> AsyncIterable[Episode | BaseException]:
    """Yield training episodes forever, optionally revisiting earlier ones.

    Episodes come from an endless cycle over ``temporal_clue_episodes[64:]``.
    Up to ``episodes_per_iteration`` of them are kept in flight as tasks, and
    each one is yielded as its task completes.

    Args:
        revisit_frequency: Once more than ``episodes_per_iteration`` distinct
            episodes have been yielded, the least-visited episode is yielded
            again, repeatedly, for as long as ``random.random()`` stays below
            this value.

    Yields:
        An ``Episode``, or the exception raised while resolving one.
    """
    pending: set[asyncio.Task[Episode | BaseException]] = set()
    episodes = (
        maybe_await(episode)
        for episodes in zip(
            # (clue.sample_random_episode() for _ in it.repeat(0)),
            it.cycle(temporal_clue_episodes[64:]),
            # it.cycle(zebra_grid_episodes[64:]),
            # it.cycle(math_episodes[64:]),
        )
        for episode in episodes
    )
    visited_episodes = Counter[Episode]()
    while True:
        # Top the in-flight set back up to episodes_per_iteration tasks.
        pending.update(
            asyncio.create_task(next(episodes))
            for _ in range(episodes_per_iteration - len(pending))  # type: ignore
        )
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        if len(visited_episodes) > episodes_per_iteration:
            while random.random() < revisit_frequency:
                episode = min(visited_episodes, key=lambda e: visited_episodes[e])
                visited_episodes[episode] += 1
                yield episode
        for task in done:
            try:
                result = task.result()
            except BaseException as e:
                result = e
            if isinstance(result, Episode):
                visited_episodes[result] += 1
            # Yield outside the try block: with the yield inside it,
            # `except BaseException` also caught the GeneratorExit raised when
            # the generator is closed and tried to yield it.
            yield result


async def val_episodes() -> AsyncIterable[Episode | BaseException]:
    """Yield 64 results of ``clue.sample_random_episode()`` as they complete.

    Yields:
        Each awaited result, or the exception raised while awaiting it.
    """
    for fut in asyncio.as_completed(clue.sample_random_episode() for _ in range(64)):
        try:
            result = await fut
        except BaseException as e:
            result = e
        # Yield outside the try block: with the yield inside it,
        # `except BaseException` also caught the GeneratorExit raised when the
        # generator is closed and tried to yield it.
        yield result


# Exploration settings handed to the Trainer below. They are also what
# DefaultExploreImpl would read if that (currently commented-out) explore_impl
# were selected.
explore_options = ExploreOptions(
    iterations=1,
    num_parents=6,
    branch_factor=2,
    patience=60,
    advantage_max_weight=0.15,
    sample_probability_power=None,
    sampling_kwargs={"max_tokens": 4096, "stop": ["://", "<|end_of_text|>"]},
    # split_method="prob",
    # split_point_std_deviation=0.5,
)

# Run identifier: used for the output directory and as the wandb run name/id.
model_name = "rl115"

# Full training configuration. Training episodes come from train_episodes()
# (which cycles temporal_clue_episodes[64:]); the single active eval uses the
# first 64 temporal clue episodes.
trainer = Trainer(
    base_model="NousResearch/Hermes-2-Theta-Llama-3-8B",
    output_dir=f"./models/{model_name}",
    explore_options=explore_options,
    # Alternative exploration strategies, kept for reference:
    # explore_impl=DefaultExploreImpl(explore_options),
    # explore_impl=SimpleExploreImpl(
    #     num_samples=8, sampling_kwargs={"max_tokens": 4096}
    # ),
    explore_impl=TreeExploreImpl(
        branch_factor=2,
        depth=5,
        num_roots=8,
        # best_leaf_sampling_temperature=0.05,
        # Same sampling kwargs as explore_options and the temporal_clue eval.
        sampling_kwargs={"max_tokens": 4096, "stop": ["://", "<|end_of_text|>"]},
    ),
    force_terminate_vllms=True,
    train_episodes=train_episodes(revisit_frequency=0.5),
    episodes_per_iteration=episodes_per_iteration,
    max_mask_sequence_batch_size=1,
    evals=[
        # Eval(
        #     name="variable_clue",
        #     episodes=val_episodes(),
        #     samples_per_episode=3,
        #     sampling_kwargs={"max_tokens": 4096},
        # ),
        Eval(
            name="temporal_clue",
            episodes=temporal_clue_episodes[:64],
            samples_per_episode=3,
            sampling_kwargs={"max_tokens": 4096, "stop": ["://", "<|end_of_text|>"]},
        ),
        # Eval(
        #     name="zebra_grid",
        #     episodes=zebra_grid_episodes[:64],
        #     samples_per_episode=3,
        #     sampling_kwargs={"max_tokens": 4096},
        # ),
        # Eval(
        #     name="math",
        #     episodes=math_episodes[:64],
        #     samples_per_episode=3,
        #     sampling_kwargs={"max_tokens": 4096},
        # ),
    ],
    tune_model=llama3_1_8b,
    tune_model_type="LLAMA3",
    tune_recipe_config=TuneRecipeConfig(
        seed=42,
        shuffle=True,
        num_output_chunks=4,
        resume_from_checkpoint=False,
        batch_size=1,
        epochs=1,
        max_steps_per_epoch=32,
        optimizer=ComponentConfig(
            "torch.optim.AdamW",
            # "bitsandbytes.optim.PagedAdamW8bit",
            # "bitsandbytes.optim.AdamW",
            # params=PLACEHOLDER,
            lr=2e-6,
            fused=True,
        ),
        # The only non-zero coefficients here are tanh_log_policy_coef,
        # entropy_target_coef and kl_coef; presumably the remaining loss terms
        # are switched off -- TODO confirm against PPOLoss.
        loss=ComponentConfig(
            PPOLoss,
            policy_coef=0.0,
            clip_epsilon=0.2,
            unclipped_policy_coef=0.0,
            tanh_log_policy_coef=0.8,
            tanh_log_policy_beta=2.0,
            use_reference_logprobs=True,
            advantage_prediction_coef=0.0,
            predicted_advantage_weight=0.0,
            value_coef=0.0,
            entropy_coef=0.0,
            entropy_target=0.6,
            entropy_target_coef=0.05,
            kl_coef=0.05,
            weighted_entropy_coef=0.0,
            weighted_kl_coef=0.0,
            weighted_ce_coef=0.0,
            normalize_values=False,
            normalize_value_predictions=False,
            normalize_advantages=False,
        ),
        compile=False,
        optimizer_in_bwd=False,
        gradient_accumulation_steps=1,
        enable_activation_checkpointing=True,
        enable_activation_offloading=False,
        custom_sharded_layers=["tok_embeddings", "output"],
        log_every_n_steps=1,
        log_peak_memory_stats=True,
    ),
    # tune_run=False,
    # Same value as the vLLM max_model_len below.
    tune_sequence_length=16384,
    vllm_config=vLLMConfig(
        env={"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"},
        kwargs=dict(
            block_size=32,
            disable_log_requests=True,
            enable_chunked_prefill=True,
            enable_prefix_caching=True,
            enforce_eager=True,
            gpu_memory_utilization=0.9,
            max_model_len=16384,
            max_num_seqs=512,
            max_num_batched_tokens=16384,
            preemption_mode="swap",
            return_tokens_as_token_ids=True,
            swap_space=100,
        ),
        max_concurrent_samples=512,
        min_time_between_requests=0.0,
        # Base of 120 plus 15 per visible CUDA device.
        timeout=120 + 15 * torch.cuda.device_count(),
    ),
    # Name and id are both the run identifier.
    wandb_kwargs=dict(
        name=model_name,
        id=model_name,
    ),
)

INFO 01-05 02:43:25 config.py:510] This model supports multiple tasks: {'embed', 'classify', 'reward', 'generate', 'score'}. Defaulting to 'generate'.
INFO 01-05 02:43:25 llm_engine.py:234] Initializing an LLM engine (v0.6.6.post1) with config: model='NousResearch/Hermes-2-Theta-Llama-3-8B', speculative_config=None, tokenizer='NousResearch/Hermes-2-Theta-Llama-3-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), s

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbradhilton[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [12]:
# Start training with iterations=50. This is a top-level `await`, so it needs
# an environment that supports awaiting at cell level (as this notebook does).
await trainer.train(iterations=50, verbosity=1)

Starting 2 vLLM servers...
$ vllm serve NousResearch/Hermes-2-Theta-Llama-3-8B --port=8004 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
$ vllm serve NousResearch/Hermes-2-Theta-Llama-3-8B --port=8011 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


temporal_clue:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/32 [00:00<?, ?episode/s]

Tuning model on 46 sequences
Experienced the following exception while stopping vLLM servers: <class 'TimeoutError'> 
$ tune run --nnodes=1 --nproc-per-node=2 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl115/config.yaml


1|23|Loss: 0.0120: 100%|██████████| 23/23 [04:12<00:00, 10.74s/it, advantage=0.0000, advantage_prediction=0.2794, entropy=0.5468, entropy_target=0.0532, exploration=0.4170, kl_div=0.1897, policy=1838595200.0000, reinforce=0.0191, tanh_log_policy=-0.0002, unclipped_policy=-9021022208.0000, value=0.0000, weighted_ce=0.0191, weighted_entropy=-0.0258, weighted_kl_div=-0.0075]       

Saved iteration 1 model files to /home/ubuntu/atreides/experiments/models/rl115/0001
Starting 2 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl115/0001 --port=8004 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
$ vllm serve /home/ubuntu/atreides/experiments/models/rl115/0001 --port=8012 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


temporal_clue:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/32 [00:00<?, ?episode/s]

Tuning model on 49 sequences
Experienced the following exception while stopping vLLM servers: <class 'TimeoutError'> 
$ tune run --nnodes=1 --nproc-per-node=2 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl115/config.yaml


1|24|Loss: 0.0201: 100%|██████████| 24/24 [04:23<00:00, 10.72s/it, advantage=0.0000, advantage_prediction=0.6457, entropy=0.6227, entropy_target=0.0227, exploration=0.9370, kl_div=0.3480, policy=402964160.0000, reinforce=0.0076, tanh_log_policy=0.0019, unclipped_policy=-2060111616.0000, value=0.0000, weighted_ce=0.0076, weighted_entropy=-0.0093, weighted_kl_div=-0.0044]        

Saved iteration 2 model files to /home/ubuntu/atreides/experiments/models/rl115/0002
Starting 2 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl115/0002 --port=8004 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
$ vllm serve /home/ubuntu/atreides/experiments/models/rl115/0002 --port=8013 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


temporal_clue:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/32 [00:00<?, ?episode/s]

Tuning model on 58 sequences
Experienced the following exception while stopping vLLM servers: <class 'TimeoutError'> 
$ tune run --nnodes=1 --nproc-per-node=2 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl115/config.yaml


1|29|Loss: 0.0074: 100%|██████████| 29/29 [05:17<00:00, 10.72s/it, advantage=0.0000, advantage_prediction=2.7164, entropy=0.6366, entropy_target=0.0366, exploration=0.4686, kl_div=0.2028, policy=5342518837248.0000, reinforce=-0.0835, tanh_log_policy=-0.0057, unclipped_policy=5340809658368.0000, value=0.0000, weighted_ce=-0.0835, weighted_entropy=0.0731, weighted_kl_div=0.0266]    

Saved iteration 3 model files to /home/ubuntu/atreides/experiments/models/rl115/0003
Starting 2 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl115/0003 --port=8004 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
$ vllm serve /home/ubuntu/atreides/experiments/models/rl115/0003 --port=8014 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


temporal_clue:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/32 [00:00<?, ?episode/s]

Tuning model on 56 sequences
Experienced the following exception while stopping vLLM servers: <class 'TimeoutError'> 
$ tune run --nnodes=1 --nproc-per-node=2 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl115/config.yaml


1|28|Loss: 0.0127: 100%|██████████| 28/28 [05:06<00:00, 10.71s/it, advantage=0.0000, advantage_prediction=0.8513, entropy=0.5809, entropy_target=0.0191, exploration=0.7960, kl_div=0.3106, policy=67462365184.0000, reinforce=-0.0402, tanh_log_policy=-0.0047, unclipped_policy=18982391808.0000, value=0.0000, weighted_ce=-0.0402, weighted_entropy=0.0196, weighted_kl_div=0.0159]    

Saved iteration 4 model files to /home/ubuntu/atreides/experiments/models/rl115/0004
Starting 2 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl115/0004 --port=8004 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
$ vllm serve /home/ubuntu/atreides/experiments/models/rl115/0004 --port=8015 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


temporal_clue:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/32 [00:00<?, ?episode/s]

Tuning model on 59 sequences
Experienced the following exception while stopping vLLM servers: <class 'TimeoutError'> 
$ tune run --nnodes=1 --nproc-per-node=2 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl115/config.yaml


1|29|Loss: 0.0146: 100%|██████████| 29/29 [05:16<00:00, 10.73s/it, advantage=0.0000, advantage_prediction=0.5995, entropy=0.4606, entropy_target=0.1394, exploration=0.3083, kl_div=0.1560, policy=2371535872.0000, reinforce=0.0007, tanh_log_policy=-0.0003, unclipped_policy=-27041712128.0000, value=0.0000, weighted_ce=0.0007, weighted_entropy=-0.0045, weighted_kl_div=-0.0008]         

Saved iteration 5 model files to /home/ubuntu/atreides/experiments/models/rl115/0005
Starting 2 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl115/0005 --port=8004 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
$ vllm serve /home/ubuntu/atreides/experiments/models/rl115/0005 --port=8016 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


temporal_clue:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/32 [00:00<?, ?episode/s]

In [12]:
# Set `kl_coef` on the tune recipe's loss config to 0.3 before tuning.
# NOTE(review): presumably the weight of the KL-divergence term in the loss —
# confirm against lib.rl.recipe.
trainer.tune_recipe_config.loss.kl_coef = 0.3
# Tune on the most recently stored explore result. With verbosity=2 the
# recorded output includes the torchrun log and the resolved recipe config.
await trainer.tune(trainer.explore_results[-1], verbosity=2)

Tuning model on 56 sequences
$ tune run --nnodes=1 --nproc-per-node=2 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl111/config.yaml
Running with torchrun...


W0105 00:49:18.623000 143029 torch/distributed/run.py:793] 
W0105 00:49:18.623000 143029 torch/distributed/run.py:793] *****************************************
W0105 00:49:18.623000 143029 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0105 00:49:18.623000 143029 torch/distributed/run.py:793] *****************************************
INFO:torchtune.utils._logging:Running FullFinetuneRecipe with resolved config:

batch_size: 1
checkpointer:
  _component_: lib.rl.mlp_head_checkpointer.MLPHeadCheckpointer
  checkpoint_dir: /home/ubuntu/atreides/experiments/models/rl111/0001
  checkpoint_files:
  - /home/ubuntu/atreides/experiments/models/rl111/0001/hf_model_0001.pt
  - /home/ubuntu/atreides/experiments/models/rl111/0001/hf_model_0002.pt
  - /home/ubuntu/atreides/experiments/models/rl111/0001/hf_mo

Saved iteration 2 model files to /home/ubuntu/atreides/experiments/models/rl111/0002


In [13]:
# Set `kl_coef` on the tune recipe's loss config to 0.4 for the following run.
# NOTE(review): presumably the weight of the KL-divergence term — confirm.
trainer.tune_recipe_config.loss.kl_coef = 0.4
# Run the training loop with iterations=50. In the recorded output each
# iteration starts vLLM servers, runs the explore progress bars, and then
# launches `tune run`; verbosity=1 shows only the tuning progress bar.
await trainer.train(iterations=50, verbosity=1)

Starting 2 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl111/0002 --port=8004 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
$ vllm serve /home/ubuntu/atreides/experiments/models/rl111/0002 --port=8013 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


temporal_clue:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

Tuning model on 62 sequences
Experienced the following exception while stopping vLLM servers: <class 'TimeoutError'> 
$ tune run --nnodes=1 --nproc-per-node=2 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl111/config.yaml


1|31|Loss: 0.0520: 100%|██████████| 31/31 [05:38<00:00, 10.72s/it, advantage=0.0000, advantage_prediction=0.7889, entropy=0.6102, entropy_target=0.0102, exploration=0.3028, kl_div=0.1441, policy=-0.0074, reinforce=0.0150, tanh_log_policy=-0.0077, unclipped_policy=-0.0423, value=0.0000, weighted_ce=0.0150, weighted_entropy=-0.0291, weighted_kl_div=-0.0009]

Saved iteration 3 model files to /home/ubuntu/atreides/experiments/models/rl111/0003
Starting 2 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl111/0003 --port=8004 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
$ vllm serve /home/ubuntu/atreides/experiments/models/rl111/0003 --port=8014 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


temporal_clue:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

In [29]:
# Turn on the `shuffle` flag of the tune recipe config.
# NOTE(review): presumably shuffles the training sequences — confirm in the
# recipe. This cell's execution count (29) is out of order relative to its
# neighbours, so its effect depends on when it was run in the kernel.
trainer.tune_recipe_config.shuffle = True

In [14]:
# Disabled toggle kept for experimentation; while commented out, `shuffle`
# keeps whatever value the kernel already holds.
# trainer.tune_recipe_config.shuffle = False
# Set `kl_coef` on the tune recipe's loss config to 0.4.
# NOTE(review): presumably the weight of the KL-divergence term — confirm.
trainer.tune_recipe_config.loss.kl_coef = 0.4
# Tune on the most recently stored explore result; verbosity=2 is the level at
# which the recorded output includes the torchrun log and resolved config.
await trainer.tune(trainer.explore_results[-1], verbosity=2)

Tuning model on 51 sequences
$ tune run --nnodes=1 --nproc-per-node=2 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl110/config.yaml
Running with torchrun...


W0104 23:57:40.116000 124059 torch/distributed/run.py:793] 
W0104 23:57:40.116000 124059 torch/distributed/run.py:793] *****************************************
W0104 23:57:40.116000 124059 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0104 23:57:40.116000 124059 torch/distributed/run.py:793] *****************************************
INFO:torchtune.utils._logging:Running FullFinetuneRecipe with resolved config:

batch_size: 1
checkpointer:
  _component_: lib.rl.mlp_head_checkpointer.MLPHeadCheckpointer
  checkpoint_dir: /home/ubuntu/atreides/experiments/models/rl110/0001
  checkpoint_files:
  - /home/ubuntu/atreides/experiments/models/rl110/0001/hf_model_0001.pt
  - /home/ubuntu/atreides/experiments/models/rl110/0001/hf_model_0002.pt
  - /home/ubuntu/atreides/experiments/models/rl110/0001/hf_mo

Saved iteration 2 model files to /home/ubuntu/atreides/experiments/models/rl110/0002


In [11]:
# Run a single explore pass and keep what it returns in `result`, which a
# later cell passes to `trainer.tune`. The recorded output shows this starting
# two vLLM servers before the `explore` progress bar appears.
result = await trainer.explore()

Starting 2 vLLM servers...
$ vllm serve NousResearch/Hermes-2-Theta-Llama-3-8B --port=8000 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
$ vllm serve NousResearch/Hermes-2-Theta-Llama-3-8B --port=8001 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
INFO 01-04 17:58:02 api_server.py:712] vLLM API server version 0.6.6.post1
INFO 01-04 17:58:02 api_server.py:713] args: Namespace(subparser='serve', model_tag='NousResearch/Hermes-2-Theta-Llama-3-8B', config='', host=None, port=8000, uvicorn_log_level='info', a

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:02,  1.47it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:02,  1.48it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.42it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.43it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:01<00:00,  2.12it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:01<00:00,  2.11it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.78it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.74it/s]

Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.79it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.75i

INFO 01-04 17:58:15 model_runner.py:1099] Loading model weights took 14.9595 GB
INFO 01-04 17:58:15 model_runner.py:1099] Loading model weights took 14.9595 GB
INFO 01-04 17:58:15 worker.py:241] Memory profiling takes 0.75 seconds
INFO 01-04 17:58:15 worker.py:241] the current vLLM instance can use total_gpu_memory (79.10GiB) x gpu_memory_utilization (0.90) = 71.19GiB
INFO 01-04 17:58:15 worker.py:241] model weights take 14.96GiB; non_torch_memory takes 0.17GiB; PyTorch activation peak memory takes 2.48GiB; the rest of the memory reserved for KV Cache is 53.57GiB.
INFO 01-04 17:58:16 gpu_executor.py:76] # GPU blocks: 13714, # CPU blocks: 25600
INFO 01-04 17:58:16 gpu_executor.py:80] Maximum concurrency for 16384 tokens per request: 26.79x
INFO 01-04 17:58:16 worker.py:241] Memory profiling takes 0.78 seconds
INFO 01-04 17:58:16 worker.py:241] the current vLLM instance can use total_gpu_memory (79.10GiB) x gpu_memory_utilization (0.90) = 71.19GiB
INFO 01-04 17:58:16 worker.py:241] model

INFO:     Started server process [16149]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8001 (Press CTRL+C to quit)


INFO 01-04 17:59:01 api_server.py:640] Using supplied chat template:
INFO 01-04 17:59:01 api_server.py:640] None
INFO 01-04 17:59:01 launcher.py:19] Available routes are:
INFO 01-04 17:59:01 launcher.py:27] Route: /openapi.json, Methods: HEAD, GET
INFO 01-04 17:59:01 launcher.py:27] Route: /docs, Methods: HEAD, GET
INFO 01-04 17:59:01 launcher.py:27] Route: /docs/oauth2-redirect, Methods: HEAD, GET
INFO 01-04 17:59:01 launcher.py:27] Route: /redoc, Methods: HEAD, GET
INFO 01-04 17:59:01 launcher.py:27] Route: /health, Methods: GET
INFO 01-04 17:59:01 launcher.py:27] Route: /tokenize, Methods: POST
INFO 01-04 17:59:01 launcher.py:27] Route: /detokenize, Methods: POST
INFO 01-04 17:59:01 launcher.py:27] Route: /v1/models, Methods: GET
INFO 01-04 17:59:01 launcher.py:27] Route: /version, Methods: GET
INFO 01-04 17:59:01 launcher.py:27] Route: /v1/chat/completions, Methods: POST
INFO 01-04 17:59:01 launcher.py:27] Route: /v1/completions, Methods: POST
INFO 01-04 17:59:01 launcher.py:27] Ro

INFO:     Started server process [16147]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


INFO 01-04 17:59:01 chat_utils.py:333] Detected the chat template content format to be 'string'. You can set `--chat-template-content-format` to override this.
INFO:     127.0.0.1:35036 - "POST /v1/chat/completions HTTP/1.1" 200 OK
vLLM server started succesfully. Logs can be found at ./logs/vllm.log
INFO 01-04 17:59:02 chat_utils.py:333] Detected the chat template content format to be 'string'. You can set `--chat-template-content-format` to override this.
INFO:     127.0.0.1:50284 - "POST /v1/chat/completions HTTP/1.1" 200 OK
vLLM server started succesfully. Logs can be found at ./logs/vllm.log


explore:   0%|          | 0/64 [00:00<?, ?episode/s]

In [18]:
from torch.nn.attention.flex_attention import create_block_mask

In [None]:
from torch.nn.attention.flex_attention import (
        BlockMask,
        create_block_mask as create_block_causal_mask_flex,
        flex_attention,
    )

In [16]:
# Inspect torchtune's private `_MaskType` alias to see which mask types it
# accepts; the bare expression on the last line makes the notebook display it.
from torchtune.modules.attention_utils import _MaskType
_MaskType

typing.Union[torch.Tensor, torch.nn.attention.flex_attention.BlockMask]

In [30]:
# Set `clip_grad_norm` on the tune recipe config to 1.0.
# NOTE(review): presumably the max norm used for gradient clipping — confirm
# against the recipe.
trainer.tune_recipe_config.clip_grad_norm = 1.0
# Tune on `result`, which must already exist in the kernel (it is assigned by
# the `result = await trainer.explore()` cell). verbosity=2 is the level at
# which the recorded output includes the torchrun log and resolved config.
await trainer.tune(result, verbosity=2)

Tuning model on 43 sequences
$ tune run --nnodes=1 --nproc-per-node=2 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl101/config.yaml
Running with torchrun...


W0104 18:50:54.849000 19603 torch/distributed/run.py:793] 
W0104 18:50:54.849000 19603 torch/distributed/run.py:793] *****************************************
W0104 18:50:54.849000 19603 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0104 18:50:54.849000 19603 torch/distributed/run.py:793] *****************************************
INFO:torchtune.utils._logging:Running FullFinetuneRecipe with resolved config:

batch_size: 1
checkpointer:
  _component_: lib.rl.mlp_head_checkpointer.MLPHeadCheckpointer
  checkpoint_dir: /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297454c680725
  checkpoint_files:
  - /home/ubuntu/.cache/huggingface/hub/models--NousResearch--Hermes-2-Theta-Llama-3-8B/snapshots/57a73110702e7b05ba3f39fef36297

Saved iteration 1 model files to /home/ubuntu/atreides/experiments/models/rl101/0001


In [31]:
# Run the training loop with iterations=50 at verbosity=1.
# NOTE(review): the recorded run stopped during its first tune step with
# "AssertionError: No model checkpoint files found to save in output
# directory" — the checkpoint-saving problem needs fixing before re-running.
await trainer.train(iterations=50, verbosity=1)

Starting 2 vLLM servers...
$ vllm serve /home/ubuntu/atreides/experiments/models/rl101/0001 --port=8000 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
$ vllm serve /home/ubuntu/atreides/experiments/models/rl101/0001 --port=8002 --block-size=32 --disable-log-requests --enable-chunked-prefill --enable-prefix-caching --enforce-eager --gpu-memory-utilization=0.9 --max-model-len=16384 --max-num-seqs=512 --max-num-batched-tokens=16384 --preemption-mode=swap --return-tokens-as-token-ids --swap-space=100 --api-key=default
vLLM servers started succesfully. Logs can be found at ./logs/vllm.log


temporal_clue:   0%|          | 0/64 [00:00<?, ?episode/s]

explore:   0%|          | 0/64 [00:00<?, ?episode/s]

Tuning model on 28 sequences
Experienced the following exception while stopping vLLM servers: <class 'TimeoutError'> 
$ tune run --nnodes=1 --nproc-per-node=2 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl101/config.yaml


AssertionError: No model checkpoint files found to save in output directory /home/ubuntu/atreides/experiments/models/rl101

In [32]:
# Tune again on the most recently stored explore result, this time with
# verbosity=2 so the torchrun log and resolved config appear in the output.
# NOTE(review): apparently a retry of the failed tune step; the recorded
# output ends in the same "No model checkpoint files found" AssertionError.
await trainer.tune(trainer.explore_results[-1], verbosity=2)

Tuning model on 28 sequences
$ tune run --nnodes=1 --nproc-per-node=2 lib.rl.recipe.TuneRecipe --config /home/ubuntu/atreides/experiments/models/rl101/config.yaml
Running with torchrun...


W0104 19:02:08.771000 21853 torch/distributed/run.py:793] 
W0104 19:02:08.771000 21853 torch/distributed/run.py:793] *****************************************
W0104 19:02:08.771000 21853 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0104 19:02:08.771000 21853 torch/distributed/run.py:793] *****************************************
INFO:torchtune.utils._logging:Running FullFinetuneRecipe with resolved config:

batch_size: 1
checkpointer:
  _component_: lib.rl.mlp_head_checkpointer.MLPHeadCheckpointer
  checkpoint_dir: /home/ubuntu/atreides/experiments/models/rl101/0001
  checkpoint_files:
  - /home/ubuntu/atreides/experiments/models/rl101/0001/hf_model_0001.pt
  - /home/ubuntu/atreides/experiments/models/rl101/0001/hf_model_0002.pt
  - /home/ubuntu/atreides/experiments/models/rl101/0001/hf_model_

AssertionError: No model checkpoint files found to save in output directory /home/ubuntu/atreides/experiments/models/rl101