In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
%%html
<style>
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [None]:
import asyncio
from lib.clue import Clue, DeductiveSolver
from lib.rl.episode import Episode, EpisodeCompletion
from lib.rl.trainer import Trainer
import re


def sample_random_episode() -> Episode:
    """Play one random Clue game and wrap its prompt in a self-scoring Episode.

    The game uses three players and the first three suspects, weapons and rooms.
    The returned episode's single user message is the game prompt. Its
    ``on_sample`` callback scores every sampled completion and commits it.

    Returns:
        An ``Episode`` whose completions are rewarded with the fraction of
        solution elements named correctly in the model's final answer.
    """
    game = Clue(
        num_players=3,
        elements={
            "suspect": Clue.suspects[:3],
            "weapon": Clue.weapons[:3],
            "room": Clue.rooms[:3],
            # "motive": Clue.motives[:6],
            # "time": Clue.get_times("21:00", "03:00", "1h"),
        },
    )
    game.play(
        deductive_solver=DeductiveSolver(
            # note_cards_in_hand=False,
            # note_responses_to_suggestions=False,
            # note_cards_that_players_do_not_have=False,
            # check_unique_card_placement_constraints=False,
            # check_player_hand_size_constraints=False,
            check_solution_has_one_and_only_one_card_per_element=False,
            check_one_of_constraints=False,
            check_inverse_one_of_constraints=False,
            merge_and_check_disjoint_inverse_one_of_constraints=False,
            exhaustively_test_possible_assignments=False,
        ),
        cp_solver_max_solve_time_per_turn=0.01,
        check_cp_solver_grid=False,
        check_if_deductive_solver_and_cp_solver_grids_match=False,
        print_playthrough=False,
    )
    prompt = game.get_prompt()
    follow_up = "Fill out your answer like this:\n" + "\n".join(
        f"{element.capitalize()}: <#{element.upper()}#>" for element in game.elements
    )

    async def reward_completion(completion: EpisodeCompletion) -> EpisodeCompletion:
        """Set ``completion.reward`` to the fraction of solution elements answered correctly."""
        if len(completion.messages) == 2:
            # Only prompt + first assistant reply so far: ask for the structured answer.
            answer_completion = await completion.follow_up(
                messages=[
                    {"role": "user", "content": follow_up},
                ]
            )
        else:
            # Previously this path hit an unbound `follow_up_completion`.
            # NOTE(review): assumes a longer completion already ends with the
            # structured answer -- confirm against how on_sample is invoked.
            answer_completion = completion
        answer = answer_completion.last_assistant_message.get("content")
        assert isinstance(answer, str)
        # Escape both parts so names containing regex metacharacters (e.g. ".")
        # are matched literally instead of acting as wildcards.
        correct = sum(
            bool(
                re.search(
                    f"{re.escape(str(element))}: {re.escape(str(solution))}",
                    answer,
                    re.IGNORECASE,
                )
            )
            for element, solution in game.solution.items()
        )
        completion.reward = correct / len(game.solution)
        return completion

    async def on_sample(completions: list[EpisodeCompletion]) -> None:
        """Score all sampled completions concurrently, then commit each one."""
        for completion in await asyncio.gather(
            *[reward_completion(completion) for completion in completions]
        ):
            completion.commit()

    return Episode(
        messages=[{"role": "user", "content": prompt}],
        on_sample=on_sample,
    )


def train_episodes():
    """Yield freshly sampled random episodes forever (one new game per item)."""
    episode = sample_random_episode()
    while True:
        yield episode
        # Sampled lazily: only once the consumer asks for the next item.
        episode = sample_random_episode()


# Build the trainer used by the rest of the notebook.
# - train_episodes is the infinite generator defined above (new game per item).
# - val_episodes is a fixed list of 64 games sampled once, right here; no random
#   seed is set, so this set changes on every kernel restart.
# NOTE(review): the meaning of samples_per_episode / branch_factor /
# episodes_per_iteration is defined by lib.rl.trainer.Trainer -- confirm there.
trainer = Trainer(
    base_model="NousResearch/Hermes-2-Theta-Llama-3-8B",
    samples_per_episode=32,
    branch_factor=5,
    train_episodes=train_episodes(),
    episodes_per_iteration=128,
    val_episodes=[sample_random_episode() for _ in range(64)],
    tune_sequence_length=8192,
    vllm_kwargs=dict(disable_log_requests=True, scheduling_policy="priority"),
    vllm_max_concurrent_requests=256,
)

In [None]:
# Run the validation eval and one exploration pass concurrently on the same trainer.
# NOTE(review): the positional arguments (0 and 1) are passed straight through to
# Trainer.eval / Trainer.explore; their meaning is not visible here -- confirm.
val_score, episodes = await asyncio.gather(trainer.eval("val", 0), trainer.explore(1))

In [5]:
# Terminate the vLLM server process and clear the trainer's cached handles so a
# later call does not reuse the dead process (presumably to free the GPU for the
# tuning cell below -- confirm).
vllm = await trainer.vllm()
vllm.process.terminate()
# Reaches into private trainer attributes; there is no public reset visible here.
trainer._vllm_task, trainer._completion_sampler = None, None

In [None]:
import os
import subprocess
from torchtune.models.llama3_1 import llama3_1_8b
from torchtune.training import cleanup_before_training
from torchtune.training.metric_logging import DiskLogger
from typing import Any

from lib.recipes.rl import ComponentConfig, RLConfig, RLRecipe
from lib.rl.pack import PackedDataset, packed_tensors
from lib.rl.ppo import PPOLoss

tensors = packed_tensors(
    episodes,
    model=trainer.model,
    sequence_length=trainer.tune_sequence_length,
    trajectories_per_episode=(
        int(trainer.samples_per_episode * trainer.tune_episode_sample_fraction)
        if trainer.tune_episode_sample_fraction < 1.0
        else None
    ),
    tokenizer=trainer.tokenizer,
)

# Download the base model with the Hugging Face CLI. The CLI's stdout is taken
# as the local checkpoint directory.
# check=True makes a failed download raise CalledProcessError (stderr attached)
# instead of silently producing an empty checkpoint_dir; the argv form avoids
# interpolating the model name into a shell command line.
checkpoint_dir = subprocess.run(
    ["huggingface-cli", "download", str(trainer.model)],
    env={**os.environ, "HF_HUB_ENABLE_HF_TRANSFER": "1"},
    capture_output=True,
    text=True,
    check=True,
).stdout.strip()
if not checkpoint_dir:
    raise RuntimeError(
        f"huggingface-cli download printed no checkpoint directory for {trainer.model}"
    )
print("Checkpoint directory:", checkpoint_dir)

checkpoint_output_dir = "/home/ubuntu/atreides/experiments/models/rl"
os.makedirs(checkpoint_output_dir, exist_ok=True)

PLACEHOLDER: Any = None

config = RLConfig(
    # Dataset
    dataset=ComponentConfig(PackedDataset, tensors=tensors),
    seed=42,
    shuffle=False,
    # Model
    model=ComponentConfig(llama3_1_8b),
    num_output_chunks=4,
    # Checkpointer
    checkpointer=ComponentConfig(
        "torchtune.training.FullModelHFCheckpointer",
        checkpoint_dir=checkpoint_dir,
        checkpoint_files=[
            "model-00001-of-00004.safetensors",
            "model-00002-of-00004.safetensors",
            "model-00003-of-00004.safetensors",
            "model-00004-of-00004.safetensors",
        ],
        recipe_checkpoint=None,
        output_dir=checkpoint_output_dir,
        model_type="LLAMA3",
    ),
    resume_from_checkpoint=False,
    # Fine-tuning arguments
    batch_size=4,
    epochs=1,
    optimizer=ComponentConfig(
        # AdamW,
        "bitsandbytes.optim.PagedAdamW8bit",
        params=PLACEHOLDER,
        lr=5e-6,
        # fused=True,
    ),
    loss=ComponentConfig(
        PPOLoss,
        # clip_epsilon=0.3,
        # entropy_coef=0.0,
        # kl_coef=0.0,
        clip_epsilon=0.3,
        entropy_coef=0.025,
        kl_coef=0.025,
    ),
    max_steps_per_epoch=None,
    compile=False,
    optimizer_in_bwd=False,
    gradient_accumulation_steps=1,
    # Training env
    device="cuda",
    # Memory management
    enable_activation_checkpointing=True,
    enable_activation_offloading=False,
    custom_sharded_layers=["tok_embeddings", "output"],
    # Reduced precision
    dtype="bf16",
    # Logging
    metric_logger=ComponentConfig(
        DiskLogger, log_dir="/home/ubuntu/atreides/experiments/logs"
    ),
    output_dir="/home/ubuntu/atreides/experiments/logs",
    log_every_n_steps=1,
    log_peak_memory_stats=True,
)

recipe = RLRecipe(config)
recipe.setup(config)
recipe.train()
recipe.cleanup()
del recipe
cleanup_before_training()

In [7]:
# Manual cleanup for when the training cell above was interrupted before it
# reached its own `del recipe`. A bare `del recipe` raises NameError once the
# name is already gone, so drop it only if it still exists.
globals().pop("recipe", None)
cleanup_before_training()

In [None]:
tensors["tokens"][0][tensors["mask"][0][1405]]

In [None]:
tensors2 = packed_tensors(
    episodes,
    model=trainer.model,
    sequence_length=trainer.tune_sequence_length,
    trajectories_per_episode=(
        int(trainer.samples_per_episode * trainer.tune_episode_sample_fraction)
        if trainer.tune_episode_sample_fraction < 1.0
        else None
    ),
    tokenizer=trainer.tokenizer,
)

In [None]:
tensors["advantages"][0][1406]

In [None]:
tensors["input_pos"][0][tensors["mask"][0][1406]]

In [None]:
print(trainer.tokenizer.decode(tensors["tokens"][0][tensors["mask"][0][1405]]))

In [None]:
print(trainer.tokenizer.decode(trainer.tokenizer.encode(next(episodes[0].completion.leaves()).all_message_params(), continue_final_message=False)))

In [None]:
print(trainer.tokenizer.decode(trainer.tokenizer.encode(next(episodes[0].completion.leaves()).all_message_params())))

In [None]:
tensors["mask"][0]

In [None]:
print(trainer.tokenizer.decode(tensors["tokens"][0]))

In [None]:
import gc

gc.collect()

In [30]:
import torch

torch.cuda.empty_cache()

In [7]:
import torch


def get_mask(ids: torch.Tensor, ancestor_ids: torch.Tensor) -> torch.Tensor:
    """Build a causal, hierarchy-aware attention mask from per-position ancestor lists.

    Args:
        ids: Node ID of every position, shape (batch_size, sequence_length).
        ancestor_ids: For every position, the IDs it may attend to (itself and
            its ancestors), zero-padded, shape
            (batch_size, sequence_length, max_ancestors). Because the padding
            value is 0, any row containing padding also allows node ID 0.

    Returns:
        Boolean tensor of shape (batch_size, sequence_length, sequence_length);
        entry [b, i, j] is True when j <= i and ``ids[b, j]`` appears in
        ``ancestor_ids[b, i]``.
    """
    seq_len = ids.size(1)
    # [b, i, j, a] is True when key position j carries the a-th allowed ID of query i.
    matches = ancestor_ids[:, :, None, :] == ids[:, None, :, None]
    allowed = matches.any(dim=-1)
    # Lower-triangular causal mask, broadcast over the batch dimension.
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=ids.device).tril()
    return allowed & causal


# mask = get_mask(tensors["ids"], tensors["ancestor_ids"])

In [None]:
get_mask(
    ids=torch.tensor([[0, 1, 2, 3, 3, 4, 4, 5, 5, 6]]),
    ancestor_ids=torch.tensor(
        [
            [
                [0, 0, 0],
                [1, 0, 0],
                [2, 1, 0],
                [3, 1, 0],
                [3, 1, 0],
                [4, 0, 0],
                [4, 0, 0],
                [5, 4, 0],
                [5, 4, 0],
                [6, 6, 6],
            ]
        ]
    ),
).int()

In [None]:
(~torch.isnan(tensors["advantages"][0]).unsqueeze(0) & ~torch.isnan(tensors["advantages"][0]).unsqueeze(1)).int()

In [None]:
import matplotlib.pyplot as plt
import torch


def show(mask: torch.Tensor) -> None:
    """Render a 2D tensor as a 10x10-inch "inferno" heat map with a colour bar."""
    fig, ax = plt.subplots(figsize=(10, 10))
    image = ax.imshow(mask, cmap="inferno")
    fig.colorbar(image, ax=ax, label="Relative Position")
    ax.set_title("Relative Position Attention Mask")
    ax.set_xlabel("Target Position")
    ax.set_ylabel("Source Position")
    plt.show()

i = 1

show(
    tensors["mask"][i].cumsum(dim=1)
    * (
        tensors["mask"][i]
        & (
            ~torch.isnan(tensors["advantages"][i]).unsqueeze(0)
            & ~torch.isnan(tensors["advantages"][i]).unsqueeze(1)
        )
    )
)

In [None]:
# Pairwise validity mask for row `i`: entry [r, c] is True only when neither
# position r nor position c has a NaN advantage. This is the sub-expression used
# in the `show(...)` call above; the pasted fragment had lost its opening
# parenthesis and leading `~`, which made the cell a SyntaxError.
# NOTE(review): assumes tensors["advantages"][i] is 1-D -- confirm.
(
    ~torch.isnan(tensors["advantages"][i]).unsqueeze(0)
    & ~torch.isnan(tensors["advantages"][i]).unsqueeze(1)
)

In [None]:
show((~torch.isnan(tensors["advantages"][0]).unsqueeze(0) & ~torch.isnan(tensors["advantages"][0]).unsqueeze(1)).int())

In [None]:
%%timeit
tensors = await trainer.tune(episodes[:100])

In [None]:
tensors

In [None]:
await trainer.eval("val")

In [None]:
episodes = await trainer.explore()

In [None]:
len(episodes)

In [None]:
from IPython.display import HTML

HTML(
    f'<div style="white-space: pre-wrap">{list(episodes[2].completion.leaves())[0].html(30.0)}</div>'
)

In [None]:
episodes: list[Episode] = trainer.eval_episodes["val"]  # type: ignore
divisor = max(
    sum(
        1
        for episode in episodes
        if any(child.model == trainer.model for child in episode.completion.children)
    ),
    1,
)
score = (
    sum(episode.completion.value(model=trainer.model) for episode in episodes) / divisor
)
score, divisor

In [None]:
[len(list(episode.completion.descendants())) for episode in episodes]

In [None]:
list(list(episodes[0].completion.children)[0].children)[0].message_params()

In [None]:
[episode.completion.value(model=trainer.model) for episode in episodes]

In [None]:
sum(
    sum(child.model == trainer.model for child in episode.completion.children)
    for episode in episodes
)

In [None]:
trainer.eval_scores["val"]

In [None]:
import torch


def get_mask(ids: torch.Tensor, parent_ids: torch.Tensor) -> torch.Tensor:
    """
    Creates a causal attention mask based on token group IDs and their parent relationships.

    Each token can attend to:
    1. All preceding tokens in its own group (tokens with same ID)
    2. All tokens in any ancestor group (following parent_ids chain up)
    The mask enforces causality - tokens cannot attend to future positions.

    This is the straightforward reference implementation. The ancestor chain is
    resolved once per target position, and a group that is its own parent (or any
    cycle in the parent chain) ends the walk instead of looping forever.

    Args:
        ids: 1D tensor of token group IDs. Shape: [sequence_length]
        parent_ids: 1D tensor of parent group IDs. Shape: [sequence_length]
            Any integer value is valid - if a parent ID is not found in ids,
            that group is treated as a root node.
            Each ID must have a corresponding parent ID at the same index.
            A group's parent is read from the group's first occurrence.

    Returns:
        2D boolean tensor of shape [sequence_length, sequence_length] where
        mask[i,j] = True means position i can attend to position j.
        False means position i cannot attend to position j.

    Example:
        ids = torch.tensor([0, 0, 1, 1])
        parent_ids = torch.tensor([100, 100, 0, 0])
        # Group 1 tokens can attend to group 0 tokens since 0 is their parent
        # Group 0's parent (100) is not in ids so it's treated as root
    """
    seq_len = len(ids)
    mask = torch.zeros((seq_len, seq_len), dtype=torch.bool)

    id_list = ids.tolist()
    parent_list = parent_ids.tolist()

    # Parent of each group, taken from the group's first occurrence.
    group_parent: dict = {}
    for group_id, parent_id in zip(id_list, parent_list):
        group_parent.setdefault(group_id, parent_id)

    for i in range(seq_len):
        # Own group plus every ancestor group reachable from this position.
        allowed = {id_list[i]}
        ancestor_id = parent_list[i]
        # Stop at a missing parent (root) or at an already-visited group (cycle).
        while ancestor_id in group_parent and ancestor_id not in allowed:
            allowed.add(ancestor_id)
            ancestor_id = group_parent[ancestor_id]

        # Causality: only positions up to and including the target.
        for j in range(i + 1):
            if id_list[j] in allowed:
                mask[i, j] = True

    return mask


def get_fast_mask(ids: torch.Tensor, parent_ids: torch.Tensor) -> torch.Tensor:
    """Group-level variant of ``get_mask``: resolve ancestry per group, then expand.

    Args:
        ids: 1D tensor of token group IDs.
        parent_ids: 1D tensor of parent group IDs, aligned with ``ids``. A
            group's parent is read from the group's first occurrence; a parent
            that does not occur in ``ids`` ends the chain.

    Returns:
        2D boolean tensor where [i, j] is True when j <= i and the group of
        position j is the group of position i or one of its ancestors.
    """
    n = len(ids)
    id_values = ids.tolist()

    # Sorted distinct groups and their row/column index in the group table.
    group_list = torch.unique(ids).tolist()
    index_of = {group_id: k for k, group_id in enumerate(group_list)}

    # First occurrence wins when a group appears at several positions.
    first_parent = {}
    for group_id, parent_id in zip(id_values, parent_ids.tolist()):
        first_parent.setdefault(group_id, parent_id)

    # allowed[g, h]: group g may attend to group h (itself or an ancestor).
    allowed = torch.zeros((len(group_list), len(group_list)), dtype=torch.bool)
    for row, group_id in enumerate(group_list):
        allowed[row, row] = True
        visited = set()
        parent_id = first_parent.get(group_id, None)
        # The `visited` check stops the walk on cyclic parent chains.
        while parent_id in group_list and parent_id not in visited:
            visited.add(parent_id)
            allowed[row, index_of[parent_id]] = True
            parent_id = first_parent.get(parent_id, None)

    # Expand the group table to positions: rows by query group, columns by key group.
    group_of_position = torch.tensor(
        [index_of[group_id] for group_id in id_values], dtype=torch.long
    )
    mask = allowed[group_of_position][:, group_of_position]

    # Tokens cannot attend to future positions.
    return mask & torch.ones(n, n, dtype=torch.bool).tril()


ids = torch.tensor([0, 0, 0, 1, 1, 2, 2, 3, 4, 4, 4, 5, 5] * 2)
parent_ids = torch.tensor([-1, -1, -1, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2] * 2)
mask = get_mask(ids, parent_ids)
fast_mask = get_fast_mask(ids, parent_ids)
assert torch.all(mask == fast_mask), "Fast implementation does not match original"

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
plt.imshow(fast_mask, cmap="binary")
plt.colorbar()
plt.title("Attention Mask")
plt.xlabel("Source Position")
plt.ylabel("Target Position")
plt.show()

In [187]:
def get_faster_mask(ids: torch.Tensor, parent_ids: torch.Tensor) -> torch.Tensor:
    """Vectorized causal attention mask from per-position group and parent IDs.

    Position i may attend to position j when j <= i and the group of j is the
    group of i or one of its ancestors. Ancestors are resolved one level per
    iteration for all positions at once.

    The previous version advanced the chain with ``parent_ids[ids]`` /
    ``parent_ids[parent_ids]``, i.e. it indexed a per-position tensor by group
    ID. That is only meaningful when every group has one token and IDs equal
    positions; even then it stopped after direct parents and wrapped around on
    negative root IDs. Here each ancestor is looked up by value instead.

    Args:
        ids: 1D tensor of token group IDs. Shape: [sequence_length]
        parent_ids: 1D tensor of parent group IDs, aligned with ``ids``. A
            group's parent is read from the group's first occurrence; a parent
            value that does not occur in ``ids`` marks a root.

    Returns:
        2D boolean tensor of shape [sequence_length, sequence_length].
    """
    seq_len = ids.size(0)
    mask = ids.unsqueeze(0) == ids.unsqueeze(1)
    ancestors = parent_ids
    # A chain can be at most seq_len levels deep, which also bounds cyclic input.
    for _ in range(seq_len):
        # hits[i, j]: position j belongs to the current ancestor group of position i.
        hits = ids.unsqueeze(0) == ancestors.unsqueeze(1)
        grown = mask | hits
        if torch.equal(grown, mask):
            # No position gained a new ancestor: every chain ended or repeated.
            break
        mask = grown
        # Step up one level: take the parent stored at the ancestor group's first
        # occurrence; keep the (unmatched) value where the chain has ended.
        first_hit = hits.int().argmax(dim=1)
        ancestors = torch.where(hits.any(dim=1), parent_ids[first_hit], ancestors)
    return mask & torch.tril(torch.ones_like(mask))

In [None]:
%%timeit
get_faster_mask(ids, parent_ids)

In [None]:
import random


class Node:
    """Tree node used to generate synthetic (ids, parent_ids) test sequences.

    Attributes are plain values set from the constructor arguments: ``id``,
    ``parent`` (``None`` when no parent is given) and ``size``. The cell below
    emits ``size`` positions for each node.
    """

    def __init__(self, id: int, parent: "Node | None" = None, size: int = 1):
        self.id, self.parent, self.size = id, parent, size


nodes = [Node(id=0, size=random.randint(1, 1)), Node(id=1, size=random.randint(1, 1))]
for i in range(2, 8):
    parent = random.choice(nodes)
    size = random.randint(1, 2)
    node = Node(id=i, parent=parent, size=size)
    nodes.append(node)
ids = torch.tensor([node.id for node in nodes for _ in range(node.size)])
parent_ids = torch.tensor(
    [
        node.parent.id if node.parent else node.id
        for node in nodes
        for _ in range(node.size)
    ]
)
# mask = get_mask(ids, parent_ids)
# faster_mask = get_faster_mask(ids, parent_ids)
torch.stack([ids, parent_ids])

In [None]:
def mask_and_pos_ids(
    ids: torch.Tensor, parent_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Creates an attention mask and position IDs for hierarchical attention based on node IDs and their parent IDs.

    Args:
        ids: A tensor of shape (batch_size, sequence_length) containing node IDs
        parent_ids: A tensor of shape (batch_size, sequence_length) containing parent IDs for each node.
            The test cases in this notebook mark a root by making it its own parent.

    Returns:
        A tuple containing:
        - mask: A boolean tensor of shape (batch_size, sequence_length, sequence_length) where True indicates
          allowed attention connections. Each position can attend to itself and any of its ancestors
          in the hierarchy, but only for previous positions (due to causal masking).
        - pos_ids: A tensor of shape (batch_size, sequence_length) with one position ID per token. It is
          computed as the running count of allowed positions along each mask row, minus one, reduced
          with a max over the query axis; in the test cases below this equals the token's depth along
          its own root-to-token path.
    """
    # mask[b, i, j]: positions i and j carry the same node ID.
    mask = ids.unsqueeze(1) == ids.unsqueeze(2)
    # _mask additionally allows j when it carries the direct parent ID of i.
    _mask = mask | (ids.unsqueeze(1) == parent_ids.unsqueeze(2))
    # Climb one ancestor level per iteration until no new connection appears.
    while torch.any(mask != _mask):
        # Replace each position's current ancestor by that ancestor's own parent,
        # looked up at the first position whose ID equals the ancestor.
        # NOTE(review): argmax over an all-False row is 0, so an ancestor ID that
        # does not occur in `ids` resolves to parent_ids[:, 0] -- confirm callers
        # always use self-parented roots.
        parent_ids = parent_ids.gather(
            1, torch.argmax((parent_ids.unsqueeze(2) == ids.unsqueeze(1)).int(), dim=2)
        )
        mask = _mask
        _mask = mask | (ids.unsqueeze(1) == parent_ids.unsqueeze(2))
    # Causal constraint: keep only keys at or before the query position.
    mask &= torch.tril(torch.ones_like(mask, dtype=torch.bool, device=ids.device))
    # mask = torch.linalg.matrix_power(mask.float(), mask.size(1) - 1) > 0
    # Masked-out entries become -1 before the max over the query axis (dim 1).
    pos_ids = (torch.where(mask, mask.cumsum(2), 0) - 1).max(1).values
    return mask, pos_ids


def test_mask_and_pos_ids(
    ids: list[int],
    parent_ids: list[int],
    expected_mask: list[list[int]],
    expected_pos_ids: list[int],
):
    """Run ``mask_and_pos_ids`` on one unbatched example and assert both outputs.

    The inputs are wrapped in a batch dimension of size 1. On failure the
    assertion message shows the mask (or position IDs) actually produced.
    """
    batched_ids = torch.tensor([ids])
    batched_parent_ids = torch.tensor([parent_ids])
    mask, pos_ids = mask_and_pos_ids(ids=batched_ids, parent_ids=batched_parent_ids)

    want_mask = torch.tensor([expected_mask])
    mask_matches = (mask.int() == want_mask).all()
    assert mask_matches, f"\n{mask.int()[0]}"

    want_pos_ids = torch.tensor([expected_pos_ids])
    pos_ids_match = (pos_ids == want_pos_ids).all()
    assert pos_ids_match, f"{pos_ids[0].tolist()}"


test_mask_and_pos_ids(
    ids=[0, 1],
    parent_ids=[0, 1],
    expected_mask=[[1, 0], [0, 1]],
    expected_pos_ids=[0, 0],
)

test_mask_and_pos_ids(
    ids=[0, 1, 1],
    parent_ids=[0, 0, 0],
    expected_mask=[[1, 0, 0], [1, 1, 0], [1, 1, 1]],
    expected_pos_ids=[0, 1, 2],
)

test_mask_and_pos_ids(
    ids=[0, 1, 2, 3],
    parent_ids=[0, 0, 1, 2],
    expected_mask=[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]],
    expected_pos_ids=[0, 1, 2, 3],
)

test_mask_and_pos_ids(
    ids=[0, 0, 1, 1],
    parent_ids=[0, 0, 1, 1],
    expected_mask=[[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]],
    expected_pos_ids=[0, 1, 0, 1],
)

test_mask_and_pos_ids(
    ids=[0, 1, 2, 3],
    parent_ids=[0, 1, 0, 1],
    expected_mask=[[1, 0, 0, 0], [0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1]],
    expected_pos_ids=[0, 0, 1, 1],
)

test_mask_and_pos_ids(
    ids=[0, 1, 2, 2, 3, 3],
    parent_ids=[0, 1, 0, 0, 1, 1],
    expected_mask=[
        [1, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0],
        [1, 0, 1, 1, 0, 0],
        [0, 1, 0, 0, 1, 0],
        [0, 1, 0, 0, 1, 1],
    ],
    expected_pos_ids=[0, 0, 1, 2, 1, 2],
)

test_mask_and_pos_ids(
    ids=[0, 1, 2, 3, 4, 4, 5, 5],
    parent_ids=[0, 0, 1, 1, 2, 2, 3, 3],
    expected_mask=[
        [1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 0, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 1, 0, 0, 0],
        [1, 1, 1, 0, 1, 1, 0, 0],
        [1, 1, 0, 1, 0, 0, 1, 0],
        [1, 1, 0, 1, 0, 0, 1, 1],
    ],
    expected_pos_ids=[0, 1, 2, 2, 3, 4, 3, 4],
)

test_mask_and_pos_ids(
    ids=[2, 1, 0],
    parent_ids=[2, 2, 0],
    expected_mask=[
        [1, 0, 0],
        [1, 1, 0],
        [0, 0, 1],
    ],
    expected_pos_ids=[0, 1, 0],
)

In [None]:
import torch


def compute_reachability_matrix(adj_matrix):
    """Compute which nodes can reach which via walks of length 1 to N.

    The reachability matrix is binarized after every expansion step. The earlier
    version accumulated raw walk counts, which grow exponentially and overflow
    float32 to inf/NaN on graphs with a few hundred nodes, corrupting the result.

    Args:
        adj_matrix: Square (N, N) adjacency tensor; an entry > 0 means an edge.
            Entries are assumed non-negative.

    Returns:
        Float tensor of shape (N, N) with 1.0 where a walk of length >= 1 exists
        and 0.0 elsewhere. A node only reaches itself if it lies on a cycle or
        has a self-loop.
    """
    # Number of nodes
    num_nodes = adj_matrix.size(0)

    # Work on a 0/1 copy so products never grow beyond N.
    adjacency = (adj_matrix > 0).float()
    reachability = adjacency.clone()

    # Add paths of length 2 to N
    for _ in range(num_nodes - 1):
        expanded = ((reachability + torch.mm(reachability, adjacency)) > 0).float()
        if torch.equal(expanded, reachability):
            # Fixed point: longer walks cannot add anything new.
            break
        reachability = expanded

    return reachability


# Example adjacency matrix
adj_matrix = torch.tensor(
    [[1, 0, 0, 0], [1, 1, 0, 0], [0, 1, 1, 0], [0, 0, 0, 1]], dtype=torch.float32
)

reachability_matrix = compute_reachability_matrix(adj_matrix)
print(reachability_matrix)

In [None]:
def compute_reachability_matrix_one_swoop(adj_matrix):
    """Reachability via repeated squaring of the binarized matrix (I + A).

    ``(I + A)^k > 0`` holds exactly where ``I + A + ... + A^k`` is non-zero, so
    squaring until the covered path length is at least N - 1 yields the full
    closure. The earlier version took ``A^(N-1)`` alone, which equals this series
    only when every node has a self-loop, and its raw counts could overflow.

    Args:
        adj_matrix: Square (N, N) adjacency tensor; an entry > 0 means an edge.
            Entries are assumed non-negative.

    Returns:
        Float tensor of shape (N, N) with 1.0 where a path of length 0 to N - 1
        exists (so the diagonal is always 1.0) and 0.0 elsewhere.
    """
    num_nodes = adj_matrix.size(0)

    # Binarized (I + A): paths of length <= 1.
    identity = torch.eye(num_nodes, dtype=torch.bool, device=adj_matrix.device)
    reachability = ((adj_matrix > 0) | identity).float()

    # Each squaring doubles the covered path length; re-binarize to avoid overflow.
    covered_length = 1
    while covered_length < num_nodes - 1:
        reachability = (torch.mm(reachability, reachability) > 0).float()
        covered_length *= 2

    return reachability


reachability_matrix = compute_reachability_matrix_one_swoop(adj_matrix)
print(reachability_matrix)

In [None]:
import random


class Node:
    """Tree node for the larger synthetic benchmark tree built in this cell.

    Holds the constructor arguments unchanged: ``id``, ``parent`` (``None`` when
    no parent is given) and ``size``.
    """

    def __init__(self, id: int, parent: "Node | None" = None, size: int = 1):
        self.id, self.parent, self.size = id, parent, size


nodes = [Node(id=0, size=random.randint(1, 1)), Node(id=1, size=random.randint(1, 1))]
for i in range(2, 128):
    parent = random.choice(nodes)
    size = random.randint(1, 128)
    node = Node(id=i, parent=parent, size=size)
    nodes.append(node)
ids = torch.tensor([node.id for node in nodes for _ in range(node.size)])
parent_ids = torch.tensor(
    [
        node.parent.id if node.parent else node.id
        for node in nodes
        for _ in range(node.size)
    ]
)
# mask = get_mask(ids, parent_ids)
# faster_mask = get_faster_mask(ids, parent_ids)
torch.stack([ids, parent_ids])

In [None]:
len(ids)

In [None]:
%%timeit
_ = mask_and_pos_ids(torch.stack([ids, ids]), torch.stack([parent_ids, parent_ids]))

In [148]:
def show(mask):
    """Render a 2D mask as a 10x10-inch black-and-white image with a colour bar."""
    fig, ax = plt.subplots(figsize=(10, 10))
    image = ax.imshow(mask, cmap="binary")
    fig.colorbar(image, ax=ax)
    ax.set_title("Attention Mask")
    ax.set_xlabel("Source Position")
    ax.set_ylabel("Target Position")
    plt.show()

In [None]:
import asyncio

process = await asyncio.create_subprocess_exec(
    "vllm",
    "serve",
    "NousResearch/Hermes-2-Theta-Llama-3-8B",
    stdout=asyncio.subprocess.PIPE,
    stderr=asyncio.subprocess.PIPE,
)
while True:
    print((await process.stdout.readline()).decode())

In [None]:
from lib.vllm import start_vllm_server, vllm_server_metrics
import os

model = "NousResearch/Hermes-2-Theta-Llama-3-8B"

os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
shutdown_server, client = await start_vllm_server(
    disable_log_requests=True,
    max_model_len=16384,
    model=model,
)

In [None]:
import asyncio
from lib.rl.episode import Episode
from typing import AsyncIterable, Literal

Split = Literal["train", "val", "test"]


async def episodes(split: Split) -> AsyncIterable[Episode]:
    """Yield ten placeholder episodes, sleeping one second before each.

    ``split`` is accepted but not used by this stub.
    """
    remaining = 10
    while remaining > 0:
        await asyncio.sleep(1)
        yield Episode()  # type: ignore
        remaining -= 1


async for episode in episodes(split="val"):
    print(episode)

In [None]:
from lib.rl.trainer import Trainer

episode = Episode()

Trainer(
    base_model=model,
    episodes={
        "train": [episode],
        "val": [episode],
        "test": [episode],
    },
)