In [1]:
import argparse

import torch
import tqdm
from tensordict import NonTensorStack, TensorDict
from tensordict.nn import (
    ProbabilisticTensorDictModule as Prob,
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
)
from torch.distributions import Categorical

from torchrl._utils import _make_ordinal_device
from torchrl.data import MCTSForest

from torchrl.envs import LLMHashingEnv
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, pipeline

try:
    is_sphinx = __sphinx_build__
except NameError:
    is_sphinx = False

parser = argparse.ArgumentParser()
parser.add_argument(
    "--pretrained",
    type=bool,
    default=not is_sphinx,
    help="Set to True to load pre-trained weights, False for random weights.",
)
parser.add_argument(
    "--model",
    choices=["llama3.1", "gpt2"],
    default="gpt2",
    help="Choose the model to use: 'llama3.1' or 'gpt2'.",
)
parser.add_argument(
    "--beta", type=int, default=3, help="Set the beta parameter for the model."
)
parser.add_argument(
    "--pool", type=int, default=1000, help="Set the pool size for processing."
)
parser.add_argument(
    "--nsteps", type=int, default=10, help="Set the number of steps for the process."
)
parser.add_argument(
    "--device",
    type=str,
    default=None,
    help="Specify the device to use (e.g., 'cpu', 'cuda').",
)
parser.add_argument(
    "--device_map",
    type=str,
    default="auto",
    help="Specify the device map for model parallelism (e.g., 'auto').",
)

args = parser.parse_args(
    [
        # When executing this in a notebook, change the parameters here, eg
        #"--device", "cuda:0"
        "--device", "cpu",
        #"--model", "llama3.1",

    ]
)

  from .autonotebook import tqdm as notebook_tqdm
  interpolation: int = Image.BILINEAR,
  interpolation: int = Image.NEAREST,
  interpolation: int = Image.BICUBIC,


In [2]:
if args.model == "gpt2":
    tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
    if args.pretrained:
        cfg = GPT2Config.from_pretrained("openai-community/gpt2")
    else:
        cfg = GPT2Config()
    llm = GPT2LMHeadModel(cfg).eval().requires_grad_(False)

    device = args.device

    #if torch.cuda.is_available():
    #    device = "cuda:0"
    #else:
    #    device = "cpu"

elif args.model == "llama3.1":
    if not args.pretrained:
        raise ValueError("llama3.1 can only be used with --pretrained=True")

    model_id = "meta-llama/Llama-3.1-8B"

    if args.device:
        args.device_map = None
    pipeline = pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map=args.device_map,
        device=args.device,
    )

    tokenizer = pipeline.tokenizer
    llm = pipeline.model.eval().requires_grad_(False)
    if args.device:
        device = _make_ordinal_device(args.device)
    elif torch.cuda.is_available():
        device = "cuda:0"
    elif torch.mps.is_available():
        torch.mps.empty_cache()
        device = "mps:0"
    else:
        device = "cpu"

torch.set_default_device(device)

text_to_tensor = Seq(
    Mod(tokenizer, in_keys=["query"], out_keys=["out"]),
    # A renaming layer
    Mod(lambda x: x, in_keys=[("out", "input_ids")], out_keys=["observation"]),
).select_out_keys("observation")
td = TensorDict(
    query=NonTensorStack.from_list(["hello world! Give me a high five"] * 4),
    #query=NonTensorStack.from_list([
    #    "hello world! Give me a high five",
    #    "What's up?",
    #    "How's it going?",
    #    "I'm awake."
    #]),
    batch_size=[4],
)
print(text_to_tensor(td))

TensorDict(
    fields={
        observation: Tensor(shape=torch.Size([4, 8]), device=cpu, dtype=torch.int64, is_shared=False),
        query: NonTensorStack(
            ['hello world! Give me a high five', 'hello world!...,
            batch_size=torch.Size([4]),
            device=None)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)


In [3]:
env = LLMHashingEnv(vocab_size=tokenizer.vocab_size, tokenizer=tokenizer)

In [4]:
def select_unique_obs(td):
    # Get the obs (the hash)
    hashes = td["hashing"]
    hashes = hashes.squeeze()
    assert hashes.ndim == 1
    # the indices of the unique values are the unique values of the inverse indices returned from `unique`
    _, unique_hashes = torch.unique(hashes, dim=0, return_inverse=True)
    unique_hashes = unique_hashes.unique()
    return td[unique_hashes]

In [5]:
class LLMWrapper(torch.nn.Module):
    def __init__(self, gpt):
        super().__init__()
        self.gpt = gpt

    def forward(self, x: torch.Tensor) -> TensorDict:
        result = TensorDict.from_dataclass(self.gpt(x, return_dict=True), device=device)
        return result


llm_module = Mod(LLMWrapper(llm), in_keys=["observation"], out_keys=["data"])

In [6]:
select_last = Mod(
    lambda x: x[:, -1:], in_keys=[("data", "logits")], out_keys=["logits"]
)

In [7]:
class CategoricalWithoutReplacement(Categorical):
    def sample(self, sample_shape=()) -> torch.Tensor:
        n = sample_shape.numel()
        probs = self.probs
        probs_shape = probs.shape
        if len(probs_shape) > 2:
            probs = probs.flatten(0, -2)
        samples = torch.multinomial(probs, n, replacement=False)
        return samples.view((*sample_shape, *probs_shape[:-1]))


prob_module = Prob(
    in_keys=["logits"],
    out_keys=["action"],
    default_interaction_type="random",
    distribution_class=CategoricalWithoutReplacement,
    return_log_prob=True,
    log_prob_key="logits_select",
    num_samples=args.pool,
)

In [8]:
def select_top_k(td: TensorDict, top_k=args.beta) -> TensorDict:
    logits = td["logits_select"]
    topk = logits.topk(top_k, dim=0)
    topk_indices = topk.indices.squeeze(-1)
    return td[topk_indices].set("topk_indices", topk_indices)

In [9]:
policy = Seq(
    # Only get the unique obs
    select_unique_obs,
    # Call to the LLM
    llm_module,
    # Select last logit
    select_last,
    # Sample
    prob_module,
    # Reshape to -1
    lambda td: td.reshape(-1),
    # Top-k
    select_top_k,
)

In [10]:
x = tokenizer(["Check out TorchRL!"])["input_ids"]
td = TensorDict(observation=x, batch_size=[1])
td = env.reset(td)
env.check_env_specs(tensordict=td, return_contiguous=False)

2025-01-13 15:07:57,648 [torchrl][INFO] check_env_specs succeeded!


In [11]:
forest = MCTSForest(observation_keys=["hashing"], action_keys=["action", "logits_select"])

In [13]:
input_string = "The reason I ate the apple was "

with torch.no_grad():
    # Total number of candidates
    pool = args.pool
    # Number of selected beams
    beta = args.beta
    x = tokenizer([input_string])["input_ids"]
    reset_td = env.reset(
        TensorDict(observation=x, batch_size=[1]).repeat_interleave(args.beta)
    )
    tds = []
    # beam search
    td = reset_td
    reset_td = reset_td[0].clone()

    pbar = tqdm.tqdm(range(args.nsteps))
    for _ in pbar:
        td = policy(td)
        next_td = env.step(td)

        tds.append(next_td)
        next_td_filtered = next_td.exclude(
            "observation", "text", ("next", "observation"), ("next", "text")
        )
        forest.extend(next_td_filtered)
        pbar.set_description(f"Forest length: {len(forest)}")

        print("action", next_td["action"])
        td = env.step_mdp(next_td)
        print("hash", td["hashing"])

    tds = TensorDict.lazy_stack(tds, -1)
    for i in range(tds.shape[0]):
        print(tds[i, -1]["next", "text"])

    tree = forest.get_tree(reset_td)
    valid_paths = list(tree.valid_paths())
    print("valid paths", valid_paths)

    for path in valid_paths:
        rollout = tree.rollout_from_path(path)
        print(input_string, tokenizer.decode(rollout["action"].squeeze(-1)))
        print(rollout["logits_select"].sum())

    def make_labels(local_tree, path):
        if path:
            r = tree.rollout_from_path(path)
            actions = r["action"]
            return input_string + tokenizer.decode(actions.squeeze(-1))
        return input_string

    tree.plot(make_labels=make_labels)

Forest length: 33:  10%|█         | 1/10 [00:00<00:00,  9.33it/s]

action tensor([[  220],
        [33125],
        [46308]])
hash tensor([[-4556944625400516864],
        [ -808573244379598300],
        [ 2737375120620791417]])


Forest length: 36:  20%|██        | 2/10 [00:01<00:07,  1.14it/s]

action tensor([[33125],
        [ 7570],
        [15401]])
hash tensor([[-3712056468678266966],
        [-4678643172556498879],
        [ 3301479647910938983]])


Forest length: 39:  30%|███       | 3/10 [00:03<00:08,  1.23s/it]

action tensor([[15401],
        [ 7570],
        [15401]])
hash tensor([[8862343678105057667],
        [5446914216984854687],
        [3855610188286197921]])


Forest length: 42:  40%|████      | 4/10 [00:05<00:09,  1.61s/it]

action tensor([[15401],
        [15401],
        [ 1100]])
hash tensor([[ 8101390298399478270],
        [ 8692009458689120094],
        [-4192349278742540019]])


Forest length: 45:  50%|█████     | 5/10 [00:11<00:16,  3.23s/it]

action tensor([[ 1100],
        [27449],
        [15008]])
hash tensor([[-5933794507642788476],
        [-2123644125514036229],
        [ 7528370810698822507]])


Forest length: 48:  60%|██████    | 6/10 [00:18<00:17,  4.48s/it]

action tensor([[33922],
        [ 7570],
        [15401]])
hash tensor([[-8111050819451086459],
        [-7452106307813998861],
        [ 4144126010748071399]])


Forest length: 51:  70%|███████   | 7/10 [00:25<00:15,  5.22s/it]

action tensor([[22653],
        [45894],
        [35143]])
hash tensor([[7901103242211363821],
        [4675350937355286155],
        [ -11016180407394597]])


Forest length: 54:  80%|████████  | 8/10 [00:34<00:12,  6.39s/it]

action tensor([[35143],
        [33922],
        [33346]])
hash tensor([[-1392948906077328681],
        [ 7433371700422540310],
        [ 5673736469984173128]])


Forest length: 57:  90%|█████████ | 9/10 [00:43<00:07,  7.42s/it]

action tensor([[33922],
        [ 7995],
        [ 7995]])
hash tensor([[-9076204173052785009],
        [ 3471902211221434652],
        [ 3471902211221434652]])


Forest length: 60: 100%|██████████| 10/10 [00:46<00:00,  4.61s/it]

action tensor([[33922],
        [23229],
        [16838]])
hash tensor([[5745392113208399321],
        [5711735439505919670],
        [8769517905672612937]])
The reason I ate the apple was  circus Soviet Budget Budget freelance Soviet Grade Shuttle postponed postponed
The reason I ate the apple was  circus Soviet Budget Budget freelance Soviet Grade Shuttle postponed hostage
The reason I ate the apple was  circus Soviet Budget Budget freelance Budget ShuttleGPU alertpleted
valid paths [(0,), (2,), (1, 0), (1, 1, 1), (1, 2, 0), (1, 2, 1), (1, 1, 0, 1), (1, 1, 0, 0, 0), (1, 1, 0, 0, 1, 1), (1, 1, 0, 0, 2, 0), (1, 1, 0, 0, 2, 1), (1, 1, 0, 0, 1, 0, 0), (1, 1, 0, 0, 1, 0, 1)]
The reason I ate the apple was   
tensor(-8.8272)
The reason I ate the apple was   Powerful
tensor(-8.9823)
The reason I ate the apple was   circus circus
tensor(-17.0512)
The reason I ate the apple was   circus Soviet Budget read read
tensor(-43.6607)
The reason I ate the apple was   circus Budget Budget
tensor(-25.2


