In [1]:
import sys
sys.path.append('..')

import os
import math
import torch
import random
import numpy as np
from pathlib import Path
from torch.cuda import empty_cache
from IPython.display import Audio
import matplotlib.pyplot as plt

from commons import DEVICE, CACHE_DIR, CTX,  TEXT, MIMI, SPEAKER_FILE, CONVERT
from commons import Config as cfg
from omni.hfload import convert_to_hf
from omni.tokenlib import get_tokenizer
from omni.gpt2_model import GPT, GPTConfig
import json

Cache directory at:  /home/apurva/.cache/indri
Gap tokens:  1519


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from train_with_mimi import get_text_tokenizer

2024-10-03 08:11:11.%f | INFO     | train_with_mimi.py:16 | {'__module__': 'commons', 'coarse_codebooks': 2, 'per_codebook_size': 1024, 'VOCAB_SIZES': {'text': 50257, 'mimi': 8192}, 'OFFSET': {'text': 0, 'mimi': 50257}, 'TASK_TOKENS': {'convert': '[convert]', 'continue': '[continue]'}, 'MODALITY_TOKENS': {'text': '[text]', 'mimi': '[mimi]'}, 'UNKNOWN_SPEAKER_ID': '[spkr_unk]', 'STOP_TOKEN': '[stop]', 'VOCAB_SIZE': 59968, '__dict__': <attribute '__dict__' of 'Config' objects>, '__weakref__': <attribute '__weakref__' of 'Config' objects>, '__doc__': None}


In [3]:
# Device used for the prompt tensor built later in this notebook.
# Prefer the first CUDA GPU, but fall back to CPU when no GPU is visible so the
# cell does not hand out a device string that cannot be used on this machine.
# NOTE: this deliberately rebinds the `DEVICE` name imported from `commons`.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [4]:
# Load the trained checkpoint through `convert_to_hf` (from omni.hfload).
# The checkpoint is looked up under the *current* user's home directory instead
# of one developer's absolute path, and the model is placed on the notebook-wide
# DEVICE rather than a second, independently hard-coded 'cuda:0'.
CHECKPOINT_PATH = Path.home() / 'Downloads' / 'gpt_136000.pt'
omni_model = convert_to_hf(str(CHECKPOINT_PATH), device=DEVICE)

  custom_gpt = torch.load(path, map_location=device)['model']
  custom_gpt_config = torch.load(path, map_location=device)['config']


loaded config GPTConfig(block_size=1024, vocab_size=59968, n_layer=36, n_head=20, n_embd=1280, dropout=0.0, bias=True)


In [5]:
# NOTE(review): this import rebinds `Audio`. The first cell imported `Audio`
# from IPython.display; after this line the name refers to `datasets.Audio`.
from datasets import load_dataset, Audio
from transformers import MimiModel, AutoFeatureExtractor
# NOTE(review): `librispeech_dummy` is not referenced by any later cell visible
# in this notebook — confirm whether the download is still needed.
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

# load model and feature extractor
# `model` is the Mimi codec; a later cell calls `model.decode(...)` on the
# generated codes to obtain a waveform.
model = MimiModel.from_pretrained("kyutai/mimi")
# NOTE(review): `feature_extractor` is likewise unused in the visible cells.
feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi")

  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)


In [6]:
# Build the tokenizer with the helper imported from `train_with_mimi`, so the
# notebook encodes/decodes with whatever tokenizer that script constructs.
text_tokenizer = get_text_tokenizer()

text vocab size 50257


In [7]:
# Encode the control markers once; they are spliced around the text when the
# prompt is assembled, and the stop marker is also used to end generation.
convert_token, stop_token = (
    text_tokenizer.encode(marker)
    for marker in (cfg.TASK_TOKENS[CONVERT], cfg.STOP_TOKEN)
)

# Modality markers: one announcing text, one announcing Mimi audio codes.
text_modality_token, acoustic_modality_token = (
    text_tokenizer.encode(cfg.MODALITY_TOKENS[modality])
    for modality in (TEXT, MIMI)
)

In [50]:
# Speaker marker: the "unknown speaker" token from the shared config.
speaker_id = text_tokenizer.encode(cfg.UNKNOWN_SPEAKER_ID)

# Sentence to synthesise, and its token ids.
text = 'hey we are indri labs and we are glad to present you the super fast text to speech system that we have built'
txt_toks = text_tokenizer.encode(text)

In [51]:
# Prompt layout: [text] <text tokens> [convert] [mimi] <speaker marker>
prompt_segments = [
    text_modality_token,
    txt_toks,
    convert_token,
    acoustic_modality_token,
    speaker_id,
]
prompt_ids = np.hstack(prompt_segments)

# Add a leading batch axis so the tensor is (1, prompt_length).
input_tokens = torch.tensor(prompt_ids, dtype=torch.long, device=DEVICE).unsqueeze(0)
print(f'Text tokens: {input_tokens.shape}')

# Round-trip through the tokenizer as a visual sanity check of the prompt.
text_tokenizer.decode(input_tokens[0])

Text tokens: torch.Size([1, 28])


'[text]hey we are indri labs and we are glad to present you the super fast text to speech system that we have built[convert][mimi][spkr_unk]'

In [52]:
from transformers import LogitsProcessor
class AlternatingCodebooksLogitsProcessor(LogitsProcessor):
    """Constrain sampling to one codebook's id range per generation step.

    The generated stream interleaves codebooks: the k-th generated position
    belongs to codebook ``k % num_codebooks``. At each step every logit outside
    that codebook's id range
    ``[offset + idx * codebook_size, offset + (idx + 1) * codebook_size)``
    is set to ``-inf``. The stop token keeps its original score at every step
    so generation can still terminate.

    Args:
        input_start_len: Length of the prompt; positions from this index on are
            treated as generated codebook positions.
        codebook_size: Number of ids reserved for each codebook.
        num_codebooks: Number of interleaved codebooks.
        offset: Id of the first entry of codebook 0 in the model vocabulary.
        stop_token: Id (or list of ids) whose scores are never masked.
    """

    def __init__(self, input_start_len: int, codebook_size: int, num_codebooks: int, offset: int, stop_token: int):
        self.input_start_len = input_start_len
        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks
        self.offset = offset
        self.stop_token = stop_token

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return a copy of ``scores`` with out-of-codebook logits set to ``-inf``.

        Args:
            input_ids: Ids produced so far (prompt included); only the size of
                the last axis is used, to work out which codebook comes next.
            scores: Next-token logits, indexed as ``[batch, vocab]``.

        Returns:
            A new tensor shaped like ``scores``; the input is not modified.
        """
        curr_len = input_ids.shape[-1]
        codebook_idx = (curr_len - self.input_start_len) % self.num_codebooks

        # Vocabulary slice owned by the codebook whose turn it is.
        band_start = self.offset + codebook_idx * self.codebook_size
        band_end = band_start + self.codebook_size

        scores_processed = scores.clone()
        scores_processed[:, :band_start] = -float("inf")
        scores_processed[:, band_end:] = -float("inf")
        # Use the stop token given to the constructor. The previous code read a
        # notebook-level global named `stop_token`, which ignored the argument
        # and made the class depend on hidden kernel state.
        scores_processed[:, self.stop_token] = scores[:, self.stop_token]
        return scores_processed

In [53]:
# CTX is imported from `commons`; generation runs inside that context.
with CTX:
    # Stop as soon as the model emits the [stop] marker.
    omni_model.generation_config.eos_token_id = stop_token
    semantic_tokens = omni_model.generate(
        input_tokens,
        # The batch holds a single, unpadded prompt, so every position is real.
        # Passing the mask explicitly avoids the "attention mask ... not set"
        # warning that generate() otherwise emits.
        attention_mask=torch.ones_like(input_tokens),
        # Same value as the model's block_size reported when it was loaded.
        max_length=1024,
        temperature=0.6,
        top_k=30,
        do_sample=True,
        logits_processor=[
            AlternatingCodebooksLogitsProcessor(
                input_start_len=len(input_tokens[0]),
                # 4 codebooks x 2048 ids each; must agree with the constants
                # used when the stream is de-interleaved further down.
                codebook_size=2048,
                num_codebooks=4,
                offset=cfg.OFFSET[MIMI],
                stop_token=stop_token,
            )
        ],
    )
    # Move to host memory as a NumPy array for the post-processing cells.
    semantic_tokens = semantic_tokens.detach().cpu().numpy()
    print(semantic_tokens.shape)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


(1, 365)


In [54]:
# Drop the prompt prefix so only the newly generated ids remain, then decode
# them to their token strings for inspection.
prompt_len = len(input_tokens[0])
sem_tokens = semantic_tokens[0][prompt_len:]
text_tokenizer.decode(sem_tokens)

'[aco_1049][aco_2848][aco_5690][aco_7706][aco_1946][aco_2291][aco_5838][aco_7492][aco_1156][aco_3567][aco_4879][aco_7856][aco_420][aco_2519][aco_4456][aco_7236][aco_225][aco_3590][aco_5977][aco_6418][aco_915][aco_3052][aco_5024][aco_6902][aco_1828][aco_2323][aco_5576][aco_6985][aco_281][aco_2323][aco_5477][aco_6748][aco_285][aco_3945][aco_5477][aco_6704][aco_1546][aco_3831][aco_4977][aco_7762][aco_220][aco_2833][aco_5636][aco_6867][aco_1989][aco_2068][aco_5024][aco_6426][aco_929][aco_2349][aco_4740][aco_6416][aco_929][aco_2349][aco_4403][aco_7378][aco_1014][aco_2861][aco_4898][aco_6545][aco_1781][aco_3317][aco_4602][aco_6321][aco_1536][aco_3431][aco_4560][aco_6418][aco_342][aco_2481][aco_4626][aco_6969][aco_1592][aco_3900][aco_5792][aco_6854][aco_371][aco_3138][aco_4200][aco_8058][aco_75][aco_3442][aco_5083][aco_7942][aco_90][aco_2099][aco_6080][aco_7888][aco_1222][aco_2099][aco_5060][aco_6330][aco_1222][aco_2099][aco_5216][aco_7964][aco_1822][aco_3090][aco_4305][aco_7727][aco_1916][ac

In [55]:
# Index of the first stop token among the generated ids.
# If generation ended without emitting one (e.g. it ran into max_length),
# fall back to None so that `sem_tokens[:last]` keeps the whole sequence
# instead of this cell raising IndexError.
stop_positions = np.flatnonzero(np.isin(sem_tokens, stop_token))
last = stop_positions[0] if stop_positions.size else None

In [56]:
# Keep the ids that precede the stop marker and subtract the Mimi vocabulary
# offset, turning ids in the combined vocabulary into zero-based audio-code ids.
audio_tokens = sem_tokens[:last] - cfg.OFFSET[MIMI]
audio_tokens

array([1049, 2848, 5690, 7706, 1946, 2291, 5838, 7492, 1156, 3567, 4879,
       7856,  420, 2519, 4456, 7236,  225, 3590, 5977, 6418,  915, 3052,
       5024, 6902, 1828, 2323, 5576, 6985,  281, 2323, 5477, 6748,  285,
       3945, 5477, 6704, 1546, 3831, 4977, 7762,  220, 2833, 5636, 6867,
       1989, 2068, 5024, 6426,  929, 2349, 4740, 6416,  929, 2349, 4403,
       7378, 1014, 2861, 4898, 6545, 1781, 3317, 4602, 6321, 1536, 3431,
       4560, 6418,  342, 2481, 4626, 6969, 1592, 3900, 5792, 6854,  371,
       3138, 4200, 8058,   75, 3442, 5083, 7942,   90, 2099, 6080, 7888,
       1222, 2099, 5060, 6330, 1222, 2099, 5216, 7964, 1822, 3090, 4305,
       7727, 1916, 3113, 4807, 6857, 1602, 3203, 5595, 6683, 1931, 2387,
       4722, 6835,  845, 3343, 4733, 7680,  257, 3431, 5366, 7479, 1368,
       3817, 5712, 8077, 1657, 2613, 5024, 6869, 1253, 2952, 4763, 7620,
       1344, 4029, 4301, 6282,   90, 3784, 5959, 7176,  237, 2666, 4543,
       7622, 2030, 2834, 4257, 6969,  284, 3360, 60

In [57]:
def deserialize_tokens(tokens, num_codebooks=4, codebook_size=2048):
    """De-interleave a flat token stream into one row per codebook.

    The stream is expected to cycle through the codebooks in order
    (cb0, cb1, ..., cb{K-1}, cb0, ...), with the ids of codebook ``i`` shifted
    up by ``i * codebook_size``. That shift is removed so every row holds
    codebook-local ids.

    Args:
        tokens: 1-D sequence (NumPy array or list) of interleaved ids.
        num_codebooks: Number of interleaved codebooks. Defaults to 4.
        codebook_size: Id range reserved for each codebook. Defaults to 2048.

    Returns:
        Array of shape ``(num_codebooks, n_frames)`` where ``n_frames`` is the
        number of *complete* frames; a trailing partial frame is dropped.

    Raises:
        ValueError: If ``num_codebooks`` is less than 1 or ``tokens`` is not 1-D.
    """
    if num_codebooks < 1:
        raise ValueError(f"num_codebooks must be >= 1, got {num_codebooks}")

    tokens = np.asarray(tokens)
    if tokens.ndim != 1:
        raise ValueError(f"tokens must be 1-D, got shape {tokens.shape}")

    # Only complete frames can be decoded; the shortest strided slice has
    # exactly len(tokens) // num_codebooks entries.
    n_frames = tokens.shape[0] // num_codebooks

    return np.stack([
        tokens[cb::num_codebooks][:n_frames] - cb * codebook_size
        for cb in range(num_codebooks)
    ])

In [58]:
# Split the flat, interleaved stream into one row per codebook.
mimi_tokens = deserialize_tokens(audio_tokens)

In [59]:
# Inspect the de-interleaved codes: one row per codebook, one column per frame.
mimi_tokens

array([[1049, 1946, 1156,  420,  225,  915, 1828,  281,  285, 1546,  220,
        1989,  929,  929, 1014, 1781, 1536,  342, 1592,  371,   75,   90,
        1222, 1222, 1822, 1916, 1602, 1931,  845,  257, 1368, 1657, 1253,
        1344,   90,  237, 2030,  284,  313, 1997,  199, 1109, 1109,  786,
         786,  272, 1967, 1112,  887, 1295, 1767, 1356,  101, 2027, 1615,
         912,  847, 1580,  984, 1658,  912, 1967, 1762,   96, 1333,  327,
         735, 1985, 1615, 1304, 1304, 1747, 1612, 1250,  345, 1085, 1865,
        1324,  668, 1039,  579, 1448, 1814, 1771],
       [ 800,  243, 1519,  471, 1542, 1004,  275,  275, 1897, 1783,  785,
          20,  301,  301,  813, 1269, 1383,  433, 1852, 1090, 1394,   51,
          51,   51, 1042, 1065, 1155,  339, 1295, 1383, 1769,  565,  904,
        1981, 1736,  618,  786, 1312, 1882,  974, 1155,  400, 1900,  401,
        1900, 1435, 1155, 1776, 1168, 1444,  758,  366,   51,  194, 1843,
          31,  731, 1312, 1496, 1744, 1776, 1908, 1159, 1584,

In [60]:
# Decode the codes back to a waveform with the Mimi codec. `expand_dims` adds
# the leading batch axis, giving (1, num_codebooks, n_frames).
# This is inference only, so autograd is switched off: no graph is built and
# the returned audio tensor does not carry requires_grad.
with torch.no_grad():
    out = model.decode(torch.tensor(np.expand_dims(mimi_tokens, axis=0)))

In [61]:
# Sanity-check the decoded waveform's shape; the last axis is the sample count
# (161280 samples in the captured run, i.e. 6.72 s at the 24000 Hz rate used
# when the file is saved).
out.audio_values.shape

torch.Size([1, 1, 161280])

In [62]:
import torchaudio

In [63]:
# Write the first (only) item of the batch to disk.
# The tensor is detached and moved to CPU first so saving does not depend on
# autograd state or on the device the codec ran on, and the sample rate is read
# from the codec's own config (24000 for kyutai/mimi) instead of a magic number.
torchaudio.save(
    'test.wav',
    out.audio_values[0].detach().cpu(),
    sample_rate=model.config.sampling_rate,
)