In [None]:
from transformers import AutoTokenizer
from snac import SNAC
import torch
import requests
import os
from IPython.display import display, Audio

# --- Model and Tokenizer Setup ---
# Identifiers handed to the respective from_pretrained() loaders below
tokenizer_name = "canopylabs/3b-hi-ft-research_release"
snac_model_name = "hubertsiuzdak/snac_24khz"

# Friendli AI completion endpoint plus credentials read from the environment
url = "https://api.friendli.ai/dedicated/v1/completions"
friendli_token = os.getenv("FRIENDLI_TOKEN")
# Friendli AI Model ID e.g. "fdo6dto2hkng"; a placeholder is used when
# FRIENDLI_EID is unset or empty
friendli_model_id = os.getenv("FRIENDLI_EID") or "YOUR_ENDPOINT_ID"

# SNAC model, loaded and placed on the CPU in one step
snac_model = SNAC.from_pretrained(snac_model_name).to("cpu")

# Tokenizer used to turn the text prompt into token IDs
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

# --- Input Prompt Setup ---
chosen_voice = "zoe"  # Voice to use (see GitHub for other voices)
prompt = "हेलो आज आप कैसे हैं?"  # Single prompt string

# The voice name is prepended to the text as a "<voice>: " tag
prompt = f"{chosen_voice}: " + prompt

# --- Tokenization and Special Token Addition ---
# Tokenize the tagged prompt into a PyTorch tensor of IDs
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Marker IDs placed around the tokenized prompt
start_token = torch.tensor([[128259]], dtype=torch.long)  # Start of human
end_tokens = torch.tensor([[128009, 128260]], dtype=torch.long)  # End of text, End of human

# Concatenate along dim 1: start marker, tokenizer output, end markers
modified_input_ids = torch.cat((start_token, input_ids, end_tokens), dim=1)

# Use the GPU when torch reports one, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# IDs for API transmission, placed on the selected device
input_ids_for_api = modified_input_ids.to(device)

print(f"Using device: {device}")
print(f"Input token IDs for API: {input_ids_for_api}")


# Abort before any network traffic when no API credential is available.
if not friendli_token:
    raise ValueError("FRIENDLI_TOKEN environment variable must be set.")

# Configure API request payload
payload = {
    "model": friendli_model_id,
    # The API expects a flat list of token IDs: take row 0 of the batch.
    # Tensor.tolist() copies the data to host memory when needed.
    "tokens": input_ids_for_api[0].tolist(),
    "max_tokens": 1200,
    "temperature": 0.6,
    "top_p": 0.95,
    "repetition_penalty": 1.1,
    "eos_token": [128258],  # End of speech token
}
headers = {
    "Authorization": "Bearer " + friendli_token,
    "Content-Type": "application/json",
}

# (connect, read) timeouts in seconds. requests applies no timeout by default,
# so without this a stalled connection would block the script indefinitely.
# The read timeout is generous because the whole generation is returned in a
# single response.
request_timeout = (10, 300)

# Send API request
print("Calling Friendli AI API...")
response = None  # Pre-bound so the except branch can tell whether a response arrived
try:
    response = requests.post(url, json=payload, headers=headers, timeout=request_timeout)
    response.raise_for_status()  # Raise exception for HTTP errors
    response_data = response.json()
    print("API call successful.")
except requests.exceptions.RequestException as e:
    # Covers connection errors, timeouts and HTTP error statuses alike
    print(f"API call failed: {e}")
    if response is not None:
        print(f"Response content: {response.text}")
    raise  # Stop execution on error

# Extract generated token IDs from the first choice of the API response
response_ids = response_data['choices'][0]['tokens']
# Wrap in a list to add a batch dimension, then place on the selected device
generated_ids = torch.tensor([response_ids], dtype=torch.int64).to(device)
print(f"Generated token IDs (Raw): {generated_ids}")

# --- Output Parsing and Audio Generation ---
# Marker token IDs used to isolate the audio part of the generated sequence
token_to_find = 128257    # Start of speech token
token_to_remove = 128258  # End of speech token

# Positions of every start-of-speech marker; entry [1] of the tuple indexes
# the second (sequence) dimension
token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)

if len(token_indices[1]) == 0:
    # Marker absent: fall back to the whole sequence (potential error)
    cropped_tensor = generated_ids
    print(f"Warning: Token '{token_to_find}' (Start of speech) not found. Using all generated IDs.")
else:
    # Keep only what follows the last start-of-speech marker
    last_occurrence_idx = token_indices[1][-1].item()
    cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
    print(f"Token '{token_to_find}' found. Subsequent IDs: {cropped_tensor}")

# Drop end-of-speech markers. Boolean indexing flattens the result, so the
# leading dimension is added back afterwards.
mask = cropped_tensor != token_to_remove
processed_tensor = cropped_tensor[mask].unsqueeze(0)

print(f"IDs after removing token '{token_to_remove}': {processed_tensor}")

# Convert the remaining IDs to audio codes, in whole groups of 7
code_list = []
if processed_tensor.numel() == 0:
    print("No generated tokens to process.")
else:
    row = processed_tensor[0]  # The first (only) row
    row_length = row.size(0)
    # Largest multiple of 7 not exceeding the row length
    new_length = row_length - (row_length % 7)
    if new_length < row_length:
        print(f"Warning: The last {row_length - new_length} tokens are ignored as they form an incomplete set of 7.")

    if new_length == 0:
        print("No valid tokens to convert to audio codes.")
    else:
        trimmed_row = row[:new_length]
        # Remove the token ID offset (128266) from every entry in one tensor op
        code_list = (trimmed_row - 128266).tolist()
        print(f"Audio codes (length: {len(code_list)}): {code_list[:21]}...") # Only the beginning is shown

# Audio code redistribution function
def redistribute_codes(single_code_list):
    """Split a flat list of interleaved audio codes into three layers and decode them.

    The list is consumed in frames of 7 codes. Within each frame, position 0
    goes to layer 1, positions 1 and 4 go to layer 2, and positions 2, 3, 5
    and 6 go to layer 3. The code at position ``k`` has ``k * 4096``
    subtracted before it is stored. Codes after the last complete frame are
    ignored.

    Args:
        single_code_list: Flat, indexable sequence of integer codes.

    Returns:
        The result of ``snac_model.decode`` for the three layers, or an empty
        tensor when the input contains no complete frame or when any code
        falls outside ``[0, 4096)`` after its offset is removed.
    """
    frame_size = 7
    # Stride between the per-position offsets. It is also used as the upper
    # bound for a valid code -- assumes the SNAC codebook holds 4096 entries
    # (TODO confirm against the model config).
    codebook_size = 4096

    if not single_code_list:
        return torch.tensor([])  # Return empty tensor for empty list

    layer_1 = []
    layer_2 = []
    layer_3 = []

    # Floor division guarantees every index used below exists, so no
    # IndexError handling is needed inside the loop.
    num_frames = len(single_code_list) // frame_size
    for i in range(num_frames):
        base = frame_size * i
        layer_1.append(single_code_list[base])
        layer_2.append(single_code_list[base + 1] - codebook_size)
        layer_3.append(single_code_list[base + 2] - 2 * codebook_size)
        layer_3.append(single_code_list[base + 3] - 3 * codebook_size)
        layer_2.append(single_code_list[base + 4] - 4 * codebook_size)
        layer_3.append(single_code_list[base + 5] - 5 * codebook_size)
        layer_3.append(single_code_list[base + 6] - 6 * codebook_size)

    print(f"SNAC Decoding: Layer 1 ({len(layer_1)}), Layer 2 ({len(layer_2)}), Layer 3 ({len(layer_3)}) codes")
    if not layer_1 or not layer_2 or not layer_3:
        print("Warning: Insufficient codes for decoding.")
        return torch.tensor([])

    # A token that is not an audio code in the expected slot ends up negative
    # or too large here. Reject it up front rather than letting the decoder
    # fail on an out-of-range lookup.
    all_codes = layer_1 + layer_2 + layer_3
    if min(all_codes) < 0 or max(all_codes) >= codebook_size:
        print(f"Warning: Audio codes outside the expected range [0, {codebook_size}). Skipping decoding.")
        return torch.tensor([])

    # Convert to tensors (add batch dimension)
    codes = [
        torch.tensor(layer_1).unsqueeze(0),
        torch.tensor(layer_2).unsqueeze(0),
        torch.tensor(layer_3).unsqueeze(0),
    ]

    # Decoding is inference only, so no autograd graph is recorded
    with torch.no_grad():
        audio_hat = snac_model.decode(codes)
    return audio_hat

# Generate audio
if code_list:
    # Audio codes are available: split them into layers and decode
    samples = redistribute_codes(code_list)
    print("Audio generation complete.")
else:
    # Nothing to decode: fall back to an empty tensor
    samples = torch.tensor([])
    print("Skipping audio generation as there are no audio codes.")

# --- Audio Output ---
if samples.numel() == 0:
    print("\nNo generated audio.")
else:
    print("\n--- Generated Audio ---")
    print(f"Input prompt: {prompt}")
    # Convert to a NumPy array and show a player (works in Jupyter/Colab
    # environments); the sample rate passed to Audio is 24000
    waveform = samples.detach().squeeze().cpu().numpy()
    display(Audio(waveform, rate=24000))