In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import ctrlg

# Defin model
model_name = "ctrlg/gpt2-large_common-gen"
hmm_model_name = f'ctrlg/hmm_gpt2-large_common-gen_4096' # alternatively ctrlg/hmm_gpt2-large_common-gen_4096 for better quality

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Define HMM
hmm_model = ctrlg.HMM.from_pretrained(hmm_model_name)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from typing import List, Optional

def get_dfa_model_v2(
        hmm_model: torch.nn.Module,
        prompt_ids: List[int],   # Shape: (B, T)
        tokenizer:AutoTokenizer,
        keyphrases:List[List[str]]=[[' ']], 
        suffix_ids:Optional[List[int]]=None, 
        min_new_tokens:int=5, 
        max_new_tokens:int=32,
        device:torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ):
    """Constructs a DFA model for the given prompt and keyphrases.

    Args:
        prompt_ids (List[int]): Prompt integer list
        keyphrases (List[List[str]], optional): List of keyphrases to be constrained. Defaults to [[' ']].
        suffix_ids (Optional[List[int]], optional): Suffix integer list. Defaults to None.
        min_new_tokens (int, optional): Minimum number of tokens to generate. Defaults to 5.
        max_new_tokens (int, optional): Maximum number of tokens to generate. Defaults to 32.
        device (torch.device, optional): Device to run the model on. Defaults to torch.device('cuda' if torch.cuda.is_available() else 'cpu').

    Returns:
        constraint_logits_processor: Logits processor for the DFA model.
    """

    vocab_size = len(tokenizer)

    ##################################### prefix, suffix, prompt #####################################
    prefix = '' # generate text starting with nothing
    suffix = '.<|endoftext|>' # generate text ending with '<|endoftext|>'; a suffix must end with the eos token

    prefix_ids = tokenizer.encode(prefix)
    if suffix_ids is None:
        suffix_ids = tokenizer.encode(suffix)

    ##################################### DFA Construction #####################################
    # ac_builder constructs a DFA representing the constraint that (at least) 
    # one the patterns must appear; a pattern is a sequence of token ids
    ac_builder = ctrlg.AhoCorasickBuilder(vocab_size)

    dfa_graphs = []

    # constraint 1:
    for keyphrase in keyphrases:
        patterns = [tokenizer.encode(x) for x in keyphrase]
        dfa_graphs.append(ac_builder.build(patterns))

    # taking the intersection of the DFAs, i.e., "logical and" of the constraints.
    # This function also minimizes the constructed DFA, which is mainly CPU-based operations;
    # Due to its pure python implemenation, DFA minimization can be slow for complex constraints
    dfa_graph = ctrlg.DFA_prod(dfa_graphs, mode='intersection')

    # compile the dfa_graph for efficient GPU execution
    dfa_model = ctrlg.DFAModel(dfa_graph, vocab_size).to(device)

    ##################################### token length #####################################

    constraint_logits_processor = ctrlg.ConstraintLogitsProcessor(
        hmm_model, 
        dfa_model,
        min_new_tokens, 
        max_new_tokens,
        prompt_ids, 
        prefix_ids=prefix_ids, 
        suffix_ids=suffix_ids
    )

    return constraint_logits_processor

In [3]:
import torch

# Places
prompt = 'What is 10+2?'
solution = " 13 "

# Math
# prompt = "To express 20 as a sum of different powers of 2, we would write $20 = 2^4 + 2^2$. The sum of the exponents of these powers is $4 + 2 = 6$. If 400 were expressed as a sum of at least two distinct powers of 2, what would be the least possible sum of the exponents of these powers?"
# solution = "6"

suffix_ids = tokenizer.encode(solution)
prompt_ids = tokenizer.encode(prompt)
# keyphrases=[['beach', "soccer"]], 

max_new_tokens = 32
min_new_tokens = 5
lproc = get_dfa_model_v2(
    hmm_model=hmm_model,
    prompt_ids=prompt_ids, 
    tokenizer=tokenizer, 
    suffix_ids=suffix_ids,
    max_new_tokens=max_new_tokens, 
    min_new_tokens=min_new_tokens,
    device="cpu"
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [68]:
from transformers import LogitsProcessorList


# set beam_size for beam search; usually the larger the beam_size the
# higher the generation quality
# beam_size = 8
# lproc.hmm_batch_size = beam_size

output_gen = model.generate(
    input_ids=torch.tensor(prompt_ids).reshape(1, -1),
    max_new_tokens=max_new_tokens, 
    logits_processor=LogitsProcessorList([lproc]),
    pad_token_id=tokenizer.eos_token_id,
    # num_beams=beam_size,
    do_sample=True,
    length_penalty=0.2,
)

In [70]:
tokenizer.decode(output_gen[0], skip_special_tokens=False)

'What is 10+2? \xa010+2 = 13 \xa0* 13 \xa0 = 13 \xa0* 13 \xa0 = 13 \xa0 13  13  13  13'