In [13]:
from transformers import AutoTokenizer 
from itertools import groupby
import numpy as np
import re

class HybridPhonemeTokenizer:
    """Tokenize text that mixes plain English with ``{...}`` phoneme spans.

    Text outside curly braces is tokenized with ``tokenizer_eng``; text inside
    braces is tokenized with ``tokenizer_g2p`` and its ids are shifted up by
    ``g2p_offset`` (the English vocabulary size) so the two id ranges do not
    overlap. Decoding reverses this, re-wrapping phoneme runs in braces.
    """

    def __init__(self,
        tokenizer_eng = 'parler-tts/parler-tts-mini-v1',
        tokenizer_g2p = 'therealvul/tokenizer_g2pen_v2',
        eng_prefix = 'eng:',
        g2p_prefix = 'g2p:',
        # maps eng tokens to g2p
        # TODO special tokens
        special_tokens = {
            "<pad>": "[PAD]",
            "</s>": "[SEP]",
            "<unk>": "[UNK]",
            "<extra_id_99>": "[CLS]",
            "<extra_id_98>": "[MASK]",},
         **kwargs):
        """Load both tokenizers and compute the g2p id offset.

        Args:
            tokenizer_eng: repo id or path passed to ``AutoTokenizer.from_pretrained``
                for the plain-text tokenizer.
            tokenizer_g2p: repo id or path for the phoneme (g2p) tokenizer.
            eng_prefix, g2p_prefix: stored on the instance; not used by the
                methods of this class.
            special_tokens: mapping from English special tokens to g2p ones
                (per the inline comment); stored as a copy, not used yet.
            **kwargs: accepted and ignored.
        """
        self.tokenizer_eng = AutoTokenizer.from_pretrained(tokenizer_eng)
        self.tokenizer_g2p = AutoTokenizer.from_pretrained(tokenizer_g2p)
        self.eng_prefix = eng_prefix
        self.g2p_prefix = g2p_prefix
        # Copy so instances never share (and mutate) the default dict object.
        self.special_tokens = None if special_tokens is None else dict(special_tokens)
        self.g2p_offset = self.prune_tokens()

    def __call__(self, text):
        """Encode ``text`` into a single hybrid id sequence.

        Segments wrapped in ``{...}`` are phoneme strings for the g2p tokenizer
        (ids shifted by ``g2p_offset``); everything else goes to the English
        tokenizer. Whitespace around each segment is stripped, and segments
        that are empty after stripping are skipped.

        NOTE(review): every English segment is encoded with the English
        tokenizer's default special tokens, so a prompt with several English
        segments carries them once per segment (e.g. the Llama BOS id 128000
        appears twice in this notebook's round-trip output).

        Returns:
            dict with ``input_ids`` and an all-ones ``attention_mask`` of the
            same length, both lists of Python ints.
        """
        input_ids = []
        for segment in re.split(r'({.*?})', text):
            segment = segment.strip()
            # Whitespace-only gaps (e.g. between "{A} {B}") must not reach the
            # English tokenizer as "", which would emit only special tokens.
            if not segment:
                continue
            if segment.startswith('{') and segment.endswith('}'):
                g2p_ids = self.tokenizer_g2p(segment[1:-1])['input_ids']
                input_ids.extend(tok_id + self.g2p_offset for tok_id in g2p_ids)
            else:
                input_ids.extend(self.tokenizer_eng(segment)['input_ids'])
        return {'input_ids': input_ids, 'attention_mask': [1] * len(input_ids)}

    def prune_tokens(self):
        """Return the offset added to g2p token ids.

        Currently this is just the size of the English vocabulary
        (``len(get_vocab())``), so g2p ids start right after it. No pruning is
        performed yet ("Prune the tokenizer here?" is still open).
        """
        return len(self.tokenizer_eng.get_vocab())

    def _list_decode(self, input_ids, skip_special_tokens=False):
        """Decode one hybrid id sequence into text with ``{...}`` phoneme spans.

        Consecutive ids are grouped by range: ids ``>= g2p_offset`` are shifted
        back down and decoded with the g2p tokenizer inside braces; all other
        ids are decoded with the English tokenizer.
        """
        decode_args = {
            'clean_up_tokenization_spaces': True,
            'skip_special_tokens': skip_special_tokens,
        }
        output = ''
        for is_g2p, group in groupby(input_ids, key=lambda x: x >= self.g2p_offset):
            if is_g2p:
                # Keep the phoneme span separated from any preceding text.
                if not output.endswith(' '):
                    output += ' '
                phonemes = self.tokenizer_g2p.decode(
                    [tok_id - self.g2p_offset for tok_id in group], **decode_args)
                output += '{' + phonemes + '} '
            else:
                output += self.tokenizer_eng.decode(list(group), **decode_args)
        return output.strip()

    def batch_decode(self, input_ids, skip_special_tokens=False):
        """Decode hybrid id sequences back to text.

        Accepts either a batch (a sequence of id sequences: lists, tuples,
        arrays, ...) or a single flat id sequence. NOTE: unlike Hugging Face
        ``batch_decode``, a flat sequence is treated as ONE sequence and decoded
        to a one-element list, rather than decoding each id separately.

        Returns:
            list of decoded strings; empty input yields ``[]``.
        """
        if len(input_ids) == 0:
            return []
        # A scalar first element means a single flat sequence: wrap it.
        # np.ndim also recognizes tuple/array/tensor rows, not just lists.
        if np.ndim(input_ids[0]) == 0:
            input_ids = [input_ids]
        return [self._list_decode(seq, skip_special_tokens) for seq in input_ids]

# Hybrid tokenizer used in the cells below: the Llama 3.1 tokenizer handles
# plain text, the class-default g2p tokenizer handles {...} phoneme spans.
prompt_tokenizer = HybridPhonemeTokenizer(tokenizer_eng='meta-llama/Meta-Llama-3.1-8B')

In [14]:
text = "It's {S OW0 M AH0 CH}; It's a {B IH0 G T AY0 M}"
input_ids = prompt_tokenizer(text)['input_ids']
# A flat id list is decoded as one sequence -> a one-element list of strings.
decoded = prompt_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
print(input_ids)
print(decoded)

# Round trip: re-encoding the decoded text should reproduce the same ids, and
# decoding those again should reproduce the same text. Encode only once.
roundtrip_ids = prompt_tokenizer(decoded[0])['input_ids']
roundtrip_decoded = prompt_tokenizer.batch_decode(roundtrip_ids, skip_special_tokens=True)
print(roundtrip_ids)
print(roundtrip_decoded)
assert roundtrip_ids == input_ids, "encode(decode(ids)) changed the ids"
assert roundtrip_decoded == decoded, "decode is not stable across a round trip"

[128000, 2181, 596, 128296, 128317, 128510, 128346, 128321, 128356, 128000, 26, 1102, 596, 264, 128404, 128296, 128304, 128645, 128296, 128311, 128270, 128292, 128315, 128282]
["It's { S OW0 M AH0 CH} ; It's a { B IH0 G T AY0 M}"]
[128000, 2181, 596, 128296, 128317, 128510, 128346, 128321, 128356, 128000, 26, 1102, 596, 264, 128404, 128296, 128304, 128645, 128296, 128311, 128270, 128292, 128315, 128282]
["It's { S OW0 M AH0 CH} ; It's a { B IH0 G T AY0 M}"]


In [15]:
from transformers import AutoTokenizer

# Standalone copies of the component tokenizers, for side-by-side comparison
# with prompt_tokenizer in the next cell.
tokenizer_eng = AutoTokenizer.from_pretrained(
    'parler-tts/parler-tts-mini-v1'
)
# NOTE(review): this id looks like a local/relative path and differs from the
# class default 'therealvul/tokenizer_g2pen_v2'; a fresh environment needs this
# directory (or hub repo) to exist — confirm it is the same g2p tokenizer.
tokenizer_g2p = AutoTokenizer.from_pretrained(
    'tokenizer_g2p_v2'
)

In [16]:
# Encode each prompt with every tokenizer and show ids plus decoded text.
# NOTE: on a flat id list, the Hugging Face tokenizers' batch_decode decodes
# each id separately, whereas HybridPhonemeTokenizer.batch_decode decodes the
# whole list as one sequence; wrap the ids in another list
# (t.batch_decode([ids], ...)) to compare whole-sequence decodes instead.
tokenizers = [
    tokenizer_eng,
    tokenizer_g2p,
    prompt_tokenizer,
]
compare_prompts = [
    "Hi there!",
    " Whoa there! ",
    "{S OW0 }",
    "It's {S OW0 M AH0 CH OW0}",
]
for p in compare_prompts:
    for t in tokenizers:
        tokenized = t(p.strip())
        print(
            tokenized,
            t.batch_decode(tokenized['input_ids'], skip_special_tokens=True),
            sep='\n',
        )

{'input_ids': [2018, 132, 55, 1], 'attention_mask': [1, 1, 1, 1]}
['Hi', 'there', '!', '']
{'input_ids': [40, 21, 40, 5], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}
[' ', 'H', ' ', '!']
{'input_ids': [128000, 13347, 1070, 0], 'attention_mask': [1, 1, 1, 1]}
['Hi there!']
{'input_ids': [2645, 9, 132, 55, 1], 'attention_mask': [1, 1, 1, 1, 1]}
['Who', 'a', 'there', '!', '']
{'input_ids': [130, 40, 5], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]}
[' W', ' ', '!']
{'input_ids': [128000, 15546, 64, 1070, 0], 'attention_mask': [1, 1, 1, 1, 1]}
['Whoa there!']
{'input_ids': [3, 2, 134, 3, 15251, 632, 3, 2, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}
['', '', 'S', '', 'OW', '0', '', '', '']
{'input_ids': [40, 61, 254], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]}
[' ', 'S ', 'OW0 ']
{'input_ids': [128296, 128317, 128510], 'attention_mask': [1, 1, 1]}
['{ S OW0 }']
{'input_ids': [94, 31, 7, 3, 2, 134, 3, 15251, 632, 283, 3, 14084, 632, 9302,