## Autocomplete phrases based on logits (Sequence Selection)

Typically, autocomplete is done using past search behavior. An interesting alternative is to use an LLM grounded on the document being searched -- this avoids cold-start issues and reduces leakage of sensitive data. You can think of it as a variant of the Sequence Selection in Logits Masking, except that the human typing ends up doing the selection!

In [1]:
#%pip install --upgrade --quiet transformers torch fbgemm-gpu accelerate

In [2]:
# CHANGE this to the Llama model for which you have applied for access via Hugging Face
# See: https://www.llama.com/docs/getting-the-models/hugging-face/
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"

import os
from dotenv import load_dotenv

# Pull environment variables (HF_TOKEN is the one checked below) from the keys file.
load_dotenv("../keys.env")

# Look the token up with .get() so that a missing HF_TOKEN fails with the
# helpful message below instead of a bare KeyError.
assert os.environ.get("HF_TOKEN", "").startswith("hf"),\
       "Please sign up for access to the specific Llama model via HuggingFace and provide access token in keys.env file"

## Load document

Ideally, this is done only once (such as by using Context Caching or Prompt Caching).

In [3]:
# Download text of play from Project Gutenberg
TXT_URL = "https://www.gutenberg.org/cache/epub/1522/pg1522.txt"
LOCAL_FILE = "julius_caesar.txt"

import requests

def download_text_file(url, file_path, timeout=30):
    """Download ``url`` and write the raw response bytes to ``file_path``.

    Args:
        url: Address of the file to fetch.
        file_path: Local path to write; an existing file is overwritten.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: If the server answers with a 4xx or 5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # Without an explicit timeout, requests waits indefinitely on a stalled server.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)
    with open(file_path, "wb") as file:
        file.write(response.content)
    print(f"File downloaded successfully to {file_path}")

download_text_file(TXT_URL, LOCAL_FILE)

File downloaded successfully to julius_caesar.txt


In [4]:
# Read the downloaded play as a list of lines. The context manager closes the
# file handle, and the explicit encoding avoids relying on the platform default.
with open(LOCAL_FILE, encoding="utf-8") as play_file:
    lines = play_file.readlines()

In [5]:
# Keep only the text between the Project Gutenberg START and END marker lines.
# The defaults cover a file with a missing marker: "just before the first line"
# and "just past the last line", so no real line is dropped by the slice below.
start_index = -1
end_index = len(lines)
for idx, line in enumerate(lines):
    if line.startswith("*** START OF THE PROJECT GUTENBERG EBOOK"):
        start_index = idx
    elif line.startswith("*** END OF THE PROJECT GUTENBERG EBOOK"):
        end_index = idx
lines = lines[start_index + 1:end_index]
# Drop blank (whitespace-only) lines.
lines = [line for line in lines if line.strip()]
len(lines)

3605

## Use logits processing to select the next word

Display the options to the user, then continue with the one they select.

In [6]:
import torch
import numpy as np
from transformers.generation.logits_process import (
    LogitsProcessor,
    LOGITS_PROCESSOR_INPUTS_DOCSTRING,
)
from transformers.utils import add_start_docstrings

class AutocompleteLogitsProcessor(LogitsProcessor):
    """Logits processor that keeps one candidate sequence and masks the rest.

    On every call the candidate sequences are decoded to text and passed to
    ``selection_func``; the row at the index it returns keeps its logits and
    every other row is overwritten with a large negative value.

    Args:
        tokenizer: Object whose ``decode`` turns one row of ``input_ids`` into text.
        selection_func: Callable that receives the list of decoded strings and
            returns the index of the row to keep.
    """

    def __init__(self, tokenizer, selection_func):
        self.selection_func = selection_func
        self.tokenizer = tokenizer

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(
        self, input_ids: torch.LongTensor, input_logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # Show the selector the text of every candidate and ask which one to keep.
        candidates = [self.tokenizer.decode(row_ids) for row_ids in input_ids]
        keep = self.selection_func(candidates)

        # Work on a copy so the incoming tensor is left untouched.
        masked_logits = input_logits.clone()

        # Logits run from -inf to zero, so -10000 effectively rules a row out;
        # a finite value is used because torch doesn't like -np.inf here.
        rejected_rows = [row for row in range(len(input_ids)) if row != keep]
        for row in rejected_rows:
            masked_logits[row] = -10000

        return masked_logits

In [7]:
from transformers import pipeline

# Build the text-generation pipeline for MODEL_ID; later cells call it as `pipe(...)`.
# NOTE(review): `kwargs` is passed here as a literal keyword named "kwargs", not
# unpacked with `**`. Whether `return_full_text=False` actually takes effect this
# way depends on how the installed transformers version forwards unknown
# keywords -- TODO confirm.
pipe = pipeline(
    task="text-generation", 
    model=MODEL_ID,
    use_fast=True,
    kwargs={
        "return_full_text": False,
    },
    model_kwargs={}
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0


In [8]:
def simulate_human_selection(choices: [str]) -> int:
    """Stand in for a user by picking one of ``choices`` at random.

    Args:
        choices: The candidate strings on offer; must not be empty.

    Returns:
        The index into ``choices`` of the picked entry.
    """
    import random

    # Simplification: the simulated user always picks an entry from the list.
    return random.randrange(len(choices))

def get_autocomplete_choice(document_lines: [str], typed_so_far: str) -> str:
    """Ask the model to complete ``typed_so_far`` in the style of the document.

    Lines of ``document_lines`` that contain ``typed_so_far`` (case-insensitive
    substring match) are used as examples in the system prompt. Generation runs
    through ``pipe`` with an ``AutocompleteLogitsProcessor`` whose selector is
    ``simulate_human_selection``.

    Args:
        document_lines: Lines of the document being searched.
        typed_so_far: The text the user has typed so far.

    Returns:
        The content of the last generated message, stripped of surrounding
        whitespace.
    """
    # Llama has a very limited context, so only the matching lines are sent.
    needle = typed_so_far.lower()
    matches = [text.strip() for text in document_lines if needle in text.lower()]
    print(f"Found {len(matches)} lines containing {typed_so_far}: {matches}")
    document = '\n'.join(matches)

    system_prompt = f"""
        Complete this phrase in a style similar to the ones below. You are acting as auto-complete.
        Simply complete the phrase without any introduction or preamble.
        Make sure it is only one sentence
        
        ** Examples **:
        {document}
    """
    user_prompt = f"""
        ** Phrase **:
        {typed_so_far} ____
    """
    input_message = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # The logits processor lets the (simulated) user choose among the candidates.
    selector = AutocompleteLogitsProcessor(pipe.tokenizer, simulate_human_selection)
    generation_options = {
        "max_new_tokens": 256,
        "do_sample": True,
        "temperature": 0.8,
        "num_beams": 5,
        "logits_processor": [selector],
    }
    results = pipe(input_message, **generation_options)

    # The last message of the first result holds the model's completion.
    final_message = results[0]['generated_text'][-1]
    return final_message['content'].strip()

In [9]:
# Demo: request a completion for the prefix "Lend" using the filtered lines
# of the play, then print the suggested continuation.
typed_text = "Lend"
choice = get_autocomplete_choice(lines, typed_text)
print(choice)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Found 3 lines containing Lend: ['Look in the calendar, and bring me word.', 'Lend me your hand.', 'Friends, Romans, countrymen, lend me your ears;']
your hearts
