In [134]:
from transformers import AutoTokenizer 
from itertools import groupby
import numpy as np
import string
import re

class HybridPhonemeTokenizer:
    def __init__(self,
        tokenizer_eng = 'parler-tts/parler-tts-mini-v1',
        tokenizer_g2p = 'therealvul/tokenizer_g2pen_v2',
        eng_special = {
            'pad_token': "<pad>",
            'sep_token': "</s>",
            'unk_token': "<unk>",
        },
        g2p_special = {
            'unk_token': "[UNK]",
            'pad_token': "[PAD]",
            'cls_token': "[CLS]",
            'sep_token': "[SEP]",
            'mask_token':"[MASk]",
        },
         **kwargs):

        self.tokenizer_eng = AutoTokenizer.from_pretrained(
            tokenizer_eng, **eng_special)
        self.tokenizer_g2p = AutoTokenizer.from_pretrained(
            tokenizer_g2p, **g2p_special)

        # Not sure if this is actually necessary - ByteLevel pretokenizer
        # removes possibility of <unk> tokens
        self.special_tokens = {
            self.tokenizer_g2p.pad_token_id: self.tokenizer_eng.pad_token_id,
            self.tokenizer_g2p.bos_token_id: self.tokenizer_eng.bos_token_id,
            self.tokenizer_g2p.cls_token_id: self.tokenizer_eng.cls_token_id,
            self.tokenizer_g2p.eos_token_id: self.tokenizer_eng.eos_token_id,
            self.tokenizer_g2p.unk_token_id: self.tokenizer_eng.unk_token_id,
            self.tokenizer_g2p.mask_token_id: self.tokenizer_eng.mask_token_id
        }

        self.g2p_offset = self.prune_tokens()

    def __call__(self, text):
        parts = re.split(r'({.*?})', text)
        result = []
        for i, part in enumerate(parts):
            if not len(part):
                continue
            part = part.strip()
            if not (part.startswith('{') and part.endswith('}')):
                ids = self.tokenizer_eng(part)['input_ids']
                result += [i for i in ids]
            else:
                ids = self.tokenizer_g2p(part[1:-1])['input_ids']
                print(ids)
                for i,id in enumerate(ids):
                    if id in self.special_tokens:
                        ids[i] = self.special_tokens[id] - self.g2p_offset
                result += [i + self.g2p_offset for i in ids]
        return {'input_ids': result, 'attention_mask': list(np.ones_like(result))}

    def prune_tokens(self):
        # Prune the tokenizer here?
        g2p_offset = len(self.tokenizer_eng.get_vocab())
        return g2p_offset
    
    # Returns string constructed from decoded tokens with space handling
    def _list_decode(self, input_ids, skip_special_tokens=False):
        decode_args = {
            'clean_up_tokenization_spaces': True,
            'skip_special_tokens': skip_special_tokens
        }
        output = ''
        for key, group in groupby(input_ids,
            key=lambda x: x >= self.g2p_offset):
            g = list(group)
            if key:
                if len(output) == 0 or output[-1] != ' ':
                    output += ' '
                output += '{'
                output += self.tokenizer_g2p.decode(
                    [i - self.g2p_offset for i in g],
                     **decode_args)
                output += '}'
            else:
                decoded = self.tokenizer_eng.decode(
                    g, **decode_args)
                if len(output) and output[-1] == '}':
                    if len(decoded) and not decoded[0] in string.punctuation:
                        output += ' '
                output += decoded
        return output.strip()

    # Returns list of string tokens with no space handling
    def _decode_tokens(self, input_ids, skip_special_tokens=False):
        toks = []
        for key, group in groupby(input_ids,
            key=lambda x: x >= self.g2p_offset):
            g = list(group)
            if key:
                toks.extend([self.tokenizer_g2p.decode(i) for i in g])
            else:
                toks.extend([self.tokenizer_eng.decode(i) for i in g])
        return toks
    
    def batch_decode(self, input_ids, skip_special_tokens=False):
        if not isinstance(input_ids[0], list):
            return self._decode_tokens(input_ids)

        return [self._list_decode(l, skip_special_tokens) for l in input_ids]

prompt_tokenizer = HybridPhonemeTokenizer()

In [144]:
text = "And my heaven will be a big heaven."
input_ids = prompt_tokenizer(text)['input_ids']
decoded = prompt_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
print(input_ids)
print(decoded)

text = "It's {S OW0 M AH0 CH} larger than life."
input_ids = prompt_tokenizer(text)['input_ids']
decoded = prompt_tokenizer.batch_decode([input_ids], skip_special_tokens=True)
print(input_ids)
print(decoded)
print(prompt_tokenizer(decoded[0])['input_ids'])
print(prompt_tokenizer.batch_decode([prompt_tokenizer(decoded[0])['input_ids']], 
    skip_special_tokens=True))

[275, 82, 9922, 56, 36, 3, 9, 600, 9922, 5, 1]
['And', 'my', 'heaven', 'will', 'be', '', 'a', 'big', 'heaven', '.', '</s>']
[40, 61, 254, 90, 65, 100]
[94, 31, 7, 1, 32140, 32161, 32354, 32190, 32165, 32200, 2186, 145, 280, 5, 1]
["It's { S OW0 M AH0 CH} larger than life."]
[40, 61, 254, 90, 65, 100]
[94, 31, 7, 1, 32140, 32161, 32354, 32190, 32165, 32200, 2186, 145, 280, 5, 1]
[40, 61, 254, 90, 65, 100]
["It's { S OW0 M AH0 CH} larger than life."]


In [101]:
from transformers import AutoTokenizer
tokenizer_eng = AutoTokenizer.from_pretrained('parler-tts/parler-tts-mini-v1')
tokenizer_g2p = AutoTokenizer.from_pretrained('tokenizer_g2p_v2')

In [141]:
tokenizers = [tokenizer_eng, tokenizer_g2p, prompt_tokenizer]
compare_prompts = ["Hi there!", " Whoa there! ", "{S OW0 }", "It's {S OW0 M AH0 CH OW0}"]
for p in compare_prompts:
    for t in tokenizers:
        tokenized = t(p.strip())
        print(tokenized)
        print(t.batch_decode([tokenized['input_ids']], skip_special_tokens=True))

{'input_ids': [2018, 132, 55, 1], 'attention_mask': [1, 1, 1, 1]}
['Hi there!']
{'input_ids': [40, 21, 40, 5], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}
[' H!']
{'input_ids': [2018, 132, 55, 1], 'attention_mask': [1, 1, 1, 1]}
['Hi there!']
{'input_ids': [2645, 9, 132, 55, 1], 'attention_mask': [1, 1, 1, 1, 1]}
['Whoa there!']
{'input_ids': [130, 40, 5], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]}
[' W!']
{'input_ids': [2645, 9, 132, 55, 1], 'attention_mask': [1, 1, 1, 1, 1]}
['Whoa there!']
{'input_ids': [3, 2, 134, 3, 15251, 632, 3, 2, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}
['S OW0 ']
{'input_ids': [40, 61, 254], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]}
[' S OW0 ']
[40, 61, 254]
{'input_ids': [32140, 32161, 32354], 'attention_mask': [1, 1, 1]}
['{ S OW0 }']
{'input_ids': [94, 31, 7, 3, 2, 134, 3, 15251, 632, 283, 3, 14084, 632, 9302, 3, 15251, 632, 2, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1