### Installation

In [1]:
%%capture
# Install the runtime dependencies; %%capture hides the pip output of this cell.
# NOTE(review): later cells import `dotenv` and `openai`, which are not installed
# here — confirm they are already available in the target runtime.
import os, re
if "COLAB_" not in "".join(os.environ.keys()):
    # No COLAB_* environment variable present: use the plain install.
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # The xformers pin is chosen from the torch version that is already installed.
    import torch; v = re.match(r"[0-9\.]{3,}", str(torch.__version__)).group(0)
    xformers = "xformers==" + ("0.0.32.post2" if v == "2.8.0" else "0.0.29.post3")
    !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth
# Packages installed in both branches: pinned transformers/trl, the SNAC codec,
# and the audio/recording helpers used by the later cells.
!pip install transformers==4.55.4
!pip install --no-deps trl==0.22.2
!pip install snac
!pip install json_repair
!pip install -U bitsandbytes
!pip install librosa ipywebrtc

### test

In [None]:
# NOTE(review): the installation cell already runs `pip install snac`, so this
# cell is only needed when the SNAC decoding test is run on its own.
!pip install snac

Collecting snac
  Downloading snac-1.2.1-py3-none-any.whl.metadata (3.5 kB)
Downloading snac-1.2.1-py3-none-any.whl (8.4 kB)
Installing collected packages: snac
Successfully installed snac-1.2.1


In [None]:
# -*- coding: utf-8 -*-
# SNAC decoding only, driven by speech_tokens — no language model involved

import torch
from IPython.display import display, Audio

# ---------------------------
# Configuration
# ---------------------------
SNAC_DEVICE   = "cpu"      # or "cuda" if you want to run the codec on the GPU
CODEBOOK_SIZE = 4096
SR            = 24000

# Special tokens / ids (same values as before)
START_OF_SPEECH = 128257
END_OF_SPEECH   = 128258
AUDIO_BASE      = 128266   # base offset of the audio codes

# ---------------------------
# SNAC
# ---------------------------
from snac import SNAC
# Load the pretrained codec on SNAC_DEVICE and switch it to eval mode.
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(SNAC_DEVICE)
snac_model.eval()

# ---------------------------
# توابع کمکی
# ---------------------------
def _split_7(code_list):
    for i in range(0, len(code_list), 7):
        yield code_list[i:i+7]

def _to_layers(group7):
    """
    group7: [g0..g6]  (همان ترتیبی که در بخش encode ساخته بودی)
    بازسازی لایه‌ها:
      L1: [g0]
      L2: [g1, g4]              (دوبرابر L1)
      L3: [g2, g3, g5, g6]      (چهاربرابر L1)
    سپس هر کدام را از شیفت لایه‌ای‌شان کم می‌کنیم:
      g1-1*4096, g2-2*4096, g3-3*4096, g4-4*4096, g5-5*4096, g6-6*4096
    """
    layer_1 = group7[0]                    # 0*4096
    layer_2_a = group7[1] - (1*4096)
    layer_3_a = group7[2] - (2*4096)
    layer_3_b = group7[3] - (3*4096)
    layer_2_b = group7[4] - (4*4096)
    layer_3_c = group7[5] - (5*4096)
    layer_3_d = group7[6] - (6*4096)

    l1 = [layer_1]
    l2 = [layer_2_a, layer_2_b]
    l3 = [layer_3_a, layer_3_b, layer_3_c, layer_3_d]
    return l1, l2, l3

def _valid_idx(x):
    """Return True when ``x`` lies in the half-open range [0, CODEBOOK_SIZE)."""
    return x >= 0 and x < CODEBOOK_SIZE

def _clamp(x):
    """Clip ``x`` into [0, CODEBOOK_SIZE - 1]; used instead of dropping a frame."""
    if x < 0:
        return 0
    if x >= CODEBOOK_SIZE:
        return CODEBOOK_SIZE - 1
    return x

def tokens_to_audio(single_tokens,
                    clamp_out_of_range=True,
                    drop_invalid_groups=False):
    """
    Decode one sequence of token ids into audio with the SNAC codec.

    Args:
        single_tokens: list of integer token ids.
        clamp_out_of_range: when a 7-code frame holds an index outside the
            codebook, clip the indices into range and keep the frame.
        drop_invalid_groups: when True, such frames are skipped instead; this
            flag takes precedence over ``clamp_out_of_range``.

    Returns:
        ``(audio_hat, invalid_groups)``: the result of ``snac_model.decode``
        (or None when nothing decodable remains) and the number of frames
        that contained at least one out-of-range index.
    """
    # 1) Keep only what follows the last START_OF_SPEECH, then strip END_OF_SPEECH.
    ids = torch.tensor(single_tokens, dtype=torch.long)
    start_positions = (ids == START_OF_SPEECH).nonzero(as_tuple=True)[0]
    first_kept = (start_positions[-1].item() + 1) if len(start_positions) > 0 else 0
    ids = ids[first_kept:]
    ids = ids[ids != END_OF_SPEECH]

    # 2) Keep audio tokens only, i.e. ids at or above AUDIO_BASE.
    ids = ids[ids >= AUDIO_BASE]
    if ids.numel() == 0:
        return None, 0

    # 3) Remove the base offset and truncate to a whole number of 7-code frames.
    ids = ids - AUDIO_BASE
    ids = ids[:(ids.numel() // 7) * 7]
    if ids.numel() == 0:
        return None, 0

    # 4) Redistribute the frames over the three layers.
    layer_1, layer_2, layer_3 = [], [], []
    invalid_groups = 0

    for frame in _split_7(ids.tolist()):
        if len(frame) < 7:
            continue
        l1, l2, l3 = _to_layers(frame)

        if not all(_valid_idx(v) for v in l1 + l2 + l3):
            invalid_groups += 1
            # Dropping wins over clamping; with neither flag set the frame is skipped too.
            if drop_invalid_groups or not clamp_out_of_range:
                continue
            l1 = [_clamp(v) for v in l1]
            l2 = [_clamp(v) for v in l2]
            l3 = [_clamp(v) for v in l3]

        layer_1.extend(l1)
        layer_2.extend(l2)
        layer_3.extend(l3)

    if len(layer_1) == 0:
        return None, invalid_groups

    # 5) The three code tensors passed to decode have lengths T, 2T and 4T (batch of 1).
    T_len = len(layer_1)
    assert len(layer_2) == 2*T_len and len(layer_3) == 4*T_len, "Bad layer lengths"

    codes = [
        torch.tensor(layer, dtype=torch.long, device=SNAC_DEVICE).unsqueeze(0)
        for layer in (layer_1, layer_2, layer_3)
    ]

    with torch.no_grad():
        audio_hat = snac_model.decode(codes)
    return audio_hat, invalid_groups

def batch_decode(list_of_speech_tokens, **kwargs):
    """
    Decode several token sequences with ``tokens_to_audio``.

    Args:
        list_of_speech_tokens: list of samples, each a sequence of token ids.
        **kwargs: forwarded unchanged to ``tokens_to_audio``.

    Returns:
        ``(outs, total_invalid)``: the decoded audio of every sample that
        produced output, and the invalid-frame count summed over all samples.
    """
    results = [tokens_to_audio(tokens, **kwargs) for tokens in list_of_speech_tokens]
    outs = [audio for audio, _ in results if audio is not None]
    total_invalid = sum(invalid for _, invalid in results)
    return outs, total_invalid

# ---------------------------
# Usage example
# ---------------------------
# Assumption: you supply the prepared speech_tokens yourself.
# The made-up example below only illustrates the expected format:
tokens=[128257,
  131218,
  136235,
  136515,
  144558,
  148523,
  150464,
  154642,
  129487,
  134651,
  136558,
  142685,
  144819,
  151524,
  152963,
  130658,
  133795,
  137739,
  143299,
  146973,
  152813,
  154924,
  130966,
  132618,
  137231,
  143524,
  146778,
  149618,
  155158,
  128462,
  134248,
  137387,
  142620,
  145688,
  149434,
  156624,
  132358,
  134957,
  136586,
  143180,
  147290,
  150438,
  155761,
  129601,
  136154,
  137875,
  144297,
  144752,
  151858,
  156231,
  129177,
  133996,
  139072,
  144002,
  145169,
  149983,
  156534,
  129225,
  134137,
  136677,
  143674,
  146723,
  151441,
  154929,
  129328,
  136339,
  138183,
  144634,
  148214,
  151967,
  153539,
  128691,
  135299,
  140525,
  144544,
  144817,
  150946,
  156283,
  130315,
  134672,
  136956,
  144344,
  145802,
  149193,
  153168,
  129485,
  134579,
  138534,
  144603,
  147722,
  149462,
  156400,
  129447,
  136401,
  137552,
  143547,
  147453,
  151401,
  155878,
  130560,
  135710,
  137760,
  142509,
  146441,
  148880,
  153841,
  130611,
  134413,
  137037,
  140936,
  147585,
  152232,
  156047,
  128971,
  135996,
  137454,
  140578,
  148194,
  152287,
  154521,
  130370,
  134132,
  138277,
  141699,
  147641,
  151298,
  153987,
  130996,
  134075,
  139185,
  140600,
  144711,
  149252,
  154747,
  130684,
  135417,
  139528,
  140924,
  146626,
  150880,
  155806,
  129454,
  135479,
  139712,
  141165,
  148011,
  150040,
  156299,
  130857,
  134645,
  136637,
  143972,
  147780,
  149795,
  155750,
  130513,
  133301,
  137539,
  144094,
  145830,
  149325,
  154228,
  129518,
  134468,
  139962,
  144134,
  148701,
  151618,
  153498,
  129900,
  135991,
  137635,
  143056,
  145145,
  151968,
  156659,
  130071,
  133137,
  138461,
  144024,
  148236,
  149054,
  154221,
  129754,
  135953,
  137940,
  142675,
  146536,
  151219,
  153570,
  130532,
  132748,
  138288,
  140988,
  145872,
  148753,
  153574,
  128875,
  134338,
  138095,
  144066,
  145282,
  151729,
  154595,
  131334,
  133761,
  136498,
  140985,
  145036,
  149295,
  156217,
  130850,
  132887,
  139875,
  142555,
  146848,
  148785,
  153816,
  130302,
  135807,
  138046,
  143319,
  148667,
  151583,
  156549,
  129484,
  135136,
  139289,
  141953,
  145191,
  151423,
  152917,
  129485,
  134323,
  138339,
  141564,
  147121,
  150539,
  154376,
  130532,
  134465,
  138823,
  142304,
  147424,
  152612,
  154240,
  128275,
  135136,
  137682,
  141703,
  148641,
  151036,
  153004,
  130382,
  134588,
  136605,
  141885,
  147126,
  151755,
  153247,
  129411,
  133105,
  137102,
  141265,
  146425,
  150582,
  155729,
  129601,
  134390,
  137686,
  143782,
  146723,
  148791,
  155376,
  130833,
  132427,
  139403,
  143352,
  146850,
  150882,
  153870,
  132012,
  134266,
  139278,
  142070,
  145795,
  152431,
  156050,
  128353,
  133311,
  139089,
  141381,
  145337,
  150936,
  153658,
  129165,
  132799,
  138210,
  144024,
  148135,
  149418,
  155356,
  129487,
  133996,
  138536,
  143124,
  147704,
  152810,
  153321,
  130384,
  135927,
  138443,
  142237,
  145303,
  151054,
  154024,
  130029,
  133548,
  137258,
  143373,
  144783,
  149282,
  153016,
  129772,
  133963,
  138841,
  142417,
  148703,
  149282,
  154217,
  131717,
  135245,
  140171,
  144027,
  145950,
  149844,
  153407,
  131596,
  135385,
  139372,
  142814,
  144961,
  150238,
  153159,
  132348,
  132836,
  138600,
  144373,
  147360,
  149548,
  152912,
  130315,
  134672,
  139898,
  142696,
  146626,
  151844,
  153891,
  128770,
  136413,
  137107,
  143482,
  147744,
  151192,
  153276,
  131006,
  136348,
  139105,
  143476,
  145282,
  149964,
  153775,
  131735,
  133210,
  136862,
  142169,
  146531,
  149635,
  154444,
  131006,
  134490,
  137864,
  141447,
  145138,
  152020,
  154807,
  129712,
  134455,
  137448,
  142695,
  148701,
  149888,
  154707,
  130961,
  135889,
  136618,
  142757,
  145036,
  150982,
  153548,
  132336,
  133566,
  136532,
  144141,
  144674,
  151524,
  154218,
  131334,
  135593,
  138772,
  143492,
  145707,
  149599,
  155261,
  132358,
  132402,
  138222,
  143275,
  147615,
  149794,
  155571,
  130855,
  135338,
  138003,
  141850,
  146780,
  152782,
  155511,
  129725,
  136115,
  139445,
  144575,
  146849,
  151246,
  154118,
  131253,
  133179,
  138705,
  143106,
  147386,
  151942,
  156832,
  130045,
  135962,
  137421,
  143449,
  146626,
  151188,
  156483,
  132002,
  132887,
  139849,
  140611,
  146848,
  148878,
  153289,
  129487,
  136235,
  139295,
  140720,
  144715,
  150801,
  153430,
  130532,
  135656,
  138352,
  143432,
  146188,
  152060,
  156417,
  129810,
  132577,
  138638,
  144318,
  146839,
  148926,
  153504,
  131456,
  135637,
  138052,
  140711,
  148569,
  149434,
  154011]

# Wrap the single example sequence in a batch of one.
speech_batch = [tokens]

# With clamp_out_of_range=True and drop_invalid_groups=False, frames holding
# out-of-range indices are clamped and kept; invalid_total still counts them.
audios, invalid_total = batch_decode(speech_batch, clamp_out_of_range=True, drop_invalid_groups=False)

if len(audios) == 0:
    raise RuntimeError(f"No valid 7-code groups for SNAC decoding (discarded groups: {invalid_total}).")

# NOTE(review): the message says "Discarded", but with the flags above the
# counted groups were clamped rather than dropped.
print(f"Decoded {len(audios)} sample(s). Discarded groups: {invalid_total}")
for i, a in enumerate(audios, 1):
    print(f"Sample {i}")
    # Render an inline audio player at the SR sample rate.
    display(Audio(a.detach().squeeze().cpu().numpy(), rate=SR))


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/300 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/79.5M [00:00<?, ?B/s]

Decoded 1 sample(s). Discarded groups: 0
Sample 1


<a name="Inference"></a>
### Inference
Let's run the model! You can change the prompts.



In [5]:
from unsloth import FastLanguageModel
from huggingface_hub import login
import torch
import os
from dotenv import load_dotenv

# Load secrets from a .env file stored on Google Drive. The path is under
# /content/drive, so the drive-mount cell has to run before this one.
load_dotenv(dotenv_path="/content/drive/MyDrive/projects/TTS/.env")

# HF_TOKEN is None when the variable is missing from the environment / .env file.
HF_TOKEN = os.getenv('HF_TOKEN')
login(HF_TOKEN)

In [None]:
from IPython.display import display, Audio
from peft import PeftModel
# NOTE(review): DEVICE_LLAMA is not read by any cell visible in this notebook.
DEVICE_LLAMA = "cuda"
SNAC_DEVICE  = "cpu"      # device the SNAC codec is moved to below
CODEBOOK_SIZE = 4096      # upper bound used by _valid_idx when checking codes

# Load the base Orpheus checkpoint in 4-bit together with its tokenizer.
base_model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/orpheus-3b-0.1-ft",
    max_seq_length=3000,
    dtype=None,
    load_in_4bit=True,
)

# Attach the fine-tuned PEFT adapter on top of the base model.
model = PeftModel.from_pretrained(base_model, "David-ger/Orpheus-llama3b-fa-finetuned-gpt5-mini-4865")

model.eval()

# Load the SNAC codec used to turn generated codes into audio.
from snac import SNAC
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(SNAC_DEVICE)
snac_model.eval()



### example


In [None]:
# Text to synthesize, written in Latin-script transliteration.
prompts = [
    "faateme lotfan chaai beriz"

]

# NOTE(review): chosen_voice is not read by the visible cells; the "female1"
# voice tag is hard-coded where the prompt is built.
chosen_voice = None

In [None]:

# Prefix every prompt with the voice tag. Previously only prompts[0] was used,
# so additional entries in `prompts` were silently ignored even though the
# padding, generation and display code in this cell handles a whole batch.
prompts_ = [f"female1: {prompt}" for prompt in prompts]

# Tokenize each prompt on its own; padding to a common length happens later.
all_input_ids = []
for prompt in prompts_:
    input_ids_ = tokenizer(prompt, return_tensors="pt").input_ids
    all_input_ids.append(input_ids_)

# Special ids wrapped around the text prompt.
# NOTE(review): presumably start_of_human / end_of_text / end_of_human, going by
# the template quoted in the voice-cloning cell — confirm against the tokenizer.
start_token = torch.tensor([[128259]], dtype=torch.long)
end_tokens  = torch.tensor([[128009, 128260]], dtype=torch.long)

all_modified_input_ids = []
for input_ids_ in all_input_ids:
    modified_input_ids = torch.cat([start_token, input_ids_, end_tokens], dim=1)
    all_modified_input_ids.append(modified_input_ids)

# Left-pad every sequence to the longest one with id 128263 and build the
# matching attention mask (0 over padding, 1 over real tokens).
all_padded_tensors = []
all_attention_masks = []
max_length = max(m.shape[1] for m in all_modified_input_ids)
for m in all_modified_input_ids:
    padding = max_length - m.shape[1]
    pad_ids  = torch.full((1, padding), 128263, dtype=torch.long)
    attn_pad = torch.zeros((1, padding), dtype=torch.long)
    attn_tok = torch.ones((1, m.shape[1]), dtype=torch.long)
    padded_tensor   = torch.cat([pad_ids, m], dim=1)
    attention_mask  = torch.cat([attn_pad, attn_tok], dim=1)
    all_padded_tensors.append(padded_tensor)
    all_attention_masks.append(attention_mask)

# Stack the per-prompt rows into one batch.
input_ids      = torch.cat(all_padded_tensors, dim=0)
attention_mask = torch.cat(all_attention_masks, dim=0)

input_ids = input_ids.to(model.device)
attention_mask = attention_mask.to(model.device)

# Sample up to 1200 new tokens; generation stops at id 128258 (END_OF_SPEECH).
generated_ids = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_new_tokens=1200,
    do_sample=True,
    temperature=0.6,
    top_p=0.95,
    repetition_penalty=1.1,
    num_return_sequences=1,
    eos_token_id=128258,
    use_cache=True
)

token_to_find   = 128257   # START_OF_SPEECH
token_to_remove = 128258   # END_OF_SPEECH

# Crop every row after the column of the last START_OF_SPEECH match; the same
# column index is applied to all rows of the batch.
token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
if len(token_indices[1]) > 0:
    last_occurrence_idx = token_indices[1][-1].item()
    cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
else:
    cropped_tensor = generated_ids

# Move each row to the CPU and drop the END_OF_SPEECH ids.
processed_rows = []
for row in cropped_tensor:
    row_cpu = row.detach().to("cpu")
    masked_row = row_cpu[row_cpu != token_to_remove]
    processed_rows.append(masked_row)

# Truncate to whole 7-code frames and subtract the audio base offset (128266).
code_lists = []
for row in processed_rows:
    row_length = row.size(0)
    new_length = (row_length // 7) * 7
    trimmed_row = row[:new_length]
    trimmed_row = (trimmed_row - 128266).cpu().tolist()
    code_lists.append(trimmed_row)

def _split_7(code_list):
    for i in range(0, len(code_list), 7):
        yield code_list[i:i+7]

def _to_layers(group7):
    layer_1 = group7[0]
    layer_2_a = group7[1] - 4096
    layer_3_a = group7[2] - (2*4096)
    layer_3_b = group7[3] - (3*4096)
    layer_2_b = group7[4] - (4*4096)
    layer_3_c = group7[5] - (5*4096)
    layer_3_d = group7[6] - (6*4096)

    l1 = [layer_1]
    l2 = [layer_2_a, layer_2_b]
    l3 = [layer_3_a, layer_3_b, layer_3_c, layer_3_d]
    return l1, l2, l3

def _valid_idx(x):
    """Return True when ``x`` lies in the half-open range [0, CODEBOOK_SIZE)."""
    return x >= 0 and x < CODEBOOK_SIZE

def redistribute_codes(code_list):
    """
    Decode a flat list of codes into audio with SNAC.

    The layering logic is the usual one; before decoding, every index is
    checked to lie in [0, CODEBOOK_SIZE). A 7-code frame with any index
    outside that range is skipped and counted.

    Returns:
        ``(audio_hat, invalid_groups)``; ``audio_hat`` is None when no valid
        frame is left.
    """
    layer_1, layer_2, layer_3 = [], [], []
    invalid_groups = 0

    for frame in _split_7(code_list):
        if len(frame) < 7:
            continue
        l1, l2, l3 = _to_layers(frame)

        if not all(_valid_idx(v) for v in l1 + l2 + l3):
            invalid_groups += 1
            continue

        layer_1.extend(l1)
        layer_2.extend(l2)
        layer_3.extend(l3)

    if len(layer_1) == 0:
        return None, invalid_groups

    codes = [
        torch.tensor(layer, dtype=torch.long, device=SNAC_DEVICE).unsqueeze(0)
        for layer in (layer_1, layer_2, layer_3)
    ]

    with torch.no_grad():
        audio_hat = snac_model.decode(codes)
    return audio_hat, invalid_groups

# Decode every code list; rows that yield no valid frame contribute no sample.
my_samples = []
total_invalid = 0
for code_list in code_lists:
    audio, invalid = redistribute_codes(code_list)
    total_invalid += invalid
    if audio is not None:
        my_samples.append(audio)

if len(my_samples) == 0:
    raise RuntimeError(
        f"No valid 7-code groups for SNAC decoding (discarded groups: {total_invalid}). "
    )

# Warn when the number of decoded samples differs from the number of prompts.
if len(prompts) != len(my_samples):

    print(f"Warning: {total_invalid} invalid groups discarded; prompts={len(prompts)}, samples={len(my_samples)}")

# Print each prompt (or a fallback label) and render an inline player at 24 kHz.
for i, samples in enumerate(my_samples):
    print(prompts[i] if i < len(prompts) else f"Sample {i+1}")
    display(Audio(samples.detach().squeeze().cpu().numpy(), rate=24000))

# Release the reference to the decoded audio tensors.
del my_samples


chetor mitoonam komaket konam?


### prepare TTS

In [9]:
def generate_audio(text):
  """
  Synthesize ``text`` with the fine-tuned model and play the result inline.

  The text is prefixed with the "female1" voice tag, wrapped in the prompt's
  special ids, sampled with ``model.generate`` and the generated codes are
  decoded with ``snac_model``. Relies on the notebook globals ``tokenizer``,
  ``model``, ``snac_model``, ``SNAC_DEVICE`` and ``CODEBOOK_SIZE``.

  Args:
      text: the sentence to speak.

  Returns:
      None. The text is printed and an audio player is displayed per sample.

  Raises:
      RuntimeError: when no valid 7-code frame could be decoded.
  """
  voiced_prompts = [f"female1: {text}"]

  tokenized = [tokenizer(p, return_tensors="pt").input_ids for p in voiced_prompts]

  # Special ids placed before and after the prompt text.
  start_token = torch.tensor([[128259]], dtype=torch.long)
  end_tokens  = torch.tensor([[128009, 128260]], dtype=torch.long)

  framed = [torch.cat([start_token, ids, end_tokens], dim=1) for ids in tokenized]

  # Left-pad with id 128263 up to the longest sequence and build the mask.
  max_length = max(seq.shape[1] for seq in framed)
  padded_rows, mask_rows = [], []
  for seq in framed:
      pad_len = max_length - seq.shape[1]
      padded_rows.append(
          torch.cat([torch.full((1, pad_len), 128263, dtype=torch.long), seq], dim=1)
      )
      mask_rows.append(
          torch.cat(
              [torch.zeros((1, pad_len), dtype=torch.long),
               torch.ones((1, seq.shape[1]), dtype=torch.long)],
              dim=1,
          )
      )

  input_ids = torch.cat(padded_rows, dim=0).to(model.device)
  attention_mask = torch.cat(mask_rows, dim=0).to(model.device)

  generated_ids = model.generate(
      input_ids=input_ids,
      attention_mask=attention_mask,
      max_new_tokens=1200,
      do_sample=True,
      temperature=0.6,
      top_p=0.95,
      repetition_penalty=1.1,
      num_return_sequences=1,
      eos_token_id=128258,
      use_cache=True
  )

  start_of_speech = 128257
  end_of_speech = 128258

  # Keep only what follows the last start-of-speech id.
  hits = (generated_ids == start_of_speech).nonzero(as_tuple=True)
  if len(hits[1]) > 0:
      cropped = generated_ids[:, hits[1][-1].item() + 1:]
  else:
      cropped = generated_ids

  # Per row: move to CPU, drop end-of-speech ids, truncate to whole 7-code
  # frames and remove the audio base offset.
  code_lists = []
  for row in cropped:
      row_cpu = row.detach().to("cpu")
      kept = row_cpu[row_cpu != end_of_speech]
      usable = (kept.size(0) // 7) * 7
      code_lists.append((kept[:usable] - 128266).cpu().tolist())

  def _decode(code_list):
      """Layer the codes frame by frame and decode them; skips invalid frames."""
      layer_1, layer_2, layer_3 = [], [], []
      rejected = 0

      for offset in range(0, len(code_list), 7):
          frame = code_list[offset:offset + 7]
          if len(frame) < 7:
              continue

          # Slot k of a frame carries a shift of k*4096.
          l1 = [frame[0]]
          l2 = [frame[1] - 4096, frame[4] - (4*4096)]
          l3 = [frame[2] - (2*4096), frame[3] - (3*4096),
                frame[5] - (5*4096), frame[6] - (6*4096)]

          if all(0 <= v < CODEBOOK_SIZE for v in l1 + l2 + l3):
              layer_1.extend(l1)
              layer_2.extend(l2)
              layer_3.extend(l3)
          else:
              rejected += 1

      if len(layer_1) == 0:
          return None, rejected

      codes = [
          torch.tensor(layer, dtype=torch.long, device=SNAC_DEVICE).unsqueeze(0)
          for layer in (layer_1, layer_2, layer_3)
      ]

      with torch.no_grad():
          audio_hat = snac_model.decode(codes)
      return audio_hat, rejected

  decoded = []
  total_invalid = 0
  for code_list in code_lists:
      audio, invalid = _decode(code_list)
      total_invalid += invalid
      if audio is not None:
          decoded.append(audio)

  if len(decoded) == 0:
      raise RuntimeError(
          f"No valid 7-code groups for SNAC decoding (discarded groups: {total_invalid}). "
      )

  for clip in decoded:
      print(text)
      display(Audio(clip.detach().squeeze().cpu().numpy(), rate=24000))


### ASR


In [None]:
from transformers import pipeline
# Whisper speech-recognition pipeline. In the visible cells `pipe` is only
# referenced from commented-out code in the "listen and speak" section.
pipe = pipeline(task="automatic-speech-recognition", model="openai/whisper-large-v3")

Device set to use cuda:0


### chatbot

In [7]:
from openai import OpenAI

# Relies on `os` being imported, and the .env file loaded, by the Hugging Face
# login cell. OPENAI_TOKEN is None when the variable is not set.
OPENAI_TOKEN = os.getenv('OPENAI_TOKEN')
client = OpenAI(api_key=OPENAI_TOKEN)
def chatbot(sentence: str):
    """
    Ask the chat model to answer ``sentence`` in Latin-script Persian (Finglish).

    Args:
        sentence: the user's message; the system prompt expects Persian input.

    Returns:
        The message content of the first choice in the completion response.
    """
    system_prompt = """
        you are a smart assistant.
        user write in persian and you must write your formally answer in finglish.

        Example:
        user: سلام

        output answer:
        salaam, chetor mitoonam komaketoon konam?

        """
    user_prompt = f"user: {sentence}\noutput answer: "

    completion = client.chat.completions.create(
        model="gpt-5-mini",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        reasoning_effort="minimal",
        max_completion_tokens=10000,
    )

    return completion.choices[0].message.content


### listen and speak

In [12]:
# Quick check of the TTS helper with a short greeting.
generate_audio("salaam khoobi?")

salaam khoobi?


In [11]:
# Second short sample for the TTS helper.
generate_audio("to oomadi")

to oomadi


In [13]:
# The ASR step is disabled here; a hard-coded Persian question is sent to the
# chatbot instead of a transcription.
# result = pipe("/content/drive/MyDrive/projects/TTS/hh.ogg",generate_kwargs={"language": "persian"})
# print(result['text'])
chatbot_answer = chatbot("اسمت چیه؟")
# Speak the chatbot's reply.
generate_audio(chatbot_answer)

man yek model zabaani hastam va esm-e rasmi nadaram, amma mitoonid man ro "assistant" ya "chatbot" seda konid. chetor mitoonam komaketoon konam?


### voice cloning

In [None]:

# Reference recording to clone; the path is on Google Drive, so the drive has
# to be mounted before the cell that loads it runs.
my_wav_file_is = "/content/drive/MyDrive/projects/TTS/to_clone.ogg"
# Transcript paired with the reference recording when building the prompt.
and_the_transcript_is = """barkhalaf bachehaayi ke labkhandhaye bi-dalil daran, in koodak hesse jeddiyat dare. ehtemaal ziadi hast ke bishtar az senesh mifahme. in mitoone khoob baashe, vali az tarafi ham mitoone baes-e ezterabe darooni ya feshar-e ravani beshe ke kasi motevajehesh nist.be nazar nemirese az oon bachehaayi baashe ke sari’ miparan vasat-e jam’ ya hame ja mikhandan. bishtar ehtemaal dare aval sabr kone, faza ro tahlil kone, bad vared-e baazi ya goftogoo beshe. in mitoone neshone-ye daroon-garaayi ya hatta ezterabe ejtemaa’i-ye molayem baashe."""

# Sentences the model should speak in the cloned voice.
the_model_should_say = [
    "man fatemeh hastam"
 ]

In [4]:
from google.colab import drive
# Mount Google Drive. Cells that read paths under /content/drive (the .env
# file and the reference recording) depend on this cell having run first.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from ipywebrtc import AudioRecorder, Audio
from IPython.display import display
import ipywidgets as widgets
from huggingface_hub import snapshot_download
import torchaudio.transforms as T
import librosa

# Prefer CUDA, then Apple-silicon MPS, and otherwise fall back to the CPU.
# The previous expression returned "mps" whenever CUDA was missing, even on
# machines without an MPS backend.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

'''
The template is:

start_of_human, start_of_text, text, end_of_text, start_of_ai, start_of_speech, speech, end_of_speech, end_of_ai, start_of_human, text, end_of_human and then generate from here

'''


filename = my_wav_file_is

# Load the reference recording, resampled to 24 kHz.
audio_array, sample_rate = librosa.load(filename, sr=24000)

def tokenise_audio(waveform):
  """
  Encode a waveform with SNAC and flatten the codes into model token ids.

  Args:
      waveform: numpy array of audio samples.
          NOTE(review): presumably 1-D mono audio at 24 kHz — confirm with the caller.

  Returns:
      A flat list of ints, seven per frame; the code in slot k is stored as
      ``code + 128266 + k*4096``.
  """
  # Two leading axes are added before encoding.
  batch = torch.from_numpy(waveform).unsqueeze(0)
  batch = batch.to(dtype=torch.float32).unsqueeze(0)

  with torch.inference_mode():
    codes = snac_model.encode(batch)

  # (code level, codes per frame at that level, position within the frame's
  # codes at that level), listed in output-slot order 0..6.
  slot_layout = (
      (0, 1, 0),
      (1, 2, 0),
      (2, 4, 0),
      (2, 4, 1),
      (1, 2, 1),
      (2, 4, 2),
      (2, 4, 3),
  )

  all_codes = []
  for frame in range(codes[0].shape[1]):
    for slot, (level, per_frame, position) in enumerate(slot_layout):
      raw_code = codes[level][0][per_frame * frame + position].item()
      all_codes.append(raw_code + 128266 + slot * 4096)

  return all_codes

# Audio token ids of the reference recording.
myts = tokenise_audio(audio_array)
# Special ids framing each turn. 128257 / 128258 are the start/end-of-speech ids
# used by the decoding cells.
# NOTE(review): the remaining ids presumably follow the template quoted earlier
# in this cell — confirm against the tokenizer.
start_tokens = torch.tensor([[ 128259]], dtype=torch.int64)
end_tokens = torch.tensor([[128009, 128260, 128261, 128257]], dtype=torch.int64)
final_tokens = torch.tensor([[128258, 128262]], dtype=torch.int64)
voice_prompt = and_the_transcript_is
prompt_tokked = tokenizer(voice_prompt, return_tensors="pt")

input_ids = prompt_tokked["input_ids"]

# Reference turn: transcript text followed by the reference audio tokens.
zeroprompt_input_ids = torch.cat([start_tokens, input_ids, end_tokens, torch.tensor([myts]), final_tokens], dim=1) # SOH SOT Text EOT EOH

# Note: this rebinds the notebook-level `prompts` variable.
prompts = the_model_should_say

# Append each target sentence as a new turn after the reference turn.
all_modified_input_ids = []
for prompt in prompts:
  input_ids = tokenizer(f"female1: {prompt}", return_tensors="pt").input_ids
  second_input_ids = torch.cat([zeroprompt_input_ids, start_tokens, input_ids, end_tokens], dim=1)
  all_modified_input_ids.append(second_input_ids)


all_padded_tensors = []
all_attention_masks = []

max_length = max([modified_input_ids.shape[1] for modified_input_ids in all_modified_input_ids])

# Left-pad with id 128263 and build the matching attention mask.
for modified_input_ids in all_modified_input_ids:
  padding = max_length - modified_input_ids.shape[1]
  padded_tensor = torch.cat([torch.full((1, padding), 128263, dtype=torch.int64), modified_input_ids], dim=1)
  attention_mask = torch.cat([torch.zeros((1, padding), dtype=torch.int64), torch.ones((1, modified_input_ids.shape[1]), dtype=torch.int64)], dim=1)
  all_padded_tensors.append(padded_tensor)
  all_attention_masks.append(attention_mask)

all_padded_tensors = torch.cat(all_padded_tensors, dim=0)
all_attention_masks = torch.cat(all_attention_masks, dim=0)

# Tensors are moved to `device`, not to model.device as in the other cells.
input_ids = all_padded_tensors.to(device)
attention_mask = all_attention_masks.to(device)

#@title Run Inference

# Sample up to 990 new tokens; generation stops at id 128258 (end of speech).
with torch.no_grad():
  generated_ids = model.generate(
      input_ids=input_ids,
      # Pass the mask that was built together with the left padding. Without
      # it the pad ids are attended to whenever prompts of different lengths
      # are batched; for a single prompt the mask is all ones.
      attention_mask=attention_mask,
      max_new_tokens=990,
      do_sample=True,
      temperature=0.5,
      # top_k=40,
      top_p=0.95,
      repetition_penalty=1.1,
      num_return_sequences=1,
      eos_token_id=128258,
  )

token_to_find = 128257     # start-of-speech id
token_to_remove = 128258   # end-of-speech id

# Crop every row after the column of the last start-of-speech match.
token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)

if len(token_indices[1]) > 0:
    last_occurrence_idx = token_indices[1][-1].item()
    cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
else:
    cropped_tensor = generated_ids

# Move each row to the CPU once and drop the end-of-speech ids.
processed_rows = []
for row in cropped_tensor:
    row_cpu = row.detach().to("cpu")
    processed_rows.append(row_cpu[row_cpu != token_to_remove])

# Truncate to whole 7-code frames and subtract the audio base offset (128266).
# The subtraction is done on the tensor and converted with .tolist(), so each
# code list holds plain ints (as in the inference cell) instead of one 0-d
# tensor per code.
code_lists = []
for row in processed_rows:
    new_length = (row.size(0) // 7) * 7
    code_lists.append((row[:new_length] - 128266).tolist())

def redistribute_codes(code_list):
  """
  Rebuild the three SNAC code layers from a flat code list and decode them.

  Note: this definition returns only the decoded audio, whereas the
  ``redistribute_codes`` defined in the inference cell returns an
  ``(audio, invalid_groups)`` tuple; whichever cell ran last wins.

  Args:
      code_list: codes with the audio base offset already removed, seven per
          frame; slot k of a frame still carries a shift of k*4096. Codes are
          not range-checked here.

  Returns:
      The value returned by ``snac_model.decode`` for the rebuilt layers.
  """
  layer_1 = []
  layer_2 = []
  layer_3 = []
  # Only complete 7-code frames are used. The previous bound,
  # (len(code_list)+1)//7, counted one frame too many when six codes were
  # left over and then indexed past the end of the list.
  for i in range(len(code_list) // 7):
    layer_1.append(code_list[7*i])
    layer_2.append(code_list[7*i+1]-4096)
    layer_3.append(code_list[7*i+2]-(2*4096))
    layer_3.append(code_list[7*i+3]-(3*4096))
    layer_2.append(code_list[7*i+4]-(4*4096))
    layer_3.append(code_list[7*i+5]-(5*4096))
    layer_3.append(code_list[7*i+6]-(6*4096))
  # Build integer index tensors on the codec's device, as the other decoder
  # in this notebook does.
  codes = [torch.tensor(layer_1, dtype=torch.long, device=SNAC_DEVICE).unsqueeze(0),
         torch.tensor(layer_2, dtype=torch.long, device=SNAC_DEVICE).unsqueeze(0),
         torch.tensor(layer_3, dtype=torch.long, device=SNAC_DEVICE).unsqueeze(0)]
  # Decoding is inference only; skip autograd bookkeeping.
  with torch.no_grad():
    audio_hat = snac_model.decode(codes)
  return audio_hat

# Decode every generated code list into audio.
my_samples = []
for code_list in code_lists:
  samples = redistribute_codes(code_list)
  my_samples.append(samples)

# Re-import Audio from IPython: the top of this cell imported a different
# `Audio` from ipywebrtc, which rebinds the name at notebook level.
from IPython.display import Audio, display
for samples in my_samples:
  # Render an inline player at 24 kHz.
  display(Audio(samples.detach().squeeze().to("cpu").numpy(), rate=24000))