# OpenAudio S1 Mini Finetuning

Finetune S1 Mini on custom voice data using LoRA.

**Requirements:**
- Linux with CUDA recommended (Triton is needed for fast training; the run recorded below used Windows without Triton)
- 16GB+ VRAM recommended
- Dataset in LJSpeech format (metadata.csv + wavs/)

**Steps:**
1. Environment setup
2. Dataset preparation
3. VQ token extraction
4. Protobuf dataset build
5. LoRA finetuning
6. Merge and export
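
Before starting, it can help to confirm that the dataset really follows the LJSpeech layout named in the requirements. The sketch below is a minimal check, assuming `metadata.csv` uses `id|text` lines and that each id maps to `wavs/<id>.wav`. The helper name `check_ljspeech_layout` and the throwaway demo data are illustrative and not part of fish-speech.

```python
from pathlib import Path
import tempfile


def check_ljspeech_layout(dataset_dir):
    """Return (entries, missing_wavs, orphan_wavs) for an LJSpeech-style folder."""
    dataset_dir = Path(dataset_dir)
    entries = {}
    with open(dataset_dir / "metadata.csv", encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip("\n").split("|")
            # Keep only lines that have both an id and a transcript
            if len(parts) >= 2 and parts[0]:
                entries[parts[0]] = parts[1]
    wav_ids = {p.stem for p in (dataset_dir / "wavs").glob("*.wav")}
    missing = sorted(set(entries) - wav_ids)   # listed in metadata, no audio
    orphans = sorted(wav_ids - set(entries))   # audio with no transcript
    return entries, missing, orphans


# Demo on a throwaway dataset (hypothetical ids, empty placeholder files)
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "wavs").mkdir()
    (root / "metadata.csv").write_text("a001|hello\na002|world\n", encoding="utf-8")
    (root / "wavs" / "a001.wav").write_bytes(b"")
    (root / "wavs" / "a003.wav").write_bytes(b"")
    entries, missing, orphans = check_ljspeech_layout(root)
    print(len(entries), missing, orphans)
```

A non-empty `orphans` list would explain a WAV count that is higher than the number of metadata entries.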


In [None]:
# Cell 1: Environment Setup
import os
import sys
import platform
from pathlib import Path
import subprocess
import shutil
import time

# Navigate to fish-speech root
notebook_dir = Path.cwd()
if notebook_dir.name == 'notebooks':
    project_root = notebook_dir.parent.parent
else:
    project_root = notebook_dir.parent

os.chdir(project_root)
sys.path.insert(0, str(project_root))

print(f"Project root: {project_root}")

import torch

print("\n" + "="*60)
print("ENVIRONMENT CHECK")
print("="*60)
print(f"Python: {sys.version.split()[0]}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"VRAM: {vram_gb:.1f} GB")
    if vram_gb >= 16:
        print("[OK] Sufficient VRAM")
    else:
        print("[WARN] Low VRAM - reduce batch size")
else:
    print("[ERROR] No GPU detected!")

try:
    import triton
    print(f"Triton: {triton.__version__}")
except ImportError:
    print("[INFO] Triton not available")


Project root: c:\Users\PC\Desktop\fish-speech

ENVIRONMENT CHECK
Python: 3.10.11
PyTorch: 2.9.1+cu130
CUDA Available: True
GPU: NVIDIA GeForce RTX 5070 Ti
VRAM: 15.9 GB
[WARN] Low VRAM - reduce batch size
[INFO] Triton not available


In [None]:
# Cell 1b: Fix Protobuf Compatibility
# The generated protobuf file (text_data_pb2.py) was built with protobuf 6.x and imports runtime_version.
# Older protobuf releases lack that module, so we either patch the generated file or upgrade protobuf.

import google.protobuf

print("Checking protobuf version...")
protobuf_version = google.protobuf.__version__
print(f"Installed protobuf: {protobuf_version}")

# Check if runtime_version exists
try:
    from google.protobuf import runtime_version
    print("[OK] Protobuf version is compatible (has runtime_version)")
    PROTOBUF_FIXED = True
except ImportError:
    print(f"[WARN] Protobuf {protobuf_version} doesn't have runtime_version")
    print("[INFO] Patching protobuf file for compatibility...")
    
    # Patch the protobuf file to work with older versions
    pb2_file = project_root / "fish_speech" / "datasets" / "protos" / "text_data_pb2.py"
    
    if pb2_file.exists():
        content = pb2_file.read_text(encoding='utf-8')
        
        # Check if already patched
        if "# Compatibility shim" in content or "try:" in content.split('\n')[7:12]:
            print("[INFO] File already patched")
            PROTOBUF_FIXED = True
        elif "from google.protobuf import runtime_version as _runtime_version" in content:
            # Replace the import and validation with try/except
            old_import = """from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    6,
    31,
    1,
    '',
    'text-data.proto'
)"""
            
            new_import = """from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
# Compatibility shim for older protobuf versions
try:
    from google.protobuf import runtime_version as _runtime_version
    _runtime_version.ValidateProtobufRuntimeVersion(
        _runtime_version.Domain.PUBLIC,
        6,
        31,
        1,
        '',
        'text-data.proto'
    )
except ImportError:
    # Older protobuf versions don't have runtime_version - skip validation
    pass
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder"""
            
            if old_import in content:
                content = content.replace(old_import, new_import)
                pb2_file.write_text(content, encoding='utf-8')
                print(f"[OK] Patched {pb2_file.name} for compatibility")
                PROTOBUF_FIXED = True
            else:
                print("[WARN] Could not find exact pattern to patch")
                PROTOBUF_FIXED = False
        else:
            print("[INFO] File doesn't need patching or already patched")
            PROTOBUF_FIXED = True
    else:
        print(f"[ERROR] Protobuf file not found: {pb2_file}")
        PROTOBUF_FIXED = False
    
    # Try importing again
    if PROTOBUF_FIXED:
        try:
            from fish_speech.datasets.protos.text_data_pb2 import SampledData
            print("[OK] Protobuf import successful after patching")
            PROTOBUF_FIXED = True
        except Exception as e:
            print(f"[ERROR] Still can't import protobuf: {e}")
            print('[SUGGESTION] Upgrade protobuf: pip install --upgrade "protobuf>=6.0.0"')
            PROTOBUF_FIXED = False

if not PROTOBUF_FIXED:
    print("\n[ACTION REQUIRED] Please run:")
    print('  pip install --upgrade "protobuf>=6.0.0"')
    print("Then restart this notebook.")
else:
    print("\n[OK] Protobuf compatibility check passed - ready to proceed!")


Checking protobuf version...
Installed protobuf: 3.20.3
[WARN] Protobuf 3.20.3 doesn't have runtime_version
[INFO] Patching protobuf file for compatibility...
[INFO] File already patched
[OK] Protobuf import successful after patching

[OK] Protobuf compatibility check passed - ready to proceed!


In [None]:
# Cell 2: Configuration - EDIT THESE VALUES

# Dataset paths
DATASET_NAME = "neymar"
DATASET_DIR = Path("neymar_Dataset_enhanced")
METADATA_CSV = DATASET_DIR / "metadata.csv"
WAVS_DIR = DATASET_DIR / "wavs"

# Output paths
DATA_DIR = Path("data")
SPEAKER_DIR = DATA_DIR / DATASET_NAME
PROTOS_DIR = DATA_DIR / "protos"

# Model paths
CHECKPOINT_PATH = Path("checkpoints/openaudio-s1-mini")
CODEC_PATH = CHECKPOINT_PATH / "codec.pth"

# Training config
PROJECT_NAME = f"{DATASET_NAME}_finetune"
MAX_STEPS = 1000
BATCH_SIZE = 4
LEARNING_RATE = 1e-4
VAL_CHECK_INTERVAL = 100
LORA_R = 8
LORA_ALPHA = 16

# Output model
OUTPUT_MODEL = Path(f"checkpoints/openaudio-s1-mini-{DATASET_NAME}")

# Verify paths
print("="*60)
print("CONFIGURATION")
print("="*60)
print(f"Dataset: {DATASET_DIR}")
print(f"Output: {OUTPUT_MODEL}")
print(f"Training: {MAX_STEPS} steps, batch={BATCH_SIZE}, lr={LEARNING_RATE}")
print(f"LoRA: r={LORA_R}, alpha={LORA_ALPHA}")

assert DATASET_DIR.exists(), f"Dataset not found: {DATASET_DIR}"
assert CHECKPOINT_PATH.exists(), f"Model not found: {CHECKPOINT_PATH}"

wav_count = len(list(WAVS_DIR.glob("*.wav"))) if WAVS_DIR.exists() else 0
print(f"\n[OK] Found {wav_count} audio files")


CONFIGURATION
Dataset: neymar_Dataset_enhanced
Output: checkpoints\openaudio-s1-mini-neymar
Training: 1000 steps, batch=4, lr=0.0001
LoRA: r=8, alpha=16

[OK] Found 743 audio files
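
The `LORA_R` and `LORA_ALPHA` values set the adapter rank and its scaling. As a rough illustration of the standard LoRA formulation (an assumption here, since this notebook does not show how fish-speech implements it), the NumPy sketch below applies a rank-`r` update scaled by `alpha / r` and checks that folding it into the base weight gives the same output. The matrix sizes are made up for the demo.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in = 6, 5          # toy layer size, not the real model dimensions
r, alpha = 8, 16            # mirrors LORA_R and LORA_ALPHA from the config cell
scale = alpha / r           # update is multiplied by alpha / r

W = rng.normal(size=(d_out, d_in))   # frozen base weight
A = rng.normal(size=(r, d_in))       # trainable down-projection
B = np.zeros((d_out, r))             # trainable up-projection, zero at init


def lora_forward(x, W, A, B, scale):
    """Base projection plus the scaled low-rank adapter path."""
    return x @ W.T + scale * (x @ A.T @ B.T)


x = rng.normal(size=(3, d_in))

# With B = 0 the adapter is a no-op, so training starts from the base model
assert np.allclose(lora_forward(x, W, A, B, scale), x @ W.T)

# Pretend training has updated B, then fold the adapter into one matrix
B = rng.normal(size=(d_out, r))
W_merged = W + scale * (B @ A)
assert np.allclose(lora_forward(x, W, A, B, scale), x @ W_merged.T)
print(f"scale = {scale}, merged weight shape = {W_merged.shape}")
```

With r=8 and alpha=16 the update is scaled by 2. Merging in this sense is what makes it possible to ship a single `model.pth` without separate adapter weights.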


In [None]:
# Cell 3: Analyze Dataset
import pandas as pd
import soundfile as sf
from IPython.display import Audio, display

# Parse LJSpeech metadata
with open(METADATA_CSV, 'r', encoding='utf-8') as f:
    lines = f.readlines()

data = []
for line in lines:
    parts = line.strip().split('|')
    if len(parts) >= 2:
        data.append({'id': parts[0], 'text': parts[1]})

df = pd.DataFrame(data)
print(f"Entries: {len(df)}")

# Text stats
df['text_length'] = df['text'].str.len()
print(f"Text length: {df['text_length'].min()}-{df['text_length'].max()} chars")

# Audio stats
total_duration = 0
for idx, row in df.iterrows():
    wav_path = WAVS_DIR / f"{row['id']}.wav"
    if wav_path.exists():
        try:
            info = sf.info(str(wav_path))
            total_duration += info.duration
        except Exception:
            # Skip unreadable or corrupt files instead of aborting the scan
            pass

print(f"Total duration: {total_duration/60:.1f} minutes")

# Sample
display(df.head())
sample_wav = WAVS_DIR / f"{df.iloc[0]['id']}.wav"
if sample_wav.exists():
    display(Audio(filename=str(sample_wav)))


Entries: 742
Text length: 25-295 chars
Total duration: 77.4 minutes


Unnamed: 0,id,text,text_length
0,NEY0001,"Porque dói muito, né? Ter o sonho e ir embor...",81
1,NEY0002,"Eu preferia muito bem não ter feito o gol, es...",159
2,NEY0003,A importância do Instituto pra mim é muito g...,52
3,NEY0004,o Instituto é o gol da minha carreira mais im...,72
4,NEY0005,"Então, isso pra mim é um orgulho muito grand...",137


In [None]:
# Cell 4: Convert to Fish Speech format
print("Converting dataset...")

SPEAKER_DIR.mkdir(parents=True, exist_ok=True)

converted = 0
for idx, row in df.iterrows():
    file_id = row['id']
    text = row['text']
    
    src_wav = WAVS_DIR / f"{file_id}.wav"
    dst_wav = SPEAKER_DIR / f"{file_id}.wav"
    dst_lab = SPEAKER_DIR / f"{file_id}.lab"
    
    if not src_wav.exists():
        continue
    
    if not dst_wav.exists():
        shutil.copy2(src_wav, dst_wav)
    
    dst_lab.write_text(text, encoding='utf-8')
    converted += 1

print(f"Converted {converted} files to {SPEAKER_DIR}")

# Verify
wav_count = len(list(SPEAKER_DIR.glob("*.wav")))
lab_count = len(list(SPEAKER_DIR.glob("*.lab")))
print(f"Verification: {wav_count} WAV, {lab_count} LAB")


Converting dataset...
Converted 742 files to data\neymar
Verification: 742 WAV, 742 LAB


In [None]:
# Cell 5: Extract VQ tokens
print("Extracting VQ tokens (this takes several minutes)...")

cmd = [
    sys.executable,
    "tools/vqgan/extract_vq.py",
    str(DATA_DIR),
    "--num-workers", "1",
    "--batch-size", "16",
    "--config-name", "modded_dac_vq",
    "--checkpoint-path", str(CODEC_PATH),
]

start = time.time()
result = subprocess.run(cmd, capture_output=True, text=True, encoding='utf-8', errors='replace')
elapsed = time.time() - start

if result.returncode != 0:
    print(f"[ERROR] VQ extraction failed")
    print(result.stderr[-1500:])
else:
    npy_files = list(SPEAKER_DIR.glob("*.npy"))
    print(f"[OK] Generated {len(npy_files)} .npy files in {elapsed:.0f}s")


Extracting VQ tokens (this takes several minutes)...
[OK] Generated 742 .npy files in 8s


In [None]:
# Cell 6: Build protobuf dataset
print("Building protobuf dataset...")

PROTOS_DIR.mkdir(parents=True, exist_ok=True)

cmd = [
    sys.executable,
    "tools/llama/build_dataset.py",
    "--input", str(DATA_DIR),
    "--output", str(PROTOS_DIR),
    "--text-extension", ".lab",
    "--num-workers", "4",
]

result = subprocess.run(cmd, capture_output=True, text=True, encoding='utf-8', errors='replace')

if result.returncode != 0:
    print(f"[ERROR] Dataset build failed")
    print(result.stderr[-1500:])
else:
    proto_files = list(PROTOS_DIR.iterdir())
    total_size = sum(f.stat().st_size for f in proto_files if f.is_file())
    print(f"[OK] Built {len(proto_files)} files ({total_size/1024/1024:.1f} MB)")


Building protobuf dataset...
[OK] Built 1 files (3.8 MB)


In [15]:
# Cell 7: LoRA Finetuning
print(f"Starting LoRA finetuning: {MAX_STEPS} steps...")

cmd = [
    sys.executable,
    "fish_speech/train.py",
    "--config-name", "text2semantic_finetune",
    f"project={PROJECT_NAME}",
    f"trainer.max_steps={MAX_STEPS}",
    f"trainer.val_check_interval={VAL_CHECK_INTERVAL}",
    f"data.batch_size={BATCH_SIZE}",
    f"model.optimizer.lr={LEARNING_RATE}",
    f"+lora@model.model.lora_config=r_{LORA_R}_alpha_{LORA_ALPHA}",
]

# Fix: Delete checkpoint directory to prevent auto-resume
# The training code calls get_latest_checkpoint() unconditionally
ckpt_dir = project_root / f"results/{PROJECT_NAME}/checkpoints"
if ckpt_dir.exists():
    print(f"[INFO] Removing existing checkpoints to start fresh training...")
    shutil.rmtree(ckpt_dir)


# Windows: use single device strategy (no DDP)
if platform.system() == 'Windows':
    cmd.append("trainer.strategy=auto")
    cmd.append("trainer.devices=1")
start = time.time()
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, encoding='utf-8', errors='replace', bufsize=1)

# Safely iterate over stdout
if process.stdout is not None:
    for line in process.stdout:
        print(line, end='')
else:
    # Fallback: wait and get output
    stdout, stderr = process.communicate()
    if stdout:
        print(stdout)
    if stderr:
        print(stderr, file=sys.stderr)

process.wait()
elapsed = time.time() - start

if process.returncode == 0:
    print(f"\n[OK] Training completed in {elapsed/60:.1f} minutes")
else:
    print(f"\n[ERROR] Training failed")


Starting LoRA finetuning: 1000 steps...
[INFO] Removing existing checkpoints to start fresh training...
[2025-12-19 05:57:59,290][__main__][INFO] - [rank: 0] Instantiating datamodule <fish_speech.datasets.semantic.SemanticDataModule>
[2025-12-19 05:57:59,462][datasets][INFO] - PyTorch version 2.9.1+cu130 available.
[2025-12-19 05:58:00,019][__main__][INFO] - [rank: 0] Instantiating model <fish_speech.models.text2semantic.lit_module.TextToSemantic>
2025-12-19 05:58:00.029 | INFO     | fish_speech.models.text2semantic.llama:from_pretrained:416 - Override max_seq_len to 4096
2025-12-19 05:58:00.170 | INFO     | fish_speech.models.text2semantic.llama:from_pretrained:432 - Loading model from checkpoints/openaudio-s1-mini, config: DualARModelArgs(model_type='dual_ar', vocab_size=155776, n_layer=28, n_head=16, dim=1024, intermediate_size=3072, n_local_heads=8, head_dim=128, rope_base=1000000, n

In [None]:
# Cell 8: Merge LoRA weights
results_dir = Path(f"results/{PROJECT_NAME}/checkpoints")

if not results_dir.exists():
    print(f"[ERROR] No results found. Run training first.")
else:
    checkpoints = sorted(results_dir.glob("*.ckpt"))
    
    if checkpoints:
        LORA_CHECKPOINT = checkpoints[-1]
        print(f"Using checkpoint: {LORA_CHECKPOINT.name}")
        
        cmd = [
            sys.executable,
            "tools/llama/merge_lora.py",
            "--lora-config", f"r_{LORA_R}_alpha_{LORA_ALPHA}",
            "--base-weight", str(CHECKPOINT_PATH),
            "--lora-weight", str(LORA_CHECKPOINT),
            "--output", str(OUTPUT_MODEL),
        ]
        
        result = subprocess.run(cmd, capture_output=True, text=True, encoding='utf-8', errors='replace')
        
        if result.returncode == 0:
            print(f"[OK] Model saved to {OUTPUT_MODEL}")
            for f in OUTPUT_MODEL.iterdir():
                print(f"  - {f.name}")
        else:
            print(f"[ERROR] Merge failed")
            print(result.stderr)
    else:
        print(f"[ERROR] No .ckpt files found in {results_dir}")


Using checkpoint: step_000001000.ckpt
[OK] Model saved to checkpoints\openaudio-s1-mini-neymar
  - config.json
  - model.pth
  - special_tokens.json
  - tokenizer.tiktoken
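
The merge cell picks the last entry of `sorted(...)`, which orders names lexicographically. That works for zero-padded names such as `step_000001000.ckpt`, but it would misorder unpadded names. The sketch below shows one way to pick the checkpoint by its numeric step instead. The helper `latest_checkpoint` is hypothetical and assumes the `step_<digits>.ckpt` naming seen in the output above. The file names in the demo are invented.

```python
import re
from pathlib import Path

# Assumed naming pattern, based on the checkpoint name printed by the merge cell
STEP_RE = re.compile(r"step_(\d+)\.ckpt$")


def latest_checkpoint(paths):
    """Return the path with the highest step number, or None if none match."""
    best, best_step = None, -1
    for p in paths:
        m = STEP_RE.search(Path(p).name)
        if m and int(m.group(1)) > best_step:
            best, best_step = Path(p), int(m.group(1))
    return best


# Invented names: a plain lexicographic sort would put "step_99.ckpt" last
names = ["step_000000900.ckpt", "step_000001000.ckpt", "step_99.ckpt", "last.ckpt"]
print("lexicographic pick:", sorted(names)[-1])
print("numeric pick:", latest_checkpoint(names))
```

In the notebook this could replace `checkpoints[-1]` with `latest_checkpoint(results_dir.glob("*.ckpt"))`, keeping a fallback for the `None` case.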


In [None]:
# Cell 9: Test finetuned model
TEST_MODEL = OUTPUT_MODEL if OUTPUT_MODEL.exists() else CHECKPOINT_PATH
TEST_TEXT = "Hello, this is a test of the finetuned voice."

# Get reference from dataset
sample_wavs = list(WAVS_DIR.glob("*.wav"))[:1]
REF_AUDIO = sample_wavs[0] if sample_wavs else None

print(f"Model: {TEST_MODEL}")
print(f"Text: {TEST_TEXT}")

# Encode reference
# Note: Use base model's codec (merged model doesn't include codec.pth)
if REF_AUDIO:
    result = subprocess.run([
        sys.executable, "fish_speech/models/dac/inference.py",
        "-i", str(REF_AUDIO),
        "--checkpoint-path", str(CHECKPOINT_PATH / "codec.pth"),
    ], capture_output=True, text=True, encoding='utf-8', errors='replace')
    if result.returncode != 0:
        print(f"[WARN] Reference encoding failed: {result.stderr[-200:] if result.stderr else 'No error output'}")

# Generate semantic tokens
# Note: Don't use prompt tokens for simple TTS - they cause the model to generate
# based on the reference audio content rather than the text input
# Prompt tokens are for voice cloning, not for controlling content
cmd = [
    sys.executable, "fish_speech/models/text2semantic/inference.py",
    "--text", TEST_TEXT,
    "--checkpoint-path", str(TEST_MODEL),
]
result = subprocess.run(cmd, capture_output=True, text=True, encoding='utf-8', errors='replace')

if result.returncode != 0:
    print(f"[ERROR] Text2semantic generation failed:")
    if result.stderr:
        print(result.stderr[-1000:])
    if result.stdout:
        print("STDOUT:", result.stdout[-500:])

# Decode to audio
# Note: Use base model's codec (merged model doesn't include codec.pth)
if Path("temp/codes_0.npy").exists():
    codec_path = CHECKPOINT_PATH / "codec.pth"

    result = subprocess.run([
        sys.executable, "fish_speech/models/dac/inference.py",
        "-i", "temp/codes_0.npy",
        "--checkpoint-path", str(codec_path),
        "-o", "fake.wav",
    ], capture_output=True, text=True, encoding='utf-8', errors='replace')
    
    if result.returncode != 0:
        print(f"[ERROR] DAC decode failed:")
        if result.stderr:
            print(result.stderr[-1000:])
        if result.stdout:
            print("STDOUT:", result.stdout[-500:])

# Play result
if Path("fake.wav").exists():
    final_path = Path(f"outputs/{DATASET_NAME}_test.wav")
    Path("outputs").mkdir(exist_ok=True)
    shutil.move("fake.wav", final_path)
    print(f"[OK] Saved to {final_path}")
    display(Audio(filename=str(final_path)))
else:
    print("[ERROR] Generation failed")


NameError: name 'OUTPUT_MODEL' is not defined

In [None]:
# Standalone Test Cell - Test Finetuned Model
# This cell can be run independently without running previous cells
# It tests the finetuned model with a short text to avoid long generation times

import os
import sys
import subprocess
import shutil
from pathlib import Path
import time

# Setup paths - use absolute paths to avoid resolution issues
# Try several methods to find the project root

# Method 1: Try to find fish_speech starting from current directory
current = Path(os.getcwd()).resolve()
project_root = None

# Search upward from current directory
search_path = current
max_depth = 10
for _ in range(max_depth):
    if (search_path / "fish_speech").exists():
        project_root = search_path
        break
    if search_path == search_path.parent:  # Reached root
        break
    search_path = search_path.parent

# Method 2: If not found, try common locations
if project_root is None:
    # Try Desktop/fish-speech
    desktop = Path.home() / "Desktop" / "fish-speech"
    if (desktop / "fish_speech").exists():
        project_root = desktop

# Method 3: Try relative to this notebook file location (if available)
if project_root is None:
    try:
        # In Jupyter, we can try to infer from sys.path
        for path in sys.path:
            p = Path(path).resolve()
            if (p / "fish_speech").exists():
                project_root = p
                break
            # Also check parent
            if (p.parent / "fish_speech").exists():
                project_root = p.parent
                break
    except Exception:
        pass

if project_root is None:
    raise RuntimeError(
        f"Could not find fish_speech directory.\n"
        f"Current working directory: {os.getcwd()}\n"
        f"Searched from: {current}\n"
        f"Please run this cell from inside the fish-speech repository."
    )

project_root = project_root.resolve()
print(f"Current working directory: {os.getcwd()}")
print(f"Project root found: {project_root}")

os.chdir(project_root)
sys.path.insert(0, str(project_root))

# Configuration - adjust these as needed
FINETUNED_MODEL = project_root / "checkpoints/openaudio-s1-mini-neymar"
BASE_MODEL = project_root / "checkpoints/openaudio-s1-mini"
TEST_TEXT = "Hello, this is a test."
MAX_NEW_TOKENS = 200  # Limit generation length for faster testing
OUTPUT_DIR = (project_root / "temp").resolve()
print(f"Output directory: {OUTPUT_DIR}")

# Create output directory with proper error handling
try:
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    print(f"[OK] Output directory ready: {OUTPUT_DIR}")
except (PermissionError, OSError) as e:
    # Fallback: use a temp directory in user's temp folder
    import tempfile
    OUTPUT_DIR = Path(tempfile.gettempdir()) / "fish_speech_test"
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    print(f"[WARN] Could not create temp dir in project root ({e}), using: {OUTPUT_DIR}")

# Set to True to test base model first (for comparison)
TEST_BASE_MODEL_FIRST = False  # Now testing finetuned model with proper 1000-step checkpoint

print("="*60)
print("STANDALONE FINETUNED MODEL TEST")
print("="*60)
print(f"Finetuned model: {FINETUNED_MODEL}")
print(f"Base model (for codec): {BASE_MODEL}")
print(f"Test text: {TEST_TEXT}")
print(f"Max tokens: {MAX_NEW_TOKENS}")
print()

# Check if models exist
TEST_MODEL = None
OUTPUT_SUFFIX = "_finetuned"

if not BASE_MODEL.exists():
    print(f"[ERROR] Base model not found: {BASE_MODEL}")
elif TEST_BASE_MODEL_FIRST:
    print("[INFO] Testing base model first for comparison...")
    TEST_MODEL = BASE_MODEL
    OUTPUT_SUFFIX = "_base"
else:
    if not FINETUNED_MODEL.exists():
        print(f"[ERROR] Finetuned model not found: {FINETUNED_MODEL}")
        print("Please run the training and merge cells first.")
    else:
        TEST_MODEL = FINETUNED_MODEL
        OUTPUT_SUFFIX = "_finetuned"

if TEST_MODEL:
    # Step 1: Generate semantic tokens
    print("[1/3] Generating semantic tokens...")
    total_start = time.time()
    step_start = time.time()
    
    cmd = [
        sys.executable,
        "fish_speech/models/text2semantic/inference.py",
        "--text", TEST_TEXT,
        "--checkpoint-path", str(TEST_MODEL),
        "--max-new-tokens", str(MAX_NEW_TOKENS),  # Limit generation length
        "--output-dir", str(OUTPUT_DIR),
    ]
    
    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        encoding='utf-8',
        errors='replace'
    )
    
    elapsed = time.time() - step_start
    
    if result.returncode != 0:
        print(f"[ERROR] Text2semantic generation failed after {elapsed:.1f}s")
        if result.stderr:
            print("STDERR:", result.stderr[-500:])
        if result.stdout:
            print("STDOUT:", result.stdout[-500:])
    else:
        print(f"[OK] Semantic tokens generated in {elapsed:.1f}s")
        
        # Step 2: Decode to audio
        codes_file = OUTPUT_DIR / "codes_0.npy"
        if codes_file.exists():
            # Validate codes before decoding
            import numpy as np
            codes = np.load(codes_file)
            print(f"  Codes shape: {codes.shape}")
            print(f"  Codes dtype: {codes.dtype}")
            print(f"  Codes range: {codes.min()} - {codes.max()}")
            
            # Check for potential issues
            if codes.shape[0] != 10:
                print(f"  [WARN] Expected 10 codebooks, got {codes.shape[0]}")
            if codes.max() >= 4096:
                print(f"  [WARN] Code values exceed codebook size (max={codes.max()})")
            if (codes == 0).sum() / codes.size > 0.9:
                print(f"  [WARN] Too many zeros in codes ({(codes == 0).sum() / codes.size * 100:.1f}%)")
            
            # Check first codebook (should be semantic, can have zeros)
            non_zero_first = (codes[0] != 0).sum()
            print(f"  First codebook non-zero tokens: {non_zero_first}/{codes.shape[1]}")
            
            # Check other codebooks (should have values)
            for i in range(1, min(4, codes.shape[0])):
                non_zero = (codes[i] != 0).sum()
                print(f"  Codebook {i} non-zero tokens: {non_zero}/{codes.shape[1]}")
            
            if codes.shape[0] == 10 and codes.max() < 4096:
                print(f"[2/3] Decoding audio from {codes_file}...")
                step_start = time.time()
                
                cmd = [
                    sys.executable,
                    "fish_speech/models/dac/inference.py",
                    "-i", str(codes_file),
                    "--checkpoint-path", str(BASE_MODEL / "codec.pth"),
                    "-o", str(OUTPUT_DIR / f"test_output{OUTPUT_SUFFIX}.wav"),
                ]
                
                result = subprocess.run(
                    cmd,
                    capture_output=True,
                    text=True,
                    encoding='utf-8',
                    errors='replace'
                )
                
                elapsed = time.time() - step_start
                
                if result.returncode != 0:
                    print(f"[ERROR] Audio decoding failed after {elapsed:.1f}s")
                    if result.stderr:
                        print("STDERR:", result.stderr[-500:])
                else:
                    print(f"[OK] Audio decoded in {elapsed:.1f}s")
                    
                    # Step 3: Display result
                    output_wav = OUTPUT_DIR / f"test_output{OUTPUT_SUFFIX}.wav"
                    if output_wav.exists():
                        print(f"[3/3] Success! Audio saved to: {output_wav}")
                        print(f"Total time: {time.time() - total_start:.1f}s")
                        
                        if TEST_BASE_MODEL_FIRST:
                            print("\n[INFO] Base model test complete. Set TEST_BASE_MODEL_FIRST=False to test finetuned model.")
                        
                        # Try to display audio if IPython is available
                        try:
                            from IPython.display import Audio, display
                            display(Audio(filename=str(output_wav)))
                        except ImportError:
                            print("(Install IPython to play audio in notebook)")
                    else:
                        print("[ERROR] Output audio file not found")
            else:
                print("[ERROR] Codes validation failed - check warnings above")
                print("The finetuned model may not have converged properly.")
                print("Try:")
                print("  1. Training for more steps")
                print("  2. Using the base model to verify the pipeline works")
                print("  3. Checking training logs for convergence issues")
        else:
            print(f"[ERROR] Codes file not found: {codes_file}")
            print("Check the text2semantic generation output above for errors.")


Current working directory: C:\Users\PC\Desktop\fish-speech
Project root found: C:\Users\PC\Desktop\fish-speech
Output directory: C:\Users\PC\Desktop\fish-speech\temp
[OK] Output directory ready: C:\Users\PC\Desktop\fish-speech\temp
STANDALONE FINETUNED MODEL TEST
Finetuned model: C:\Users\PC\Desktop\fish-speech\checkpoints\openaudio-s1-mini-neymar
Base model (for codec): C:\Users\PC\Desktop\fish-speech\checkpoints\openaudio-s1-mini
Test text: Hello, this is a test.
Max tokens: 200

[1/3] Generating semantic tokens...
[OK] Semantic tokens generated in 22.4s
  Codes shape: (10, 199)
  Codes dtype: int32
  Codes range: 0 - 4061
  First codebook non-zero tokens: 6/199
  Codebook 1 non-zero tokens: 199/199
  Codebook 2 non-zero tokens: 198/199
  Codebook 3 non-zero tokens: 199/199
[2/3] Decoding audio from C:\Users\PC\Desktop\fish-speech\temp\codes_0.npy...
[OK] Audio decoded in 8.8s
[3/3] Success! Audio saved to: C:\Users\PC\Desktop\fish-speech\temp\test_output_finetuned.wav
Total time: 31

## Next Steps

Your finetuned model is at: `checkpoints/openaudio-s1-mini-{name}/`

**Use in voice-service:**
```bash
VOICE_S1_CHECKPOINT_PATH=checkpoints/openaudio-s1-mini-neymar
```

**Enable LoRA hot-swap:**
```bash
VOICE_LORA_ENABLED=true
VOICE_LORA_PATH=results/neymar_finetune/checkpoints/step_xxx.ckpt
```
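
Before launching the service, a quick check that these variables point at real files can save a failed start. The sketch below is only an illustration: it assumes the three variable names shown above, that the checkpoint path is a directory, and that the LoRA path is a single file. The helper `check_voice_env` is hypothetical and is not part of voice-service.

```python
import os
from pathlib import Path


def check_voice_env(env=None):
    """Return a list of problems found in the voice-service variables."""
    env = os.environ if env is None else env
    problems = []

    ckpt = env.get("VOICE_S1_CHECKPOINT_PATH")
    if not ckpt:
        problems.append("VOICE_S1_CHECKPOINT_PATH is not set")
    elif not Path(ckpt).is_dir():
        problems.append(f"checkpoint directory not found: {ckpt}")

    # Only validate the LoRA path when hot-swap is switched on
    if env.get("VOICE_LORA_ENABLED", "").lower() == "true":
        lora = env.get("VOICE_LORA_PATH")
        if not lora:
            problems.append("VOICE_LORA_ENABLED is true but VOICE_LORA_PATH is not set")
        elif not Path(lora).is_file():
            problems.append(f"LoRA checkpoint not found: {lora}")

    return problems


for problem in check_voice_env():
    print("[WARN]", problem)
```

Passing a plain dict instead of relying on `os.environ` makes the check easy to try out without changing the shell environment.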
