## Autocomplete phrases based on logits (Sequence Selection)

Typically, autocomplete is driven by past search behavior. An interesting alternative is to use an LLM grounded on the document being searched: this avoids cold-start issues and reduces leakage of sensitive data. You can think of it as a variant of Sequence Selection in Logits Masking, except that the human who is typing ends up doing the selection!
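To make the idea concrete without loading a model, here is a minimal sketch under stated assumptions: the hypothetical helper `candidate_continuations` stands in for the LLM by proposing only phrases that literally follow the typed prefix somewhere in the document, and a selection function (the person typing) chooses among them. The names and the regex-based matching are illustrative, not part of this notebook.

```python
import re
from collections import Counter
from typing import Callable


def candidate_continuations(document: str, typed_so_far: str,
                            num_words: int = 3, k: int = 5) -> list[str]:
    """Hypothetical stand-in for the LLM: return up to k distinct phrases of
    1..num_words words that follow typed_so_far somewhere in the document."""
    pattern = re.compile(
        r"\b" + re.escape(typed_so_far) + r"\b((?:\W+\w+){1,%d})" % num_words,
        re.IGNORECASE,
    )
    # Collapse whitespace/newlines inside each phrase, then count occurrences;
    # most frequent first, ties keep the order they appear in the document
    counts = Counter(" ".join(m.group(1).split()) for m in pattern.finditer(document))
    return [phrase for phrase, _ in counts.most_common(k)]


def autocomplete(document: str, typed_so_far: str,
                 select: Callable[[list[str]], int]) -> str:
    """Offer the grounded candidates and let `select` (the human) pick one."""
    choices = candidate_continuations(document, typed_so_far)
    if not choices:
        return typed_so_far  # nothing in the document continues this prefix
    return f"{typed_so_far} {choices[select(choices)]}"
```

For example, on a document containing "lend me your ears" and "Lend me your hand", the candidates for "Lend" are `me your ears` and `me your hand`, and a selector that returns 1 yields `Lend me your hand`. The LLM version in this notebook replaces the literal matching with model-generated beams, which can also propose paraphrases that never appear verbatim.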

In [None]:
#%pip install --upgrade --quiet transformers torch fbgemm-gpu accelerate

In [1]:
# CHANGE this to the Llama model for which you have applied for access via Hugging Face
# See: https://www.llama.com/docs/getting-the-models/hugging-face/
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"

import os
from dotenv import load_dotenv
load_dotenv("../keys.env")
assert os.environ["HF_TOKEN"][:2] == "hf",\
       "Please sign up for access to the specific Llama model via HuggingFace and provide access token in keys.env file"

## Load document

Ideally, the document is loaded into the model's context only once (for example, by using Context Caching or Prompt Caching).

In [2]:
# Download text of play from Project Gutenberg
TXT_URL="https://www.gutenberg.org/cache/epub/1522/pg1522.txt"
LOCAL_FILE="julius_caesar.txt"

import requests

def download_text_file(url, file_path):
    response = requests.get(url)
    response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)
    with open(file_path, "wb") as file:
        file.write(response.content)
    print(f"File downloaded successfully to {file_path}")

download_text_file(TXT_URL, LOCAL_FILE)

File downloaded successfully to julius_caesar.txt


In [3]:
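# Project Gutenberg texts are UTF-8; close the file once it has been read
with open(LOCAL_FILE, encoding="utf-8") as f:
    lines = f.readlines()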

In [5]:
start_index = 0
end_index = -1
for idx, line in enumerate(lines):
    if line.startswith("*** START OF THE PROJECT GUTENBERG EBOOK"):
        start_index = idx
    if line.startswith("*** END OF THE PROJECT GUTENBERG EBOOK"):
        end_index = idx
lines = lines[start_index+1:end_index]
len(lines)

4662

## Use logits processing to select the next word

At each generation step, display the candidate continuations (one per beam) to the user and keep only the one they select.
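In the notebook the person is simulated, but a real selection function only has to print the options and read back an index. A minimal console sketch, assuming a blocking `input()` prompt is acceptable; `console_selection` and its `read` parameter are hypothetical, and `read` exists only so the function can be exercised without a keyboard:

```python
from typing import Callable


def console_selection(choices: list[str],
                      read: Callable[[str], str] = input) -> int:
    """Print numbered choices and return the index the user types.

    Re-prompts until the reply is a valid index.
    """
    if not choices:
        raise ValueError("no choices to select from")
    for idx, choice in enumerate(choices):
        print(f"{idx}: {choice}")
    while True:
        reply = read(f"Pick 0-{len(choices) - 1}: ").strip()
        if reply.isdigit() and int(reply) < len(choices):
            return int(reply)
        print("Not a valid choice, try again.")
```

It has the same shape as `simulate_human_selection` (a list of strings in, an index out), so it could be passed to `AutocompleteLogitsProcessor` in its place.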

In [6]:
import torch
import numpy as np
from transformers.generation.logits_process import (
    LogitsProcessor,
    LOGITS_PROCESSOR_INPUTS_DOCSTRING,
)
from transformers.utils import add_start_docstrings

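class AutocompleteLogitsProcessor(LogitsProcessor):
    """Lets a selection function (e.g., the human typing) pick which beam survives each step."""

    def __init__(self, tokenizer, selection_func):
        self.tokenizer = tokenizer
        self.selection_func = selection_func
        self.prompt_length = None  # set on the first call

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(
        self, input_ids: torch.LongTensor, input_logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        output_logits = input_logits.clone()

        # On the first step, input_ids holds only the prompt; remember its length
        if self.prompt_length is None:
            self.prompt_length = input_ids.shape[1]

        # Decode only the newly generated tokens of each beam, not the whole document
        decoded = [
            self.tokenizer.decode(seq[self.prompt_length:], skip_special_tokens=True)
            for seq in input_ids
        ]
        selected = self.selection_func(decoded)

        # In beam search, the scores are log-probabilities (-inf to 0).
        # Mask out every beam other than the selected one. Use a large negative
        # number, because torch doesn't handle -np.inf well here.
        for idx in range(len(input_ids)):
            if idx != selected:
                output_logits[idx] = -10000

        return output_logits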
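To see why masking whole rows keeps only the chosen sequence, here is a framework-free sketch that uses plain lists in place of tensors (an assumption for illustration; the helper names are hypothetical). Each row holds one beam's next-token log-probabilities. Beam search then ranks (beam, token) pairs across all rows, roughly as `top_candidates` does, so once the other rows are set to -10000 every top pair comes from the selected beam, provided the vocabulary has at least as many tokens as candidates requested.

```python
def mask_unselected_beams(scores: list[list[float]], selected: int,
                          fill: float = -10000.0) -> list[list[float]]:
    """Return a copy of scores in which every beam except `selected` is set to `fill`."""
    return [row[:] if idx == selected else [fill] * len(row)
            for idx, row in enumerate(scores)]


def top_candidates(scores: list[list[float]], k: int) -> list[tuple[int, int]]:
    """Rank all (beam, token) pairs by score and return the k best."""
    pairs = [(score, beam, token)
             for beam, row in enumerate(scores)
             for token, score in enumerate(row)]
    pairs.sort(reverse=True)
    return [(beam, token) for _, beam, token in pairs[:k]]


# Three beams, four-token vocabulary (made-up log-probabilities)
scores = [[-0.1, -2.0, -3.0, -4.0],
          [-1.5, -0.7, -2.5, -3.5],
          [-0.2, -1.0, -5.0, -6.0]]
print(top_candidates(scores, 2))                            # [(0, 0), (2, 0)]
print(top_candidates(mask_unselected_beams(scores, 1), 2))  # [(1, 1), (1, 0)]
```

Without masking, the best pairs come from beams 0 and 2; after the user selects beam 1, both surviving candidates extend beam 1, even though its scores were lower.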

In [None]:
from transformers import pipeline

pipe = pipeline(
    task="text-generation", 
    model=MODEL_ID,
    use_fast=True,
    kwargs={
        "return_full_text": False,
    },
    model_kwargs={}
)


In [7]:
import random

def simulate_human_selection(choices: list[str]) -> int:
    # Stand-in for the person typing: pick one of the displayed choices at random
    selected = random.randrange(len(choices))
    for idx, choice in enumerate(choices):
        print(idx, ": ", choice, " (selected)" if idx == selected else "")
    return selected

def get_autocomplete_phrases(document: str, typed_so_far: str) -> str:
    system_prompt = f"""
        Use the following document to identify a potential continuation
        for the given phrase. Provide just the continuation without any preamble.
        
        <document>
        {document}
        </document>
        
        **Phrase to complete**:
    """
    user_prompt = f"""
        {typed_so_far}
    """

    input_message = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}   
    ]
    
    # At each step, let the (simulated) human choose which beam to keep
    autocomplete = AutocompleteLogitsProcessor(pipe.tokenizer, simulate_human_selection)
    
    results = pipe(input_message, 
                   max_new_tokens=16,
                   do_sample=True,
                   temperature=0.8,
                   num_beams=10,
                   logits_processor=[autocomplete])
    
    return results[0]['generated_text'][-1]['content'].strip()

# readlines() keeps the trailing newlines, so join with an empty string
completion = get_autocomplete_phrases(''.join(lines), "Lend")
print(completion)


Result:
```
Little donkey, ears so bright,
Hee-hawing loud through day's delight,
He trots along with gentle pace,
A friendly friend in a sunny place.
```
It has three words starting with d.

## Combine prompting and sequence selection

Enhance the prompt to make it clear that we want the poem to be readable,
and use logits processing to prefer words that start with the desired letter.
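The selection criterion itself can be simple. `AlliterativeLogitsProcessor` is defined elsewhere and not shown here, so the following is only a hypothetical sketch of a criterion in the same spirit: count the words that start with the target letter and select the candidate with the most of them. The helper names and the word-splitting regex are assumptions.

```python
import re


def count_words_starting_with(text: str, letter: str) -> int:
    """Count words in text whose first letter matches `letter` (case-insensitive)."""
    words = re.findall(r"[A-Za-z][A-Za-z']*", text)
    return sum(1 for word in words if word[0].lower() == letter.lower())


def select_most_alliterative(choices: list[str], letter: str) -> int:
    """Return the index of the candidate with the most words starting with `letter`."""
    counts = [count_words_starting_with(choice, letter) for choice in choices]
    return counts.index(max(counts))  # first candidate wins ties
```

Wrapped as `lambda choices: select_most_alliterative(choices, "d")`, it has the same shape as `simulate_human_selection`. On the two donkey poems in this notebook it counts 3 and 15 words starting with d, respectively.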

In [13]:
def generate_alliterative_poem_v3(animal: str) -> str:
    start_letter = animal[0]
    
    system_prompt = f"""
        You are writing nursery rhymes about animals for a children's book.
        Each poem should be 3-5 lines long. The poem must be readable and suitable for children.
        Return only the poem, without any introduction or preamble.
    """
    user_prompt = f"""
        Write a poem about a {animal} that has a few alliterations involving {start_letter}.
        Do not overdo alliteration, and emphasize readability.
    """

    input_message = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}   
    ]
    
    # alliterate on the first letter of the animal. So, donkey would be D
    grammar_processor = AlliterativeLogitsProcessor(pipe.tokenizer, start_letter)
    
    results = pipe(input_message, 
                   max_new_tokens=256,
                   do_sample=True,
                   temperature=0.8,
                   num_beams=10,
                   use_cache=True, # default is True
                   logits_processor=[grammar_processor])
    return results[0]['generated_text'][-1]['content'].strip()

poem = generate_alliterative_poem_v3("donkey")
print(poem)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Down the dusty desert, donkey did stray
Dreaming of delicious dates to devour each day
Dainty donkey danced down the desert way


Result:
```
Down the dusty desert, donkey did stray
Dreaming of delicious dates to devour each day
Dainty donkey danced down the desert way
```
It's still a readable poem, but we have picked the candidate with the most words starting with d.