In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')

# Overview

This notebook demonstrates how to use `MetricGenerator`, a `Transformers` wrapper that attempts to generate song lyrics with the a specific metric structure. `MetricGenerator` allows you to specify an *initialization* song that is used to specify the target metric structure. 


`MetricGenerator` then in


In [3]:
from bragi.metric_generator import MetricGenerator
from transformers import LLaMAForCausalLM, LLaMATokenizer
import torch 

CACHE_DIR = 'weights'
SEP = "<sep>"
MODEL_PATH  = "/src/weights"
device = 'cuda' if torch.cuda.is_available() else 'cpu'


model = LLaMAForCausalLM.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR, local_files_only=True).to(device)
tokenizer = LLaMATokenizer.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR, local_files_only=True)

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

# Overview

This notebook demonstrates how to use `MetricGenerator`, which is a `Transformers` wrapper that injects a logit warper that constrains the model vocabulary during generation. In theory, the 

In [200]:
generator = MetricGenerator(model=model, tokenizer=tokenizer, device='cuda')

### A song about dogs

In [201]:
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
prompt = """This is a song about dogs:\n"""
torch.manual_seed(2)
output = generator.generate(
    prompt = prompt,
    text_init = text_init,
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.7,
    max_length = 100,
    new_line_token='||',
    free_tokens=['||', '?', '.', ','], 
    bad_words_ids=[[8876]],
)

print('---text_init----')
print(text_init)
print('\n')

print('----output-----')
print(output)
print('\n')

print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init)}")
# print(tokenizer.decode(output[0], skip_special_tokens=True).strip())

---text_init----
Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


----output-----
I love how the dogs can 
be so good, so bad, and 
so strange. Sometimes I love them 
when they are the most good


----Syllables-----
Syllables per line in output: tensor([6., 6., 7., 6.])
Syllables per line in `text_init`: tensor([6., 6., 7., 6.])


In [204]:
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
prompt = """This is a song about about California sun:\n"""
torch.manual_seed(2)
output = generator.generate(
    prompt = prompt,
    text_init = text_init,
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.7,
    max_length = 100,
    new_line_token='||',
    free_tokens=['||', '?', '.', ','], 
    bad_words_ids=[[8876]],
)

print('---prompt---')
print(prompt.strip())
print('\n')

print('---text_init----')
print(text_init)
print('\n')

print('----output-----')
print(output)
print('\n')

print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init)}")
# print(tokenizer.decode(output[0], skip_special_tokens=True).strip())

---prompt---
This is a song about about California sun:


---text_init----
Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


----output-----
I love how the colors 
the lights are bright, the air 
is warm and the water is 
clear. I love the way the


----Syllables-----
Syllables per line in output: tensor([6., 6., 7., 6.])
Syllables per line in `text_init`: tensor([6., 6., 7., 6.])


In [214]:
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
prompt = """This a beautiful poem that uses descriptive, earthy language to describe the ocean during a storm:\n\""""
torch.manual_seed(5)
output = generator.generate(
    prompt = prompt,
    text_init = text_init,
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.7,
    max_length = 100,
    new_line_token='||',
    free_tokens=['||', '?', '.', ','], 
    bad_words_ids=[[8876]],
)

print('---prompt---')
print(prompt.strip())
print('\n')

print('---text_init----')
print(text_init)
print('\n')

print('----output-----')
print(output)
print('\n')

print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output, pt=False)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init, pt=False)}")
# print(tokenizer.decode(output[0], skip_special_tokens=True).strip())

---prompt---
This a beautiful poem that uses descriptive, earthy language to describe the ocean during a storm:
"


---text_init----
Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


----output-----
In the storm, the wind and 
the sea are one. The waves 
are the breath of the gods. As 
they break and fall back, they


----Syllables-----
Syllables per line in output: [6, 6, 7, 6]
Syllables per line in `text_init`: [6, 6, 7, 6]


# Deviations from specified meter

In [215]:
!pip install tqdm

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.3.1[0m[39;49m -> [0m[32;49m23.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [241]:
import tqdm

prompt = """This a beautiful poem that uses descriptive, earthy language to describe the ocean during a storm:\n\""""
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
init_syllables = generator.calculate_syllable_budget(text_init)

output_syllables = []
for i in tqdm.tqdm(range(0, 100)):
    output = generator.generate(
        prompt = prompt,
        text_init = text_init,
        # syllable_budget = torch.Tensor([6., 6.]),
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        remove_invalid_values=True,
        do_sample=True,
        top_k=25,
        temperature=.7,
        max_length = 100,
        new_line_token='||',
        free_tokens=['||', '?', '.', ','], 
        bad_words_ids=[[8876]],
    )
    
    output_syllables.append(generator.calculate_syllable_budget(output))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [05:01<00:00,  3.02s/it]


In [242]:
stacked_output_syllables = torch.stack(output_syllables)
deviation = stacked_output_syllables - init_syllables

In [243]:
m = deviation.mean()
std = deviation.std()

print(f"The average line-level violation of the metric constraint is {m}.")
print(f"The vast majority of deviations are within +/- {std*2} syllables of the constraint.")

The average line-level violation of the metric constraint is -0.5174999833106995,       which suggests that the mechanism is unbiased but still flawed.
The vast majority of deviations are within +/- 1.824031949043274 syllables of the constraint.
