In [1]:
import sys
sys.path.append("..")
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.logits_process import LogitsProcessorList, InfNanRemoveLogitsProcessor
from transformers_gad.grammar_utils import IncrementalGrammarConstraint
from transformers_gad.generation.logits_process import GrammarAlignedOracleLogitsProcessor

In [2]:
NUM_ITER = 10
MODEL_ID = "TinyLlama/TinyLlama_v1.1"
TRIE_PATH = "tries/binary_len_5_0_trie.json"
DEVICE = "cuda"
DTYPE = torch.bfloat16
MAX_NEW_TOKENS = 512
TEMPERATURE = 1.0
REPETITION_PENALTY = 1.0
TOP_P = 1.0
TOP_K = 0

device = torch.device(DEVICE)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token

# Load model
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
model.to(device)
model.to(dtype=DTYPE)
model.resize_token_embeddings(len(tokenizer))

GRAMMAR_PATH = "../examples/grammars/arithmetic.ebnf"
# Load EBNF grammar
with open(GRAMMAR_PATH, "r") as file:
    grammar_str = file.read()
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)

  return self.fget.__get__(instance, owner)()


In [3]:
GRAMMAR_PATH = "../examples/test/binary_len_5_0.ebnf"
# Load EBNF grammar
with open(GRAMMAR_PATH, "r") as file:
    grammar_str = file.read()
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)

# Initialize logits processor for the grammar
gad_oracle_processor = GrammarAlignedOracleLogitsProcessor(grammar)
inf_nan_remove_processor = InfNanRemoveLogitsProcessor()
logits_processors = LogitsProcessorList([
    inf_nan_remove_processor,
    gad_oracle_processor,
])

# Tokenize prompt into ids
prompt = "Generate a binary string of length 5."

In [4]:
import torch.nn.functional as F
def get_prob(s: str) -> float:
    full_text = prompt + " " + s
    tokens = tokenizer.encode(full_text, return_tensors='pt')
    start_index = len(tokenizer.encode(prompt + " ", return_tensors="pt")[0]) - 1
    print(start_index)
    llm_prob = 1.0
    for i in range(start_index, len(tokens[0]) - 1):
        input_tokens = tokens[:, :i+1].to(DEVICE)
        outputs = model(input_tokens)
        logits = outputs.logits
        last_token_logits = logits[0, -1, :]
        probabilities = F.softmax(last_token_logits, dim=-1)
        
        next_token_id = tokens[0, i + 1]
        next_token_prob = probabilities[next_token_id].item()
        context = tokenizer.decode(input_tokens[0])
        next_token = tokenizer.decode([next_token_id])

        llm_prob *= next_token_prob
    
        #print(f"Context: '{context}', Actual next token: '{next_token}', Probability: {next_token_prob:.4f}")
    return llm_prob

In [5]:
DESIRED_PREFIX = "110"

# Set up the device
device = torch.device(DEVICE)
# Encode the desired prefix
desired_prefix_ids = tokenizer.encode(DESIRED_PREFIX, add_special_tokens=False)

with open(GRAMMAR_PATH, "r") as file:
    grammar_str = file.read()
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
gad_oracle_processor = GrammarAlignedOracleLogitsProcessor(grammar)
inf_nan_remove_processor = InfNanRemoveLogitsProcessor()
logits_processors = LogitsProcessorList([
    inf_nan_remove_processor,
    #gcd_oracle_processor,
])


# Define the prefix_allowed_tokens_fn with correct signature
def prefix_allowed_tokens_fn(batch_id: int, input_ids_partial: torch.LongTensor):
    #print(input_ids_partial)
    #print(tokenizer.decode(input_ids_partial))
    #print(input_ids_partial.shape)
    #print(tokenizer.encode('0', add_special_tokens=False)[0])
    """
    Restricts the generated tokens to enforce a specific prefix.
    
    Args:
        batch_id (int): The index of the current batch.
        input_ids_partial (torch.LongTensor): The input IDs generated so far, including the prompt.
        step (int): The current generation step.
        
    Returns:
        List[int]: A list of allowed token IDs for the current step.
    """
    step = input_ids_partial.shape[0]
    # Total tokens in the prompt
    prompt_length = input_ids.shape[1]
    
    # Calculate the step relative to the prefix
    generation_step = step - prompt_length

    if generation_step < 0:
        # Still generating the prompt; no restrictions
        return list(range(tokenizer.vocab_size))
    elif generation_step < len(desired_prefix_ids):
        # Enforce the desired prefix
        #print(desired_prefix_ids[generation_step])
        return [desired_prefix_ids[generation_step]]
    elif generation_step < len(desired_prefix_ids) + 5 - len(DESIRED_PREFIX):
        # After prefix, allow only '0' or '1' to complete the binary string
        token_0 = tokenizer.encode('0', add_special_tokens=False)[1]
        token_1 = tokenizer.encode('1', add_special_tokens=False)[1]
        return [token_0, token_1]
    else:
        return [tokenizer.eos_token_id]

# Generate continuation using beam search with prefix enforcement
output = model.generate(
    input_ids=input_ids,
    max_new_tokens=MAX_NEW_TOKENS,  # Generate 2 tokens after the prefix to reach length 5
    temperature=TEMPERATURE,
    repetition_penalty=REPETITION_PENALTY,
    top_p=TOP_P,
    top_k=TOP_K,
    output_scores=True,
    return_dict_in_generate=True,
    num_beams=1,                    # Beam search
    #logits_processor=logits_processors,     # Apply grammar and other constraints
    do_sample=False,                         # Deterministic beam search
    pad_token_id=tokenizer.eos_token_id,    # Define padding token
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn  # Enforce prefix constraints
)

NameError: name 'input_ids' is not defined

In [6]:
from transformers_gad.recognizer import AcceptState
import random
RAW, MASKED = "raw", "masked"
EPS = 10 ** -9
class SampleNode:
    def __init__(self, prefix, state: AcceptState, raw_prob, masked_prob, parent=None, last_token=None):
        self.prefix = prefix
        self.state = state
        self.prefix_prob = {RAW: float(raw_prob), MASKED: float(masked_prob)}
        self.sum_prob = None
        self.parent = parent
        self.children = {}
        self.step_prob = {}
        self.last_token = last_token

    def is_new(self):
        return self.sum_prob is None

    def to_string(self, tokenizer):
        return f"{tokenizer.decode(self.prefix[0])}@{self.prefix_prob}"

    def insert_child(self, token, new_state, step_prob):
        token_tensor = torch.tensor([[token]], device=self.prefix.device)
        new_prefix = torch.cat((self.prefix, token_tensor), dim=1)

        child = SampleNode(new_prefix, new_state, self.prefix_prob[RAW] * step_prob, 
                           self.prefix_prob[MASKED] * step_prob / self.sum_prob, self, token)
        self.children[token] = child
        self.step_prob[token] = step_prob

    def sample_next(self, forbid=None):
        items = [token for token in self.children if token != forbid]
        weights = [self.step_prob[token] for token in items]
        #print(items, weights, "forbid", forbid)
        token = random.choices(items, weights=weights, k=1)[0]
        return token

In [13]:
def _process_score(scores):
    scores[scores != scores] = 0.0

    scores[scores == float("inf")] = torch.finfo(scores.dtype).max
    scores[scores == float("-inf")] = torch.finfo(scores.dtype).min
    return F.softmax(scores, dim=-1)

MAX_TURN = 512

def _get_meaningful_parents(node):
    parent_list = []
    while node.parent is not None:
        last_token = node.last_token
        node = node.parent
        if len(node.children) > 1:
            parent_list.append((node, last_token))
    return parent_list

class SampleHolder:
    def __init__(self, model, tokenizer, constraint, device):
        self.model = model
        self.tokenizer = tokenizer
        self.constraint = constraint
        self.device = device
        root_tokens = tokenizer.encode(prompt + " ", return_tensors='pt').to(self.device)
        state = grammar.string_recognizer.get_initial_accept_state()
        self.root = SampleNode(root_tokens, state, 1.0, 1.0, None)

    def draw(self, node: SampleNode = None):
        if node is None: node = self.root
        for _ in range(MAX_TURN):
            #print("Sampling from ", node.to_string(self.tokenizer))
            if node.is_new():
                logits = self.model(node.prefix).logits
                raw_score = logits[0, -1, :]
                vocab = self.constraint.filter_vocab(node.state, self.device)
                
                step_tokens = vocab.nonzero().cpu().tolist() # [[x] for a possible token x]            
                raw_prob = _process_score(raw_score)
    
                sum_prob = sum([raw_prob[index[0]] for index in step_tokens])
                assert sum_prob > EPS
                node.sum_prob = sum_prob
                
                for _token in step_tokens:
                    token = _token[0]
                    prob = raw_prob[token].cpu()
                    if prob / sum_prob < EPS: continue
                    node.insert_child(token, self.constraint._consume_token_id(token, node.state), raw_prob[token].cpu())
            token = node.sample_next()
            node = node.children[token]
            if token == self.tokenizer.eos_token_id:
                return node
        assert False

    def mutate(self, node):
        parents = _get_meaningful_parents(node)
        if len(parents) == 0: return node
        parent, token_x = random.choice(parents)
        #print("Parent ", parent.to_string(self.tokenizer))
        #print(token_x, parent.step_prob)
        token_y = parent.sample_next(token_x)
        node_y = self.draw(parent.children[token_y])
        parents_y = _get_meaningful_parents(node_y)

        phi_x, phi_y = node.prefix_prob[RAW], node_y.prefix_prob[RAW]
        #print(float(node.prefix_prob[MASKED]), float(node_y.prefix_prob[MASKED]), float(parent.prefix_prob[MASKED]))
        #print(float(parent.sum_prob), float(parent.step_prob[token_y]))
        trans_xy = 1 / len(parents) * node_y.prefix_prob[MASKED] / (parent.sum_prob - parent.step_prob[token_x]) #/ parent.prefix_prob[MASKED]
        trans_yx = 1 / len(parents_y) * node.prefix_prob[MASKED] / (parent.sum_prob - parent.step_prob[token_y]) #/ parent.prefix_prob[MASKED]

        accept_prob = phi_y * trans_yx / phi_x / trans_xy
        #print("transition")
        #print("  ", node.to_string(self.tokenizer))
        #print("  ", node_y.to_string(self.tokenizer))
        #print("  ", float(accept_prob), float(trans_xy), float(trans_yx))
        if random.random() < accept_prob:
            return node_y
        else:
            return node

    def mcmc(self, round):
        x = self.draw()
        for _ in range(round):
            x = self.mutate(x)
        return x
        
holder = SampleHolder(model, tokenizer, grammar, DEVICE)
for i in range(100):
    res = holder.mcmc(100)
    print(res.to_string(tokenizer))

<s> Generate a binary string of length 5. 10000</s>@{'raw': 6.539553254469865e-08, 'masked': 0.313358336687088}
<s> Generate a binary string of length 5. 10011</s>@{'raw': 2.0651822474349046e-09, 'masked': 0.022224845364689827}
<s> Generate a binary string of length 5. 10101</s>@{'raw': 2.4343240756508067e-09, 'masked': 0.025169778615236282}
<s> Generate a binary string of length 5. 10000</s>@{'raw': 6.539553254469865e-08, 'masked': 0.313358336687088}
<s> Generate a binary string of length 5. 00000</s>@{'raw': 7.032596016642856e-08, 'masked': 0.3923368453979492}
<s> Generate a binary string of length 5. 10000</s>@{'raw': 6.539553254469865e-08, 'masked': 0.313358336687088}
<s> Generate a binary string of length 5. 00000</s>@{'raw': 7.032596016642856e-08, 'masked': 0.3923368453979492}
<s> Generate a binary string of length 5. 00000</s>@{'raw': 7.032596016642856e-08, 'masked': 0.3923368453979492}
<s> Generate a binary string of length 5. 10000</s>@{'raw': 6.539553254469865e-08, 'masked': 

In [43]:
generated_ids = output.sequences[0]
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
print(f"生成的文本: {generated_text}")

生成的文本: Generate a binary string of length 5 11011
