## Ctrl-G Tutorial

### **Part A**. Ctrl-G on GPT2-large (less computation required)

**Step 1. load pretrained models**

In [None]:
import os

# Environment configuration is kept ahead of the torch / ctrlg imports so the
# settings are already in place when those libraries are loaded.
device = 'cuda'
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # set your cuda device
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import torch
import ctrlg
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

# Load the pretrained base_model and hmm_model; see README.md for a complete
# list of released checkpoints. Note that the hmm_model and base_model must
# share the same vocabulary of tokens: i.e., one cannot apply
# hmm_gpt2-large_common-gen_4096 to tulu2-7b_writing-prompts. To apply Ctrl-G
# to a custom base_model, or to achieve best performance on a specific domain,
# users would need to distill an hmm_model from the base_model. Please refer
# to tutorial_distillation.ipynb for details.
# (Plain string literals: the paths contain no placeholders, so no f-prefix.)
BASE_MODEL_PATH = 'ctrlg/gpt2-large_common-gen' # a gpt2-large checkpoint domain adapted to the common-gen corpus
HMM_MODEL_PATH = 'ctrlg/hmm_gpt2-large_common-gen_4096' # alternatively 'ctrlg/hmm_gpt2-large_common-gen_32768' for better quality

# base model (moved to `device`, switched to eval mode) and its tokenizer
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH).to(device)
base_model.eval()
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

# HMM that pairs with the base model above
hmm_model = ctrlg.HMM.from_pretrained(HMM_MODEL_PATH).to(device)

**Step 2. specify logical constraints as DFAs (example constraint 1)**

In [None]:
# vocabulary size and end-of-sequence token id, both read from the HMM
vocab_size, eos_token_id = hmm_model.vocab_size, hmm_model.eos_token_id


##################################### prefix, suffix, prompt #####################################
prefix = '' # generate text starting with nothing
suffix = '<|endoftext|>' # generate text ending with '<|endoftext|>'; a suffix must end with the eos token
prompt = '<|endoftext|>' # prompt the base model with the '<|endoftext|>' token

# tokenize the three strings, in the same order as they are defined above
prefix_ids, suffix_ids, prompt_ids = (
    tokenizer.encode(text) for text in (prefix, suffix, prompt)
)
##################################### prefix, suffix, prompt #####################################


##################################### DFA Construction #####################################
# ac_builder constructs a DFA representing the constraint that (at least)
# one of the patterns must appear; a pattern is a sequence of token ids
ac_builder = ctrlg.AhoCorasickBuilder(vocab_size)
# word_count_builder constructs a DFA representing the constraint that
# the generated text consists of a to b words; refer to the source code of
# WordCountBuilder for the definition of a word.
word_count_builder = ctrlg.WordCountBuilder(tokenizer, vocab_size)

# constraint 1:
# one of ' riding a bike', ' ride bikes', ' rides a bike', ' biking', ' bikes' has to appear
# AND one of ' park', ' beach' has to appear
keyphrases = [
    [' riding a bike', ' ride bikes', ' rides a bike', ' biking', ' bikes'],
    [' park', ' beach'],
]
# one DFA per group of alternatives; every alternative is tokenized into a pattern
dfa_graphs = [
    ac_builder.build([tokenizer.encode(phrase) for phrase in alternatives])
    for alternatives in keyphrases
]

# constraint 2: generate exactly 10 words
# (the generated text must contain between a and b words)
a, b = 10, 10
dfa_graphs.append(word_count_builder.build(a, b))

# Take the intersection of the DFAs, i.e., the "logical and" of the constraints.
# This function also minimizes the constructed DFA, which is mainly CPU-based;
# due to its pure python implementation, DFA minimization can be slow for
# complex constraints.
dfa_graph = ctrlg.DFA_prod(dfa_graphs, mode='intersection')

# compile the dfa_graph for efficient GPU execution
dfa_model = ctrlg.DFAModel(dfa_graph, vocab_size).to(device)
##################################### DFA Construction #####################################


##################################### token length #####################################
# Specify the min_new_tokens and max_new_tokens to be generated (excluding
# the prefix and suffix). Make sure that the numbers here do not conflict
# with the given constraint; a conflicting setup would be, e.g., asking the
# model to generate 10 words with max_new_tokens = 8.
min_new_tokens, max_new_tokens = 5, 32
##################################### token length #####################################

**Step 3. generate with constraints.**

Due to the use of @torch.compile, the first run of the following functions could be significantly slower than the later runs.

In [None]:
# Initialize the constraint logits processor.
# Note: this step pre-computes & caches certain conditional probability tables;
# one simple optimization is to re-use the same constraint_logits_processor
# across base_model.generate calls when the constraints do not change.
constraint_logits_processor = ctrlg.ConstraintLogitsProcessor(
    hmm_model,
    dfa_model,
    min_new_tokens,
    max_new_tokens,
    prompt_ids,
    prefix_ids=prefix_ids,
    suffix_ids=suffix_ids,
)

# beam_size for beam search; usually the larger the beam_size, the higher
# the generation quality
beam_size = 128

# Set hmm_batch_size depending on the resources available: a larger value
# uses more memory, and the best speed is attained when it equals beam_size.
constraint_logits_processor.hmm_batch_size = beam_size

# generate with beam search; the prompt is wrapped in a list so the tensor
# has a leading batch dimension of 1
input_ids = torch.tensor([prompt_ids], device=device)
generation_kwargs = dict(
    input_ids=input_ids,
    do_sample=False,
    length_penalty=0.2,
    num_beams=beam_size,
    num_return_sequences=beam_size,
    min_new_tokens=min_new_tokens,
    max_new_tokens=max_new_tokens,
    logits_processor=LogitsProcessorList([constraint_logits_processor]),
    pad_token_id=tokenizer.eos_token_id,  # the eos id doubles as the pad id
)
outputs = base_model.generate(**generation_kwargs)

**Step 4. extract & rank outputs via the base model.**

In [None]:
# Extract the generated ids: prompt ids are removed, as are suffix ids that
# were (partially) generated.
generated_ids = ctrlg.extract_generated_ids(outputs.tolist(), prompt_ids, suffix_ids, eos_token_id)

# rank the generated ids by the base_model probability
generated_ids = ctrlg.rank_generated_ids(base_model, generated_ids, prompt_ids, suffix_ids)

# The prefix and suffix text is the same for every candidate, so decode it once.
prefix_text = tokenizer.decode(prefix_ids, skip_special_tokens=True)
suffix_text = tokenizer.decode(suffix_ids, skip_special_tokens=True)

# print top 10 outputs; the generated part is wrapped in ANSI bold escapes
for rank, candidate in enumerate(generated_ids[:10]):
    candidate_text = tokenizer.decode(candidate, skip_special_tokens=True)
    print(f'{rank}. {prefix_text}\033[1m{candidate_text}\033[0m{suffix_text}')

**Step 5. try some other constraints! (example constraint 2)**

In [None]:
# vocabulary size and end-of-sequence token id, both read from the HMM
vocab_size, eos_token_id = hmm_model.vocab_size, hmm_model.eos_token_id


prefix = ' on a fine sunny' # generate text starting with ' on a fine sunny'
suffix = ' in the park.<|endoftext|>' # generate text ending with ' in the park.<|endoftext|>'
prompt = '<|endoftext|> on a fine sunny' # prompt the base model with the '<|endoftext|>' token and the prefix

# tokenize the three strings, in the same order as they are defined above
prefix_ids, suffix_ids, prompt_ids = (
    tokenizer.encode(text) for text in (prefix, suffix, prompt)
)


ac_builder = ctrlg.AhoCorasickBuilder(vocab_size)
word_count_builder = ctrlg.WordCountBuilder(tokenizer, vocab_size)

# constraint 1:
# one of ' girl', ' boy', ' girls', ' boys', ' children' AND
# one of ' dogs', ' cats', ' dog', ' cat' have to appear
# in the GIVEN ORDER.
keyphrases = [
    [' girl', ' boy', ' girls', ' boys', ' children'],
    [' dogs', ' cats', ' dog', ' cat'],
]
keyphrase_dfas = [
    ac_builder.build([tokenizer.encode(phrase) for phrase in alternatives])
    for alternatives in keyphrases
]
# concatenate the per-group DFAs so the patterns appear in the given order
ordered_keyphrase_dfa = ctrlg.DFA_concatenate(keyphrase_dfas)

# constraint 2: generate 7 - 12 words
a, b = 7, 12
dfa_graphs = [ordered_keyphrase_dfa, word_count_builder.build(a, b)]

# "logical and" of both constraints, then compiled and moved to the device
dfa_graph = ctrlg.DFA_prod(dfa_graphs, mode='intersection')
dfa_model = ctrlg.DFAModel(dfa_graph, vocab_size).to(device)


# bounds on the number of newly generated tokens
min_new_tokens, max_new_tokens = 5, 32


# initialize the constraint logits processor for the new constraints
constraint_logits_processor = ctrlg.ConstraintLogitsProcessor(
    hmm_model,
    dfa_model,
    min_new_tokens,
    max_new_tokens,
    prompt_ids,
    prefix_ids=prefix_ids,
    suffix_ids=suffix_ids,
)


# beam width; the HMM batch size is set to the same value
beam_size = 128
constraint_logits_processor.hmm_batch_size = beam_size

# the prompt is wrapped in a list so the tensor has a leading batch dimension of 1
input_ids = torch.tensor([prompt_ids], device=device)

# generate with beam search
generation_kwargs = dict(
    input_ids=input_ids,
    do_sample=False,
    num_beams=beam_size,
    num_return_sequences=beam_size,
    min_new_tokens=min_new_tokens,
    max_new_tokens=max_new_tokens,
    logits_processor=LogitsProcessorList([constraint_logits_processor]),
    pad_token_id=tokenizer.eos_token_id,  # the eos id doubles as the pad id
)
outputs = base_model.generate(**generation_kwargs)

# Extract the generated ids: prompt ids are removed, as are suffix ids that
# were (partially) generated.
generated_ids = ctrlg.extract_generated_ids(outputs.tolist(), prompt_ids, suffix_ids, eos_token_id)

# rank the generated ids by the base_model probability
generated_ids = ctrlg.rank_generated_ids(base_model, generated_ids, prompt_ids, suffix_ids)

# The prefix and suffix text is the same for every candidate, so decode it once.
prefix_text = tokenizer.decode(prefix_ids, skip_special_tokens=True)
suffix_text = tokenizer.decode(suffix_ids, skip_special_tokens=True)

# print top 10 outputs; the generated part is wrapped in ANSI bold escapes
for rank, candidate in enumerate(generated_ids[:10]):
    candidate_text = tokenizer.decode(candidate, skip_special_tokens=True)
    print(f'{rank}. {prefix_text}\033[1m{candidate_text}\033[0m{suffix_text}')

### **Part B**. Ctrl-G on TULU2-7B (more computation required)

**Step 1. load pretrained models.**

In [None]:
import os

# Environment configuration is kept ahead of the torch / ctrlg imports so the
# settings are already in place when those libraries are loaded.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # set your cuda device
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import torch
import ctrlg
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

device = 'cuda'

# load the pretrained base_model and hmm_model;
# (plain string literals: the paths contain no placeholders, so no f-prefix)
BASE_MODEL_PATH = 'ctrlg/tulu2-7b_writing-prompts'
HMM_MODEL_PATH = 'ctrlg/hmm_tulu2-7b_writing-prompts_32768'

# fp16 inference: cast the weights to half precision BEFORE moving them to the
# GPU. Moving first and casting afterwards would place the weights on the
# device in their load-time precision; if that is fp32, the 7B model would
# transiently need about twice the GPU memory of its fp16 form.
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH).half().to(device)
base_model.eval()
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

# HMM that pairs with the base model above
hmm_model = ctrlg.HMM.from_pretrained(HMM_MODEL_PATH).to(device)

**Step 2. specify logical constraints as DFAs.**

In [None]:
# vocabulary size and end-of-sequence token id, both read from the HMM
vocab_size, eos_token_id = hmm_model.vocab_size, hmm_model.eos_token_id

# text the generation must start with / end with
prefix = 'Once upon a time, in a land far, far away, there was a kingdom. The kingdom was'
suffix = 'beautiful buildings. The people of this kingdom were known for their kindness and generosity, always ready to lend a helping hand.</s>'
soft_constraint = ' in fairytale style' # use empty string for no soft constraint
# chat-style prompt asking the model to continue the prefix
prompt = f'<|user|>\nContinue the given text{soft_constraint}:\n{prefix}\n<|assistant|>\n'

# [1:] drops the first id returned by encode for the prefix and suffix
# (presumably a BOS token prepended by this tokenizer -- confirm); the
# prompt keeps every id.
prefix_ids, suffix_ids = (
    tokenizer.encode(text)[1:] for text in (prefix, suffix)
)
prompt_ids = tokenizer.encode(prompt)

ac_builder = ctrlg.AhoCorasickBuilder(vocab_size)
eos_builder = ctrlg.EOSBuilder(vocab_size, eos_token_id)

# every single-phrase group below gets its own DFA, so after the
# intersection all four keyphrases are required
keyphrases = [['towering'], ['reach the sky'], ['reflected'], ['lake']]
dfa_graphs = [
    ac_builder.build([tokenizer.encode(phrase)[1:] for phrase in alternatives])
    for alternatives in keyphrases
]
# additional DFA from EOSBuilder (see ctrlg.EOSBuilder for its exact semantics)
dfa_graphs.append(eos_builder.build())

# "logical and" of all constraints, then compiled and moved to the device
dfa_graph = ctrlg.DFA_prod(dfa_graphs, mode='intersection')
dfa_model = ctrlg.DFAModel(dfa_graph, vocab_size).to(device)

# bounds on the number of newly generated tokens
min_new_tokens, max_new_tokens = 16, 32

**Step 3. generate with constraints.**

Due to the use of @torch.compile, the first run of the following functions could be significantly slower than the later runs.

In [None]:
# initialize the constraint logits processor
constraint_logits_processor = ctrlg.ConstraintLogitsProcessor(
    hmm_model,
    dfa_model,
    min_new_tokens,
    max_new_tokens,
    prompt_ids,
    prefix_ids=prefix_ids,
    suffix_ids=suffix_ids,
)


# set the hmm_batch_size & temperature on the processor
beam_size = 128 # sample 128 sequences
temperature = 0.7
constraint_logits_processor.hmm_batch_size = beam_size
constraint_logits_processor.temperature = temperature


# generate with sampling, temperature=0.7; the prompt is wrapped in a list
# so the tensor has a leading batch dimension of 1
input_ids = torch.tensor([prompt_ids], device=device)
generation_kwargs = dict(
    input_ids=input_ids,
    do_sample=True,
    num_return_sequences=beam_size,
    min_new_tokens=min_new_tokens,
    max_new_tokens=max_new_tokens,
    logits_processor=LogitsProcessorList([constraint_logits_processor]),
    pad_token_id=tokenizer.eos_token_id,  # the eos id doubles as the pad id
)
outputs = base_model.generate(**generation_kwargs)


# Extract the generated ids: prompt ids are removed, as are suffix ids that
# were (partially) generated.
generated_ids = ctrlg.extract_generated_ids(outputs.tolist(), prompt_ids, suffix_ids, eos_token_id)

# Filter out 75% of the generated ids by how well they connect with the suffix.
# The number kept is derived from beam_size (the number of sampled sequences)
# instead of being hard-coded, so the 75% ratio still holds when beam_size is
# changed; with beam_size = 128 this keeps 32 candidates. At least one
# candidate is always kept.
num_shortlisted = max(1, beam_size // 4)
generated_ids = ctrlg.rank_generated_ids(
    base_model, generated_ids, prompt_ids, suffix_ids,
    suffix_logits_only=True, suffix_length_cap=5,
)[:num_shortlisted]

# rank the remaining generated ids by the base_model for higher quality
generated_ids = ctrlg.rank_generated_ids(base_model, generated_ids, prompt_ids, suffix_ids)

# The prefix and suffix text is the same for every candidate, so decode it once.
prefix_text = tokenizer.decode(prefix_ids, skip_special_tokens=True)
suffix_text = tokenizer.decode(suffix_ids, skip_special_tokens=True)

# print top 10 outputs; the generated part is wrapped in ANSI bold escapes and
# separated from the prefix and suffix by a single space on each side
for rank, candidate in enumerate(generated_ids[:10]):
    candidate_text = tokenizer.decode(candidate, skip_special_tokens=True)
    print(f'{rank}. {prefix_text} \033[1m{candidate_text}\033[0m {suffix_text}')