In [1]:
#Llama Tokenizer - Step 1
from transformers import MimiModel, AutoFeatureExtractor, AutoTokenizer
import torch
import numpy as np

class TextTokenizer:
    """Thin wrapper that forwards ``encode``/``decode`` to a tokenizer loaded by name."""

    def __init__(self, name='Llama_tokenizer'):
        """Load the tokenizer via ``AutoTokenizer.from_pretrained`` and print its vocab size.

        Args:
            name: identifier handed to ``AutoTokenizer.from_pretrained``.
        """
        loaded = AutoTokenizer.from_pretrained(name, legacy=False)
        self.tokenizer = loaded
        print("text vocab size", loaded.vocab_size)

    def encode(self, text: str):
        """Return the wrapped tokenizer's ``encode`` result for ``text``."""
        return self.tokenizer.encode(text)

    def decode(self, tokens):
        """Return the wrapped tokenizer's ``decode`` result for ``tokens``."""
        decoded = self.tokenizer.decode(tokens)
        return decoded
    
class MimiTokenizer:
    """Wrapper around the ``kyutai/mimi`` codec model for audio <-> code conversion.

    Attributes set in ``__init__``:
        device: device the model and tensors are moved to.
        model: ``MimiModel`` in eval mode on ``device``.
        feature_extractor: matching ``AutoFeatureExtractor``.
        sampling_rate: sampling rate reported by the feature extractor.
        n_codebooks: number of quantizers requested from ``model.encode`` (8).
        vocab_size: stored as 2048; not read anywhere inside this class.
    """

    def __init__(self, device):
        """Load model and feature extractor from ``kyutai/mimi`` and move the model to ``device``."""
        self.device = device
        self.model = MimiModel.from_pretrained("kyutai/mimi")
        self.model.to(device)
        self.model.eval()
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi", device=device)
        self.sampling_rate = self.feature_extractor.sampling_rate
        self.n_codebooks = 8
        self.vocab_size = 2048

    @torch.inference_mode()
    def encode(self, waveform):
        """Encode raw audio into codec tokens.

        Args:
            waveform: raw audio passed to the feature extractor as ``raw_audio``;
                it is declared to be sampled at ``self.sampling_rate``.

        Returns:
            NumPy array holding the first element of the model's ``audio_codes``
            (presumably shaped ``(n_codebooks, frames)`` -- TODO confirm).
        """
        inputs = self.feature_extractor(raw_audio=waveform,
                                        sampling_rate=self.sampling_rate,
                                        return_tensors="pt").to(self.device)

        output = self.model.encode(inputs["input_values"], inputs["padding_mask"], num_quantizers=self.n_codebooks)
        # Only the first batch entry is returned.
        tokens = output.audio_codes[0].cpu().numpy()
        return tokens

    @torch.inference_mode()
    def decode(self, tokens):
        """Decode a 2-D array of codec tokens back to audio.

        Runs under ``torch.inference_mode()`` like ``encode`` so that no autograd
        graph is built and the returned tensor does not require grad.

        Args:
            tokens: 2-D array-like of codes (presumably ``(n_codebooks, frames)``
                -- TODO confirm); a leading batch axis of size 1 is added here.

        Returns:
            CPU tensor with the model's ``audio_values``.

        Raises:
            AssertionError: if ``tokens`` is not 2-D.
        """
        assert len(tokens.shape) == 2, "tokens must be a 2-D array"
        tokens = torch.tensor(np.expand_dims(tokens, axis=0)).to(self.device)
        output = self.model.decode(tokens)
        waveform = output.audio_values.cpu()
        return waveform

2024-12-18 10:56:00.292601: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-18 10:56:00.300439: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1734499560.309602 2387918 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1734499560.312496 2387918 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-18 10:56:00.323024: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
import torch
import numpy as np
from transformers import AutoTokenizer

class TTSTokenizer:
    """Pair of Hugging Face tokenizers: one for plain text, one for lists of token strings."""

    def __init__(self, text_tokenizer_name='tts_tokenizer', audio_tokenizer_name='tts_tokenizer'):
        """Load both tokenizers and print the text tokenizer's vocab size.

        Args:
            text_tokenizer_name: identifier for the tokenizer used on ``str`` input.
            audio_tokenizer_name: identifier for the tokenizer used on ``list[str]`` input.
        """
        self.text_tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_name, legacy=False)
        self.audio_tokenizer = AutoTokenizer.from_pretrained(audio_tokenizer_name, legacy=False)
        # The label says "text", so report the text tokenizer (identical with the defaults).
        print("text vocab size", self.text_tokenizer.vocab_size)

    def encode(self, input_data, add_special_tokens=True):
        """Encode a string or a list of strings to a ``pt`` tensor of token ids.

        Args:
            input_data: ``str`` (handled by the text tokenizer) or ``list[str]``
                (handled by the audio tokenizer).
            add_special_tokens: forwarded to the tokenizer's ``encode``.

        Returns:
            Whatever the selected tokenizer's ``encode`` returns with ``return_tensors='pt'``.

        Raises:
            TypeError: if ``input_data`` is neither a string nor a list of strings.
        """
        if isinstance(input_data, str):
            tokenizer = self.text_tokenizer
        elif isinstance(input_data, list) and all(isinstance(item, str) for item in input_data):
            tokenizer = self.audio_tokenizer
        else:
            raise TypeError("Input must be a string or a list of strings")

        return tokenizer.encode(
            input_data,
            return_tensors='pt',
            add_special_tokens=add_special_tokens
        )

    def decode(self, tokens):
        """Decode a tensor of token ids, trying the text tokenizer first.

        Args:
            tokens: ``torch.Tensor`` of token ids.

        Returns:
            The text tokenizer's decoding; if that raises, the audio tokenizer's
            decoding (a string is returned as-is, any other value is wrapped in
            ``torch.tensor`` as before).

        Raises:
            TypeError: if ``tokens`` is not a ``torch.Tensor``.
            ValueError: if neither tokenizer can decode ``tokens``.
        """
        if not isinstance(tokens, torch.Tensor):
            raise TypeError("Input must be a torch tensor of tokens")

        # Catch Exception rather than everything so Ctrl-C / SystemExit still propagate.
        try:
            return self.text_tokenizer.decode(tokens)
        except Exception:
            pass

        try:
            decoded_tokens = self.audio_tokenizer.decode(tokens)
            # A decoded string cannot be turned into a tensor; wrapping it used to
            # make this fallback fail unconditionally.
            if isinstance(decoded_tokens, str):
                return decoded_tokens
            return torch.tensor(decoded_tokens)
        except Exception as audio_err:
            raise ValueError("Unable to decode the provided tokens") from audio_err

In [3]:
# Instantiate the two tokenizers used by the later cells.
# TTSTokenizer() uses its default tokenizer names and prints a vocab size on construction.
tts_tokenizer = TTSTokenizer()
# NOTE(review): the device is hard-coded to 'cuda', so this cell needs a GPU.
mimi_tokenizer = MimiTokenizer(device='cuda')

text vocab size 128000


  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)


In [4]:
import torch
from llama_model import Llama, LlamaConfig

# Model hyper-parameters. NOTE(review): these presumably have to match the values the
# checkpoint below was trained with, otherwise load_state_dict will fail -- confirm.
config = LlamaConfig(
    dim=1024,
    n_layers=12,
    n_heads=16,
    vocab_size=144646,
    max_seq_len=2048,
    multiple_of=2048,
    use_scaled_rope=True
)
model = Llama(config).cuda()
# weights_only=True restricts unpickling to tensors and plain containers, so loading
# the checkpoint cannot execute arbitrary pickled code. map_location='cpu' stages the
# tensors on the host; load_state_dict then copies them into the CUDA parameters.
state_dict = torch.load(
    'models/llama_model_epoch_100.pth',
    map_location='cpu',
    weights_only=True,
)
model.load_state_dict(state_dict)
# Switch to evaluation mode; as the cell's last expression this also displays the module.
model.eval()

  state_dict = torch.load('models/llama_model_epoch_100.pth')


Llama(
  (tok_embeddings): Embedding(144646, 1024)
  (layers): ModuleList(
    (0-11): 12 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=1024, out_features=1024, bias=False)
        (wk): Linear(in_features=1024, out_features=1024, bias=False)
        (wv): Linear(in_features=1024, out_features=1024, bias=False)
        (wo): Linear(in_features=1024, out_features=1024, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=1024, out_features=4096, bias=False)
        (w2): Linear(in_features=4096, out_features=1024, bias=False)
        (w3): Linear(in_features=1024, out_features=4096, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=1024, out_features=144646, bias=False)
)

In [5]:
import torch.nn.functional as F
import torch

@torch.inference_mode()
def generate_output_tokens(model, input_tokens, max_new_tokens, temperature=0.8, top_k=50, stop_token=144644):
    """Sample up to ``max_new_tokens`` token ids that continue ``input_tokens``.

    Each step re-runs the whole (possibly cropped) sequence through the model with a
    full causal mask and samples one id from the logits at the last position.

    Args:
        model: module exposing ``parameters()``, ``config.max_seq_len``,
            ``tok_embeddings``, ``_prepare_rotary_embeddings(h, seqlen)``, ``layers``
            (each callable as ``layer(h, start_pos, freqs_cis, mask)`` and owning
            ``attention.cache_k`` / ``attention.cache_v``), ``norm`` and ``output``.
        input_tokens: non-empty sequence of prompt token ids.
        max_new_tokens: upper bound on the number of sampled ids.
        temperature: logits are divided by this before sampling; ``0`` selects the
            arg-max id (greedy) instead of dividing by zero.
        top_k: if not ``None``, only the ``top_k`` largest logits stay eligible.
        stop_token: if not ``None``, sampling this id ends generation; the id itself
            is not appended to the result.

    Returns:
        1-D CPU tensor of the newly sampled ids (prompt excluded). It is empty when
        the first sampled id is ``stop_token`` or ``max_new_tokens`` is 0.

    Raises:
        ValueError: if ``input_tokens`` is empty.
    """
    if len(input_tokens) == 0:
        raise ValueError("input_tokens must contain at least one token")

    # Follow the model's device instead of assuming CUDA.
    device = next(model.parameters()).device
    idx = torch.tensor(input_tokens, dtype=torch.long, device=device).unsqueeze(0)

    # Every step feeds the complete sequence, so every step starts at position 0.
    # Advancing this while still re-feeding the whole sequence misaligns positions
    # for any layer that indexes by start_pos.
    # NOTE(review): the layer implementation lives in llama_model and is not visible
    # here -- confirm that start_pos=0 with the full sequence is its prefill path.
    start_pos = 0

    # Clear any key/value cache left over from a previous call.
    for layer in model.layers:
        layer.attention.cache_k = torch.zeros_like(layer.attention.cache_k)
        layer.attention.cache_v = torch.zeros_like(layer.attention.cache_v)

    max_seq_len = model.config.max_seq_len

    for _ in range(max_new_tokens):
        # Keep at most the last max_seq_len ids as context.
        idx_cond = idx if idx.size(1) <= max_seq_len else idx[:, -max_seq_len:]
        h = model.tok_embeddings(idx_cond)
        seqlen = idx_cond.size(1)
        freqs_cis = model._prepare_rotary_embeddings(h, seqlen)
        # Causal mask: -inf strictly above the diagonal.
        mask = torch.triu(torch.full((seqlen, seqlen), float('-inf'), device=h.device), diagonal=1)

        for layer in model.layers:
            h = layer(h, start_pos, freqs_cis, mask)

        # Logits for the last position only.
        logits = model.output(model.norm(h))[:, -1, :]

        if temperature == 0:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Drop everything below the k-th largest logit.
                logits[logits < v[:, [-1]]] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

        if stop_token is not None and idx_next.item() == stop_token:
            break

        idx = torch.cat((idx, idx_next), dim=1)

    generated_tokens = idx[:, len(input_tokens):]
    # squeeze(0) removes only the batch axis, so a single new token stays 1-D.
    return generated_tokens.squeeze(0).cpu()

In [6]:
# Prompt as raw token ids. These 33 ids are identical to the first 33 ids of `result`
# in a later cell, whose recorded decoding is "<|begin_of_text|>Printing, in the only
# sense ... represented in the Exhibition" followed by [convert][spkr_unk][mimi].
input_token = [128000,  39628,     11,    304,    279,   1193,   5647,    449,    902,
           584,    527,    520,   3118,  11920,     11,  44642,    505,   1455,
           422,    539,    505,    682,    279,  19071,    323,  44948,  15609,
           304,    279,  68033, 144642, 144645, 144641]
# Sample at most 1024 new ids with the function's default temperature, top_k and stop_token.
# NOTE(review): no random seed is set, so the sampled ids differ between runs.
output = generate_output_tokens(model, input_token, max_new_tokens=1024)
print(output)

tensor([129106, 141544, 142464,  ..., 130207, 129953, 129670])


In [7]:
# Decode the sampled ids back to their token strings for inspection.
print(tts_tokenizer.decode(output))

[aco_850][aco_13288][aco_14208][aco_11413][aco_8537][aco_4973][aco_5623][aco_4847][aco_3327][aco_3494][aco_3035][aco_2369][aco_798][aco_3865][aco_1697][aco_107][aco_107][aco_1414][aco_798][aco_994][aco_798][aco_471][aco_677][aco_60][aco_1414][aco_798][aco_2017][aco_1414][aco_2017][aco_1414][aco_677][aco_677][aco_2017][aco_1691][aco_612][aco_677][aco_1414][aco_1697][aco_1181][aco_1691][aco_994][aco_1697][aco_1048][aco_798][aco_1951][aco_1951][aco_107][aco_1048][aco_107][aco_612][aco_471][aco_1691][aco_107][aco_1048][aco_1951][aco_1181][aco_1181][aco_142][aco_1697][aco_994][aco_1697][aco_994][aco_471][aco_612][aco_1181][aco_107][aco_1048][aco_1691][aco_1048][aco_60][aco_612][aco_798][aco_1026][aco_994][aco_1048][aco_994][aco_1691][aco_677][aco_107][aco_1697][aco_1181][aco_1181][aco_1697][aco_1691][aco_1414][aco_2017][aco_994][aco_60][aco_1137][aco_1181][aco_1697][aco_1048][aco_974][aco_1691][aco_994][aco_107][aco_994][aco_1697][aco_994][aco_1181][aco_471][aco_994][aco_612][aco_994][aco_9

In [8]:
# Hard-coded int32 token sequence: the same 33 prompt ids as `input_token`, followed by
# further ids and ending with 144644 (the default `stop_token` of generate_output_tokens).
# NOTE(review): presumably a reference/ground-truth example kept for comparison with
# the sampled output -- confirm where it came from.
result = torch.tensor([128000,  39628,     11,    304,    279,   1193,   5647,    449,    902,
           584,    527,    520,   3118,  11920,     11,  44642,    505,   1455,
           422,    539,    505,    682,    279,  19071,    323,  44948,  15609,
           304,    279,  68033, 144642, 144645, 144641, 128995, 130325, 132834,
        135350, 138081, 138819, 140988, 143449, 128572, 131272, 133038, 135276,
        136552, 139416, 141865, 142522, 129117, 131597, 133406, 134146, 136322,
        139289, 141923, 143985, 129069, 131186, 133662, 135220, 136512, 139853,
        141760, 142410, 129620, 131843, 133121, 134814, 137436, 138387, 141084,
        142444, 128080, 131002, 133038, 135147, 137077, 139665, 141897, 142752,
        129134, 131203, 132924, 134668, 136548, 140002, 140409, 143451, 128715,
        130220, 133849, 135225, 136357, 139775, 140460, 143596, 129650, 131937,
        132126, 134545, 136933, 139077, 141937, 143864, 129532, 131639, 132479,
        134835, 137883, 138246, 141974, 143201, 128937, 131824, 133411, 134384,
        136248, 138246, 141565, 143949, 129077, 130323, 132375, 134147, 138069,
        139010, 141512, 142954, 129352, 130386, 133860, 134608, 138148, 140177,
        141688, 143239, 128749, 130955, 133781, 135938, 137985, 140133, 141270,
        143942, 128577, 130291, 133245, 134258, 137928, 138587, 141264, 142904,
        128670, 131159, 132503, 135148, 136914, 139760, 140828, 142475, 129701,
        130488, 133298, 134413, 136709, 138498, 140735, 142696, 129762, 131203,
        132992, 136027, 136402, 138246, 141713, 142997, 128570, 131397, 133382,
        134994, 136195, 139013, 142075, 142487, 129468, 130744, 133054, 135043,
        136663, 140173, 141663, 142607, 128157, 132044, 132786, 135857, 136657,
        138778, 140733, 142574, 128839, 131041, 133885, 134696, 136357, 139132,
        140508, 144299, 129457, 131094, 132982, 135497, 137717, 140221, 141218,
        142767, 128181, 131341, 133840, 135517, 136861, 140137, 141400, 142794,
        129022, 131672, 132922, 135561, 137960, 138869, 140476, 143155, 129037,
        130437, 132452, 134366, 136541, 138498, 140552, 143492, 129067, 131891,
        132160, 135811, 137641, 138246, 140940, 143208, 129132, 130387, 133891,
        136062, 137641, 139273, 141345, 143673, 129320, 131104, 133274, 134286,
        137706, 138808, 141526, 143887, 129607, 130291, 133245, 134258, 137207,
        138795, 142159, 143942, 129607, 131259, 133245, 134434, 136773, 139812,
        141797, 143942, 128663, 130291, 133655, 135492, 137107, 139683, 141526,
        144344, 129779, 130291, 133655, 135492, 136773, 139683, 142159, 144344,
        129301, 130291, 133655, 135492, 136459, 139683, 141526, 144080, 129618,
        131765, 133686, 135920, 136515, 138568, 140382, 142654, 129782, 130680,
        134088, 135388, 138099, 139164, 141007, 142836, 128304, 130689, 133573,
        135245, 138128, 140000, 141169, 144180, 129219, 131390, 134066, 134549,
        136840, 139839, 142026, 143399, 129640, 130582, 133627, 134822, 136510,
        138754, 141771, 143770, 129196, 130787, 134070, 136099, 136409, 138550,
        142227, 142524, 128505, 130783, 133420, 135114, 137860, 140104, 140948,
        144171, 129146, 130063, 132779, 135688, 137530, 139481, 141406, 142520,
        128784, 131671, 132450, 136045, 137666, 138890, 141700, 143080, 128917,
        131754, 132133, 134842, 137151, 139775, 142253, 143311, 129708, 130779,
        133640, 135447, 136465, 139175, 141048, 142408, 128922, 130386, 132886,
        134893, 136343, 139580, 141355, 142594, 129137, 130821, 132533, 135072,
        136226, 139759, 140997, 143102, 129358, 130868, 133712, 134398, 137393,
        139796, 140534, 143363, 128126, 130538, 132762, 134984, 136773, 139979,
        140954, 143887, 128832, 131259, 133793, 135492, 137928, 139809, 141526,
        144080, 129743, 130291, 133655, 135492, 137928, 139809, 141113, 144344,
        128663, 130291, 133245, 134308, 136773, 139270, 142159, 143942, 129544,
        131104, 132879, 135492, 137928, 139812, 141113, 144344, 128835, 130941,
        132602, 135049, 136666, 139613, 140860, 144269, 128337, 131187, 133089,
        134987, 136678, 139449, 140465, 143899, 129377, 131817, 133781, 134413,
        136337, 140221, 141484, 143470, 130024, 130741, 133627, 134670, 137965,
        138514, 141771, 142828, 129252, 130240, 133686, 134497, 138092, 140015,
        142160, 143434, 129260, 130871, 133216, 135834, 137580, 139135, 141863,
        142524, 128304, 130417, 132762, 135214, 138040, 139256, 141175, 142728,
        129039, 131825, 133669, 135608, 138197, 138794, 140372, 143522, 129640,
        130696, 133360, 136125, 138013, 138913, 141370, 142998, 128797, 131259,
        133598, 134286, 136459, 138587, 141526, 143749, 128262, 131259, 133655,
        135492, 136532, 139809, 141526, 144080, 128663, 130291, 133655, 135492,
        137928, 139809, 142266, 144344, 128384, 131397, 132762, 135841, 138197,
        138794, 141281, 144204, 130025, 130837, 133686, 134666, 137959, 139039,
        141433, 143179, 129774, 130979, 133778, 134741, 137010, 138806, 141633,
        143066, 128178, 130680, 133776, 134877, 136409, 139577, 140681, 143858,
        128129, 131251, 133595, 135833, 136396, 138246, 140669, 142598, 129669,
        130333, 132352, 134878, 136717, 138246, 141089, 143782, 128421, 130505,
        132348, 134838, 137561, 138741, 142305, 143603, 129519, 131958, 132468,
        135590, 136966, 139421, 141476, 142660, 129022, 130951, 133706, 135882,
        136521, 139121, 141580, 143632, 129562, 130376, 132313, 134835, 137641,
        138246, 140703, 143165, 129252, 131627, 133411, 134736, 136208, 139572,
        141000, 142947, 129782, 130468, 133482, 134796, 136859, 140087, 141757,
        143059, 128304, 131566, 133203, 135172, 136604, 139309, 141337, 143786,
        129882, 130079, 133471, 135332, 136465, 139874, 142303, 142508, 129721,
        130225, 133411, 135207, 137798, 138246, 141517, 143754, 129356, 131988,
        133442, 135049, 137706, 139295, 141301, 143426, 128096, 131763, 133583,
        135882, 138042, 139177, 140579, 143748, 128777, 132076, 133283, 135913,
        136326, 139039, 140993, 143849, 128505, 131396, 133917, 135805, 136482,
        139801, 140598, 143465, 128526, 131624, 134083, 135933, 137119, 138807,
        142219, 143542, 129146, 131526, 133208, 134889, 136540, 138609, 140881,
        143960, 129235, 130842, 134120, 135056, 136239, 139371, 142065, 143465,
        128981, 131914, 132919, 136140, 137531, 139670, 140380, 143349, 129708,
        130302, 132534, 136098, 137274, 138754, 142050, 143112, 129473, 130871,
        133060, 136084, 137224, 139800, 141985, 142724, 129358, 130417, 132419,
        134201, 137032, 139209, 142248, 144188, 128502, 131104, 133722, 134308,
        136773, 139270, 141800, 144344, 129268, 131552, 132208, 134758, 137754,
        138645, 140963, 143949, 128885, 130731, 133042, 134676, 138067, 138689,
        141153, 142518, 129558, 130562, 132762, 135147, 136479, 138505, 141154,
        142954, 129252, 130302, 133623, 135723, 136773, 138820, 141661, 142916,
        129260, 132080, 132886, 135635, 137689, 140076, 141626, 143059, 128304,
        132035, 132754, 135963, 137388, 139734, 140330, 143591, 129631, 130914,
        133663, 134275, 137724, 139944, 141107, 142904, 129565, 130486, 132722,
        136154, 136729, 138547, 141040, 142487, 144644],dtype=torch.int32)

# Decode the whole sequence (prompt text plus the acoustic token strings) for inspection.
print(tts_tokenizer.decode(result))

<|begin_of_text|>Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition[convert][spkr_unk][mimi][aco_739][aco_2069][aco_4578][aco_7094][aco_9825][aco_10563][aco_12732][aco_15193][aco_316][aco_3016][aco_4782][aco_7020][aco_8296][aco_11160][aco_13609][aco_14266][aco_861][aco_3341][aco_5150][aco_5890][aco_8066][aco_11033][aco_13667][aco_15729][aco_813][aco_2930][aco_5406][aco_6964][aco_8256][aco_11597][aco_13504][aco_14154][aco_1364][aco_3587][aco_4865][aco_6558][aco_9180][aco_10131][aco_12828][aco_14188]<|reserved_special_token_72|>[aco_2746][aco_4782][aco_6891][aco_8821][aco_11409][aco_13641][aco_14496][aco_878][aco_2947][aco_4668][aco_6412][aco_8292][aco_11746][aco_12153][aco_15195][aco_459][aco_1964][aco_5593][aco_6969][aco_8101][aco_11519][aco_12204][aco_15340][aco_1394][aco_3681][aco_3870][aco_6289][aco_8677][aco_10821][aco_13681][aco_15608][aco_1276][aco_3383][aco_4223][aco_6579][aco_96

In [None]:
def deflatten_tokens(tokens, n_codebooks, per_codebook_size):
    """Split an interleaved token sequence into one row per codebook.

    Row ``k`` of the result is ``tokens[k::n_codebooks]``, i.e. every
    ``n_codebooks``-th element starting at offset ``k``.

    Args:
        tokens: 1-D sequence of interleaved tokens; its length must give every
            row the same size, otherwise ``np.stack`` raises.
        n_codebooks: number of rows to produce.
        per_codebook_size: NOTE(review): accepted but never used -- values are
            returned unchanged, no per-codebook offset is removed. Confirm
            whether an offset subtraction was intended.

    Returns:
        Array of shape ``(n_codebooks, len(tokens) // n_codebooks)``.
    """
    per_codebook_rows = [tokens[offset::n_codebooks] for offset in range(n_codebooks)]
    return np.stack(per_codebook_rows)