## Changing logits based on desired alliteration

We want to generate a poem in which as many words as possible start with a desired letter.

In [None]:
#%pip install --upgrade --quiet transformers torch fbgemm-gpu accelerate

In [1]:
# CHANGE this to the Llama model for which you have applied for access via Hugging Face
# See: https://www.llama.com/docs/getting-the-models/hugging-face/
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"

import os
from dotenv import load_dotenv
load_dotenv("../keys.env")
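# Use .get() so that a missing token triggers the assertion message rather than a KeyError
assert os.environ.get("HF_TOKEN", "").startswith("hf"),\
       "Please sign up for access to the specific Llama model via Hugging Face and provide an access token in the keys.env file"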

## Zero-shot generation

First, we generate a poem without any logits processing.

In [2]:
from transformers import pipeline

pipe = pipeline(
    task="text-generation", 
    model=MODEL_ID,
    use_fast=True,
    kwargs={
        "return_full_text": False,
    },
    model_kwargs={}
)

Device set to use cuda:0


In [4]:
def generate_poem(animal: str) -> str:
    system_prompt = f"""
        You are writing nursery rhymes about animals for a children's book.
        Each poem should be 3-5 lines long.
        Return only the poem, without any introduction or preamble.
    """
    user_prompt = f"""
        Write a poem about a {animal}.
    """

    input_message = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}   
    ]
    
    results = pipe(input_message, 
                   max_new_tokens=256)
    return results[0]['generated_text'][-1]['content'].strip()

poem = generate_poem("donkey")
print(poem)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Hee-haw, a donkey so bright,
Carrying loads with all his might,
His soft fur and gentle eyes shine,
A friendly friend, always on my mind.


Result:
```
Hee-haw, a donkey so bright,
Carrying loads with all his might,
His soft fur and gentle eyes shine,
A friendly friend, always on my mind.
```
"Shine" and "mind" are not a perfect rhyme, but the poem is still pretty good.
What happens if we also ask for the poem about donkeys to be alliterative?

In [3]:
def generate_alliterative_poem(animal: str) -> str:
    system_prompt = f"""
        You are writing alliterative nursery rhymes about animals for a children's book.
        Each poem should be 3-5 lines long and contain as many words as possible
        that start with the desired letter.
        Return only the poem, without any introduction or preamble.
    """
    
    # alliterate on the first letter of the animal. So, donkey would be D
    user_prompt = f"""
        Write a poem about a {animal}. Use many words that start with {animal.upper()[0]}
    """

    input_message = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}   
    ]
    
    results = pipe(input_message, 
                   max_new_tokens=256)
    return results[0]['generated_text'][-1]['content'].strip()

poem = generate_alliterative_poem("donkey")
print(poem)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Daring donkeys dash down dusty deeds,
Doting devotees delighted in their deeds,
Doting donkeys deliver delightful dreams,
Delighting dreamers daily with their dignified deeds.


Result:
```
Daring donkeys dash down dusty deeds,
Doting devotees delighted in their deeds,
Doting donkeys deliver delightful dreams,
Delighting dreamers daily with their dignified deeds.
```
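The poem above is not great: "deeds" ends three of the four lines, "Doting" opens two of them, and phrases such as "dash down dusty deeds" make little sense. In trying to match the requested style, the model has let the quality drop sharply.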
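
To put a number on the alliteration itself, separately from quality, one option is to measure the fraction of words that start with the target letter. The helper below is a hypothetical sketch (the name `alliteration_fraction` is not from the notebook); it assumes that whitespace-separated chunks count as words and it ignores case and leading punctuation.

```python
import string


def alliteration_fraction(text: str, letter: str) -> float:
    """Fraction of whitespace-separated words that start with `letter`."""
    # Strip leading punctuation (e.g. an opening quote) before checking the first letter.
    words = [w.lstrip(string.punctuation) for w in text.split()]
    words = [w for w in words if w]  # drop chunks that were only punctuation
    if not words:
        return 0.0
    matches = sum(1 for w in words if w[0].lower() == letter.lower())
    return matches / len(words)


plain_line = "Hee-haw, a donkey so bright,"
alliterative_line = "Daring donkeys dash down dusty deeds,"
print(alliteration_fraction(plain_line, "d"))
print(alliteration_fraction(alliterative_line, "d"))
```

On the first line of each poem above, this gives 0.2 for the plain prompt and 1.0 for the alliterative prompt. The metric says nothing about whether the poem makes sense, which is exactly the trade-off seen here.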

## Use logits processing to enhance the alliteration

We'll reuse the original poem prompt (the one without the alliteration request) and apply logits processing to favor words that start with the desired letter.

In [11]:
import torch
from transformers.generation.logits_process import (
    LogitsProcessor,
    LOGITS_PROCESSOR_INPUTS_DOCSTRING,
)
from transformers.utils import add_start_docstrings

class AlliterativeLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer, start_letter):
        self.tokenizer = tokenizer
        self.start_letter = start_letter
      
    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(
        self, input_ids: torch.LongTensor, input_logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        output_logits = input_logits.clone()
        
        for idx, seq in enumerate(input_ids):
            # decode the sequence
            decoded = self.tokenizer.decode(seq)
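            # count the words that start with the desired letter, ignoring case
            # (note: the decoded text includes the prompt as well as the poem so far)
            num_matches = sum(1 for word in decoded.split() if word[0].lower() == self.start_letter.lower())
            input_score = input_logits[idx]
            # Scale every score in this row by the same positive factor.
            # Raw logits are not confined to (-inf, 0]; that range applies only to log-probabilities.
            output_score = input_score / (num_matches + 1)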
            output_logits[idx] = output_score
            
        return output_logits
    
def generate_alliterative_poem_v2(animal: str) -> str:
    system_prompt = f"""
        You are writing nursery rhymes about animals for a children's book.
        Each poem should be 3-5 lines long.
        Return only the poem, without any introduction or preamble.
    """
    user_prompt = f"""
        Write a poem about a {animal}.
    """

    input_message = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}   
    ]
    
    # alliterate on the first letter of the animal. So, donkey would be D
    alliteration_processor = AlliterativeLogitsProcessor(pipe.tokenizer, animal[0])
    
    results = pipe(input_message, 
                   max_new_tokens=256,
                   do_sample=False, 
                   logits_processor=[alliteration_processor])
    return results[0]['generated_text'][-1]['content'].strip()

poem = generate_alliterative_poem_v2("donkey")
print(poem)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Hee-haw, the donkey says with glee,
His loud voice echoes wild and free,
He brays all day, a happy sight.
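
The poem above shows little alliteration. One likely reason, offered as an observation rather than as the notebook's own conclusion: with `do_sample=False` and a single beam, generation picks the highest-scoring token at each step, and dividing every score in a row by the same positive number does not change which token that is. The sketch below uses plain Python lists and a made-up four-token vocabulary (no model or tokenizer is involved) to contrast uniform scaling with a hypothetical additive bonus for tokens that start with the desired letter.

```python
# Toy vocabulary and scores; all values are invented for illustration.
vocab = ["the", " donkey", " says", " dashes"]
scores = [2.0, 1.5, 1.8, 1.2]


def argmax(values):
    """Return the index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def scale_uniformly(values, num_matches):
    """Divide every score by the same positive number, as in the cell above."""
    return [v / (num_matches + 1) for v in values]


def add_letter_bonus(values, tokens, letter, bonus):
    """Hypothetical alternative: raise only the tokens that start with `letter`."""
    return [
        v + bonus if tok.strip().lower().startswith(letter.lower()) else v
        for v, tok in zip(values, tokens)
    ]


greedy_before = vocab[argmax(scores)]
greedy_scaled = vocab[argmax(scale_uniformly(scores, num_matches=3))]
greedy_biased = vocab[argmax(add_letter_bonus(scores, vocab, "d", bonus=1.0))]

print(repr(greedy_before), repr(greedy_scaled), repr(greedy_biased))
```

In this toy setup the uniform scaling leaves the greedy choice unchanged, while the additive bonus moves it to a token that starts with "d". Per-sequence scaling may still matter when several candidate sequences are compared against each other, for example under beam search. The size of the bonus is an assumed tuning knob: too small and nothing changes, too large and the poem may stop making sense.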
