# Generating metered verse with LLaMA

This notebook demonstrates how to use Meta's 6 billion parameter LLaMA model to generate song lyrics with a specific metric structure. If you've been yearning to rewrite the happy birthday song so that it's just about dogs, `MetricGenerator` can help :). 

This notebook relies on the `MetricGenerator` class implemented in my [bragi](https://github.com/joehoover/bragi) repository. If you're wondering, Bragi (pictured below) is the [Norse god](https://en.wikipedia.org/wiki/Bragi) of poetry.

In [277]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')

from IPython.display import Image
# Display the picture of Bragi (only the last expression in a cell is rendered).
Image(url="https://upload.wikimedia.org/wikipedia/commons/thumb/2/20/Bragi_by_Wahlbom.jpg/440px-Bragi_by_Wahlbom.jpg", width=150, height=150)


The `MetricGenerator` class takes an *initialization* song that defines a target metric structure. Given this target, it attempts to generate new lyrics with the same metric structure.

For example, the following lines

<blockquote>
Happy birthday to you<br>
Happy birthday to you<br>
Happy birthday dear Marvin<br>
Happy birthday to you
</blockquote>

have 6, 6, 7, and 6 syllables respectively. With `MetricGenerator`, you can use those lyrics to initialize a metric structure and then generate new lyrics with the same metric structure.
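To make the syllable counts concrete, here is a minimal sketch that estimates syllables per line by counting runs of vowels. This is a naive heuristic written for illustration, not the method `MetricGenerator` uses; the helper names `count_syllables` and `syllables_per_line` are hypothetical.

```python
import re


def count_syllables(word: str) -> int:
    """Naive estimate: one syllable per run of consecutive vowels (y included)."""
    return len(re.findall(r"[aeiouy]+", word.lower()))


def syllables_per_line(text: str) -> list[int]:
    """Sum the per-word estimates for every line in `text`."""
    return [
        sum(count_syllables(word) for word in line.split())
        for line in text.splitlines()
    ]


lyrics = (
    "Happy birthday to you\n"
    "Happy birthday to you\n"
    "Happy birthday dear Marvin\n"
    "Happy birthday to you"
)
print(syllables_per_line(lyrics))  # [6, 6, 7, 6]
```

The heuristic happens to be right for these lines, but it miscounts many English words (a silent final "e", as in "love", is counted as an extra syllable), so treat it only as an illustration of what a per-line syllable count looks like.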

And, most importantly, you can guide the new lyrics with a prompt!



## Quick demo

Let's check it out! First, we'll load our dependencies, specify some paths, and initialize a model and tokenizer.

For this demo, I'll use Meta's 6 billion parameter `LLaMA` model.

In [None]:
from bragi.metric_generator import MetricGenerator
# Released versions of transformers name these classes LlamaForCausalLM and LlamaTokenizer;
# some pre-release forks used LLaMAForCausalLM and LLaMATokenizer instead.
from transformers import LlamaForCausalLM, LlamaTokenizer
import torch

CACHE_DIR = 'weights'
SEP = "<sep>"
MODEL_PATH = "/src/weights"
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


model = LlamaForCausalLM.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR, local_files_only=True).to(device)
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR, local_files_only=True)

Now I'll initialize the metric generator class.

In [247]:
generator = MetricGenerator(model=model, tokenizer=tokenizer, device=device)

To generate lyrics, you just need to provide a prompt and the lyrics you want to use to initialize your metric structure. Let's try the happy birthday song displayed above.

In [245]:
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
print(text_init)

Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


And let's try to generate a new song that has the same metric structure, but is about dogs. To do that, we can use something like the following prompt:

In [248]:
prompt = """This is a song about dogs:\n"""
print(prompt)

This is a song about dogs:



Now we can call our `MetricGenerator` instance. The `__call__` method of `MetricGenerator` accepts all of the arguments that the `generate()` method of a `transformers` model does, plus a few extra arguments (such as `text_init`, `free_tokens`, and `new_line_token`) that constrain the generation process so that the target metric structure is respected.

**Note:** If you want the raw token_ids, you can call `MetricGenerator.generate()` with the same arguments. In contrast, `MetricGenerator.__call__()` returns the decoded output string.

In [251]:
torch.manual_seed(2)
output = generator(
    prompt = prompt,
    text_init = text_init,
    free_tokens=['||', '?', '.', ','],
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.7,
    max_length = 100,
    new_line_token='||',
    bad_words_ids=[[8876]],
)

print('---text_init----')
print(text_init)
print('\n')

print('----output-----')
print(output)
print('\n')

print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init)}")


---text_init----
Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


----output-----
I love how the dogs can 
be so good, so bad, and 
so strange. Sometimes I love them 
when they are the most good


----Syllables-----
Syllables per line in output: tensor([6., 6., 7., 6.])
Syllables per line in `text_init`: tensor([6., 6., 7., 6.])


Neat! We rewrote happy birthday to be about dogs.

## Exploring the effect of prompt conditioning

Let's try some other prompts and see how responsive the model is to prompt conditioning. 

## **California ☀️!**

In [274]:
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
prompt = """This is a song about about California sun:\n"""

In [275]:
torch.manual_seed(2)
output = generator(
    prompt = prompt,
    text_init = text_init,
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.7,
    max_length = 100,
    new_line_token='||',
    free_tokens=['||', '?', '.', ','], 
    bad_words_ids=[[8876]],
)

print('---prompt---')
print(prompt.strip())
print('\n')

print('---text_init----')
print(text_init)
print('\n')

print('----output-----')
print(output)
print('\n')

print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init)}")
# print(tokenizer.decode(output[0], skip_special_tokens=True).strip())

---prompt---
This is a song about about California sun:


---text_init----
Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


----output-----
I love how the colors 
the lights are bright, the air 
is warm and the water is 
clear. I love the way the


----Syllables-----
Syllables per line in output: tensor([6., 6., 7., 6.])
Syllables per line in `text_init`: tensor([6., 6., 7., 6.])


## **Dark and stormy 🌊!**

In [266]:
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
prompt = """This a beautiful poem that uses descriptive, earthy language to describe the ocean during a storm:\n\""""

In [268]:

torch.manual_seed(5)
output = generator(
    prompt = prompt,
    text_init = text_init,
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.7,
    max_length = 100,
    new_line_token='||',
    free_tokens=['||', '?', '.', ','], 
    bad_words_ids=[[8876]],
)

print('---prompt---')
print(prompt.strip())
print('\n')

print('---text_init----')
print(text_init)
print('\n')

print('----output-----')
print(output)
print('\n')

print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output, pt=False)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init, pt=False)}")
# print(tokenizer.decode(output[0], skip_special_tokens=True).strip())

---prompt---
This a beautiful poem that uses descriptive, earthy language to describe the ocean during a storm:
"


---text_init----
Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


----output-----
In the storm, the wind and 
the sea are one. The waves 
are the breath of the gods. As 
they break and fall back, they


----Syllables-----
Syllables per line in output: [6, 6, 7, 6]
Syllables per line in `text_init`: [6, 6, 7, 6]


# Methodology

Now that you've seen `MetricGenerator` in action, let's take a peek under the hood. We'll discuss:

1. How `MetricGenerator` constrains generation
2. The fact that it occasionally violates the target metric structure
3. Shortcomings of the current implementation
4. Ways that we could improve `MetricGenerator`

### How does `MetricGenerator` constrain generation?

`MetricGenerator` uses a logit warper to constrain model vocabulary during the generation process. A logit warper is a callable that modifies a model's probability distribution over tokens during generation. Specifically, `MetricGenerator` uses the `SyllableRestrictionWarper`, which is implemented in `bragi.logit_warpers`. 

The `SyllableRestrictionWarper` constrains the metric structure of model-generated output by tracking a line-level syllable *budget* and a number-of-lines budget. Every time the generator *spends* one or more syllables, the line-level budget is decremented. When all of the syllables permitted for a given line are spent, the warper forces a new line token by setting the logit scores for all other tokens to `-Inf`. Further, when a new line token is emitted, the number-of-lines budget is decremented by one. Finally, when the number-of-lines budget reaches zero, the warper forces the eos token using the same approach.
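The bookkeeping described above can be sketched in plain Python. This is a toy illustration under stated assumptions, not the code in `bragi.logit_warpers`: the class name `SyllableBudgetSketch`, the four-token vocabulary, and the per-token syllable costs are all made up, and scores are a plain list rather than a tensor.

```python
import math


class SyllableBudgetSketch:
    """Toy version of the line-level and number-of-lines budgets (hypothetical)."""

    def __init__(self, line_budgets, token_costs, new_line_id, eos_id):
        self.remaining = list(line_budgets)  # syllables left for each unfinished line
        self.token_costs = token_costs       # token id -> syllable cost (assumed known)
        self.new_line_id = new_line_id
        self.eos_id = eos_id

    def _force(self, scores, token_id):
        # Set every score except the forced token's to -inf.
        return [s if i == token_id else -math.inf for i, s in enumerate(scores)]

    def warp(self, scores):
        if not self.remaining:            # no lines left: force eos
            return self._force(scores, self.eos_id)
        if self.remaining[0] <= 0:        # current line is spent: force a new line
            return self._force(scores, self.new_line_id)
        return list(scores)               # otherwise leave the scores untouched

    def update(self, token_id):
        if token_id == self.new_line_id:
            self.remaining.pop(0)         # one fewer line to write
        elif self.remaining:
            self.remaining[0] -= self.token_costs.get(token_id, 0)


def greedy_demo():
    # Toy vocabulary: 0 = one-syllable word, 1 = two-syllable word, 2 = new line, 3 = eos.
    warper = SyllableBudgetSketch([2, 1], {0: 1, 1: 2}, new_line_id=2, eos_id=3)
    scores = [0.9, 0.1, 0.0, 0.0]         # the "model" always prefers token 0
    generated = []
    while True:
        warped = warper.warp(scores)
        token = max(range(len(warped)), key=warped.__getitem__)
        generated.append(token)
        if token == warper.eos_id:
            return generated
        warper.update(token)


print(greedy_demo())  # [0, 0, 2, 0, 2, 3]
```

Note that this sketch only intervenes once a line's budget is already spent. If the toy model preferred the two-syllable token when one syllable remained, the line would overshoot its budget, which is one way a mechanism like this could end up violating the target structure.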




#### Calculating syllable cost

In [200]:
generator = MetricGenerator(model=model, tokenizer=tokenizer, device='cuda')

### A song about dogs

In [201]:
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
prompt = """This is a song about dogs:\n"""
torch.manual_seed(2)
output = generator.generate(
    prompt = prompt,
    text_init = text_init,
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.7,
    max_length = 100,
    new_line_token='||',
    free_tokens=['||', '?', '.', ','], 
    bad_words_ids=[[8876]],
)

print('---text_init----')
print(text_init)
print('\n')

print('----output-----')
print(output)
print('\n')

print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init)}")
# print(tokenizer.decode(output[0], skip_special_tokens=True).strip())

---text_init----
Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


----output-----
I love how the dogs can 
be so good, so bad, and 
so strange. Sometimes I love them 
when they are the most good


----Syllables-----
Syllables per line in output: tensor([6., 6., 7., 6.])
Syllables per line in `text_init`: tensor([6., 6., 7., 6.])


In [204]:
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
prompt = """This is a song about about California sun:\n"""
torch.manual_seed(2)
output = generator.generate(
    prompt = prompt,
    text_init = text_init,
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.7,
    max_length = 100,
    new_line_token='||',
    free_tokens=['||', '?', '.', ','], 
    bad_words_ids=[[8876]],
)

print('---prompt---')
print(prompt.strip())
print('\n')

print('---text_init----')
print(text_init)
print('\n')

print('----output-----')
print(output)
print('\n')

print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init)}")
# print(tokenizer.decode(output[0], skip_special_tokens=True).strip())

---prompt---
This is a song about about California sun:


---text_init----
Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


----output-----
I love how the colors 
the lights are bright, the air 
is warm and the water is 
clear. I love the way the


----Syllables-----
Syllables per line in output: tensor([6., 6., 7., 6.])
Syllables per line in `text_init`: tensor([6., 6., 7., 6.])


In [214]:
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
prompt = """This a beautiful poem that uses descriptive, earthy language to describe the ocean during a storm:\n\""""
torch.manual_seed(5)
output = generator.generate(
    prompt = prompt,
    text_init = text_init,
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.7,
    max_length = 100,
    new_line_token='||',
    free_tokens=['||', '?', '.', ','], 
    bad_words_ids=[[8876]],
)

print('---prompt---')
print(prompt.strip())
print('\n')

print('---text_init----')
print(text_init)
print('\n')

print('----output-----')
print(output)
print('\n')

print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output, pt=False)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init, pt=False)}")
# print(tokenizer.decode(output[0], skip_special_tokens=True).strip())

---prompt---
This a beautiful poem that uses descriptive, earthy language to describe the ocean during a storm:
"


---text_init----
Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


----output-----
In the storm, the wind and 
the sea are one. The waves 
are the breath of the gods. As 
they break and fall back, they


----Syllables-----
Syllables per line in output: [6, 6, 7, 6]
Syllables per line in `text_init`: [6, 6, 7, 6]


# Deviations from specified meter

In [215]:
!pip install tqdm


In [241]:
import tqdm

prompt = """This a beautiful poem that uses descriptive, earthy language to describe the ocean during a storm:\n\""""
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
init_syllables = generator.calculate_syllable_budget(text_init)

output_syllables = []
for i in tqdm.tqdm(range(0, 100)):
    output = generator.generate(
        prompt = prompt,
        text_init = text_init,
        # syllable_budget = torch.Tensor([6., 6.]),
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        remove_invalid_values=True,
        do_sample=True,
        top_k=25,
        temperature=.7,
        max_length = 100,
        new_line_token='||',
        free_tokens=['||', '?', '.', ','], 
        bad_words_ids=[[8876]],
    )
    
    output_syllables.append(generator.calculate_syllable_budget(output))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [05:01<00:00,  3.02s/it]


In [242]:
stacked_output_syllables = torch.stack(output_syllables)
deviation = stacked_output_syllables - init_syllables

In [243]:
m = deviation.mean()
std = deviation.std()

print(f"The average line-level violation of the metric constraint is {m}.")
print(f"The vast majority of deviations are within +/- {std*2} syllables of the constraint.")

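The average line-level violation of the metric constraint is -0.5174999833106995.
The vast majority of deviations are within +/- 1.824031949043274 syllables of the constraint.

A mean deviation of roughly -0.52 syllables per line suggests that the mechanism is not unbiased: on average, generated lines fall about half a syllable short of the target. The standard deviation of about 0.91 syllables shows that individual lines also vary noticeably around that average, so the mechanism is still flawed.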
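A signed mean can hide how often lines miss the target, because overshoots and undershoots cancel. As a rough sketch, the helper below also reports the mean absolute deviation and the share of lines that match exactly. The function name `summarize_deviations` is hypothetical, and the three sample outputs are made-up toy counts, not results from the run above.

```python
from statistics import mean


def summarize_deviations(output_syllables, target_syllables):
    """Compare per-line syllable counts of several outputs against one target."""
    deviations = [
        produced - target
        for line_counts in output_syllables
        for produced, target in zip(line_counts, target_syllables)
    ]
    return {
        "mean_signed": mean(deviations),                    # direction of the bias
        "mean_absolute": mean(abs(d) for d in deviations),  # typical size of a miss
        "exact_share": sum(d == 0 for d in deviations) / len(deviations),
    }


# Toy data for illustration only.
target = [6, 6, 7, 6]
outputs = [[6, 6, 7, 6], [5, 6, 7, 4], [6, 7, 6, 6]]
print(summarize_deviations(outputs, target))
```

With the real experiment, `output_syllables` could be passed in as lists (for example via `.tolist()` on each tensor) to get the same three numbers for the 100 generated samples.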
