In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')

# Generating verse with a constrained vocabulary

For many reasons, we might want to constrain the vocabulary of our model at inference time. For example, we might want to only sample tokens that are common English words. Or, perhaps we want to introduce hard thematic constraints by including only words from a certain domain. `MetricGenerator` allows us to do this easily with the `tokens_to_include` argument.

To demonstrate, we'll use GPT2 to generate a song with and without constrained vocabularies. 

### Environment initialization

First, let's setup our environment.

In [56]:
import torch
from bragi.metric_generator import MetricGenerator
from transformers import AutoTokenizer, AutoModelForCausalLM

device = 'cpu'
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

Initialize a `MetricGenerator` instance.

In [6]:
generator = MetricGenerator(model=model, tokenizer=tokenizer, device=device)

Now let's generate lyrics without a lexical constraint.

In [59]:
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
prompt = """This is a song about dogs:\n"""

torch.manual_seed(2)
output = generator(
    prompt = prompt,
    text_init = text_init,
    free_tokens=['||', '?', '.', ','],
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.7,
    max_length = 100,
    new_line_token='||',
    bad_words_ids=[[8876]],
)

print('----output-----')
print(output)
print('\n')

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


----output-----
I heard there were four of
a four hundred dogs on
it was hard for me to get
to the end of the day




Now, let's try constraining our vocabulary to something silly.

In [61]:
tokens_to_include=["dogs", "run", "pant", "lick", "jump", "to", "from", "fast", "now", "good"]

torch.manual_seed(2)
output = generator(
    prompt = prompt,
    text_init = text_init,
    free_tokens=['||', '?', '.', ','],
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.7,
    max_length = 100,
    new_line_token='||',
    bad_words_ids=[[8876]],
    tokens_to_include=tokens_to_include
)

print('---prompt---')
print(prompt.strip())
print('\n')

print('---text_init----')
print(text_init)
print('\n')

print('----output-----')
print(output)
print('\n')

print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init)}")


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


---prompt---
This is a song about dogs:


---text_init----
Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


----output-----
To dogs, to dogs. Dogs, dogs
. To dogs? To Dogs. From dogs
to Dogs? From Dogs to? To? From
To?? to? from to to To


----Syllables-----
Syllables per line in output: tensor([6., 6., 7., 6.])
Syllables per line in `text_init`: tensor([6., 6., 7., 6.])


Neat! Now let's try something a little more expansive. We'll restrict the model vocabulary to the 10k most common English words. To do that, I'll pull a common words list to use as a lexicon.

In [62]:
import urllib  

# Pull 100,000 most frequently used English words from this resource
# Note, it's old and not necessarily authoritative. This is just an example
lexicon_url = "https://gist.githubusercontent.com/h3xx/1976236/raw/bbabb412261386673eff521dddbe1dc815373b1d/wiki-100k.txt"

# Read data
data = urllib.request.urlopen(lexicon_url) # it's a file like object and works just like a file
tokens_to_include = []
for line in data: # files are iterable
    line = line.decode('utf8')
    if '#' not in line:
        # We'll remove new lines and lower case each word.
        tokens_to_include.append(line.replace('\n', '').lower())
        
# Since we lowered words, we probably have duplicates.
tokens_to_include = list(set(words))


In [65]:
top10k_words=tokens_to_include[0:10000]

In [67]:
tokens_to_include=tokens_to_include[0:1000]

torch.manual_seed(2)
output = generator(
    prompt = prompt,
    text_init = text_init,
    free_tokens=['||', '?', '.', ','],
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    # top_k=25,
    temperature=.7,
    max_length = 100,
    new_line_token='||',
    bad_words_ids=[[8876]],
    tokens_to_include=top10k_words
)

print('---prompt---')
print(prompt.strip())
print('\n')

print('---text_init----')
print(text_init)
print('\n')

print('----output-----')
print(output)
print('\n')

print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init)}")


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


---prompt---
This is a song about dogs:


---text_init----
Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you


----output-----
The song, Love, was about
The dogs love. They love the
love they were.They love their love
Love. The dogs loves their dogs


----Syllables-----
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Syllables per line in output: tensor([6., 6., 8., 6.])
Syllables per line in `text_init`: tensor([6., 6., 7., 6.])
