# Vocal Separation Model: BS-RoFormer (Pre-trained + Fine-tuning)

This notebook downloads a **state-of-the-art pre-trained BS-RoFormer** vocal separation model and optionally fine-tunes it.

**Pre-trained model:** `model_bs_roformer_ep_317_sdr_12.9755.ckpt` (viperx edition)
- **SDR (vocals): 12.97 dB** on MUSDB18 test set
- Well above the ~8-10 dB expected from training from scratch
- Based on ZFTurbo/Music-Source-Separation-Training framework

**Why pre-trained?**
- MUSDB18-HQ requires a manual access request on Zenodo
- Training from scratch takes 24-48 hours on an A100
- Pre-trained model already achieves SOTA quality

**Estimated time:** ~5 minutes to download the model and save it to Google Drive
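
The SDR figures quoted above are signal-to-distortion ratios in decibels. As a rough illustration, the basic energy-ratio form of the metric can be sketched as below. The function name `sdr_db` and the `eps` guard are illustrative choices, and published benchmark numbers come from evaluation toolkits that may chunk and aggregate the audio differently, so this sketch is not expected to reproduce the 12.97 dB figure exactly.

```python
import numpy as np

def sdr_db(reference, estimate, eps=1e-8):
    """Basic signal-to-distortion ratio in dB (illustrative helper)."""
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    signal_energy = np.sum(reference ** 2)
    error_energy = np.sum((reference - estimate) ** 2)
    # eps keeps the ratio finite for silent references or perfect estimates
    return 10.0 * np.log10((signal_energy + eps) / (error_energy + eps))

# Example: an estimate whose amplitude is 10% too low has an error of
# one tenth of the reference amplitude, which gives about 20 dB.
reference = np.sin(np.linspace(0.0, 10.0, 1000))
print(f"SDR: {sdr_db(reference, 0.9 * reference):.2f} dB")
```

Higher is better: every additional 10 dB means the residual error energy is ten times smaller relative to the target stem.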

In [None]:
# Check GPU
!nvidia-smi

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install audio-separator (handles model download automatically)
# Quote the extras specifier so the shell does not treat [gpu] as a glob pattern
!pip install -q "audio-separator[gpu]" torch torchaudio

import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Download pre-trained BS-RoFormer vocal separation model
# This model achieves 12.97 SDR on MUSDB18 (state-of-the-art)
# audio-separator will auto-download the checkpoint

from audio_separator.separator import Separator
import os

# Initialize separator with the best BS-RoFormer vocal model
# This triggers automatic download of the model checkpoint (~350MB)
print("Downloading pre-trained BS-RoFormer vocal model...")
print("Model: model_bs_roformer_ep_317_sdr_12.9755.ckpt")
print("Expected SDR: 12.97 dB (vocals), 17.0 dB (instrumental)")

separator = Separator(
    model_file_dir='/content/models',
    output_dir='/content/test_output'
)

# Load the BS-RoFormer model
separator.load_model(model_filename='model_bs_roformer_ep_317_sdr_12.9755.ckpt')
print("\n✅ Model downloaded and loaded successfully!")

# Show the downloaded model file
!ls -lah /content/models/*.ckpt 2>/dev/null || echo "Checking model location..."
!find /content/models -name "*.ckpt" -o -name "*.pth" -o -name "*.pt" 2>/dev/null | head -5

In [None]:
# Quick test: Separate a sample audio to verify model works
# Generate a short test signal (sine wave mix)
import numpy as np
import soundfile as sf

# Create a simple test: vocal-like sine wave + drum-like noise
sr = 44100
duration = 3.0
t = np.arange(int(sr * duration)) / sr  # sample times spaced exactly 1/sr apart

# Fake "vocal" - smooth sine
vocal = 0.5 * np.sin(2 * np.pi * 440 * t) * np.exp(-0.5 * t)
# Fake "drums" - noise bursts
drums = 0.3 * np.random.randn(len(t)) * (np.sin(2 * np.pi * 2 * t) > 0.8)
mix = np.column_stack([vocal + drums, vocal + drums])  # stereo

test_path = '/content/test_mix.wav'
sf.write(test_path, mix, sr)
print(f"Created test audio: {test_path}")

# Run separation
outputs = separator.separate(test_path)
print("\n✅ Separation test passed! Output files:")
for f in outputs:
    print(f"  - {f}")

In [None]:
# Save model to Google Drive for StemScribe integration
import shutil
from pathlib import Path

SAVE_DIR = Path('/content/drive/MyDrive/vocal_model_results')
SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Find the model checkpoint in the download directory
model_dir = Path('/content/models')
candidates = [f for ext in ('*.ckpt', '*.pth', '*.pt') for f in model_dir.rglob(ext)]

# Fall back to other common locations only if nothing was found above
if not candidates:
    for search_dir in (Path.home(), Path('/tmp')):
        candidates.extend(search_dir.rglob('*bs_roformer*'))

# Keep regular files only and drop duplicates
model_files = sorted({f for f in candidates if f.is_file()})

for f in model_files:
    dest = SAVE_DIR / f.name
    if not dest.exists():
        print(f"Copying {f.name} ({f.stat().st_size / 1e6:.1f} MB) → Google Drive...")
        shutil.copy2(str(f), str(dest))
        print(f"  ✅ Saved to {dest}")
    else:
        print(f"  Already exists: {dest}")

# Also find and save the config file
config_files = list(Path('/content/models').rglob('*.yaml'))
for f in config_files:
    dest = SAVE_DIR / f.name
    if not dest.exists():
        shutil.copy2(str(f), str(dest))
        print(f"  ✅ Config saved: {dest}")

print(f"\n📁 Files in {SAVE_DIR}:")
!ls -lah {SAVE_DIR}/

## Optional: Fine-tune on Custom Data

If you have your own vocal stems (e.g., from songs you've produced), you can fine-tune
the pre-trained model to better handle your specific music style.

**Requirements:**
- Folders with `vocals.wav` + `other.wav` (or `mixture.wav`)
- Upload to Google Drive under `MyDrive/custom_vocal_data/`
- Each song in its own subfolder

**Fine-tuning benefits:**
- Better handling of your specific genre/recording style
- Improved separation on challenging passages
- Only needs 10-50 songs and a few hours of training
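
Before uploading, it can help to confirm that every song folder matches the layout listed under Requirements. The sketch below is one way to do that. It reads the requirement as "`vocals.wav` plus either `other.wav` or `mixture.wav`", which is an interpretation of the list above, and the function name `find_incomplete_songs` is hypothetical.

```python
from pathlib import Path

def find_incomplete_songs(data_root):
    """Map each song folder name to a list of the stems it is missing."""
    problems = {}
    for song_dir in sorted(Path(data_root).iterdir()):
        if not song_dir.is_dir():
            continue  # ignore loose files next to the song folders
        present = {p.name for p in song_dir.glob("*.wav")}
        missing = []
        if "vocals.wav" not in present:
            missing.append("vocals.wav")
        # Assumption: the accompaniment may be given as either file
        if not ({"other.wav", "mixture.wav"} & present):
            missing.append("other.wav or mixture.wav")
        if missing:
            problems[song_dir.name] = missing
    return problems

# Usage (path taken from the Requirements list; adjust as needed):
# issues = find_incomplete_songs('/content/drive/MyDrive/custom_vocal_data')
# for song, missing in issues.items():
#     print(f"{song}: missing {', '.join(missing)}")
```

An empty result means every song folder contains the stems this check looks for; it does not verify sample rate, channel count, or that the stems are time-aligned.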

In [None]:
# Optional: Fine-tune the pre-trained model on custom data
# Skip this cell if you don't have custom training data

import os
from pathlib import Path

CUSTOM_DATA = Path('/content/drive/MyDrive/custom_vocal_data')
FINETUNE = CUSTOM_DATA.exists() and any(CUSTOM_DATA.iterdir())

if FINETUNE:
    print(f"Found custom data at {CUSTOM_DATA}")
    songs = [d for d in CUSTOM_DATA.iterdir() if d.is_dir()]
    print(f"Songs available for fine-tuning: {len(songs)}")
    for s in songs[:10]:
        files = list(s.glob('*.wav'))
        print(f"  {s.name}: {len(files)} wav files")
    
    # Install training dependencies
    !pip install -q ml_collections omegaconf beartype protobuf==3.20.3
    !pip install -q audiomentations torch_audiomentations auraloss
    !git clone https://github.com/ZFTurbo/Music-Source-Separation-Training.git /content/mss_training 2>/dev/null || true
    
    # Find the downloaded model checkpoint path
    model_ckpt = list(Path('/content/models').rglob('*bs_roformer*.ckpt'))
    if model_ckpt:
        ckpt_path = str(model_ckpt[0])
        print(f"\nFine-tuning from checkpoint: {ckpt_path}")
    else:
        print("⚠️ Could not find model checkpoint for fine-tuning")
        FINETUNE = False
else:
    print("No custom data found at", CUSTOM_DATA)
    print("To fine-tune, create that folder and add song subfolders with vocals.wav + other.wav")
    print("\nSkipping fine-tuning; using pre-trained model as-is (12.97 dB SDR)")

In [None]:
# Fine-tuning (runs only if custom data was found above)
if FINETUNE:
    import yaml
    
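    # Create fine-tuning config (lower learning rate, fewer epochs).
    # NOTE: the 'model' values must match the architecture of the checkpoint
    # being fine-tuned; otherwise the pre-trained weights cannot be loaded
    # into the model. Compare them with the YAML config in /content/models,
    # if one was downloaded with the checkpoint.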
    config = {
        'audio': {
            'chunk_size': 131072,
            'sample_rate': 44100,
            'num_channels': 2,
            'min_mean_abs': 0.001
        },
        'model': {
            'type': 'bs_roformer',
            'dim': 384,
            'depth': 12,
            'stereo': True,
            'num_stems': 1,
            'time_transformer_depth': 1,
            'freq_transformer_depth': 1,
            'num_bands': 60,
            'dim_head': 64,
            'heads': 8,
            'attn_dropout': 0.1,
            'ff_dropout': 0.1,
            'flash_attn': True,
            'stft_n_fft': 2048,
            'stft_hop_length': 512,
        },
        'training': {
            'batch_size': 2,
            'gradient_accumulation_steps': 8,
            'num_epochs': 20,  # Fewer epochs for fine-tuning
            'num_steps': 500,
            'lr': 1e-5,  # Lower LR for fine-tuning
            'instruments': ['vocals', 'other'],
            'target_instrument': 'vocals',
            'use_amp': True,
            'optimizer': 'adamw',
        },
        'augmentations': {
            'enable': True,
            'loudness': True,
            'loudness_min': 0.5,
            'loudness_max': 1.5,
        }
    }
    
    config_path = '/content/mss_training/configs/config_finetune_vocals.yaml'
    with open(config_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False)
    
    RESULTS_DIR = '/content/drive/MyDrive/vocal_model_results/finetuned'
    os.makedirs(RESULTS_DIR, exist_ok=True)
    
    %cd /content/mss_training
    !python train.py \
        --model_type bs_roformer \
        --config_path {config_path} \
        --start_check_point {ckpt_path} \
        --data_path {CUSTOM_DATA} \
        --results_path {RESULTS_DIR} \
        --dataset_type 1 \
        --device_ids 0 \
        --num_workers 0 \
        --pin_memory
    
    print(f"\n✅ Fine-tuning complete! Models saved to {RESULTS_DIR}")
    !ls -lah {RESULTS_DIR}/*.ckpt 2>/dev/null
else:
    print("Skipping fine-tuning (no custom data)")
    print("Pre-trained model (12.97 SDR) is ready to use!")
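
To make the numbers in the fine-tuning config above more concrete, the following sketch derives a few quantities from them: the length of one training chunk in seconds, the effective batch size under gradient accumulation, and the STFT grid size per chunk. The helper name `finetune_arithmetic` is hypothetical, and the frame count assumes a centered STFT (the default behaviour of `torch.stft`), which may not match the framework's exact settings.

```python
def finetune_arithmetic(chunk_size, sample_rate, batch_size,
                        grad_accum_steps, n_fft, hop_length):
    """Derive a few quantities implied by the fine-tuning config values."""
    return {
        # Duration of one training chunk
        "chunk_seconds": chunk_size / sample_rate,
        # Chunks contributing to each optimizer update
        "effective_batch_size": batch_size * grad_accum_steps,
        # One-sided spectrum size for a real-valued STFT
        "freq_bins": n_fft // 2 + 1,
        # Assumes a centered STFT, i.e. 1 + chunk_size // hop_length frames
        "frames_per_chunk": chunk_size // hop_length + 1,
    }

stats = finetune_arithmetic(
    chunk_size=131072, sample_rate=44100,
    batch_size=2, grad_accum_steps=8,
    n_fft=2048, hop_length=512,
)
print(f"Chunk length: {stats['chunk_seconds']:.2f} s")        # about 2.97 s
print(f"Effective batch size: {stats['effective_batch_size']}")  # 16
print(f"STFT grid: {stats['freq_bins']} bins x {stats['frames_per_chunk']} frames")  # 1025 x 257
```

If GPU memory is tight, lowering `batch_size` while raising `gradient_accumulation_steps` by the same factor keeps the effective batch size unchanged.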

In [None]:
# Summary: List all saved models
from pathlib import Path

SAVE_DIR = Path('/content/drive/MyDrive/vocal_model_results')
print("=" * 60)
print("VOCAL SEPARATION MODEL - COMPLETE")
print("=" * 60)

# List all models
all_models = list(SAVE_DIR.rglob('*.ckpt')) + list(SAVE_DIR.rglob('*.pth')) + list(SAVE_DIR.rglob('*.pt'))
if all_models:
    print(f"\n📁 Models saved to Google Drive ({SAVE_DIR}):")
    for m in all_models:
        size_mb = m.stat().st_size / 1e6
        print(f"  ✅ {m.name} ({size_mb:.1f} MB)")
else:
    print("\n⚠️ No model files found in Google Drive. Check the download step above.")

print("\n📋 Next steps:")
print("  1. Download model from Google Drive to stemscribe/backend/models/pretrained/")
print("  2. Update enhanced_separator.py to use the BS-RoFormer checkpoint")
print("  3. Test vocal separation quality on real songs")
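
For the next steps printed above, a minimal sketch of loading the saved checkpoint from a local folder might look like the following. It reuses only the `audio-separator` calls that appear earlier in this notebook (`Separator`, `load_model`, `separate`). The helper names `find_checkpoint` and `separate_with_saved_model` are hypothetical, and how `enhanced_separator.py` should actually call them is not specified here. It also assumes the YAML config sits next to the checkpoint, and `audio-separator` may still need network access for its model metadata.

```python
from pathlib import Path

def find_checkpoint(model_dir, pattern="*bs_roformer*.ckpt"):
    """Return the first matching checkpoint under model_dir, or None."""
    matches = sorted(p for p in Path(model_dir).rglob(pattern) if p.is_file())
    return matches[0] if matches else None

def separate_with_saved_model(model_dir, audio_path, output_dir):
    """Separate one audio file using a checkpoint stored in model_dir."""
    # Imported lazily so the lookup helper works without audio-separator installed
    from audio_separator.separator import Separator

    ckpt = find_checkpoint(model_dir)
    if ckpt is None:
        raise FileNotFoundError(f"No BS-RoFormer checkpoint found in {model_dir}")

    separator = Separator(model_file_dir=str(ckpt.parent), output_dir=str(output_dir))
    separator.load_model(model_filename=ckpt.name)
    return separator.separate(str(audio_path))

# Usage (paths are examples; the model folder comes from step 1 above):
# outputs = separate_with_saved_model(
#     'stemscribe/backend/models/pretrained', 'song.wav', 'separated')
```

Keeping the checkpoint lookup separate from the separation call makes it easy to fail early with a clear message when the model has not been copied yet.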

## Done!

**Pre-trained BS-RoFormer vocal model** saved to Google Drive.

| Property | Value |
|----------|-------|
| Model | BS-RoFormer (viperx ep317) |
| SDR (vocals) | 12.97 dB |
| SDR (instrumental) | 17.0 dB |
| Architecture | Band-Split RoPE Transformer |
| Checkpoint size | ~350 MB |

**To fine-tune later:**
1. Upload vocal stems to `MyDrive/custom_vocal_data/<song_name>/` on Google Drive, with `vocals.wav` and `other.wav` in each song folder
2. Re-run this notebook; it will detect the custom data automatically and fine-tune