# Audio Frontend Module Tests

This notebook tests the `audio_frontend` module of the TorchAudio Long-Form Aligner.

Each test cell will display:
- ✅ if the test passes
- ❌ if the test fails

## Setup

In [None]:
# Report library versions and GPU availability, so any failure further down
# can be matched to the environment it happened in.
import torch
import torchaudio

for label, value in (
    ("PyTorch", torch.__version__),
    ("TorchAudio", torchaudio.__version__),
    ("CUDA available", torch.cuda.is_available()),
):
    print(f"{label}: {value}")

In [None]:
# Install torchcodec if using torchaudio >= 2.8
import torchaudio
version_parts = torchaudio.__version__.split('.')
major, minor = int(version_parts[0]), int(version_parts[1].split('+')[0])
if (major, minor) >= (2, 8):
    print("TorchAudio >= 2.8 detected, installing torchcodec...")
    !pip install -q torchcodec
else:
    print(f"TorchAudio {torchaudio.__version__} - torchcodec not required")

## Import the Audio Frontend Module

Import from the modular `audio_frontend` package:

In [None]:
# =============================================================================
# Setup: Clone Repository and Configure Imports
# =============================================================================
# This cell sets up the environment for both Colab and local execution.
#
# For Colab users:
#   - Clones the repo from GitHub (dev branch for testing)
#   - No authentication needed for public repos
#
# For local users:
#   - Automatically finds the src/ directory
# =============================================================================

import sys
import os
from pathlib import Path

# ===== CONFIGURATION =====
# Only read by setup_imports() when running in Colab: the repository at
# GITHUB_REPO is cloned (or updated) on branch BRANCH.
GITHUB_REPO = "https://github.com/huangruizhe/torchaudio_aligner.git"
BRANCH = "dev"  # Use 'dev' for testing, 'main' for stable
# =========================

def setup_imports():
    """Put the project's ``src/`` directory on ``sys.path`` and return it.

    In Colab (detected by ``google.colab`` being imported), the repository is
    cloned into ``/content/torchaudio_aligner`` on the first run and updated on
    later runs. Locally, ``src/`` is looked for in the current working
    directory or its parent, and must contain an ``audio_frontend`` package.

    Returns:
        str: Path of the ``src`` directory, inserted at the front of ``sys.path``
        if it was not already present.

    Raises:
        subprocess.CalledProcessError: If the initial ``git clone`` fails (Colab).
        FileNotFoundError: If no usable ``src`` directory is found.
    """
    import subprocess

    # Check if running in Colab
    IN_COLAB = 'google.colab' in sys.modules

    if IN_COLAB:
        repo_path = '/content/torchaudio_aligner'
        src_path = f'{repo_path}/src'

        if not os.path.exists(repo_path):
            print(f"üì• Cloning repository (branch: {BRANCH})...")
            # Argument list (no shell string) plus check=True: a failed clone
            # raises here instead of being followed by a false success message.
            subprocess.run(
                ['git', 'clone', '-b', BRANCH, GITHUB_REPO, repo_path],
                check=True,
            )
            print("‚úÖ Repository cloned")
        else:
            print(f"üì• Updating repository (branch: {BRANCH})...")
            # Updating is best-effort: an existing checkout is still usable
            # (e.g. offline). Failure is reported instead of raised. all()
            # short-circuits like the former `&&` chain.
            update_ok = all(
                subprocess.run(['git', '-C', repo_path, *git_args]).returncode == 0
                for git_args in (
                    ['fetch', 'origin'],
                    ['checkout', BRANCH],
                    ['pull', 'origin', BRANCH],
                )
            )
            if update_ok:
                print("‚úÖ Repository updated")
            else:
                print("WARNING: repository update failed; using the existing checkout")

        # Verify src exists
        if os.path.exists(src_path):
            print(f"‚úÖ Found src at: {src_path}")
        else:
            print(f"‚ùå src directory NOT found at: {src_path}")
            raise FileNotFoundError(f"src not found at {src_path}")

    else:
        # Local environment - find src directory
        possible_paths = [
            Path(".").absolute().parent / "src",  # Running from tests/
            Path(".").absolute() / "src",          # Running from project root
        ]

        src_path = None
        for p in possible_paths:
            if p.exists() and (p / "audio_frontend").exists():
                src_path = str(p.absolute())
                break

        if src_path is None:
            print("‚ùå Could not find src directory locally")
            print(f"   Current directory: {os.getcwd()}")
            raise FileNotFoundError("src directory not found")

        print(f"‚úÖ Running locally from: {src_path}")

    # Add to Python path
    if src_path not in sys.path:
        sys.path.insert(0, src_path)

    return src_path

# Make src/ importable (cloning the repository first when in Colab).
src_path = setup_imports()

# Public API of the modular audio_frontend package, grouped by submodule.
from audio_frontend import (
    # Loaders
    load_audio,
    get_available_backends,
    AudioBackend,
    # Preprocessing
    resample,
    to_mono,
    normalize_peak,
    preprocess,
    # Segmentation
    AudioSegment,
    SegmentationResult,
    segment_waveform,
    # Enhancement
    AudioEnhancement,
    EnhancementResult,
    TimeMappingManager,
    enhance_audio,
    denoise_noisereduce,
    get_available_enhancement_backends,
    # Frontend
    AudioFrontend,
    segment_audio,
)

import logging
logging.basicConfig(level=logging.INFO)

# Summary banner, emitted as a single block of text (leading blank line first).
print("\n".join([
    "",
    "=" * 60,
    "‚úÖ Audio Frontend imported successfully!",
    "=" * 60,
    "Modules loaded:",
    "  ‚Ä¢ loaders: load_audio, AudioBackend",
    "  ‚Ä¢ preprocessing: resample, to_mono, normalize_peak",
    "  ‚Ä¢ segmentation: AudioSegment, SegmentationResult",
    "  ‚Ä¢ enhancement: AudioEnhancement, TimeMappingManager",
    "  ‚Ä¢ frontend: AudioFrontend, segment_audio",
]))

## Download Test Audio

We'll use Meta's Q1 2025 earnings call as test audio (same as in Tutorial.py).

In [None]:
# Download test audio (Meta Q1 2025 Earnings Call - ~1 hour)
!wget -q https://static.seekingalpha.com/cdn/s3/transcripts_audio/4780182.mp3 -O test_audio.mp3
!ls -lh test_audio.mp3

TEST_AUDIO = "test_audio.mp3"
print("‚úÖ Test audio downloaded")

## Test 1: Load Audio

In [None]:
print("=" * 60)
print("Test 1: AudioFrontend.load()")
print("=" * 60)

try:
    # Load the recording; `waveform` / `sample_rate` are reused by Tests 2-3.
    frontend = AudioFrontend(target_sample_rate=16000)
    waveform, sample_rate = frontend.load(TEST_AUDIO)

    print("\nResults:")
    print(f"  Waveform shape: {waveform.shape}")
    print(f"  Sample rate: {sample_rate} Hz")
    load_duration_s = waveform.shape[1] / sample_rate
    print(f"  Duration: {load_duration_s:.2f} seconds")
    print(f"  Duration: {load_duration_s / 60:.2f} minutes")

    # Expect a (channels, samples) tensor at a valid rate.
    assert waveform.dim() == 2, "Waveform should be 2D"
    assert sample_rate > 0, "Sample rate should be positive"
    print("\n‚úÖ Test 1 PASSED")
except Exception as e:
    print(f"\n‚ùå Test 1 FAILED: {e}")

## Test 2: Resample Audio

In [None]:
print("=" * 60)
print("Test 2: AudioFrontend.resample()")
print("=" * 60)

try:
    # Resample Test 1's audio to 16 kHz. The new length should scale by the
    # rate ratio, within 100 samples.
    print(f"\nOriginal sample rate: {sample_rate} Hz")
    print(f"Original samples: {waveform.shape[1]}")

    target_rate = 16000
    resampled = frontend.resample(waveform, sample_rate, target_rate)

    expected_samples = int(waveform.shape[1] * target_rate / sample_rate)
    print(f"Resampled samples: {resampled.shape[1]}")
    print(f"Expected samples (approx): {expected_samples}")

    length_error = abs(resampled.shape[1] - expected_samples)
    assert length_error < 100, "Resampled length mismatch"
    print("\n‚úÖ Test 2 PASSED")
except Exception as e:
    print(f"\n‚ùå Test 2 FAILED: {e}")

## Test 3: Convert to Mono

In [None]:
print("=" * 60)
print("Test 3: AudioFrontend.to_mono()")
print("=" * 60)

try:
    # Down-mix Test 1's waveform; the channel axis must collapse to size 1.
    print(f"\nOriginal channels: {waveform.shape[0]}")

    mono = frontend.to_mono(waveform)
    mono_channels = mono.shape[0]

    print(f"Mono channels: {mono_channels}")
    assert mono_channels == 1, "Should have 1 channel"
    print("\n‚úÖ Test 3 PASSED")
except Exception as e:
    print(f"\n‚ùå Test 3 FAILED: {e}")

## Test 4: Segment Audio

In [None]:
print("=" * 60)
print("Test 4: AudioFrontend.segment()")
print("=" * 60)

try:
    seg_sr = 16000  # rate the audio is resampled to; every segment must report it
    frontend = AudioFrontend(target_sample_rate=seg_sr, mono=True)
    waveform, orig_sr = frontend.load(TEST_AUDIO)
    waveform = frontend.resample(waveform, orig_sr)
    waveform = frontend.to_mono(waveform)

    result = frontend.segment(
        waveform,
        sample_rate=seg_sr,
        segment_size=15.0,
        overlap=2.0,
        min_segment_size=0.2,
    )

    # Check for an empty result *before* indexing result.segments below.
    # An empty segmentation then reports this message, not an IndexError.
    assert result.num_segments > 0, "Should have at least one segment"

    print("\nResults:")
    print(f"  Original duration: {result.original_duration_seconds:.2f} seconds ({result.original_duration_seconds/60:.2f} min)")
    print(f"  Number of segments: {result.num_segments}")
    print(f"  Segment size: {result.segment_size_samples} samples ({result.segment_size_samples/seg_sr:.2f}s)")
    print(f"  Overlap: {result.overlap_samples} samples ({result.overlap_samples/seg_sr:.2f}s)")

    print("\nFirst 3 segments:")
    for i, seg in enumerate(result.segments[:3]):
        print(f"  Segment {i}: offset={seg.offset_seconds:.2f}s, duration={seg.duration_seconds:.2f}s, shape={seg.waveform.shape}")

    print("\nLast segment:")
    last_seg = result.segments[-1]
    print(f"  Segment {last_seg.segment_index}: offset={last_seg.offset_seconds:.2f}s, duration={last_seg.duration_seconds:.2f}s")

    assert all(seg.sample_rate == seg_sr for seg in result.segments), "All segments should have correct sample rate"
    print("\n‚úÖ Test 4 PASSED")
except Exception as e:
    print(f"\n‚ùå Test 4 FAILED: {e}")

## Test 5: Full Processing Pipeline

In [None]:
print("=" * 60)
print("Test 5: AudioFrontend.process() - Full Pipeline")
print("=" * 60)

try:
    # End-to-end: file path in, SegmentationResult out. `result` is reused by
    # Tests 6 and 7.
    frontend = AudioFrontend(target_sample_rate=16000, mono=True, normalize=False)
    result = frontend.process(TEST_AUDIO, segment_size=15.0, overlap=2.0)

    print("\nResults:")
    print(f"  Original duration: {result.original_duration_seconds:.2f} seconds")
    print(f"  Number of segments: {result.num_segments}")

    assert isinstance(result, SegmentationResult)
    assert result.num_segments > 0
    print("\n‚úÖ Test 5 PASSED")
except Exception as e:
    print(f"\n‚ùå Test 5 FAILED: {e}")

## Test 6: Batching for GPU Inference

In [None]:
print("=" * 60)
print("Test 6: SegmentationResult.get_waveforms_batched()")
print("=" * 60)

try:
    # Gather Test 5's segments into one batched tensor plus per-segment lengths.
    waveforms, lengths = result.get_waveforms_batched()

    print("\nResults:")
    print(f"  Batched waveforms shape: {waveforms.shape}")
    print(f"  Lengths shape: {lengths.shape}")
    head_lengths = lengths[:5].tolist()
    print(f"  First 5 lengths: {head_lengths}")
    tail_lengths = lengths[-5:].tolist()
    print(f"  Last 5 lengths: {tail_lengths}")

    # One row and one length per segment; mono audio gives (batch, samples).
    expected_batch = result.num_segments
    assert waveforms.shape[0] == expected_batch, "Batch size mismatch"
    assert lengths.shape[0] == expected_batch, "Lengths mismatch"
    assert waveforms.dim() == 2, "Should be 2D for mono"
    print("\n‚úÖ Test 6 PASSED")
except Exception as e:
    print(f"\n‚ùå Test 6 FAILED: {e}")

## Test 7: Frame Offset Calculation

In [None]:
print("=" * 60)
print("Test 7: SegmentationResult.get_offsets_in_frames()")
print("=" * 60)

try:
    frame_duration = 0.02  # MMS model frame duration: 20 ms
    offsets = result.get_offsets_in_frames(frame_duration)

    print("\nResults:")
    print(f"  Frame duration: {frame_duration}s (20ms)")
    print(f"  Frame offsets shape: {offsets.shape}")
    print(f"  First 5 offsets (frames): {offsets[:5].tolist()}")

    # Strictly increasing: every offset is smaller than its successor. This is
    # compared element-wise in one step; empty/single-element input counts as
    # increasing.
    is_monotonic = bool((offsets[:-1] < offsets[1:]).all())
    print(f"  Monotonically increasing: {is_monotonic}")

    assert is_monotonic, "Offsets should be monotonically increasing"
    print("\n‚úÖ Test 7 PASSED")
except Exception as e:
    print(f"\n‚ùå Test 7 FAILED: {e}")

## Test 8: Convenience Function

In [None]:
print("=" * 60)
print("Test 8: segment_audio() convenience function")
print("=" * 60)

try:
    # Single call from file path to SegmentationResult.
    result = segment_audio(TEST_AUDIO, target_sample_rate=16000, segment_size=15.0, overlap=2.0)

    print("\nResults:")
    print(f"  Duration: {result.original_duration_seconds:.2f}s")
    print(f"  Segments: {result.num_segments}")

    assert isinstance(result, SegmentationResult)
    print("\n‚úÖ Test 8 PASSED")
except Exception as e:
    print(f"\n‚ùå Test 8 FAILED: {e}")

## Test 9: Normalization

In [None]:
print("=" * 60)
print("Test 9: Audio Normalization")
print("=" * 60)

try:
    target_db = -3.0  # peak level requested from the frontend, in dB
    frontend_norm = AudioFrontend(target_sample_rate=16000, mono=True, normalize=True, normalize_db=target_db)

    # Load -> resample -> mono, as in the earlier tests.
    waveform, sr = frontend_norm.load(TEST_AUDIO)
    waveform = frontend_norm.to_mono(frontend_norm.resample(waveform, sr))

    original_peak = waveform.abs().max().item()
    print(f"\nOriginal peak: {original_peak:.4f}")

    # Normalise a clone, so `waveform` is untouched even if normalisation
    # happens to work in place.
    normalized = frontend_norm.apply_normalization(waveform.clone())
    normalized_peak = normalized.abs().max().item()
    print(f"Normalized peak: {normalized_peak:.4f}")

    # dB -> linear amplitude.
    expected_peak = 10 ** (target_db / 20)
    print(f"Expected peak (-3dB): {expected_peak:.4f}")

    assert abs(normalized_peak - expected_peak) < 0.01, "Normalized peak mismatch"
    print("\n‚úÖ Test 9 PASSED")
except Exception as e:
    print(f"\n‚ùå Test 9 FAILED: {e}")

## Test 10: Listen to a Segment

In [None]:
print("=" * 60)
print("Test 10: Listen to a Segment (Visual/Audio Check)")
print("=" * 60)

try:
    import IPython.display as ipd

    result = segment_audio(TEST_AUDIO, segment_size=15.0, overlap=2.0)

    # Play the first segment, then one from the middle of the recording.
    mid_idx = result.num_segments // 2
    playlist = (
        (0, "\nPlaying Segment 0:"),
        (mid_idx, f"\nPlaying Segment {mid_idx} (middle):"),
    )
    for seg_index, heading in playlist:
        seg = result.segments[seg_index]
        print(heading)
        print(f"  Offset: {seg.offset_seconds:.2f}s")
        print(f"  Duration: {seg.duration_seconds:.2f}s")
        ipd.display(ipd.Audio(seg.waveform.numpy(), rate=seg.sample_rate))

    print("\n‚úÖ Test 10 PASSED (verify audio plays correctly)")
except Exception as e:
    print(f"\n‚ùå Test 10 FAILED: {e}")

## Test 11-14: Audio Enhancement Module (Demucs + VAD)

These tests verify the optional audio enhancement features:
- Demucs source separation (vocal extraction)
- Silence removal
- Voice Activity Detection (VAD)
- Timestamp mapping for alignment recovery

**Note**: These require optional dependencies:
```
pip install demucs pyloudnorm
```

In [None]:
# Install optional enhancement dependencies
!pip install -q demucs pyloudnorm

# Check availability
try:
    import demucs
    from demucs.pretrained import get_model_from_args
    from demucs.apply import apply_model
    DEMUCS_AVAILABLE = True
    print("‚úÖ demucs available")
except ImportError:
    DEMUCS_AVAILABLE = False
    print("‚ùå demucs not available")

try:
    import pyloudnorm
    PYLOUDNORM_AVAILABLE = True
    print("‚úÖ pyloudnorm available")
except ImportError:
    PYLOUDNORM_AVAILABLE = False
    print("‚ùå pyloudnorm not available")

In [None]:
print("=" * 60)
print("Test 11: TimeMappingManager - Timestamp Recovery")
print("=" * 60)
# Explain how processed (silence-removed) time relates to the original
# timeline before the mapping checks below.
print("""
When silence is removed from audio, timestamps change.
TimeMappingManager tracks these changes for recovery.

Example:
  Original audio:  [speech][silence][speech][silence][speech]
                    0-2s    2-5s     5-8s    8-10s    10-15s
  
  After removal:   [speech][speech][speech]
                    0-2s    2-5s    5-10s
  
  Mapping: processed_time=3.0 -> original_time=6.0
""")

# TimeMappingManager comes from the `from audio_frontend import ...` in the
# setup cell (listed under "Enhancement").

try:
    # Silence intervals (seconds, on the original timeline) that were cut out.
    mapper = TimeMappingManager([(0, 1), (3, 5), (6, 8)])

    # (processed_time, expected_original_time). Each removed interval pushes
    # every later processed timestamp forward by that interval's length.
    test_cases = [
        (-1, -1),
        (0, 1),      # After removing 0-1s silence, time 0 maps to 1
        (0.5, 1.5),
        (1, 2),
        (2, 5),      # After removing 3-5s silence, time 2 maps to 5
        (3, 8),      # After removing 6-8s silence, time 3 maps to 8
    ]

    print("Testing timestamp mappings:")
    outcomes = []
    for processed, expected_original in test_cases:
        actual = mapper.map_to_original(processed)
        passed = abs(actual - expected_original) < 1e-6
        marker = "‚úÖ" if passed else "‚ùå"
        print(f"  {marker} map_to_original({processed}) = {actual:.2f} (expected {expected_original})")
        outcomes.append(passed)
    all_passed = all(outcomes)

    assert all_passed, "Some timestamp mappings failed"
    print("\n‚úÖ Test 11 PASSED - TimeMappingManager works correctly")
except Exception as e:
    print(f"\n‚ùå Test 11 FAILED: {e}")

In [None]:
print("=" * 60)
print("Test 12: Silence Removal (Energy-based)")
print("=" * 60)

# NOTE(review): remove_silence_simple below is a notebook-local reference
# detector, not an audio_frontend API. This cell checks it on synthetic audio
# whose silence boundaries are known exactly.


def remove_silence_simple(waveform, sr, threshold_db=-50, min_dur=0.2):
    """Detect silent stretches in a 1-D waveform by per-frame peak amplitude.

    The signal is scanned with 20 ms frames every 10 ms. A frame is silent when
    its peak absolute amplitude is below ``10 ** (threshold_db / 20)``.
    A silent run:
      * starts at the start of its first silent frame;
      * ends at the start of the next non-silent frame, or at the end of the
        signal.

    Args:
        waveform: 1-D tensor of samples.
        sr: Sample rate in Hz.
        threshold_db: Silence threshold in dB relative to amplitude 1.0.
        min_dur: Shortest run, in seconds, that is reported.

    Returns:
        List of ``(start_seconds, end_seconds)`` tuples in time order.
    """
    threshold = 10 ** (threshold_db / 20)
    frame_size = int(sr * 0.02)
    hop_size = int(sr * 0.01)
    num_samples = waveform.shape[0]

    intervals = []
    in_silence = False
    silence_start = 0.0

    # "+ 1" so the final full frame (ending exactly at the last sample) is
    # also examined; without it that frame was skipped.
    for i in range(0, num_samples - frame_size + 1, hop_size):
        energy = waveform[i:i + frame_size].abs().max().item()
        if energy < threshold:
            if not in_silence:
                in_silence = True
                silence_start = i / sr
        elif in_silence:
            in_silence = False
            silence_end = i / sr
            if silence_end - silence_start >= min_dur:
                intervals.append((silence_start, silence_end))

    # A silent run that reaches the end of the signal is closed there.
    if in_silence:
        silence_end = num_samples / sr
        if silence_end - silence_start >= min_dur:
            intervals.append((silence_start, silence_end))

    return intervals


try:
    # Create test audio with silence
    sr = 16000
    duration = 10.0  # 10 seconds

    # Create audio: [noise 0-2s][silence 2-4s][noise 4-7s][silence 7-9s][noise 9-10s]
    samples = int(sr * duration)
    waveform = torch.zeros(samples)

    # Seeded generator: the synthetic "speech" is identical on every run.
    noise_gen = torch.Generator().manual_seed(0)
    waveform[0:int(sr*2)] = torch.randn(int(sr*2), generator=noise_gen) * 0.5        # 0-2s: speech
    waveform[int(sr*4):int(sr*7)] = torch.randn(int(sr*3), generator=noise_gen) * 0.5  # 4-7s: speech
    waveform[int(sr*9):int(sr*10)] = torch.randn(int(sr*1), generator=noise_gen) * 0.5 # 9-10s: speech

    print(f"üìÑ Created test audio: {duration}s with 3 speech segments")
    print(f"   Speech: 0-2s, 4-7s, 9-10s (total 6s)")
    print(f"   Silence: 2-4s, 7-9s (total 4s)")

    silence_intervals = remove_silence_simple(waveform, sr)
    print(f"\nüìÑ Detected silence intervals: {silence_intervals}")

    # Exactly the two gaps must be found, each within 20 ms of its true
    # boundaries. The end is reported at the start of the first frame that
    # overlaps speech, i.e. slightly early.
    expected_intervals = [(2.0, 4.0), (7.0, 9.0)]
    assert len(silence_intervals) == len(expected_intervals), (
        f"Expected {len(expected_intervals)} silence periods, got {len(silence_intervals)}"
    )
    for (start, end), (exp_start, exp_end) in zip(silence_intervals, expected_intervals):
        assert abs(start - exp_start) <= 0.02 and abs(end - exp_end) <= 0.02, (
            f"Silence ({start:.2f}, {end:.2f}) does not match expected ({exp_start}, {exp_end})"
        )
    print(f"‚úÖ Detected {len(silence_intervals)} silence periods")

    print("\n‚úÖ Test 12 PASSED - Silence removal works")
except Exception as e:
    print(f"\n‚ùå Test 12 FAILED: {e}")

In [None]:
print("=" * 60)
print("Test 13: Silero VAD (Voice Activity Detection)")
print("=" * 60)

try:
    vad_sr = 16000     # rate at which audio is fed to Silero VAD in this test
    clip_seconds = 30  # only the start of the recording, to keep the test fast

    # Load Silero VAD
    print("Loading Silero VAD model...")
    vad_model, utils = torch.hub.load(
        repo_or_dir="snakers4/silero-vad",
        model="silero_vad",
        force_reload=False,
        onnx=False,
    )
    get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks = utils
    print("‚úÖ Silero VAD loaded")

    # Trim at the native rate *before* resampling. Resampling the whole
    # ~1 hour recording only to keep its first 30 s is wasted work.
    waveform, orig_sr = torchaudio.load(TEST_AUDIO)
    waveform = waveform[:, :orig_sr * clip_seconds]
    waveform = torchaudio.functional.resample(waveform, orig_sr, vad_sr)
    waveform = waveform.mean(0)  # Mono
    waveform = waveform[:vad_sr * clip_seconds]  # guard against resampler rounding

    print(f"\nüìÑ Test audio: {waveform.shape[0]/vad_sr:.2f}s")

    # Get speech timestamps ('start'/'end' are sample indices at vad_sr)
    speech_timestamps = get_speech_timestamps(
        waveform,
        vad_model,
        threshold=0.4,
        min_silence_duration_ms=500,
        sampling_rate=vad_sr,
    )

    print(f"üìÑ Detected {len(speech_timestamps)} speech segments:")
    for i, ts in enumerate(speech_timestamps[:5]):
        start = ts['start'] / vad_sr
        end = ts['end'] / vad_sr
        print(f"   [{i}] {start:.2f}s - {end:.2f}s (duration: {end-start:.2f}s)")
    if len(speech_timestamps) > 5:
        print(f"   ... and {len(speech_timestamps) - 5} more")

    # Collect speech chunks
    speech_waveform = collect_chunks(speech_timestamps, waveform)

    original_dur = waveform.shape[0] / vad_sr
    speech_dur = speech_waveform.shape[0] / vad_sr
    print(f"\nüìÑ VAD result: {original_dur:.2f}s -> {speech_dur:.2f}s ({100*speech_dur/original_dur:.1f}%)")

    assert len(speech_timestamps) > 0, "Should detect some speech"
    assert speech_dur < original_dur, "Speech duration should be less than original"
    print("\n‚úÖ Test 13 PASSED - Silero VAD works")
except Exception as e:
    print(f"\n‚ùå Test 13 FAILED: {e}")

In [None]:
print("=" * 60)
print("Test 14: Demucs Vocal Extraction (Optional - Slow)")
print("=" * 60)

if not DEMUCS_AVAILABLE:
    print("‚ö†Ô∏è Demucs not installed, skipping test")
    print("   Install with: pip install demucs")
else:
    try:
        from demucs.pretrained import get_model
        from demucs.apply import apply_model

        # Load the model through the public get_model() loader, instead of
        # building a throwaway "args" object for get_model_from_args().
        print("Loading Demucs model (htdemucs)...")
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        demucs_model = get_model("htdemucs")
        demucs_model = demucs_model.to(device).eval()
        print(f"‚úÖ Demucs loaded on {device}")
        print(f"   Sources: {demucs_model.sources}")  # ['drums', 'bass', 'other', 'vocals']

        # Load a short segment (10 seconds) - Demucs is slow
        waveform, orig_sr = torchaudio.load(TEST_AUDIO)
        waveform = waveform[:, :orig_sr * 10]  # First 10 seconds

        print(f"\nüìÑ Test audio: {waveform.shape[1]/orig_sr:.2f}s at {orig_sr}Hz")

        # This test feeds stereo audio at model.samplerate; duplicate a mono
        # channel when needed.
        waveform_stereo = waveform.repeat(2, 1) if waveform.shape[0] == 1 else waveform

        # Resample to model's sample rate
        if orig_sr != demucs_model.samplerate:
            wav = torchaudio.functional.resample(waveform_stereo, orig_sr, demucs_model.samplerate)
        else:
            wav = waveform_stereo

        # Add batch dimension [B, C, T]
        if wav.dim() == 2:
            wav = wav.unsqueeze(0)
        wav = wav.to(device)

        print("Applying Demucs source separation...")
        with torch.no_grad():
            # Stored as `separated`, not `result`. `result` holds the
            # SegmentationResult that Tests 6-7 read, and overwriting it with
            # a tensor would break them if they are re-run afterwards.
            separated = apply_model(demucs_model, wav, device=device, split=True, overlap=0.25)

        # Extract vocals (indexing below treats `separated` as
        # batch x source x channel x time).
        vocals_idx = demucs_model.sources.index("vocals")
        vocals = separated[0, vocals_idx].mean(0).cpu()  # Average channels to mono
        vocals = torchaudio.functional.resample(vocals, demucs_model.samplerate, orig_sr)

        print(f"\nüìÑ Extracted vocals: {vocals.shape[0]/orig_sr:.2f}s")
        print(f"   Original peak: {waveform.abs().max().item():.4f}")
        print(f"   Vocals peak: {vocals.abs().max().item():.4f}")

        # Listen to comparison
        import IPython.display as ipd
        print("\nüîä Original audio (first 10s):")
        ipd.display(ipd.Audio(waveform.mean(0).cpu().numpy(), rate=orig_sr))

        print("üîä Extracted vocals:")
        ipd.display(ipd.Audio(vocals.cpu().numpy(), rate=orig_sr))

        print("\n‚úÖ Test 14 PASSED - Demucs vocal extraction works")
    except Exception as e:
        print(f"\n‚ùå Test 14 FAILED: {e}")
        import traceback
        traceback.print_exc()

## Test 15-17: Additional Denoising Libraries

These tests verify additional denoising options:
- **noisereduce**: Lightweight spectral gating (CPU-friendly) - **Recommended, works everywhere**
- **DeepFilterNet**: Deep learning noise suppression (48kHz full-band) - requires Rust compiler
- **Resemble Enhance**: AI speech denoising - requires torch==2.1.1

**For Colab, just use noisereduce:**
```
pip install noisereduce
```

The other options have complex build requirements. noisereduce is lightweight, effective, and works on CPU.

In [None]:
# Check availability of additional denoising libraries
print("Checking additional denoising libraries...")

NOISEREDUCE_AVAILABLE = False
DEEPFILTERNET_AVAILABLE = False
RESEMBLE_ENHANCE_AVAILABLE = False

try:
    import noisereduce as nr
    NOISEREDUCE_AVAILABLE = True
    print("‚úÖ noisereduce available")
except ImportError:
    print("‚ùå noisereduce not available (pip install noisereduce)")

# Only `enhance` and `init_df` are imported from df.enhance. Importing its
# `load_audio` here would silently replace the audio_frontend.load_audio
# imported in the setup cell.
# DeepFilterNet and resemble-enhance have fragile build/version requirements.
# Any import-time error (not only ImportError) is therefore reported as "not
# available" instead of aborting the notebook.
try:
    from df.enhance import enhance, init_df
    DEEPFILTERNET_AVAILABLE = True
    print("‚úÖ deepfilternet available")
except Exception as exc:
    print("‚ùå deepfilternet not available (pip install deepfilternet)")
    print(f"   reason: {exc}")

try:
    from resemble_enhance.enhancer.inference import denoise, enhance as resemble_enhance_fn
    RESEMBLE_ENHANCE_AVAILABLE = True
    print("‚úÖ resemble-enhance available")
except Exception as exc:
    print("‚ùå resemble-enhance not available (pip install resemble-enhance)")
    print(f"   reason: {exc}")

In [None]:
print("=" * 60)
print("Test 15: noisereduce - Spectral Gating Noise Reduction")
print("=" * 60)

if not NOISEREDUCE_AVAILABLE:
    print("‚ö†Ô∏è noisereduce not installed, skipping test")
    print("   Install with: pip install noisereduce")
else:
    try:
        import noisereduce as nr
        import numpy as np

        # First 10 seconds of the recording, averaged to mono.
        waveform, orig_sr = torchaudio.load(TEST_AUDIO)
        waveform = waveform.mean(0)[:orig_sr * 10]

        print(f"üìÑ Test audio: {waveform.shape[0]/orig_sr:.2f}s at {orig_sr}Hz")

        # noisereduce is called on a NumPy array.
        audio_np = waveform.numpy()

        print("Applying noisereduce (stationary=False)...")
        reduced = nr.reduce_noise(
            y=audio_np,
            sr=orig_sr,
            stationary=False,   # non-stationary spectral gating
            prop_decrease=1.0,
            n_fft=512,
        )
        reduced_tensor = torch.from_numpy(reduced).float()

        print("\nüìÑ Result:")
        print(f"   Original peak: {waveform.abs().max().item():.4f}")
        print(f"   Denoised peak: {reduced_tensor.abs().max().item():.4f}")

        # Before/after playback.
        import IPython.display as ipd
        for caption, clip in (
            ("\nüîä Original audio (first 10s):", audio_np),
            ("üîä noisereduce denoised:", reduced),
        ):
            print(caption)
            ipd.display(ipd.Audio(clip, rate=orig_sr))

        assert reduced.shape == audio_np.shape, "Output shape should match input"
        print("\n‚úÖ Test 15 PASSED - noisereduce works")
    except Exception as e:
        print(f"\n‚ùå Test 15 FAILED: {e}")
        import traceback
        traceback.print_exc()

In [None]:
print("=" * 60)
print("Test 16: DeepFilterNet - Deep Learning Noise Suppression")
print("=" * 60)

if not DEEPFILTERNET_AVAILABLE:
    print("‚ö†Ô∏è DeepFilterNet not installed, skipping test")
    print("   Install with: pip install deepfilternet")
else:
    try:
        # Import only what this test uses. df.enhance also exports a
        # `load_audio`, which would shadow audio_frontend.load_audio.
        from df.enhance import enhance, init_df

        # Initialize DeepFilterNet
        print("Initializing DeepFilterNet...")
        model, df_state, _ = init_df()
        df_sr = df_state.sr()
        print(f"‚úÖ DeepFilterNet initialized (sample_rate: {df_sr}Hz)")

        # Mono, first 10 seconds, at DeepFilterNet's own sample rate. Trimming
        # at the native rate first avoids resampling the whole recording.
        waveform, orig_sr = torchaudio.load(TEST_AUDIO)
        waveform = waveform.mean(0)[:orig_sr * 10]
        if orig_sr != df_sr:
            waveform = torchaudio.functional.resample(waveform, orig_sr, df_sr)
        waveform = waveform[:df_sr * 10]

        print(f"üìÑ Test audio: {waveform.shape[0]/df_sr:.2f}s at {df_sr}Hz")

        # DeepFilterNet's enhance() expects a [C, T] torch tensor (not a NumPy
        # array) and returns a tensor. as_tensor() also tolerates an ndarray
        # return; squeeze(0) drops the channel axis again.
        print("Applying DeepFilterNet...")
        enhanced = enhance(model, df_state, waveform.unsqueeze(0))
        enhanced_tensor = torch.as_tensor(enhanced).float().cpu().squeeze(0)

        # Resample back if needed
        if orig_sr != df_sr:
            enhanced_tensor = torchaudio.functional.resample(enhanced_tensor, df_sr, orig_sr)
            waveform = torchaudio.functional.resample(waveform, df_sr, orig_sr)

        print(f"\nüìÑ Result:")
        print(f"   Original peak: {waveform.abs().max().item():.4f}")
        print(f"   Enhanced peak: {enhanced_tensor.abs().max().item():.4f}")

        # Listen to comparison
        import IPython.display as ipd
        print("\nüîä Original audio (first 10s):")
        ipd.display(ipd.Audio(waveform.numpy(), rate=orig_sr))

        print("üîä DeepFilterNet enhanced:")
        ipd.display(ipd.Audio(enhanced_tensor.numpy(), rate=orig_sr))

        print("\n‚úÖ Test 16 PASSED - DeepFilterNet works")
    except Exception as e:
        print(f"\n‚ùå Test 16 FAILED: {e}")
        import traceback
        traceback.print_exc()

In [None]:
print("=" * 60)
print("Test 17: Resemble Enhance - AI Speech Enhancement")
print("=" * 60)

if not RESEMBLE_ENHANCE_AVAILABLE:
    print("‚ö†Ô∏è Resemble Enhance not installed, skipping test")
    print("   Install with: pip install resemble-enhance")
    print("   Note: Heavy model (~1GB), best with GPU")
else:
    try:
        from resemble_enhance.enhancer.inference import denoise, enhance as resemble_enhance_fn

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        resemble_sr = 44100  # Resemble Enhance works at 44.1kHz

        # Mono copy at resemble_sr, cut to its first 10 seconds (heavy model).
        waveform, orig_sr = torchaudio.load(TEST_AUDIO)
        waveform = waveform.mean(0)
        if orig_sr != resemble_sr:
            waveform = torchaudio.functional.resample(waveform, orig_sr, resemble_sr)
        waveform = waveform[:resemble_sr * 10]

        print(f"üìÑ Test audio: {waveform.shape[0]/resemble_sr:.2f}s at 44100Hz")

        # Denoise only (faster than the full enhance pass).
        print("Applying Resemble Enhance (denoise only)...")
        enhanced, out_sr = denoise(waveform, resemble_sr, device)

        # Bring both signals back to the file's native rate for playback.
        if out_sr == orig_sr:
            waveform_out = waveform
        else:
            enhanced = torchaudio.functional.resample(enhanced, out_sr, orig_sr)
            waveform_out = torchaudio.functional.resample(waveform, resemble_sr, orig_sr)

        print("\nüìÑ Result:")
        print(f"   Original peak: {waveform.abs().max().item():.4f}")
        print(f"   Enhanced peak: {enhanced.abs().max().item():.4f}")

        # Before/after playback.
        import IPython.display as ipd
        for caption, signal in (
            ("\nüîä Original audio (first 10s):", waveform_out),
            ("üîä Resemble Enhance denoised:", enhanced),
        ):
            print(caption)
            ipd.display(ipd.Audio(signal.cpu().numpy(), rate=orig_sr))

        print("\n‚úÖ Test 17 PASSED - Resemble Enhance works")
    except Exception as e:
        print(f"\n‚ùå Test 17 FAILED: {e}")
        import traceback
        traceback.print_exc()

## Test 18: Real-World Noisy Audio Enhancement Comparison

This test uses the **NASA Apollo 11 moon landing audio** (1969) - a challenging real-world noisy recording.
We compare all denoising methods side-by-side so you can hear the before/after difference.

Source: [NASA Apollo 11 Archive](https://history.nasa.gov/alsj/a11/video11.html#Step)

Transcript: *"I'm at the foot of the ladder... That's one small step for man, one giant leap for mankind."*

In [None]:
# Download NASA Apollo 11 Moon Landing Audio (1969) - Very noisy historical recording
# Banner: describes the clip and its reference transcript.
print("=" * 60)
print("Test 18: NASA Apollo 11 Audio - Enhancement Comparison")
print("=" * 60)
print("""
This is Neil Armstrong's famous "One small step for man" recording from 1969.
The audio quality is poor (recorded on the Moon!) - perfect for testing denoising.

Transcript: "I'm at the foot of the ladder. The LM footpads are only depressed 
in the surface about 1 or 2 inches... That's one small step for man, 
one giant leap for mankind."
""")

import subprocess
import os

# Download the NASA video and extract audio
NOISY_AUDIO = "apollo11_audio.wav"

if not os.path.exists(NOISY_AUDIO):
    print("Downloading NASA Apollo 11 video...")
    !wget -q https://www.nasa.gov/wp-content/uploads/static/history/alsj/a11/a11.v1092338.mov -O apollo11.mov
    print("Extracting audio...")
    !ffmpeg -loglevel warning -y -i apollo11.mov -vn -acodec pcm_s16le -ar 16000 -ac 1 {NOISY_AUDIO}
    !rm apollo11.mov
    print(f"‚úÖ Audio extracted: {NOISY_AUDIO}")
else:
    print(f"‚úÖ Using cached: {NOISY_AUDIO}")

!ls -lh {NOISY_AUDIO}

In [None]:
# Load the Apollo 11 audio and compare all enhancement methods.
#
# Each installed method runs on the same clip. Its output is resampled back to
# the clip's sample rate, played inline, and stored in `enhanced_versions` for
# side-by-side listening. Methods whose package is unavailable are skipped.
# Methods that raise are recorded in `failed_methods`, so the final verdict
# reports them instead of always printing PASSED.
import traceback

import IPython.display as ipd

print("Loading Apollo 11 audio...")
apollo_waveform, apollo_sr = torchaudio.load(NOISY_AUDIO)
# torchaudio.load returns (channels, time). Averaging over channels gives mono;
# for the single-channel file extracted above it equals squeeze(0).
apollo_waveform = apollo_waveform.mean(dim=0)
duration = apollo_waveform.shape[0] / apollo_sr

print(f"üìÑ Apollo 11 Audio: {duration:.2f}s at {apollo_sr}Hz")
print(f"   This is VERY noisy 1969 audio from the Moon!")

# Store all enhanced versions for comparison (key -> 1-D tensor at apollo_sr)
enhanced_versions = {}
# Keys of methods that were available but raised while running
failed_methods = []


def _print_banner(title):
    """Print the section banner shown above each audio player."""
    print("\n" + "=" * 60)
    print(f"üîä {title}")
    print("=" * 60)


def _enhance_noisereduce():
    """Non-stationary spectral gating with noisereduce; output stays at apollo_sr."""
    import noisereduce as nr

    nr_result = nr.reduce_noise(
        y=apollo_waveform.numpy(),
        sr=apollo_sr,
        stationary=False,
        prop_decrease=1.0,
    )
    return torch.from_numpy(nr_result).float()


def _enhance_deepfilternet():
    """DeepFilterNet at its native rate (df_state.sr()), resampled back to apollo_sr."""
    global df_model, df_state
    from df.enhance import enhance as df_enhance, init_df

    # Reuse a model loaded by an earlier cell if one exists
    if "df_model" not in globals():
        df_model, df_state, _ = init_df()

    apollo_48k = torchaudio.functional.resample(apollo_waveform, apollo_sr, df_state.sr())
    # df.enhance.enhance takes a (channels, time) torch Tensor, not a 1-D numpy array
    df_result = df_enhance(df_model, df_state, apollo_48k.unsqueeze(0))
    df_result = torch.as_tensor(df_result).float().reshape(-1)  # back to mono 1-D
    return torchaudio.functional.resample(df_result, df_state.sr(), apollo_sr)


def _enhance_resemble():
    """Resemble Enhance denoiser fed 44.1 kHz audio; output resampled back to apollo_sr."""
    from resemble_enhance.enhancer.inference import denoise as resemble_denoise

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    apollo_44k = torchaudio.functional.resample(apollo_waveform, apollo_sr, 44100)
    resemble_result, out_sr = resemble_denoise(apollo_44k, 44100, device)
    resemble_result_16k = torchaudio.functional.resample(resemble_result, out_sr, apollo_sr)
    return resemble_result_16k.cpu()


def _enhance_demucs():
    """Demucs vocal stem, downmixed to mono and resampled back to apollo_sr."""
    global demucs_model
    from demucs.pretrained import get_model_from_args
    from demucs.apply import apply_model

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Use model from Test 14 if available, otherwise load it
    if "demucs_model" not in globals():
        demucs_model = get_model_from_args(
            type("args", (object,), dict(name="htdemucs", repo=None))
        ).to(device).eval()

    # Duplicate the mono channel into a stereo [2, T] input (demucs.audio API changed,
    # so preprocessing is done with torchaudio)
    apollo_stereo = apollo_waveform.unsqueeze(0).repeat(2, 1)
    if apollo_sr != demucs_model.samplerate:
        wav = torchaudio.functional.resample(apollo_stereo, apollo_sr, demucs_model.samplerate)
    else:
        wav = apollo_stereo
    wav = wav.unsqueeze(0).to(device)  # add batch dimension -> [1, 2, T]

    with torch.no_grad():
        sources = apply_model(demucs_model, wav, device=device, split=True, overlap=0.25)

    vocals_idx = demucs_model.sources.index("vocals")
    vocals = sources[0, vocals_idx].mean(0).cpu()  # stereo vocal stem -> mono
    return torchaudio.functional.resample(vocals, demucs_model.samplerate, apollo_sr)


def _compare_method(key, title, name, available, enhance_fn, success_msg=None):
    """Run one enhancement method, play its output, and record it under `key`.

    Args:
        key: Dictionary key used in `enhanced_versions` / `failed_methods`.
        title: Banner text shown above the audio player.
        name: Human-readable method name used in status messages.
        available: Availability flag from the setup cells; False skips the method.
        enhance_fn: Zero-argument callable returning a 1-D tensor at apollo_sr.
        success_msg: Status text on success (defaults to "<name> applied").
    """
    if not available:
        print(f"\n‚ö†Ô∏è {name} not available")
        return
    if success_msg is None:
        success_msg = f"{name} applied"
    _print_banner(title)
    try:
        result = enhance_fn()
        enhanced_versions[key] = result
        ipd.display(ipd.Audio(result.numpy(), rate=apollo_sr))
        print(f"‚úÖ {success_msg}")
    except Exception as e:
        failed_methods.append(key)
        print(f"‚ùå {name} failed: {e}")
        traceback.print_exc()


# Original
_print_banner("ORIGINAL (Noisy Apollo 11 Recording)")
ipd.display(ipd.Audio(apollo_waveform.numpy(), rate=apollo_sr))
enhanced_versions["original"] = apollo_waveform

# 1. noisereduce
_compare_method("noisereduce", "noisereduce (Spectral Gating)", "noisereduce",
                NOISEREDUCE_AVAILABLE, _enhance_noisereduce)
# 2. DeepFilterNet
_compare_method("deepfilternet", "DeepFilterNet (Deep Learning 48kHz)", "DeepFilterNet",
                DEEPFILTERNET_AVAILABLE, _enhance_deepfilternet)
# 3. Resemble Enhance (if available)
_compare_method("resemble", "Resemble Enhance (AI Speech Enhancement)", "Resemble Enhance",
                RESEMBLE_ENHANCE_AVAILABLE, _enhance_resemble)
# 4. Demucs (vocal extraction - different use case but interesting)
_compare_method("demucs_vocals", "Demucs (Vocal Extraction)", "Demucs",
                DEMUCS_AVAILABLE, _enhance_demucs,
                success_msg="Demucs vocal extraction applied")

print("\n" + "="*60)
print("COMPARISON COMPLETE!")
print("="*60)
print(f"\nEnhanced versions available: {list(enhanced_versions.keys())}")
print("\nListen to each version above and compare the noise reduction quality!")
print("The Apollo 11 audio is a challenging test case due to extreme noise.")
if failed_methods:
    print(f"\n‚ùå Test 18 FAILED: {', '.join(failed_methods)} raised errors")
else:
    print("\n‚úÖ Test 18 PASSED - Enhancement comparison complete")

In [None]:
# Final recap of the feature areas this notebook exercised.
# Each section is (heading, [checklist entries]); every entry is printed
# with the same "  ‚úÖ " prefix beneath its heading.
_RULE = "=" * 60
_SUMMARY_SECTIONS = [
    ("Core Features (Tests 1-10):", [
        "Test 1-3: Load, resample, mono conversion",
        "Test 4-5: Segmentation with overlap",
        "Test 6-7: Batching and frame offsets",
        "Test 8-9: Convenience functions, normalization",
        "Test 10: Audio playback verification",
    ]),
    ("Enhancement Features (Tests 11-14, requires: pip install demucs pyloudnorm):", [
        "Test 11: TimeMappingManager (timestamp recovery)",
        "Test 12: Silence removal (energy-based)",
        "Test 13: Silero VAD (Voice Activity Detection)",
        "Test 14: Demucs vocal extraction",
    ]),
    ("Additional Denoising (Tests 15-17, install as needed):", [
        "Test 15: noisereduce (pip install noisereduce)",
        "Test 16: DeepFilterNet (pip install deepfilternet)",
        "Test 17: Resemble Enhance (pip install resemble-enhance)",
    ]),
    ("Real-World Comparison (Test 18):", [
        "Test 18: NASA Apollo 11 audio - all methods compared",
    ]),
]

print(_RULE)
print("TEST SUMMARY")
print(_RULE)
print("\nAudio Frontend module tests complete.")
for _heading, _entries in _SUMMARY_SECTIONS:
    print("\n" + _heading)
    for _entry in _entries:
        print("  ‚úÖ " + _entry)
print("\n" + _RULE)
print("""
AUDIO FRONTEND COMPLETE!

Features:
‚îú‚îÄ‚îÄ Loading: torchaudio, soundfile (fallback)
‚îú‚îÄ‚îÄ Preprocessing: resample, mono, normalize
‚îú‚îÄ‚îÄ Segmentation: overlap for divide-and-conquer alignment
‚îú‚îÄ‚îÄ Batching: GPU-ready tensor batching
‚îî‚îÄ‚îÄ Enhancement (optional)
    ‚îú‚îÄ‚îÄ Demucs: Vocal extraction from music/noise
    ‚îú‚îÄ‚îÄ Silero VAD: Voice Activity Detection
    ‚îú‚îÄ‚îÄ noisereduce: Spectral gating (CPU-friendly)
    ‚îú‚îÄ‚îÄ DeepFilterNet: Deep learning 48kHz denoising
    ‚îú‚îÄ‚îÄ Resemble Enhance: AI speech enhancement (GPU)
    ‚îî‚îÄ‚îÄ TimeMappingManager: Timestamp recovery

Usage:
    from audio_frontend import AudioEnhancement
    
    # Denoise with noisereduce (lightweight)
    enhancer = AudioEnhancement()
    result = enhancer.enhance("noisy.mp3", denoise_method="noisereduce")
    
    # Denoise with DeepFilterNet (48kHz quality)
    result = enhancer.enhance("noisy.mp3", denoise_method="deepfilternet")
    
    # Denoise with Resemble Enhance (best quality, GPU)
    result = enhancer.enhance("noisy.mp3", denoise_method="resemble")
    
    # Vocal extraction with Demucs
    result = enhancer.enhance("audio_with_music.mp3", extract_vocals=True)
""")
print(_RULE)