In [66]:
import argparse
from contextlib import nullcontext
import os
from typing import Any, Dict, List, Literal, Optional

# Third-party
import ctrlg
import torch
import wandb
from datasets import load_dataset
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoTokenizer, LogitsProcessor, LogitsProcessorList
from transformers import TrainerCallback


def get_dfa_model(
    hmm_model,
    tokenizer,
    prompt_ids: List[int],
    keyphrase_ids: Optional[List[List[str]]] = None,
    suffix_ids: Optional[List[int]] = None,
    min_new_tokens=5,
    max_new_tokens=32,
):
    """Build a Ctrl-G constraint logits processor from keyphrase and suffix constraints.

    One DFA is built per keyphrase group (at least one phrase of the group must
    appear); the groups are then intersected, so every group must be satisfied.

    Args:
        hmm_model: HMM model handed to ``ctrlg.ConstraintLogitsProcessor``; the
            DFA model is moved to the device of its first parameter.
        tokenizer: Tokenizer used to encode the prefix, the default suffix and
            the keyphrases. ``len(tokenizer)`` is used as the vocabulary size.
        prompt_ids (List[int]): Token ids of the prompt, e.g. ``[1, 2, 3]``.
            A 1-D ``torch.Tensor`` is also accepted and converted to a list.
        keyphrase_ids (Optional[List[List[str]]]): Groups of keyphrase *strings*
            (despite the name, these are text and are encoded here),
            e.g. ``[["4"]]``. ``None`` means ``[[" "]]``.
        suffix_ids (Optional[List[int]]): Token ids the generation must end
            with, e.g. ``[4]``. ``None`` means the encoding of
            ``".<|endoftext|>"``.
        min_new_tokens: Forwarded to ``ctrlg.ConstraintLogitsProcessor``.
        max_new_tokens: Forwarded to ``ctrlg.ConstraintLogitsProcessor``.

    Returns:
        The ``ctrlg.ConstraintLogitsProcessor`` built from the arguments.

    Raises:
        ValueError: If ``prompt_ids`` is a tensor that is not 1-D.
    """
    # Avoid a shared mutable default; None stands for the original default.
    if keyphrase_ids is None:
        keyphrase_ids = [[" "]]

    # Normalise tensor inputs to plain Python lists so that prompt, prefix and
    # suffix ids all reach ctrlg with the same (list) type.
    if isinstance(prompt_ids, torch.Tensor):
        if prompt_ids.dim() != 1:
            raise ValueError(
                "prompt_ids must be a 1-D sequence of token ids, "
                f"got a tensor with shape {tuple(prompt_ids.shape)}"
            )
        prompt_ids = prompt_ids.tolist()
    if isinstance(suffix_ids, torch.Tensor):
        suffix_ids = suffix_ids.tolist()

    device = next(hmm_model.parameters()).device
    vocab_size = len(tokenizer)

    # Prefix & suffix constraints
    prefix = ""  # generate text starting with nothing
    # Default suffix; a suffix must end with the eos token.
    # NOTE(review): the literal '<|endoftext|>' presumably matches GPT-2 style
    # tokenizers only -- confirm before using another model family.
    suffix = ".<|endoftext|>"
    prefix_ids = tokenizer.encode(prefix)
    if suffix_ids is None:
        suffix_ids = tokenizer.encode(suffix)

    # DFA construction
    # ac_builder constructs a DFA representing the constraint that (at least)
    # one of the patterns must appear; a pattern is a sequence of token ids.
    ac_builder = ctrlg.AhoCorasickBuilder(vocab_size)
    dfa_graphs = [
        ac_builder.build([tokenizer.encode(phrase) for phrase in keyphrase_group])
        for keyphrase_group in keyphrase_ids
    ]

    # Take the intersection of the DFAs, i.e. the "logical and" of the constraints.
    # This function also minimizes the constructed DFA, which is mainly CPU-based;
    # due to its pure Python implementation, DFA minimization can be slow for
    # complex constraints.
    dfa_graph = ctrlg.DFA_prod(dfa_graphs, mode="intersection")

    # Compile the dfa_graph for efficient GPU execution.
    dfa_model = ctrlg.DFAModel(dfa_graph, vocab_size).to(device)

    # Constraint logits processor
    return ctrlg.ConstraintLogitsProcessor(
        hmm_model,
        dfa_model,
        min_new_tokens,
        max_new_tokens,
        prompt_ids,
        prefix_ids=prefix_ids,
        suffix_ids=suffix_ids,
    )


class DummyLogitsProcessor(LogitsProcessor):
    """Logits processor that optionally delegates to a Ctrl-G constraint processor.

    If an HMM model is supplied and every prompt shares the same solution, a
    constraint processor is built with ``get_dfa_model`` from the first
    prompt's token ids and that shared solution. In every other case the
    scores are returned unchanged.
    """

    def __init__(
        self,
        prompts: List[Dict[str, Any]],  # E.g., [{'prompt': 'What is 2+2?', 'solution': '4'}, {'prompt': 'What is 2+3?', 'solution': '5'}]
        prompt_ids: torch.Tensor,  # 1-D token ids, or batched with the first row used
        tokenizer: AutoTokenizer,
        hmm_model: Optional[torch.nn.Module] = None,
        min_new_tokens: int = 5,
        max_new_tokens: int = 32,
        constraint_mode: Literal["suffix", "keyphrase"] = "suffix",
    ):
        """Store the configuration and build the constraint processor when possible.

        Args:
            prompts (List[Dict[str, Any]]): Dicts containing `prompt` and
                `solution` keys; only `solution` is read here.
            prompt_ids (torch.Tensor): Prompt token ids. Either a 1-D tensor
                for a single prompt, or a tensor with a leading batch
                dimension, in which case the first row is used.
            tokenizer (AutoTokenizer): Tokenizer used to encode the solution.
            hmm_model (Optional[torch.nn.Module]): HMM model for the
                constraint; ``None`` disables constrained decoding.
            min_new_tokens (int): Forwarded to ``get_dfa_model``.
            max_new_tokens (int): Forwarded to ``get_dfa_model``.
            constraint_mode: ``"suffix"`` uses the encoded solution as the
                required suffix; ``"keyphrase"`` uses the solution string as
                the only keyphrase. Any other value yields the default
                keyphrase ``[[" "]]`` and no explicit suffix.
        """
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.hmm_model = hmm_model
        self.min_new_tokens = min_new_tokens
        self.max_new_tokens = max_new_tokens
        self.dfa_logits_processor = None

        # Without an HMM there is nothing to constrain with.
        if hmm_model is None:
            return

        # A single DFA is shared by the whole batch, so it is only valid when
        # all solutions are the same. An empty `prompts` list is treated as
        # "no shared solution" instead of indexing prompts[0].
        shares_solution = len(prompts) > 0 and all(
            p["solution"] == prompts[0]["solution"] for p in prompts
        )
        if not shares_solution:
            # give warning that no dfa logits processor will be used
            print("Warning: no DFA logits processor will be used")
            return

        device = prompt_ids.device
        hmm_model = hmm_model.to(device)

        solution = prompts[0]["solution"]
        keyphrases = [[solution]] if constraint_mode == "keyphrase" else [[" "]]
        suffix_ids = tokenizer.encode(solution) if constraint_mode == "suffix" else None

        # Take the first row of a batched tensor and use a 1-D tensor as is;
        # indexing [0] on a 1-D tensor would give a single scalar token.
        # The ids are passed as a plain list, like keyphrases and suffix_ids.
        first_prompt_ids = prompt_ids[0] if prompt_ids.dim() > 1 else prompt_ids

        self.dfa_logits_processor = get_dfa_model(
            hmm_model=hmm_model,
            tokenizer=tokenizer,
            prompt_ids=first_prompt_ids.tolist(),
            keyphrase_ids=keyphrases,
            suffix_ids=suffix_ids,
            min_new_tokens=min_new_tokens,
            max_new_tokens=max_new_tokens,
        )

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        """Apply the constraint processor if one was built, else return `scores` unchanged."""
        if self.dfa_logits_processor is not None:
            return self.dfa_logits_processor(input_ids, scores)
        return scores



def get_tokenizer(model_name):
    """Load the tokenizer for `model_name`, adding a fallback chat template if it has none.

    Args:
        model_name: Identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer. Its ``chat_template`` is left untouched when one
        is already set; otherwise it is replaced by the minimal template below.
    """
    tok = AutoTokenizer.from_pretrained(model_name)

    if tok.chat_template is None:
        # Minimal fallback "chat" template: renders only user and assistant
        # messages as "User: ..." / "Assistant: ..." and always ends with
        # "Assistant:". Adjust to something more realistic if needed.
        tok.chat_template = """{% for message in messages %}
    {% if message['role'] == 'user' %}
    User: {{ message['content'] }}
    {% elif message['role'] == 'assistant' %}
    Assistant: {{ message['content'] }}
    {% endif %}
    {% endfor %}Assistant:"""

    # Alternative ChatML-style template (not used):
    # "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

    return tok



In [46]:
# Load the base causal LM, its tokenizer (through get_tokenizer, which sets a
# fallback chat template when the tokenizer has none) and the Ctrl-G HMM checkpoint.
# NOTE(review): AutoTokenizer is already imported in the first cell; this import
# would normally live there together with AutoModelForCausalLM.
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2-large")
tokenizer = get_tokenizer("gpt2-large")
hmm_model = ctrlg.HMM.from_pretrained("ctrlg/hmm_gpt2-large_common-gen_4096")

In [None]:
# Keep only the first 10 training examples so the notebook stays quick to run.
ds = load_dataset("trl-lib/DeepMath-103K", split="train").select(range(10))

In [60]:
idx = 2
# Tokenize the chat-formatted prompt and add a leading batch dimension:
# DummyLogitsProcessor reads the first prompt's ids via `prompt_ids[0]`.
prompt_ids = torch.tensor(tokenizer.apply_chat_template(ds[idx]["prompt"])).unsqueeze(0)
# Use the same example for the solution constraint as for the prompt ids
# (previously the prompt came from ds[idx] but the solution from ds[0]).
prompts = [ds[idx]]
print(ds[idx])
logits_processor = DummyLogitsProcessor(
    prompt_ids=prompt_ids,
    prompts=prompts,
    tokenizer=tokenizer,
    hmm_model=hmm_model,
)

{'prompt': [{'content': 'Compute the limit: $$\\lim_{x\\rightarrow 0}\\frac{3x^2-3x\\sin x}{x^2+x\\cos\\frac{1}{x}}$$', 'role': 'user'}], 'solution': '$0$'}


InductorError: CppCompileError: C++ compile error

Command:
clang++ /var/folders/r_/d81gvkws1n5fg2_7mb57cwzh0000gn/T/torchinductor_marawangamal/nr/cnrboeoye3ekapo3ioaksfsdkmujvkny74onmkyewiyvhwdpmvra.main.cpp -D TORCH_INDUCTOR_CPP_WRAPPER -D STANDALONE_TORCH_HEADER -D C10_USING_CUSTOM_GENERATED_MACROS -D CPU_CAPABILITY_NEON -D AT_BUILD_ARM_VEC256_WITH_SLEEF -O3 -DNDEBUG -fno-trapping-math -funsafe-math-optimizations -ffinite-math-only -fno-signed-zeros -fno-math-errno -fno-finite-math-only -fno-unsafe-math-optimizations -ffp-contract=off -shared -fPIC -undefined dynamic_lookup -Wall -std=c++17 -Wno-unused-variable -Wno-unknown-pragmas -Werror=ignored-optimization-argument -Xclang -fopenmp -include /var/folders/r_/d81gvkws1n5fg2_7mb57cwzh0000gn/T/torchinductor_marawangamal/precompiled_headers/c62bfr35myvt7jik6nuf236zkm2n42zpgvjfrvmt26eeh234b7cl.h -I/opt/homebrew/opt/python@3.13/Frameworks/Python.framework/Versions/3.13/include/python3.13 -I/Users/marawangamal/Documents/github/ctrl-rlvr/.venv/lib/python3.13/site-packages/torch/include -I/Users/marawangamal/Documents/github/ctrl-rlvr/.venv/lib/python3.13/site-packages/torch/include/torch/csrc/api/include -I/opt/homebrew/opt/libomp/include -o /var/folders/r_/d81gvkws1n5fg2_7mb57cwzh0000gn/T/torchinductor_marawangamal/nr/cnrboeoye3ekapo3ioaksfsdkmujvkny74onmkyewiyvhwdpmvra.main.so -lomp -lc10 -L/opt/homebrew/opt/python@3.13/Frameworks/Python.framework/Versions/3.13/lib -L/Users/marawangamal/Documents/github/ctrl-rlvr/.venv/lib/python3.13/site-packages/torch/lib -L/opt/homebrew/opt/libomp/lib

Output:
fatal error: file '/var/folders/r_/d81gvkws1n5fg2_7mb57cwzh0000gn/T/torchinductor_marawangamal/precompiled_headers/c62bfr35myvt7jik6nuf236zkm2n42zpgvjfrvmt26eeh234b7cl.h' has been modified since the precompiled header '/var/folders/r_/d81gvkws1n5fg2_7mb57cwzh0000gn/T/torchinductor_marawangamal/precompiled_headers/c62bfr35myvt7jik6nuf236zkm2n42zpgvjfrvmt26eeh234b7cl.h.pch' was built: mtime changed (was 1763344469, now 1763855015)
note: please rebuild precompiled header '/var/folders/r_/d81gvkws1n5fg2_7mb57cwzh0000gn/T/torchinductor_marawangamal/precompiled_headers/c62bfr35myvt7jik6nuf236zkm2n42zpgvjfrvmt26eeh234b7cl.h.pch'
1 error generated.


Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"


In [65]:
# Smoke test: call get_dfa_model directly with small placeholder ids/keyphrases
# instead of real prompt data.
# NOTE(review): the recorded TypeError for this cell mentions `keyphrases`, which
# does not match the `keyphrase_ids` parameter defined in this notebook --
# presumably an older definition was still live in the kernel; re-run after a restart.
get_dfa_model(
    hmm_model=hmm_model,
    tokenizer=tokenizer,
    prompt_ids=[1,2,3],
    keyphrase_ids=[["4"]],
    suffix_ids=[4],
    min_new_tokens=5,
    max_new_tokens=32,
)

TypeError: get_dfa_model() got an unexpected keyword argument 'keyphrase_ids'. Did you mean 'keyphrases'?

In [54]:
# Same check DummyLogitsProcessor uses to decide whether to build a DFA processor:
# do all prompts share the first prompt's solution?
all(p["solution"] == prompts[0]["solution"] for p in prompts)

True

In [None]:
# Set TORCHDYNAMO_VERBOSE=1 to get the internal stack trace for the InductorError.
# NOTE(review): torch is already imported at this point -- confirm the variable is
# still picked up when set this late. `os` is also already imported in the first cell.
import os
os.environ["TORCHDYNAMO_VERBOSE"] = "1"

# Mirror of the keyphrase/suffix selection done in DummyLogitsProcessor.__init__.
constraint_mode = "suffix"
keyphrase_ids = (
    [[prompts[0]["solution"]]]
    if constraint_mode == "keyphrase"
    else [[" "]]
)
suffix_ids = (
    tokenizer.encode(prompts[0]["solution"])
    if constraint_mode == "suffix"
    else None
)

# NOTE(review): the `keyphrase_ids` / `suffix_ids` computed above are not used by
# the call below, which passes literal placeholders -- presumably to get a minimal
# reproduction of the compile error; confirm this is intended.
dfa_logits_processor = get_dfa_model(
    hmm_model=hmm_model,
    tokenizer=tokenizer,
    # Plain list of ids here, so the tensor-vs-list concern noted in
    # DummyLogitsProcessor does not apply to this call.
    prompt_ids=[1, 2, 3],
    keyphrase_ids=[["4"]],
    suffix_ids=[4],
    min_new_tokens=6,
    max_new_tokens=128,
)

InductorError: CppCompileError: C++ compile error

Command:
clang++ /var/folders/r_/d81gvkws1n5fg2_7mb57cwzh0000gn/T/torchinductor_marawangamal/nr/cnrboeoye3ekapo3ioaksfsdkmujvkny74onmkyewiyvhwdpmvra.main.cpp -D TORCH_INDUCTOR_CPP_WRAPPER -D STANDALONE_TORCH_HEADER -D C10_USING_CUSTOM_GENERATED_MACROS -D CPU_CAPABILITY_NEON -D AT_BUILD_ARM_VEC256_WITH_SLEEF -O3 -DNDEBUG -fno-trapping-math -funsafe-math-optimizations -ffinite-math-only -fno-signed-zeros -fno-math-errno -fno-finite-math-only -fno-unsafe-math-optimizations -ffp-contract=off -shared -fPIC -undefined dynamic_lookup -Wall -std=c++17 -Wno-unused-variable -Wno-unknown-pragmas -Werror=ignored-optimization-argument -Xclang -fopenmp -include /var/folders/r_/d81gvkws1n5fg2_7mb57cwzh0000gn/T/torchinductor_marawangamal/precompiled_headers/c62bfr35myvt7jik6nuf236zkm2n42zpgvjfrvmt26eeh234b7cl.h -I/opt/homebrew/opt/python@3.13/Frameworks/Python.framework/Versions/3.13/include/python3.13 -I/Users/marawangamal/Documents/github/ctrl-rlvr/.venv/lib/python3.13/site-packages/torch/include -I/Users/marawangamal/Documents/github/ctrl-rlvr/.venv/lib/python3.13/site-packages/torch/include/torch/csrc/api/include -I/opt/homebrew/opt/libomp/include -o /var/folders/r_/d81gvkws1n5fg2_7mb57cwzh0000gn/T/torchinductor_marawangamal/nr/cnrboeoye3ekapo3ioaksfsdkmujvkny74onmkyewiyvhwdpmvra.main.so -lomp -lc10 -L/opt/homebrew/opt/python@3.13/Frameworks/Python.framework/Versions/3.13/lib -L/Users/marawangamal/Documents/github/ctrl-rlvr/.venv/lib/python3.13/site-packages/torch/lib -L/opt/homebrew/opt/libomp/lib

Output:
fatal error: file '/var/folders/r_/d81gvkws1n5fg2_7mb57cwzh0000gn/T/torchinductor_marawangamal/precompiled_headers/c62bfr35myvt7jik6nuf236zkm2n42zpgvjfrvmt26eeh234b7cl.h' has been modified since the precompiled header '/var/folders/r_/d81gvkws1n5fg2_7mb57cwzh0000gn/T/torchinductor_marawangamal/precompiled_headers/c62bfr35myvt7jik6nuf236zkm2n42zpgvjfrvmt26eeh234b7cl.h.pch' was built: mtime changed (was 1763344469, now 1763855015)
note: please rebuild precompiled header '/var/folders/r_/d81gvkws1n5fg2_7mb57cwzh0000gn/T/torchinductor_marawangamal/precompiled_headers/c62bfr35myvt7jik6nuf236zkm2n42zpgvjfrvmt26eeh234b7cl.h.pch'
1 error generated.


Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
