In [14]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')


In [1]:
import pronouncing

from transformers.generation.logits_process import LogitsProcessorList
import torch

from transformers.testing_utils import require_torch, torch_device
import random

from transformers.generation.logits_process import MinLengthLogitsProcessor, TopKLogitsWarper


To minimize computational cost and latency, we'll develop with `gpt2`

In [2]:
%%capture
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

In [5]:
# %%capture
# !pip install unittest

Define some functions for getting syllables

In [387]:
def cmu_syllable_counter(word):
    """
    Returns inf for OOV tokens.
    Note: This prohibits things like numbers and punctuation. Very naive and dumb.
    """
    pronunciation_list = pronouncing.phones_for_word(word)
    if len(pronunciation_list) > 0:
        syllable_count = pronouncing.syllable_count(pronunciation_list[0])
    else:
        return float("Inf")
    
    return syllable_count

def syllable_mapper(vocab):
    syllable_map = {}
    for token, idx in vocab.items():
        n_syllables = cmu_syllable_counter(token)
        try:
            syllable_map[n_syllables].append(idx)
        except KeyError:
            syllable_map[n_syllables] = [idx]
    return syllable_map

def token_syllable_scores(tokenizer, pt=True, free_tokens=['\n', '!', ',', ':', '?', ';', ' ']):
    """
    Returns list or torch tensor of size==tokenizer.vocab_size where element i is the count of syllables
    for token i.
    """
    sorted_vocab = {k: v for k, v in sorted(tokenizer.vocab.items(), key=lambda item: item[1])}
    syllable_scores = []
    for token, idx in sorted_vocab.items():

        # Have to decode the vocab item to deal with special characters, e.g. '\n' is represented as 'Ċ'
        
   
        decoded_token = tokenizer.decode(sorted_vocab[token])
        if decoded_token != ' ':
            decoded_token = decoded_token.strip()
        if decoded_token not in free_tokens:
            n_syllables = cmu_syllable_counter(decoded_token)
        else:
            n_syllables = 0

        syllable_scores.append(n_syllables)
    
    # Scoring is wrong for '\n' patching
    syllable_scores[tokenizer('\n')['input_ids'][0]] = 0
        
    if pt:
        return torch.Tensor(syllable_scores)
    return syllable_scores

In [386]:
tokenizer('\n')['input_ids'][0]

198

In [383]:
tokenizer('\n')['input_ids'][0]

198

In [384]:
tokenizer.decode(198)

'\n'

In [385]:
def test_syllable_mapper():
    vocab = {'living': 0, 'on': 1, 'the': 2, "road": 3}
    
    def convert_ids_to_tokens(token_id, vocab):
        return [token for token in vocab if vocab[token]==token_id][0]
    
    syllable_map = syllable_mapper(vocab)
    assert len(syllable_map[1]) == 3
    assert convert_ids_to_tokens(syllable_map[2][0], vocab) == 'living'
    


class TestTokenizer():
    
    def __init__(self):
        self.vocab = {'living': 0, 'on': 1, 'the': 2, "road": 5, "!": 4, "corpus": 3}
    
    def decode(self, token_id): 
        return [k for k,v in self.vocab.items() if v==token_id][0]
    
def test_token_syllable_scores():
    tokenizer = TestTokenizer()
    vocab = {'living': 0, 'on': 1, 'the': 2, "road": 5, "!": 4, "corpus": 3}
    
    syllable_scores = token_syllable_scores(tokenizer, pt=False)
    assert syllable_scores[0] == 2
    assert syllable_scores[1] == 1
    assert syllable_scores[3] == 2
    assert syllable_scores[4] == 0
    
test_syllable_mapper()
test_token_syllable_scores()


TypeError: 'TestTokenizer' object is not callable

In [388]:
syllable_scores = token_syllable_scores(tokenizer, pt=True)

In [389]:
syllable_scores[198]

tensor(0.)

Let's demonstrate how this works

In [9]:
# Bug now that max is inf
inf_mask = syllable_scores == float("Inf")
temp_scores = syllable_scores.masked_fill(inf_mask, -float('Inf'))
max_word = temp_scores.argmax()
word = tokenizer.decode(max_word.item())
print(f"A word with the max number of syllables is '{word.strip()}', which has {cmu_syllable_counter(word)} syllables. It has a score of {temp_scores[max_word.item()]}.")
                     

A word with the max number of syllables is 'homosexuality', which has inf syllables. It has a score of 7.0.


In [10]:
import transformers

In [11]:
text = """Happy birthday to you
Happy birthday to you
Happy birthday dear John"""


In [12]:
text_init = """Happy birthday to you"""

Here's some code I wrote based on the `poesy` package that returns syllable counts by line.

In [15]:
# from poesy import Poem
from bragi.verse_parsers import PoesyParsedVerseHandler
verse_handler = PoesyParsedVerseHandler()
_, syllable_budget = verse_handler.example(text_init)
syllable_budget = torch.Tensor(syllable_budget)
print('text_init: ', f'"{text_init}"', '\nSyllable budget', syllable_budget)

text_init:  "Happy birthday to you" 
Syllable budget tensor([6.])


Here, I define a LogitsWarper that tracks syllable counts by line.

In [109]:
tokenizer('\n', return_tensors='pt')['input_ids'].item()

198

In [110]:
tokenizer.decode(198)

'\n'

In [782]:
from transformers.generation.logits_process import LogitsWarper

# TODO
# 1. Add support for beam search
class SyllableRestrictionWarper(LogitsWarper):
    def __init__(
        self, 
        prompt: str,
        tokenizer: transformers.PreTrainedTokenizerFast,
        syllable_budget: torch.Tensor,
        syllable_scorer: callable,
        filter_value: float = -float("Inf"), 
        min_tokens_to_keep: int = 1,
        free_tokens=['!', ',', ':', '?', ';', ' ',],
        new_line_token = '\n',
        num_beams: int = 10,
    ):
#         if not isinstance(syllable_budget, int) or syllable_budget < 0:
#             raise ValueError(f"`syllable_budget` has to be a strictly positive or zero integer, but is {syllable_budget}")

        self.syllable_budget = syllable_budget.repeat(num_beams)
        self.filter_value = filter_value
        self.syllable_scores = syllable_scorer(tokenizer, pt=True, free_tokens=free_tokens)#.repeat(num_beams, 1)
        
        self.new_line_token = '\n'
        self.new_line_token_id = tokenizer('\n', return_tensors='pt')['input_ids'].item()
        
        print('prompt in warper ', prompt)
        
        self.prompt_offset = tokenizer(prompt, return_tensors='pt')['input_ids'].shape[1]
        
        self.line_number = 0
        self.line_budget = syllable_budget.shape[0] - 1
        print(f"The line budget is: {self.line_budget}")
        
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        batch_size = scores.shape[0]


        syllable_scores = self.syllable_scores
        
        if self.line_number > self.syllable_budget.shape[0] - 1:
            syllable_budget = torch.Tensor([0])
        else:
            syllable_budget = self.syllable_budget[self.line_number, None]

        # Update syllable budget 
        syllable_budget = self.update_syllable_budget(input_ids, syllable_budget, syllable_scores)
        
        
        # Remove all tokens with more syllables than `syllable_budget`
        syllable_scores = syllable_scores.repeat(batch_size, 1)
        indices_to_remove =  syllable_scores > syllable_budget[:,None]
        
        # Check if line has been completed
        line_completed = syllable_budget <= 0
        scores = scores.masked_fill(indices_to_remove, self.filter_value)

        if True in line_completed:
  
            
            # Force EOS if line budget is spent
            if line_budget < 0 or self.line_number > self.syllable_budget.shape[0]:

                indices_to_remove[line_completed,:] = torch.full_like(indices_to_remove[line_completed,:], True)
                scores[line_completed, tokenizer.eos_token_id] = -1
            # Otherwise, force new line and move to next line budget
            else:
                scores[line_completed,  self.new_line_token_id] = -1
                
            self.line_number += 1
            self.line_budget -= 1

        return scores

    
    def update_syllable_budget(self, input_ids, syllable_budget, syllable_scores):
        if input_ids.shape[1] > self.prompt_offset:
            syllable_cost = syllable_scores[input_ids[:,-1]]
            syllable_budget -= syllable_cost

        return syllable_budget 

In [783]:
# from typing import Optional, List

# from bragi.verse_parsers import PoesyParsedVerseHandler


# class MetricGenerator():
#     def __init__(
#         self, 
#         model, 
#         tokenizer, 
#         syllable_scorer
#     ):
#         self.model = model
#         self.tokenizer = tokenizer
#         self.syllable_scorer = syllable_scorer
#         self.verse_handler = verse_handler = PoesyParsedVerseHandler()
        
    
#     def _calculate_syllable_budget(
#         self,
#         text_init
#     ):
        
#         _, syllable_budget = self.verse_handler.example(text_init)
#         syllable_budget = torch.Tensor(syllable_budget)
#         return syllable_budget
    
#     def generate(
#         self, 
#         prompt, 
#         text_init: Optional[str] = None,
#         syllable_budget: Optional[torch.Tensor] = None, 
#         free_tokens: Optional[List] = ['\n', '!', ',', ':', '?', ';', ' '],
#         num_beams: Optional[int] = 1, 
#         **kwargs
#     ):
        
#         if text_init and syllable_budget:
#             raise Error("You cannot specify both `syllable_budget` and `text_init`. Choose one or the other.")
        
#         if not text_init and not torch.is_tensor(syllable_budget):
#             raise Error("You must provide either `syllable_budget` or `text_init`.")
        
#         if text_init:
#             syllable_budget = self._calculate_syllable_budget(text_init)
            
#         processors = LogitsProcessorList()
        
#         processors.append(
#             SyllableRestrictionWarper(
#                 tokenizer=tokenizer,
#                 syllable_budget=syllable_budget,
#                 syllable_scorer=syllable_scorer,
#                 free_tokens=free_tokens,
#                 num_beams = num_beams,
#                 prompt = prompt,
#             )

#         )
        
#         input_ids = tokenizer(prompt, return_tensors="pt").input_ids

#         outputs = model.generate(
#             input_ids,
#             num_beams=num_beams,
#             logits_processor=processors,
#             **kwargs
#         )

#         return outputs

In [845]:
syllable_scorer

<function __main__.token_syllable_scores(tokenizer, pt=True, free_tokens=['\n', '!', ',', ':', '?', ';', ' '])>

In [877]:
device = 'cuda:1'
model = model.to(device)

In [881]:
from bragi.metric_generator import MetricGenerator
generator = MetricGenerator(model=model, tokenizer=tokenizer, device=device)#, syllable_scorer=syllable_scorer)

In [882]:
# from bragi.verse_parsers import token_syllable_scores


In [897]:
text_init = "Crappy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"
prompt = """Happy birthday to you\n"""

output = generator.generate(
    prompt = prompt,
    text_init = text_init,
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=1,
    remove_invalid_values=True,
    do_sample=True,
    max_length = 200,
)

print(output)
# print(tokenizer.decode(output[0], skip_special_tokens=True).strip())

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


I really loved this book
you should buy it: It's
the best i love about all
about your little girl



In [898]:
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init)}")

Syllables per line in output: tensor([6., 5., 7., 6.])
Syllables per line in `text_init`: tensor([6., 6., 7., 6.])


In [715]:
tokenizer = tokenizer
syllable_budget = torch.Tensor([6., 6.])
syllable_scorer = token_syllable_scores
free_tokens=['\n', '!', ',', ':', '?', ';', ' ']
prompt = """Happy birthday to you\n"""
num_beams = 1

processors = LogitsProcessorList()
processors.append(
    SyllableRestrictionWarper(
        tokenizer=tokenizer,
        syllable_budget=syllable_budget,
        syllable_scorer=syllable_scorer,
        free_tokens=free_tokens,
        num_beams = num_beams,
        prompt = prompt,
    )

)

input_ids = tokenizer(prompt, return_tensors="pt").input_ids

outputs = model.generate(
    input_ids,
    num_beams=num_beams,
    num_return_sequences=1,
    no_repeat_ngram_size=1,
    remove_invalid_values=True,
    logits_processor=processors,
    do_sample=False
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


prompt in warper  Happy birthday to you

The line budget is: 1
Output:
----------------------------------------------------------------------------------------------------
Happy birthday to you
I'm so happy that
You are my friend, I love




In [641]:
text_init = "Happy birthday to you,\nHappy birthday to you,\nHappy birthday dear Marvin,\nHappy birthday to you"

# from poesy import Poem
from bragi.verse_parsers import PoesyParsedVerseHandler
verse_handler = PoesyParsedVerseHandler()
_, syllable_budget = verse_handler.example(text_init)
syllable_budget = torch.Tensor(syllable_budget)
print('text_init: ', f'"{text_init}"', '\nSyllable budget', syllable_budget)

text_init:  "Happy birthday to you,
Happy birthday to you,
Happy birthday dear Marvin,
Happy birthday to you" 
Syllable budget tensor([6., 6., 7., 6.])


In [None]:
class TopKLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
    Args:
        top_k (`int`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

        self.top_k = max(top_k, min_tokens_to_keep)
        self.filter_value = filter_value

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        top_k = min(self.top_k, scores.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores

In [48]:
import unittest

global_rng = random.Random()

def ids_tensor(shape, vocab_size, rng=None, name=None):
    #  Creates a random int32 tensor of the shape within the vocab size
    if rng is None:
        rng = global_rng

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous()


class LogitsProcessorTest(unittest.TestCase):
    def _get_uniform_logits(self, batch_size: int, length: int):
        scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
        return scores
    
    def test_min_length_dist_processor(self):
            vocab_size = 20
            batch_size = 4
            eos_token_id = 0

            min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)

            # check that min length is applied at length 5
            input_ids = ids_tensor((batch_size, 5), vocab_size=20)
            scores = self._get_uniform_logits(batch_size, vocab_size)
            scores_before_min_length = min_dist_processor(input_ids, scores)
            assert scores_before_min_length[:, eos_token_id].tolist() == 4 * [-float("inf")]

            # check that min length is not applied anymore at length 15
            input_ids = ids_tensor((batch_size, 15), vocab_size=20)
            scores = self._get_uniform_logits(batch_size, vocab_size)
            scores_before_min_length = min_dist_processor(input_ids, scores)
            assert not torch.isinf(scores_before_min_length).any()
            
    def test_top_k_dist_warper(self):
            input_ids = None
            vocab_size = 10
            batch_size = 2

            # create ramp distribution
            ramp_logits = (
                torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1)
            )
            ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size

            top_k_warp = TopKLogitsWarper(3)

            scores = top_k_warp(input_ids, ramp_logits)
            
            return scores

            # check that correct tokens are filtered
            self.assertListEqual(torch.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
            self.assertListEqual(torch.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True])

            # check special cases
            length = 5

            logits = self._get_uniform_logits(batch_size=batch_size, length=length)
            top_k_warp_safety_check = TopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)

            scores = top_k_warp_safety_check(input_ids, logits)
            # uniform dist is not changed
            self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0])

            ramp_logits = torch.arange(length, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1)
            scores = top_k_warp_safety_check(input_ids, ramp_logits)

            # min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
            self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])


In [47]:
lp_test = LogitsProcessorTest()
lp_test.test_min_length_dist_processor()
lp_test.test_top_k_dist_warper()
