In [1]:
# Install into the *kernel's* environment (%pip), not whatever `pip` the shell resolves to.
# soundfile is imported below for writing WAV files, so install it explicitly as well.
%pip install -q transformers torch torchaudio soundfile
%pip install -q snac bitsandbytes




In [2]:
# Make sure the newest bitsandbytes release is present in the kernel's environment;
# the 4-bit model loading further down depends on it.
%pip install -q -U bitsandbytes



In [3]:
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from snac import SNAC
import soundfile as sf

# -------------------------------------------------------------
# Configuration
# -------------------------------------------------------------
# Seed PyTorch's RNGs so the sampled generations (do_sample=True) are repeatable
# across Restart & Run All, as far as the GPU kernels in use are deterministic.
SEED = 42
torch.manual_seed(SEED)

# Hugging Face repo ids, named once so model and tokenizer cannot drift apart.
MODEL_ID = "maya-research/veena-tts"
SNAC_MODEL_ID = "hubertsiuzdak/snac_24khz"

# 4-bit NF4 weights with double quantization; matmuls are computed in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load model and tokenizer
print("Loading Veena model...")
t0 = time.perf_counter()
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",
    # NOTE(review): this executes code shipped in the model repo — keep only if the repo is trusted.
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
print(f"Model loaded in {(time.perf_counter() - t0):.2f}s")

# Initialize SNAC decoder (inference mode, placed on the current CUDA device)
print("Loading SNAC decoder...")
snac_model = SNAC.from_pretrained(SNAC_MODEL_ID).eval().cuda()
print("SNAC ready.")

# -------------------------------------------------------------
# Constants
# -------------------------------------------------------------
# Control-token ids used to frame the prompt and to stop generation
# (see generate_speech for how they are arranged).
START_OF_SPEECH_TOKEN = 128257
END_OF_SPEECH_TOKEN = 128258
START_OF_HUMAN_TOKEN = 128259
END_OF_HUMAN_TOKEN = 128260
START_OF_AI_TOKEN = 128261
END_OF_AI_TOKEN = 128262
# First vocabulary id of the audio codes; the functions below treat the following
# 7 * 4096 ids as seven consecutive 4096-entry ranges, one per slot of a 7-token frame.
AUDIO_CODE_BASE_OFFSET = 128266

# Voice names accepted as the `speaker` argument of generate_speech
# (presumably the voices this checkpoint was trained with — verify against the model card).
speakers = ["kavya", "agastya", "maitri", "vinaya"]

# -------------------------------------------------------------
# Functions
# -------------------------------------------------------------
def generate_speech(text, speaker="kavya", temperature=0.4, top_p=0.9):
    """Synthesize speech for ``text`` in the given voice and print timing information.

    The prompt ``"<spk_{speaker}> {text}"`` is tokenized, wrapped in the
    human/AI/speech control tokens, and passed to ``model.generate`` with
    sampling enabled. Generated ids that fall inside the audio-code range are
    kept, trimmed to whole 7-token frames, and decoded by ``decode_snac_tokens``.

    Args:
        text: The text to speak. Must contain at least one non-whitespace character.
        speaker: Voice name inserted into the ``<spk_...>`` prompt tag.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

    Returns:
        The decoded waveform as returned by ``decode_snac_tokens`` (a NumPy
        array with values clamped to [-1, 1]).

    Raises:
        ValueError: If ``text`` is empty/blank, if no complete frame of audio
            tokens was generated, or if decoding yields no audio.
    """
    if not text or not text.strip():
        raise ValueError("text must be a non-empty string")

    codes_per_frame = 7   # audio tokens emitted per frame
    codebook_size = 4096  # ids reserved per frame slot

    total_start = time.perf_counter()

    # Tokenize the speaker-tagged prompt and frame it with the control tokens.
    prompt = f"<spk_{speaker}> {text}"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
    input_tokens = [
        START_OF_HUMAN_TOKEN, *prompt_tokens,
        END_OF_HUMAN_TOKEN, START_OF_AI_TOKEN, START_OF_SPEECH_TOKEN
    ]
    input_ids = torch.tensor([input_tokens], device=model.device)
    # Single unpadded sequence: every position is real, so the mask is all ones.
    # Passing it explicitly avoids generate() having to guess it from pad_token_id.
    attention_mask = torch.ones_like(input_ids)

    # Budget ~1.3 frames per input character plus 21 tokens (three frames) of
    # headroom, capped at 700 tokens (100 frames).
    max_tokens = min(int(len(text) * 1.3) * codes_per_frame + 21, 700)

    # Generate audio tokens
    gen_start = time.perf_counter()
    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.05,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=[END_OF_SPEECH_TOKEN, END_OF_AI_TOKEN]
        )
    gen_time = time.perf_counter() - gen_start

    # Keep only the newly generated ids that lie in the audio-code range.
    generated_ids = output[0][len(input_tokens):].tolist()
    audio_range_end = AUDIO_CODE_BASE_OFFSET + codes_per_frame * codebook_size
    snac_tokens = [
        token_id for token_id in generated_ids
        if AUDIO_CODE_BASE_OFFSET <= token_id < audio_range_end
    ]
    if not snac_tokens:
        raise ValueError("No audio tokens generated")

    # The decoder only accepts whole frames and returns None otherwise, so drop
    # a trailing partial frame (e.g. when generation stopped mid-frame).
    usable = len(snac_tokens) - len(snac_tokens) % codes_per_frame
    snac_tokens = snac_tokens[:usable]
    if not snac_tokens:
        raise ValueError(
            f"Fewer than {codes_per_frame} audio tokens generated; cannot form a complete frame"
        )

    # Decode
    decode_start = time.perf_counter()
    audio = decode_snac_tokens(snac_tokens, snac_model)
    decode_time = time.perf_counter() - decode_start
    if audio is None:
        # Fail here with a clear message instead of handing None to the caller.
        raise ValueError("SNAC decoding produced no audio")

    total_time = time.perf_counter() - total_start
    print(f"[{speaker}] Generation: {gen_time:.2f}s | Decode: {decode_time:.2f}s | Total: {total_time:.2f}s")
    return audio


def decode_snac_tokens(snac_tokens, snac_model):
    """Regroup a flat stream of 7-token frames into three code levels and decode it.

    Within each frame, slot 0 feeds level 0, slots 1 and 4 feed level 1, and
    slots 2, 3, 5 and 6 feed level 2. Each slot's vocabulary offset
    (``AUDIO_CODE_BASE_OFFSET + slot * 4096``) is subtracted first.

    Args:
        snac_tokens: Flat sequence of audio token ids; length must be a
            non-zero multiple of 7.
        snac_model: Decoder exposing ``parameters()`` and ``decode(codes)``.

    Returns:
        ``None`` when ``snac_tokens`` is empty or not a multiple of 7 long;
        otherwise the decoded audio, squeezed, clamped to [-1, 1] and
        converted to a NumPy array on the CPU.

    Raises:
        ValueError: If any de-offset code falls outside 0..4095.
    """
    # Nothing to decode, or an incomplete frame: signal with None.
    if not snac_tokens:
        return None
    if len(snac_tokens) % 7:
        return None

    target_device = next(snac_model.parameters()).device
    slot_offsets = [AUDIO_CODE_BASE_OFFSET + slot * 4096 for slot in range(7)]

    level0, level1, level2 = [], [], []
    for start in range(0, len(snac_tokens), 7):
        # Remove each slot's vocabulary offset to recover raw codebook indices.
        frame = [
            token - offset
            for token, offset in zip(snac_tokens[start:start + 7], slot_offsets)
        ]
        level0.append(frame[0])
        level1.extend((frame[1], frame[4]))
        level2.extend((frame[2], frame[3], frame[5], frame[6]))

    hierarchical_codes = []
    for level in (level0, level1, level2):
        level_tensor = torch.tensor(
            level, dtype=torch.int32, device=target_device
        ).unsqueeze(0)
        out_of_range = (level_tensor < 0) | (level_tensor > 4095)
        if out_of_range.any():
            raise ValueError("Invalid SNAC token values")
        hierarchical_codes.append(level_tensor)

    with torch.no_grad():
        waveform = snac_model.decode(hierarchical_codes)
    return waveform.squeeze().clamp(-1, 1).cpu().numpy()

# -------------------------------------------------------------
# Example Usage
# -------------------------------------------------------------
# Demo sentences: Hindi, English, and Hindi-English code-mixed.
text_hindi = "आज मैंने एक नई तकनीक के बारे में सीखा जो कृत्रिम बुद्धिमत्ता का उपयोग करके मानव जैसी आवाज़ उत्पन्न कर सकती है।"
text_english = "Today I learned about a new technology that uses artificial intelligence to generate human-like voices."
text_mixed = "मैं तो पूरा presentation prepare कर चुका हूं! कल रात को ही मैंने पूरा code base चेक किया।"

# One entry per demo: (banner to print, text to speak, voice, output WAV path).
demo_runs = [
    ("\n--- Generating Hindi ---", text_hindi, "kavya", "output_hindi_kavya.wav"),
    ("\n--- Generating English ---", text_english, "agastya", "output_english_agastya.wav"),
    ("\n--- Generating Code-Mixed ---", text_mixed, "maitri", "output_mixed_maitri.wav"),
]

for banner, sentence, voice, wav_path in demo_runs:
    print(banner)
    audio = generate_speech(sentence, speaker=voice)
    # Audio is written as a 24 kHz WAV file.
    sf.write(wav_path, audio, 24000)

print("\nAll generations done. Check WAV files in your runtime.")


Loading Veena model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.58G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/180 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/22.9M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/697 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

Model loaded in 154.21s
Loading SNAC decoder...


config.json:   0%|          | 0.00/300 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/79.5M [00:00<?, ?B/s]

SNAC ready.

--- Generating Hindi ---
[kavya] Generation: 52.17s | Decode: 1.71s | Total: 53.88s

--- Generating English ---
[agastya] Generation: 39.88s | Decode: 0.04s | Total: 39.92s

--- Generating Code-Mixed ---
[maitri] Generation: 35.25s | Decode: 0.04s | Total: 35.28s

All generations done. Check WAV files in your runtime.


In [4]:
# Print the module tree of the loaded model, e.g. to confirm that the attention
# and MLP projections were loaded as 4-bit (Linear4bit) layers.
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(156951, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear4bit(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=3072, out_features=3072, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear4bit(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear4bit(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNo