# Structured Generation Experiments with HuggingFace `transformers` LogitsProcessor

These experiments follow this article:

[https://towardsdatascience.com/structured-generative-ai-e772123428e4](https://towardsdatascience.com/structured-generative-ai-e772123428e4)

and use `LogitsProcessor` and `LogitsProcessorList` from `transformers.generation.logits_process`.

---

- The article's idea is to force the model's output to conform to SQL syntax.
- The mental image is to treat the task as "translating into a structured language".
- In such a language, only a limited set of tokens is legitimate at each generation step.

**We want to "insert" this knowledge (i.e. of legitimate tokens) into the generative process**

---

### How to do it

At each generation step the model outputs a logit value **for every token in the vocabulary**.

To restrict generation, assign `-inf` to the logits of all tokens that must not occur at that step. After the softmax their probability is exactly 0, so they can never be picked.
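Here is a small pure-Python illustration of why `-inf` works, using made-up logits for a 5-token vocabulary. The softmax turns `-inf` into a probability of exactly 0, and the tokens that are still allowed are renormalized among themselves. `mask_logits` and `softmax` are hypothetical helpers written for this sketch. At least one token has to stay allowed; if every logit were `-inf`, the probabilities would be undefined.

```python
import math

def softmax(logits):
    # subtract the max for numerical stability; exp(-inf) evaluates to 0.0
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def mask_logits(logits, allowed_ids):
    # keep the logits of allowed token IDs, set every other position to -inf
    return [x if i in allowed_ids else float("-inf") for i, x in enumerate(logits)]

logits = [2.0, 0.5, 1.0, -1.0, 3.0]  # made-up scores for a toy 5-token vocabulary
probs = softmax(mask_logits(logits, allowed_ids={1, 2}))
print(probs)
```

Token 4 had the highest raw logit, but once it is masked it can no longer be chosen.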

### LogitsProcessor

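To use this with HuggingFace, you implement a subclass of `LogitsProcessor` with a `__call__` method. `generate` calls this method after the scores for the next token have been computed and before the next token is chosen, whether that choice is made by sampling, greedy search or beam search.

The method:
- receives the token IDs generated so far (`input_ids`) and the scores for all vocabulary tokens (`scores`; these are raw logits, or log-probabilities under beam search), with one row per sequence being generated (one row per beam when beam search is used)
- returns modified scores for all tokens (e.g. some of them now set to `-inf` according to some rule(s))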
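Below is a minimal sketch of that interface that runs with PyTorch alone. `BanTokensProcessor` is a hypothetical example. Instead of subclassing `transformers.LogitsProcessor`, it only copies the `__call__(input_ids, scores)` signature. In real code you would subclass `LogitsProcessor`, wrap the instance in a `LogitsProcessorList` and pass it to `generate(logits_processor=...)`. Both tensors have one row per sequence.

```python
import torch

class BanTokensProcessor:
    """Toy processor with the same call signature as a transformers LogitsProcessor.

    Real code would subclass transformers.LogitsProcessor; that is left out here
    so the sketch runs with torch alone.
    """

    def __init__(self, banned_ids):
        self.banned_ids = list(banned_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # input_ids: (rows, seq_len) tokens generated so far
        # scores:    (rows, vocab_size) scores for the next token
        scores = scores.clone()  # do not modify the caller's tensor in place
        scores[:, self.banned_ids] = float("-inf")
        return scores

input_ids = torch.tensor([[0, 5], [0, 7]])  # two sequences generated so far
scores = torch.zeros(2, 10)                 # flat scores over a 10-token vocabulary
out = BanTokensProcessor([3, 4])(input_ids, scores)
print(out)
```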

**We'll use BART for this notebook and we're trying to generate SQL**

```python
import torch

from transformers import BartForConditionalGeneration, BartTokenizerFast, PreTrainedTokenizer
from transformers.generation.logits_process import LogitsProcessorList, LogitsProcessor


name = 'facebook/bart-large'
tokenizer = BartTokenizerFast.from_pretrained(name, add_prefix_space=True)
pretrained_model = BartForConditionalGeneration.from_pretrained(name)
```



We want to "translate" natural language into SQL. Let's start with the following example text and see what the model does without any training **and, of course, without any structured-generation constraint either**.

```python
# imagine a kind of SQL query described in natural language:
to_translate = 'customers emails from the us'

# TODO: I have questions about this step but at end of article there is comment
# about tokenization (I wonder if the split into words is because you want to ensure
# that the output/constraints correspond to SQL-lang word-level constraints??)
words = to_translate.split()
tokenized_text = tokenizer([words], is_split_into_words=True)

out = pretrained_model.generate(
    torch.tensor(tokenized_text["input_ids"]),
    max_new_tokens=20,
)

print(out)
print("----")

print(tokenizer.convert_ids_to_tokens(out[0], skip_special_tokens=True))
print("----")

print(tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(out[0], skip_special_tokens=True)))
```

```
tensor([[   2,    0, 9690, 5575,   31,    5,  201,    2]])
----
['More', 'Ġemails', 'Ġfrom', 'Ġthe', 'Ġus']
----
More emails from the us
```


Of course we shouldn't expect SQL at this step!

We won't train the model but we will see if we can actually guide it to return SQL queries.

We need a way to map each generated token to a list of permissible next tokens.

**For simplicity here: we focus on the immediate predecessor token, but you can implement more advanced mechanisms**

We use a dictionary, which defines for each token (key) which tokens (values) are allowed to follow it.

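NOTE: `<s>` (ID `0`) is the first token the model actually generates. The output IDs above start with `2, 0`, i.e. the decoder start token `</s>` followed by `<s>`. So `<s>` serves as the start state of our rules, and we want our SQL queries to start with either SELECT or DELETE here.

In this hypothetical schema, the only columns are `name`, `email` and `id`, and the only tables are `customers` and `vendors`.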

```python
column_names = ["name", "email", "id"]
table_names = ["customers", "vendors"]

rules = {
    "<s>": ["SELECT", "DELETE"],
    'SELECT': column_names,  # names of columns in our schema
    'DELETE': column_names,
    'name': [',', 'FROM'],
    'email': [',', 'FROM'],
    'id': [',', 'FROM'],
    ',': column_names,
    'FROM': table_names,  # names of tables in our schema
    'customers': ['</s>'],
    'vendors': ['</s>'],  # end of the generation
}
```
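The short pure-Python checker below shows what the dictionary means in practice by testing a whole token sequence against it. A sequence is accepted only if it starts with `<s>`, ends with `</s>`, and every adjacent pair of tokens is permitted by the rules. `follows_rules` is a hypothetical helper and is not part of the article. The dictionary is repeated so the snippet runs on its own.

```python
column_names = ["name", "email", "id"]
table_names = ["customers", "vendors"]

rules = {
    "<s>": ["SELECT", "DELETE"],
    "SELECT": column_names,
    "DELETE": column_names,
    "name": [",", "FROM"],
    "email": [",", "FROM"],
    "id": [",", "FROM"],
    ",": column_names,
    "FROM": table_names,
    "customers": ["</s>"],
    "vendors": ["</s>"],
}

def follows_rules(tokens, rules):
    """Return True if tokens start with <s>, end with </s>, and each step is allowed."""
    if not tokens or tokens[0] != "<s>" or tokens[-1] != "</s>":
        return False
    for prev, nxt in zip(tokens, tokens[1:]):
        if nxt not in rules.get(prev, []):
            return False
    return True

print(follows_rules("<s> SELECT email , email , id , email FROM customers </s>".split(), rules))  # True
print(follows_rules("<s> SELECT FROM customers </s>".split(), rules))  # False: SELECT needs a column
```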

The rules are written as strings, but the model works with token IDs, so we convert them to IDs.

**You do this in a class that inherits from `LogitsProcessor`**

We then implement the `__call__` method in this class. `generate` calls it at every step, after the logits are computed.

The method:
1. creates a new tensor filled with `-inf`
2. looks up which IDs are legitimate after the last generated token, according to the rules dict
3. copies the original scores of those IDs into the new tensor
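The three steps can be sketched on a toy vocabulary where the token at position `i` has ID `i`, so no tokenizer or model is needed. `apply_rules`, `vocab` and `rules_ids` are hypothetical names. The sketch also covers the case where `scores` has several rows, as happens under beam search: each row is masked according to its **own** last token, because different beams can be at different points of the query.

```python
import torch

# toy vocabulary: the token at position i has ID i
vocab = ["<s>", "SELECT", "name", "email", ",", "FROM", "customers", "vendors", "</s>"]
tok2id = {tok: i for i, tok in enumerate(vocab)}

# two of the article's rules, already converted to IDs
rules_ids = {
    tok2id[","]: [tok2id["name"], tok2id["email"]],
    tok2id["FROM"]: [tok2id["customers"], tok2id["vendors"]],
}

def apply_rules(input_ids, scores, rules_ids):
    # 1. new tensor of -inf with the same shape as scores: (rows, vocab_size)
    new_scores = torch.full_like(scores, float("-inf"))
    for row in range(input_ids.shape[0]):
        # 2. legitimate next IDs depend on this row's own last token
        legit_ids = rules_ids[int(input_ids[row, -1])]
        # 3. copy the original scores of the legitimate IDs
        new_scores[row, legit_ids] = scores[row, legit_ids]
    return new_scores

# two beams of equal length that are at different points of the query
input_ids = torch.tensor([
    [tok2id["<s>"], tok2id["SELECT"], tok2id["email"], tok2id[","]],
    [tok2id["<s>"], tok2id["SELECT"], tok2id["email"], tok2id["FROM"]],
])
scores = torch.arange(18, dtype=torch.float32).reshape(2, 9)  # made-up scores

masked = apply_rules(input_ids, scores, rules_ids)
print([vocab[i] for i in masked.argmax(dim=-1).tolist()])  # ['email', 'vendors']
```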

```python
def convert_token_to_id(token):
    # assumes each rule token is a single sub-token: only the first ID is kept
    return tokenizer(token, add_special_tokens=False)['input_ids'][0]

class SQLLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # translate the string rules into token-ID rules
        self.rules = {convert_token_to_id(k): [convert_token_to_id(v0) for v0 in v] for k, v in rules.items()}

    def __call__(self, input_ids, scores):
        # input_ids: (batch_size * num_beams, seq_len); scores: (batch_size * num_beams, vocab_size).
        # generate() runs beam search here (facebook/bart-large's config sets num_beams=4; check
        # pretrained_model.generation_config), which is why scores has 4 rows for a single input:
        # one row per beam. The rows are identical at the first step but diverge later.
        if not (input_ids == self.tokenizer.bos_token_id).any():
            # we must allow the start token to appear before we start processing
            return scores

        # create a new tensor of -inf with the same shape, dtype and device as scores
        new_scores = torch.full_like(scores, float('-inf'))

        # each beam is constrained by its own last generated token
        for row in range(input_ids.shape[0]):
            legit_ids = self.rules[int(input_ids[row, -1])]
            # place the original scores of the legitimate tokens in the new tensor
            new_scores[row, legit_ids] = scores[row, legit_ids]

        return new_scores
```

## Structured generation

That's all we need to run generation with a logits processor:

```python
to_translate = 'customers emails from the us'
words = to_translate.split()
tokenized_text = tokenizer([words], is_split_into_words=True, return_offsets_mapping=True)

#print(torch.tensor(tokenized_text["input_ids"]))
logits_processor = LogitsProcessorList([SQLLogitsProcessor(tokenizer)])

out = pretrained_model.generate(
    torch.tensor(tokenized_text["input_ids"]),
    max_new_tokens=20,
    logits_processor=logits_processor,
)
print(tokenizer.convert_tokens_to_string(
    tokenizer.convert_ids_to_tokens(
        out[0], skip_special_tokens=True)))
```

```
 SELECT email , email , id , email FROM customers
```


The output follows our rules, but it is a bit strange as "real" SQL (`email` is selected twice). Remember that we **didn't train the model at this point!**
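The repeated `email` is possible because each rule looks only at the immediately preceding token. The notebook notes that more advanced mechanisms are possible, and one such mechanism is a rule that inspects the whole generated prefix. The sketch below is hypothetical: `allowed_next`, `FAKE_SCORES` and `greedy_decode` are illustrative names, and fixed made-up scores with greedy decoding stand in for the model. It is not a reproduction of the beam-search run above. Its point is that the prefix-aware rule removes columns that have already been used, while the predecessor-only rule lets a score-driven decoder loop.

```python
column_names = ["name", "email", "id"]
table_names = ["customers", "vendors"]
rules = {
    "<s>": ["SELECT", "DELETE"], "SELECT": column_names, "DELETE": column_names,
    "name": [",", "FROM"], "email": [",", "FROM"], "id": [",", "FROM"],
    ",": column_names, "FROM": table_names,
    "customers": ["</s>"], "vendors": ["</s>"],
}

def allowed_next(prefix):
    """Hypothetical prefix-aware rule: the predecessor rule minus repeated columns."""
    last = prefix[-1]
    allowed = list(rules[last])
    used = {tok for tok in prefix if tok in column_names}
    if last in ("SELECT", "DELETE", ","):
        # do not offer a column that is already in the column list
        allowed = [tok for tok in allowed if tok not in used]
    elif last in column_names and used == set(column_names):
        # all columns are used, so a trailing ',' could never be completed
        allowed = [tok for tok in allowed if tok != ","]
    return allowed

# made-up scores standing in for model logits (no model involved)
FAKE_SCORES = {"SELECT": 2.0, "DELETE": 1.0, "name": 1.0, "email": 3.0, "id": 1.5,
               ",": 2.5, "FROM": 0.5, "customers": 1.0, "vendors": 0.2, "</s>": 0.0}

def greedy_decode(allowed_fn, max_len=20):
    # pick the highest-scoring allowed token until </s> or max_len tokens
    prefix = ["<s>"]
    while prefix[-1] != "</s>" and len(prefix) < max_len:
        prefix.append(max(allowed_fn(prefix), key=FAKE_SCORES.get))
    return prefix

print(" ".join(greedy_decode(lambda p: rules[p[-1]])))  # loops on "email ," until max_len
print(" ".join(greedy_decode(allowed_next)))  # <s> SELECT email , id , name FROM customers </s>
```

Inside a real `LogitsProcessor` the same idea would require converting each row of `input_ids` back to tokens, or keeping per-beam state, before the allowed set is computed.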

---

## Be careful with tokenization

Tokenization is crucial when using GenAI for structured output.

The article gives the example of training a model to generate JSON. Suppose your tokenizer assigns different token IDs to `[` and `[[`, or tokenizes `my_var_name` differently depending on whether a bracket is adjacent (`{my_var` vs `{ my_var`). Results and training will then be worse, because the logic the model has to learn becomes more complicated: it must learn that the double-bracket token means two copies of the single bracket rather than a distinct concept.

**So when training, ensure each concept and punctuation mark is consistently converted to the same token, e.g. by adding spaces before words and characters in this JSON example.** This does not require designing your own tokenizer. With an off-the-shelf model the tokenizer is fixed, but you still control how the training text is formatted.

Then, at prediction time, the model outputs JSON with spaces, which you can remove before parsing.
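The sketch below applies the spacing idea to JSON training targets. It assumes the goal is one space around every structural character (`{ } [ ] : ,`) outside string literals. With that format, `[[` is always written as `[ [`, and each bracket looks the same wherever it appears. `space_out_json` is a hypothetical helper. JSON allows whitespace between tokens, so the spaced text still parses to the same object with `json.loads`, and for JSON specifically the spaces do not even have to be stripped.

```python
import json

STRUCTURAL = set("{}[]:,")

def space_out_json(text):
    """Hypothetical helper: one space between JSON tokens, string literals untouched."""
    tokens, current = [], ""
    in_string = escaped = False
    for ch in text:
        if in_string:
            current += ch
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                # closing quote ends the string literal
                in_string = False
                tokens.append(current)
                current = ""
        elif ch == '"':
            # opening quote: flush any pending bare token (number, true, null, ...)
            if current:
                tokens.append(current)
            current, in_string = ch, True
        elif ch in STRUCTURAL or ch.isspace():
            if current:
                tokens.append(current)
                current = ""
            if ch in STRUCTURAL:
                tokens.append(ch)
        else:
            current += ch
    if current:
        tokens.append(current)
    return " ".join(tokens)

target = '{"filter":[["country","us"]],"note":"a, b"}'
spaced = space_out_json(target)
print(spaced)  # { "filter" : [ [ "country" , "us" ] ] , "note" : "a, b" }
print(json.loads(spaced) == json.loads(target))  # True
```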

---

## Further reading

- Recent bookmarks (May 2024) on domain-specific tokenization
- ZeTT: Zero-Shot Tokenizer Transfer (arXiv paper)