# Generating metered verse with LLaMA

This notebook demonstrates how to use Meta's 7 billion parameter (7B) LLaMA model to generate song lyrics with a specific metric structure. If you've been yearning to rewrite the Happy Birthday song so that it's just about dogs, `MetricGenerator` can help :).

This notebook relies on the `MetricGenerator` class implemented in my [bragi](https://github.com/joehoover/bragi) repository. If you're wondering, Bragi (pictured below) is the [Norse god](https://en.wikipedia.org/wiki/Bragi) of poetry.

In [277]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')

from IPython.display import Image
Image(url="https://upload.wikimedia.org/wikipedia/commons/thumb/2/20/Bragi_by_Wahlbom.jpg/440px-Bragi_by_Wahlbom.jpg", width=150, height=150)


The `MetricGenerator` class lets you provide an *initialization* song that defines a target metric structure. Given this target structure, it then attempts to generate new lyrics with the same metric structure.

For example, the following lines

<blockquote>
Happy birthday to you<br>
Happy birthday to you<br>
Happy birthday dear Marvin<br>
Happy birthday to you
</blockquote>

have 6, 6, 7, and 6 syllables respectively. With `MetricGenerator`, you can use those lyrics to initialize a metric structure and then generate new lyrics with the same metric structure.

And, most importantly, you can guide the new lyrics with a prompt!

## Quick demo

Let's check it out! First, we'll load our dependencies, specify some paths, and initialize a model and tokenizer.

For this demo, I'll use Meta's 7 billion parameter (7B) `LLaMA` model.

In [None]:
from bragi.metric_generator import MetricGenerator
from transformers import LlamaForCausalLM, LlamaTokenizer
import torch

CACHE_DIR = 'weights'
SEP = "<sep>"
MODEL_PATH = "/src/weights"
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


model = LlamaForCausalLM.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR, local_files_only=True).to(device)
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR, local_files_only=True)

Next, I'll instantiate `MetricGenerator` with the model and tokenizer.

In [324]:
generator = MetricGenerator(model=model, tokenizer=tokenizer, device=device)

To generate lyrics, you just need to provide a prompt and the lyrics you want to use to initialize your metric structure. Let's try the Happy Birthday song displayed above.

In [245]:
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
print(text_init)

Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


And let's try to generate a new song that has the same metric structure, but is about dogs. To do that, we can use something like the following prompt:

In [320]:
prompt = """This is a song about dogs:\n"""
print(prompt)

This is a song about dogs:



Now we can call our `MetricGenerator` instance. Its `__call__` method accepts all of the arguments that Hugging Face's `model.generate()` does, plus a few extra arguments (such as `text_init`, `new_line_token`, and `free_tokens`) that constrain generation so that the target metric structure is respected.

**Note:** If you want the raw token IDs, call `MetricGenerator.generate()` with the same arguments. `MetricGenerator.__call__()`, in contrast, returns the decoded output string.

In [326]:
torch.manual_seed(2)
output = generator(
    prompt = prompt,
    text_init = text_init,
    free_tokens=['||', '?', '.', ','],
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.7,
    max_length = 100,
    new_line_token='||',
    bad_words_ids=[[8876]],
)

print('---text_init----')
print(text_init)
print('\n')

print('----output-----')
print(output)
print('\n')

print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init)}")


---text_init----
Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


----output-----
I love how the dogs can 
be so good, so bad, and 
so strange. Sometimes I love them 
when they are the most good


----Syllables-----
Syllables per line in output: tensor([6., 6., 7., 6.])
Syllables per line in `text_init`: tensor([6., 6., 7., 6.])


Neat! We rewrote Happy Birthday to be about dogs.

## Exploring the effect of prompt conditioning

Let's try some other prompts and see how responsive the model is to prompt conditioning. 

## **California ☀️!**

In [274]:
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
prompt = """This is a song about about California sun:\n"""

In [275]:
torch.manual_seed(2)
output = generator(
    prompt = prompt,
    text_init = text_init,
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.7,
    max_length = 100,
    new_line_token='||',
    free_tokens=['||', '?', '.', ','], 
    bad_words_ids=[[8876]],
)

print('---prompt---')
print(prompt.strip())
print('\n')

print('---text_init----')
print(text_init)
print('\n')

print('----output-----')
print(output)
print('\n')

print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init)}")
# print(tokenizer.decode(output[0], skip_special_tokens=True).strip())

---prompt---
This is a song about about California sun:


---text_init----
Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


----output-----
I love how the colors 
the lights are bright, the air 
is warm and the water is 
clear. I love the way the


----Syllables-----
Syllables per line in output: tensor([6., 6., 7., 6.])
Syllables per line in `text_init`: tensor([6., 6., 7., 6.])


## **Dark and stormy 🌊!**

In [266]:
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
prompt = """This a beautiful poem that uses descriptive, earthy language to describe the ocean during a storm:\n\""""

In [268]:

torch.manual_seed(5)
output = generator(
    prompt = prompt,
    text_init = text_init,
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.7,
    max_length = 100,
    new_line_token='||',
    free_tokens=['||', '?', '.', ','], 
    bad_words_ids=[[8876]],
)

print('---prompt---')
print(prompt.strip())
print('\n')

print('---text_init----')
print(text_init)
print('\n')

print('----output-----')
print(output)
print('\n')

print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output, pt=False)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init, pt=False)}")
# print(tokenizer.decode(output[0], skip_special_tokens=True).strip())

---prompt---
This a beautiful poem that uses descriptive, earthy language to describe the ocean during a storm:
"


---text_init----
Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


----output-----
In the storm, the wind and 
the sea are one. The waves 
are the breath of the gods. As 
they break and fall back, they


----Syllables-----
Syllables per line in output: [6, 6, 7, 6]
Syllables per line in `text_init`: [6, 6, 7, 6]


# Methodology

Now that you've seen `MetricGenerator` in action, let's take a peek under the hood. We'll discuss:

1. How `MetricGenerator` constrains generation
2. Cases where it violates the target metric structure
3. Other shortcomings of the current implementation
4. Ways we could improve `MetricGenerator`

### How does `MetricGenerator` constrain generation?

`MetricGenerator` uses a logit warper to constrain the model's vocabulary during generation. A logit warper is a callable that modifies the model's next-token scores (logits) at each generation step, before the next token is selected. Specifically, `MetricGenerator` uses the `SyllableRestrictionWarper`, which is implemented in `bragi.logit_warpers`.

The `SyllableRestrictionWarper` constrains the metric structure of model-generated output by tracking two budgets: a per-line syllable budget and a number-of-lines budget. Every time the generator *spends* one or more syllables, the current line's budget is decremented. When all of the syllables permitted for a line are spent, the warper forces a newline token by setting the logit scores of all other tokens to `-Inf`. Each emitted newline token decrements the number-of-lines budget by one, and when that budget reaches zero, the warper forces the EOS token the same way.
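To make the bookkeeping concrete, here is a minimal pure-Python sketch of the budget logic described above. It is not bragi's implementation: the class name `SyllableBudgetSketch`, the `warp`/`update` split, and the toy token ids are assumptions for illustration, and logits are a plain list rather than a PyTorch tensor. Two simplifications: newline and EOS cost zero syllables and are never masked early, and EOS is forced as soon as the last line's syllables are spent rather than after a final newline.

```python
import math


class SyllableBudgetSketch:
    """Hypothetical, simplified budget tracker (not bragi's SyllableRestrictionWarper)."""

    def __init__(self, line_budgets, token_costs, newline_id, eos_id):
        self.line_budgets = list(line_budgets)  # syllables allowed on each line
        self.token_costs = token_costs          # token id -> syllable cost (missing ids cost 0)
        self.newline_id = newline_id
        self.eos_id = eos_id
        self.line = 0                           # index of the line being generated
        self.remaining = self.line_budgets[0]   # syllables left on that line

    def _force(self, logits, token_id):
        # Keep only `token_id`; every other token gets a score of -inf.
        return [s if i == token_id else -math.inf for i, s in enumerate(logits)]

    def warp(self, logits):
        if self.remaining <= 0:
            # Line budget spent: force a newline, or EOS if this was the last line.
            last_line = self.line == len(self.line_budgets) - 1
            return self._force(logits, self.eos_id if last_line else self.newline_id)
        # Otherwise mask every token that would overspend the current line.
        return [
            s if self.token_costs.get(i, 0) <= self.remaining else -math.inf
            for i, s in enumerate(logits)
        ]

    def update(self, token_id):
        # Record the token that was actually sampled after warping.
        if token_id == self.newline_id:
            self.line += 1  # one fewer line left in the number-of-lines budget
            if self.line < len(self.line_budgets):
                self.remaining = self.line_budgets[self.line]
        else:
            self.remaining -= self.token_costs.get(token_id, 0)
```

In a decoding loop, `warp` would run on each step's logits and `update` on the token sampled from the warped scores.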




#### Calculating syllable cost

The `SyllableRestrictionWarper` requires a method for calculating the *cost* of each token in the model's vocabulary. In theory, this cost function can implement arbitrary scoring logic; however, for our purposes model tokens are *scored* according to their number of syllables. 

This is accomplished with the `phones_for_word` and `syllable_count` functions from the `pronouncing` library, a Python interface to the CMU Pronouncing Dictionary. `bragi` provides the `verse_parsers.cmu_syllable_counter` function, which wraps these `pronouncing` functions.

In [284]:
from bragi.verse_parsers import cmu_syllable_counter
print(f"'Hello' has {cmu_syllable_counter('Hello')} syllables")
print(f"'world' has {cmu_syllable_counter('world')} syllables")

'Hello' has 2 syllables
'world' has 1 syllables
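Under the hood, this style of counting looks up a word's CMU phones and counts the vowel phones, each of which carries a stress digit (0, 1, or 2). The sketch below reproduces that idea with the standard library only; `TOY_CMU` is a hypothetical two-word stand-in for the full dictionary, and the `*_sketch` functions are illustrative rather than `pronouncing`'s or `bragi`'s own code.

```python
# Hypothetical mini-dictionary with CMU-style pronunciations (stress digits on vowels).
TOY_CMU = {
    "hello": ["HH AH0 L OW1", "HH EH0 L OW1"],
    "world": ["W ER1 L D"],
}


def phones_for_word_sketch(word):
    # Case-insensitive lookup; an empty list means the word is out of vocabulary.
    return TOY_CMU.get(word.lower(), [])


def syllable_count_sketch(phones):
    # Every vowel phone ends in a stress digit, so counting digits counts syllables.
    return sum(1 for phone in phones.split() if phone[-1].isdigit())


def cmu_syllables_sketch(word):
    prons = phones_for_word_sketch(word)
    if not prons:
        return None  # OOV: no pronunciation to count
    return syllable_count_sketch(prons[0])  # use the first listed pronunciation


print(cmu_syllables_sketch("Hello"), cmu_syllables_sketch("world"))
```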


`MetricGenerator` calculates the syllables of every token in the model's vocabulary.

However, because LLaMA uses a [BPE tokenizer](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaTokenizer), it's crucial to calculate token syllables on decoded tokens rather than on the raw vocabulary entries. For example, the word 'and' is tokenized as '▁and', which is out of vocabulary (OOV) for the CMU dictionary. `MetricGenerator` uses `verse_parsers.token_syllable_scores`, which handles decoding and syllable counting for all tokens in the model vocabulary.
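A minimal sketch of the decode-then-count step, assuming SentencePiece's convention of marking a preceding space with U+2581 ('▁'). In practice the pieces would come from the tokenizer (for example via `tokenizer.convert_ids_to_tokens`); here `pieces`, `TOY_CMU`, and the helper names are hypothetical, and tokens with no dictionary entry get a cost of `None`. This is an illustration, not the code of `verse_parsers.token_syllable_scores`.

```python
# Hypothetical mini-dictionary with CMU-style pronunciations.
TOY_CMU = {
    "and": ["AH0 N D", "AE1 N D"],
    "happy": ["HH AE1 P IY0"],
    "dog": ["D AO1 G"],
}


def syllables(word):
    prons = TOY_CMU.get(word.lower())
    if not prons:
        return None  # OOV
    return sum(1 for phone in prons[0].split() if phone[-1].isdigit())


def piece_to_text(piece):
    # Turn the SentencePiece space marker back into a space, then trim it.
    return piece.replace("\u2581", " ").strip()


def token_costs_sketch(pieces):
    # Map each token id (its position in `pieces`) to the syllables of its decoded text.
    return {i: syllables(piece_to_text(p)) for i, p in enumerate(pieces)}


pieces = ["\u2581and", "\u2581happy", "\u2581dog", "s"]
print(syllables(pieces[0]))        # raw piece: not found in the dictionary
print(token_costs_sketch(pieces))  # decoded pieces: found, except the fragment "s"
```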

#### Calculating the metric structure of the initialization text

To calculate the metric structure of the initialization text, `MetricGenerator` uses a wrapper around the `poesy` package. `poesy` is a Python interface for the `eSpeak` speech synthesis library, which provides more robust syllable calculation than simply performing lexical lookups against the CMU Pronouncing Dictionary.

`bragi` provides the `PoesyParsedVerseHandler` class as an interface to `poesy`. You can use its `example` method to extract the syllabic structure of a string whose lines are separated by '\n'.

For example:

In [295]:
from bragi.verse_parsers import PoesyParsedVerseHandler
verse_handler = PoesyParsedVerseHandler()

text = """An elderly man called Keith,
Mislaid his set of false teeth.
They'd been laid on a chair,
He'd forgot they were there,
Sat down, and was bitten beneath."""

print(text)

example, syllable_budget = verse_handler.example(text)
print(f"\nThe lines of the limerick above have the following syllable counts: {syllable_budget}")

An elderly man called Keith,
Mislaid his set of false teeth.
They'd been laid on a chair,
He'd forgot they were there,
Sat down, and was bitten beneath.

The lines of the limerick above have the following syllable counts: [7, 7, 6, 6, 8]


**Note:** The `PoesyParsedVerseHandler.example()` method also transforms an input text into a control code that can be used for fine-tuning. In addition to capturing the syllabic structure of the input text, it attempts to detect the rhyme scheme. This functionality was inspired by [Ormazabal et al. (2022)](https://aclanthology.org/2022.findings-emnlp.268/), who showed that fine-tuning a language model on control codes like the one below substantially improves its ability to produce verse with a specific metric structure.

E.g.:

In [298]:
print(example)

<PREF>
<SYLLABLES: 7><RHYME: A>
<SYLLABLES: 7><RHYME: A>
<SYLLABLES: 6><RHYME: B>
<SYLLABLES: 6><RHYME: B>
<SYLLABLES: 8><RHYME: A>
</PREF>

An elderly man called Keith,
Mislaid his set of false teeth.
They'd been laid on a chair,
He'd forgot they were there,
Sat down, and was bitten beneath.
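The control-code layout printed above is simple enough to rebuild by hand. The sketch below assumes the per-line syllable counts and rhyme labels have already been computed (by `poesy`, in bragi's case) and only formats them; `build_control_code` is a hypothetical helper that copies the printed layout, not bragi's implementation.

```python
def build_control_code(text, syllables, rhymes):
    """Hypothetical helper: prefix a verse with per-line syllable and rhyme codes."""
    lines = text.split("\n")
    if not (len(lines) == len(syllables) == len(rhymes)):
        raise ValueError("need one syllable count and one rhyme label per line")
    codes = [f"<SYLLABLES: {s}><RHYME: {r}>" for s, r in zip(syllables, rhymes)]
    # Control code first, then a blank line, then the verse itself.
    return "\n".join(["<PREF>", *codes, "</PREF>", "", text])


limerick = """An elderly man called Keith,
Mislaid his set of false teeth.
They'd been laid on a chair,
He'd forgot they were there,
Sat down, and was bitten beneath."""

print(build_control_code(limerick, [7, 7, 6, 6, 8], ["A", "A", "B", "B", "A"]))
```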


### Putting it all together

`SyllableRestrictionWarper` uses the line-level syllable counts obtained from `poesy` to construct a syllable budget. Then, at each generation step, tokens are dynamically masked according to their syllable cost, conditional on the budget remaining for the current line.

For example, imagine that the tokens for the 6-syllable sequence "An elderly man called" have already been generated on a line with a 7-syllable budget. Only one syllable remains, so tokens like "Marvin", "Magnolia", and "murky" will be masked, as will any other token with two or more syllables.
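The arithmetic in that example, as a few lines of Python. The candidate costs are hard-coded for illustration, and `","` is assumed to cost zero syllables, which is how we read the `free_tokens` argument used earlier; a real run would read costs from the precomputed per-token syllable table.

```python
line_budget = 7
spent = 6  # "An elderly man called" = 1 + 3 + 1 + 1 syllables
remaining = line_budget - spent

# Illustrative syllable costs; "," is assumed to be a zero-cost free token.
candidate_costs = {"Marvin": 2, "Magnolia": 4, "murky": 2, "Keith": 1, ",": 0}

allowed = [tok for tok, cost in candidate_costs.items() if cost <= remaining]
masked = [tok for tok, cost in candidate_costs.items() if cost > remaining]
print("allowed:", allowed)
print("masked:", masked)
```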

# Limitations

## Metric Structure Errors

`MetricGenerator` does not always strictly adhere to the initialized metric structure. The bookkeeping for the syllable and line budgets is still alpha-quality, and a bug there is the most likely cause of this behavior. In the example below, the first line comes out one syllable short.

In [317]:
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
prompt = """Song lyrics for "Heavy Metal Sunset", an 80's metal song about Los Angeles:\n\""""

torch.manual_seed(1)
output = generator(
    prompt = prompt,
    text_init = text_init,
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=3,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.8,
    max_length = 100,
    new_line_token='||',
    free_tokens=['||', '?', '.', ','], 
    bad_words_ids=[[8876]],
)

print('---prompt---')
print(prompt.strip())
print('\n')

print('---text_init----')
print(text_init)
print('\n')

print('----output-----')
print(output)
print('\n')

print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output, pt=False)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init, pt=False)}")


---prompt---
Song lyrics for "Heavy Metal Sunset", an 80's metal song about Los Angeles:
"


---text_init----
Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


----output-----
Heartache and sorrow 
A thousand miles apart 
I can see it clear, I can 
I feel it in my heart


----Syllables-----
Syllables per line in output: [5, 6, 7, 6]
Syllables per line in `text_init`: [6, 6, 7, 6]


### Investigating deviations from specified meter

To get a sense of `MetricGenerator`'s failure rate, we can look at the distribution of line-level deviations from the target number of syllables. Below, I run 100 generations against a target text and calculate the mean and standard deviation of these deviations.

In [215]:
!pip install tqdm


In [241]:
import tqdm

prompt = """This a beautiful poem that uses descriptive, earthy language to describe the ocean during a storm:\n\""""
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
init_syllables = generator.calculate_syllable_budget(text_init)

output_syllables = []
for i in tqdm.tqdm(range(0, 100)):
    output = generator.generate(
        prompt = prompt,
        text_init = text_init,
        # syllable_budget = torch.Tensor([6., 6.]),
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        remove_invalid_values=True,
        do_sample=True,
        top_k=25,
        temperature=.7,
        max_length = 100,
        new_line_token='||',
        free_tokens=['||', '?', '.', ','], 
        bad_words_ids=[[8876]],
    )
    
    output_syllables.append(generator.calculate_syllable_budget(output))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [05:01<00:00,  3.02s/it]


In [242]:
stacked_output_syllables = torch.stack(output_syllables)
deviation = stacked_output_syllables - init_syllables

In [243]:
m = deviation.mean()
std = deviation.std()

# Negative deviations mean a line came out with too few syllables; positive, too many.
print(f"The average line-level deviation from the metric constraint is {m:.2f} syllables.")
print(f"The vast majority of deviations fall within {m:.2f} +/- {2*std:.2f} syllables (mean +/- 2 standard deviations).")

The average line-level deviation from the metric constraint is -0.52 syllables.
The vast majority of deviations fall within -0.52 +/- 1.82 syllables (mean +/- 2 standard deviations).


### Beam search is not supported

An obvious way to improve output quality is to conduct generation with beam search. However, the current implementation of the syllable budgeting process does not support beam search.

# Future Work


**Excluding partial words from the model vocabulary**

Currently, beyond the syllable budget, `MetricGenerator` does not enforce any additional constraints on the generation process or on the vocabulary available at a given step. This means that it *can* output partial words or nonsense tokens, so excluding partial-word tokens may improve output quality. However, standard text-generation practices such as setting `top_k` appear to largely mitigate the partial-word issue: with a reasonably small `top_k` value (e.g., 25-50), partial-word emissions are a negligible problem. This has therefore been left for future work.
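For reference, top-k filtering keeps only the `k` highest-scoring tokens at each step and masks the rest, which tends to push low-scoring fragments out of the candidate set. The pure-Python sketch below is roughly what `top_k=25` does in the calls above (ties at the cutoff are kept); the scores and the fragment `"og"` are made up for illustration.

```python
import math


def top_k_filter(logits, k):
    """Keep the k highest scores and set the rest to -inf (ties at the cutoff are kept)."""
    if k <= 0:
        raise ValueError("k must be positive")
    if k >= len(logits):
        return list(logits)
    cutoff = sorted(logits, reverse=True)[k - 1]  # k-th highest score
    return [x if x >= cutoff else -math.inf for x in logits]


# Made-up next-token scores; "og" stands in for a partial-word fragment.
scores = {"\u2581dogs": 3.1, "\u2581cats": 2.4, "\u2581bark": 1.9, "og": -0.7}
kept = dict(zip(scores, top_k_filter(list(scores.values()), k=3)))
print(kept)  # "og" is masked to -inf
```

Note that this only helps when fragments score low; a high-scoring fragment would survive any `top_k`.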

**Fine-tuning on in-distribution data**

While emitting partial words is largely mitigated by masking all but the highest-probability tokens, no current generation setting addresses the problem of abrupt stopping. Because `MetricGenerator` forces an EOS token when the syllable budget of the last line is spent, most generated sequences end unnaturally. Supporting beam search might partially mitigate this, but fine-tuning on in-domain data such as song lyrics and poems will likely be far more effective. The fine-tuning process could be further enhanced with control codes ([Ormazabal et al. (2022)](https://aclanthology.org/2022.findings-emnlp.268/)) that encode the metric structure of the target text. At inference time, generation would condition on the control code, which should, in theory, guide the model toward natural endings.

Fine-tuning on in-distribution data would also provide an opportunity to strengthen the conditioning effect of prompts on generated sequences. For example, the instruction-tuning approach in [Chakrabarty et al. (2022)](https://arxiv.org/abs/2210.13669) lets users request rhetorical devices such as metaphor.

**Constrained verse generation with a fine-tuned model**

Ultimately, the best solution will likely combine rule-based constraints (e.g., excluding tokens based on their syllabic cost) with fine-tuning. This approach should yield a model that is forced to follow the rules but has also been taught how to follow them with style.

# Deviations from specified meter