In [None]:
import os
import torch
import re
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from sparktts.models.audio_tokenizer import BiCodecTokenizer
from huggingface_hub import snapshot_download

# --- 1. SETUP PATHS ---
# Your fine-tuned model path on HF
FINE_TUNED_REPO = "omarmosleh/spark-tts-merged-dataset" 
# The base model for architecture config and tokenizer
BASE_MODEL_REPO = "SparkAudio/Spark-TTS-0.5B" 
BASE_DIR = "Spark-TTS-0.5B"

# SECURITY: an access token must never be written into the notebook source.
# It is read from the HF_TOKEN environment variable instead (export it in the
# shell that launches Jupyter, or run `huggingface-cli login` once).
# When the variable is unset, None is passed below so that huggingface_hub /
# transformers fall back to the locally cached login, if there is one.
# NOTE(review): any token that was previously committed in this cell should be
# revoked in the Hugging Face account settings.
HF_TOKEN = os.environ.get("HF_TOKEN") or None

# --- 2. DOWNLOAD BASE ASSETS (Config/Tokenizer) ---
# We need the official config/tokenizer because fine-tuned repos often lack them
if not os.path.exists(BASE_DIR):
    print(f"‚è≥ Downloading base assets from {BASE_MODEL_REPO}...")
    snapshot_download(repo_id=BASE_MODEL_REPO, local_dir=BASE_DIR, token=HF_TOKEN)

# --- 3. INITIALIZE COMPONENTS ---
print("‚è≥ Loading model with fine-tuned weights...")

# Load official config to force 0.5B architecture dimensions
config = AutoConfig.from_pretrained(os.path.join(BASE_DIR, "LLM"), trust_remote_code=True)

# Load official tokenizer
tokenizer = AutoTokenizer.from_pretrained(os.path.join(BASE_DIR, "LLM"), trust_remote_code=True)

# Load your FINE-TUNED weights from HF using the official config
model = AutoModelForCausalLM.from_pretrained(
    FINE_TUNED_REPO,
    config=config,
    trust_remote_code=True,
    dtype=torch.float32,
    device_map={"": "cpu"}, # Force CPU
    token=HF_TOKEN
)

# Load Audio Tokenizer (The Decoder)
audio_tokenizer = BiCodecTokenizer(BASE_DIR, "cpu")

print("‚úÖ System Ready.")

In [None]:


# --- 4. INFERENCE FUNCTION ---
@torch.inference_mode()
def generate_speech(text, expected_global_tokens=32):
    """Synthesize speech for ``text`` with the fine-tuned LLM and the BiCodec decoder.

    The prompt ends with ``<|start_global_token|>``, so the LLM is expected to
    emit ``<|bicodec_global_N|>`` tags followed by ``<|bicodec_semantic_N|>``
    tags. Both kinds of IDs are parsed out of the decoded text and handed to
    ``audio_tokenizer.detokenize``.

    Args:
        text: The sentence to speak.
        expected_global_tokens: Number of global tokens the decoder needs.
            The recorded failure in this notebook (3 global tokens ->
            "1x384 and 4096x1024" matmul error) implies 128 features per token
            and therefore 32 tokens -- TODO confirm for other checkpoints.
            Pass ``None`` to only require that at least one was generated.

    Returns:
        Whatever ``audio_tokenizer.detokenize`` returns (the waveform), or
        ``None`` when the LLM produced no semantic tokens or an unusable
        number of global tokens.
    """
    prompt = f"<|task_tts|><|start_content|>{text}<|end_content|><|start_global_token|>"
    inputs = tokenizer([prompt], return_tensors="pt").to("cpu")

    # Sample a continuation and keep only the newly generated part.
    gen_ids = model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.4)
    prompt_len = inputs.input_ids.shape[1]
    # Special tokens must be kept: the <|bicodec_...|> tags ARE the payload.
    gen_text = tokenizer.batch_decode(gen_ids[:, prompt_len:], skip_special_tokens=False)[0]

    semantic_ids = [int(t) for t in re.findall(r"<\|bicodec_semantic_(\d+)\|>", gen_text)]
    global_ids = [int(t) for t in re.findall(r"<\|bicodec_global_(\d+)\|>", gen_text)]

    if not semantic_ids:
        return None

    # Feeding the wrong number of global tokens makes the decoder crash with a
    # shape-mismatch RuntimeError, so reject the sample here instead.
    if expected_global_tokens is None:
        if not global_ids:
            return None
    elif len(global_ids) != expected_global_tokens:
        return None

    # Both tensors get a leading batch axis: (1, num_tokens).
    s_tokens = torch.tensor(semantic_ids, dtype=torch.long).unsqueeze(0)
    g_tokens = torch.tensor(global_ids, dtype=torch.long).unsqueeze(0)

    # Detokenize to Audio
    return audio_tokenizer.detokenize(g_tokens, s_tokens)

# --- 5. EXECUTION ---
# Run one synthesis pass; generate_speech returns None when no usable tokens came back.
input_text = "ŸáŸÜŸàŸàÿØÿ©ÿå ŸáÿßÿØ ÿßŸÑÿµŸàÿ™ ŸÖÿπŸÖŸàŸÑ ŸÉÿßŸÖŸÑÿß ÿ®ÿßŸÑ AI"
wav = generate_speech(input_text)

if wav is None:
    print("‚ùå Failed to generate tokens.")
else:
    # soundfile is only needed once there is audio to write out.
    import soundfile as sf
    sf.write("output.wav", wav, 16000)
    print("‚úÖ Success! Created output.wav")

‚è≥ Re-aligning model architecture for CPU...
Missing tensor: mel_transformer.spectrogram.window
Missing tensor: mel_transformer.mel_scale.fb
‚úÖ Model loaded with 0.5B dimensions. Try running your inference script now.


In [None]:
import os
import re
import torch
import numpy as np
import soundfile as sf
from IPython.display import Audio, display

# --- CONFIGURATION ---
# Folder that receives every generated .wav file. It is created up front
# (no error if it already exists) so later writes never hit a missing directory.
OUTPUT_DIR = "fulldataset"
os.makedirs(name=OUTPUT_DIR, exist_ok=True)

@torch.inference_mode()
def generate_speech_enhanced(text, file_prefix="01", expected_global_tokens=32):
    """
    Generates speech and provides detailed logs of the tokenization process.

    The LLM continues a prompt that ends in ``<|start_global_token|>``; the
    ``<|bicodec_global_N|>`` and ``<|bicodec_semantic_N|>`` tags found in its
    output are converted to ID tensors, decoded to a waveform by
    ``audio_tokenizer.detokenize``, written to ``OUTPUT_DIR`` and played inline.

    Args:
        text: The sentence to speak. Also used (minus punctuation, spaces
            turned into underscores) as part of the output filename.
        file_prefix: String placed in front of the cleaned text in the filename.
        expected_global_tokens: Number of global tokens the decoder needs.
            The recorded failure in this notebook (3 global tokens ->
            "1x384 and 4096x1024" matmul error) implies 128 features per token
            and therefore 32 tokens -- TODO confirm for other checkpoints.
            Pass ``None`` to only require that at least one was generated.

    Returns:
        Path of the written ``.wav`` file, or ``None`` when the generated
        tokens cannot be decoded (no semantic tokens, or an unusable number of
        global tokens).
    """
    print(f"\nüöÄ Starting Inference for: '{text}'")
    
    # 1. Prepare Prompt
    prompt = f"<|task_tts|><|start_content|>{text}<|end_content|><|start_global_token|>"
    inputs = tokenizer([prompt], return_tensors="pt").to("cpu")
    
    # 2. Generate Tokens from LLM
    print("‚è≥ Generating tokens from LLM (this may take a moment on CPU)...")
    gen_ids = model.generate(
        **inputs, 
        max_new_tokens=1024, 
        do_sample=True, 
        temperature=0.4,
        top_k=50
    )
    
    # 3. Decode and Parse
    # We skip special tokens = False because we NEED to see the <|bicodec_...|> tags
    gen_text = tokenizer.batch_decode(gen_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=False)[0]
    
    # Extract Semantic and Global tokens using regex
    semantic_matches = re.findall(r"<\|bicodec_semantic_(\d+)\|>", gen_text)
    global_matches = re.findall(r"<\|bicodec_global_(\d+)\|>", gen_text)

    # --- METADATA LOGGING ---
    print("-" * 30)
    print(f"üìä INFERENCE STATISTICS:")
    print(f"üîπ Global Tokens Generated:   {len(global_matches)}")
    print(f"üîπ Semantic Tokens Generated: {len(semantic_matches)}")
    
    if not semantic_matches:
        print("‚ùå Error: No semantic tokens were generated. Try a different prompt.")
        return None

    # Feeding the wrong number of global tokens makes the decoder fail with a
    # shape-mismatch RuntimeError, so stop here with a readable message instead.
    if expected_global_tokens is None:
        globals_ok = len(global_matches) > 0
    else:
        globals_ok = len(global_matches) == expected_global_tokens
    if not globals_ok:
        wanted = "at least 1" if expected_global_tokens is None else str(expected_global_tokens)
        print(
            f"Error: expected {wanted} global tokens but the model generated "
            f"{len(global_matches)}. Skipping detokenization; try generating again."
        )
        return None

    # Both tensors get a leading batch axis: (1, num_tokens).
    s_tokens = torch.tensor([int(t) for t in semantic_matches], dtype=torch.long).unsqueeze(0)
    g_tokens = torch.tensor([int(t) for t in global_matches], dtype=torch.long).unsqueeze(0)

    # 4. Detokenize to Audio (The Codec Step)
    print("üîä Converting tokens to waveform via BiCodec...")
    wav = audio_tokenizer.detokenize(g_tokens, s_tokens)
    
    # 5. Save and Export
    # Clean text for filename (remove special characters)
    clean_text = re.sub(r'[^\w\s-]', '', text).strip().replace(' ', '_')
    filename = f"{file_prefix}_{clean_text}.wav"
    filepath = os.path.join(OUTPUT_DIR, filename)
    
    sample_rate = 16000 # Standard for Spark-TTS
    sf.write(filepath, wav, sample_rate)
    
    file_size_kb = os.path.getsize(filepath) / 1024
    print(f"‚úÖ Audio saved to: {filepath}")
    print(f"üìè Output Size: {file_size_kb:.2f} KB")
    print("-" * 30)
    
    # 6. Play in Notebook
    display(Audio(wav, rate=sample_rate))
    
    return filepath

# --- EXECUTION ---
# The prefix is set by hand and becomes the leading part of the output filename.
manual_prefix = "001" 
input_text = "ÿ¥Ÿà ÿßŸÑÿ£ÿÆÿ®ÿßÿ±ÿü ÿ∑ŸÖŸÜŸÜŸä ÿπŸÜŸÉÿå ÿßŸÜÿ¥ÿßŸÑŸÑŸá ÿ®ÿÆŸäÿ±ÿü"

# Holds the path of the written .wav file, or None when generation failed.
output_path = generate_speech_enhanced(text=input_text, file_prefix=manual_prefix)

Generating speech for: 'ŸáŸÜŸàŸàÿØÿ©ÿå ŸáÿßÿØ ÿßŸÑÿµŸàÿ™ ŸÖÿπŸÖŸàŸÑ ŸÉÿßŸÖŸÑÿß ÿ®ÿßŸÑ AI'
Generating token sequence...
Token sequence generated.
Found 346 semantic tokens.
Found 3 global tokens.
Detokenizing audio tokens...


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x384 and 4096x1024)