# Audio Frontend Module Tests

This notebook tests the `audio_frontend` module of the TorchAudio Long-Form Aligner.

Each test cell will display:
- ✅ if the test passes
- ❌ if the test fails

## Setup

In [None]:
# Report the library versions and accelerator availability of this runtime
import torch
import torchaudio

environment_report = "\n".join([
    f"PyTorch: {torch.__version__}",
    f"TorchAudio: {torchaudio.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
])
print(environment_report)

In [None]:
# Install torchcodec if using torchaudio >= 2.8
import torchaudio
version_parts = torchaudio.__version__.split('.')
major, minor = int(version_parts[0]), int(version_parts[1].split('+')[0])
if (major, minor) >= (2, 8):
    print("TorchAudio >= 2.8 detected, installing torchcodec...")
    !pip install -q torchcodec
else:
    print(f"TorchAudio {torchaudio.__version__} - torchcodec not required")

## Define the Audio Frontend Module

Copy of `torchaudio_aligner/src/audio_frontend.py` for Colab testing:

In [None]:
"""
Audio Frontend Module for TorchAudio Long-Form Aligner
"""

from dataclasses import dataclass
from typing import Optional, List, Callable, Tuple, Union
from pathlib import Path
import logging

import torch
import torchaudio

# Module-level logger; AudioFrontend reports its progress through it.
logger = logging.getLogger(__name__)
# Ask the root logger for INFO-level output so those progress messages appear.
logging.basicConfig(level=logging.INFO)


@dataclass
class AudioSegment:
    """One slice of a longer recording together with its position metadata.

    Positions are stored in samples; the two properties below convert them to
    seconds using ``sample_rate``.
    """
    waveform: torch.Tensor  # audio samples of this slice
    sample_rate: int        # samples per second
    offset_samples: int     # start of the slice within the source audio, in samples
    length_samples: int     # number of samples in the slice
    segment_index: int      # running index of the slice

    @property
    def duration_seconds(self) -> float:
        """Length of the slice, in seconds."""
        samples, rate = self.length_samples, self.sample_rate
        return samples / rate

    @property
    def offset_seconds(self) -> float:
        """Start of the slice within the source audio, in seconds."""
        samples, rate = self.offset_samples, self.sample_rate
        return samples / rate


@dataclass
class SegmentationResult:
    """Result of audio segmentation containing all segments and metadata."""
    segments: List[AudioSegment]      # segments in the order they were cut
    original_duration_samples: int    # length of the segmented audio, in samples
    original_duration_seconds: float  # the same length, in seconds
    sample_rate: int                  # sample rate shared by all segments
    segment_size_samples: int         # nominal (maximum) segment length
    overlap_samples: int              # samples shared by consecutive segments
    num_segments: int                 # number of entries in ``segments``

    def get_waveforms_batched(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Stack all segment waveforms into one zero-padded batch.

        All segments are assumed to share the dimensionality (and channel
        count) of the first one.

        Returns:
            ``(waveforms, lengths)``. ``waveforms`` has shape
            ``(num_segments, max_len)`` when the segments are 1-D and
            ``(num_segments, channels, max_len)`` otherwise; it uses the dtype
            and device of the first segment. ``lengths`` is a ``long`` tensor
            holding the unpadded length of every segment. With no segments,
            an empty ``(0, 0)`` batch and an empty ``lengths`` are returned.
        """
        if not self.segments:
            # Nothing to batch; avoid max() failing on an empty sequence.
            return torch.zeros(0, 0), torch.zeros(0, dtype=torch.long)

        first = self.segments[0].waveform
        batch_size = len(self.segments)
        max_len = max(seg.waveform.shape[-1] for seg in self.segments)

        if first.dim() == 1:
            batch_shape = (batch_size, max_len)
        else:
            batch_shape = (batch_size, first.shape[0], max_len)

        # new_zeros keeps the segments' dtype and device, so float64 or GPU
        # audio is not silently cast/copied into a float32 CPU buffer.
        waveforms = first.new_zeros(batch_shape)
        lengths = torch.zeros(batch_size, dtype=torch.long)

        for i, seg in enumerate(self.segments):
            length = seg.waveform.shape[-1]
            # The ellipsis covers both the 1-D and the (channels, time) layout.
            waveforms[i, ..., :length] = seg.waveform
            lengths[i] = length

        return waveforms, lengths

    def get_offsets_in_frames(self, frame_duration_seconds: float) -> torch.Tensor:
        """Convert every segment's start offset from samples to frame indices.

        Args:
            frame_duration_seconds: Duration of one frame in seconds; must be
                positive.

        Returns:
            A ``long`` tensor with one frame index per segment (rounded down).

        Raises:
            ValueError: If ``frame_duration_seconds`` is not positive.
        """
        if frame_duration_seconds <= 0:
            raise ValueError(
                f"frame_duration_seconds must be positive, got {frame_duration_seconds}"
            )

        # float64 keeps sample offsets exact well beyond 2**24, where float32
        # (the default result of dividing an integer tensor) starts to round.
        offsets = torch.tensor(
            [seg.offset_samples for seg in self.segments], dtype=torch.float64
        )
        frames = offsets / self.sample_rate / frame_duration_seconds
        # A quotient that is mathematically whole can land just below the
        # integer (0.06 / 0.02 == 2.9999999999999996); the small tolerance
        # keeps rounding down from losing a whole frame in that case.
        return torch.floor(frames + 1e-6).long()


class AudioFrontend:
    """Load, resample, down-mix, normalize and segment audio.

    Args:
        target_sample_rate: Sample rate in Hz that ``resample`` converts to
            unless told otherwise.
        mono: If True, ``preprocess`` averages all channels into one.
        normalize: If True, ``preprocess`` peak-normalizes the audio.
        normalize_db: Target peak level in dB, where 0 dB is a peak of 1.0.
        preprocessors: Optional callables ``f(waveform, sample_rate)`` that
            return a waveform; applied in order after the built-in steps.
    """

    def __init__(
        self,
        target_sample_rate: int = 16000,
        mono: bool = True,
        normalize: bool = False,
        normalize_db: float = -3.0,
        preprocessors: Optional[List[Callable[[torch.Tensor, int], torch.Tensor]]] = None,
    ):
        self.target_sample_rate = target_sample_rate
        self.mono = mono
        self.normalize = normalize
        self.normalize_db = normalize_db
        self.preprocessors = preprocessors or []

    def load(self, audio_path: Union[str, Path]) -> Tuple[torch.Tensor, int]:
        """Read an audio file with ``torchaudio.load``.

        Args:
            audio_path: Path of the file to read.

        Returns:
            ``(waveform, sample_rate)`` exactly as returned by torchaudio; no
            resampling or channel mixing is applied here.

        Raises:
            FileNotFoundError: If ``audio_path`` does not exist.
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        logger.info(f"Loading audio from: {audio_path}")
        waveform, sample_rate = torchaudio.load(str(audio_path))
        logger.info(f"Loaded: shape={waveform.shape}, sr={sample_rate}, duration={waveform.shape[1]/sample_rate:.2f}s")
        return waveform, sample_rate

    def resample(self, waveform: torch.Tensor, orig_sample_rate: int, target_sample_rate: Optional[int] = None) -> torch.Tensor:
        """Resample ``waveform`` from ``orig_sample_rate`` to the target rate.

        Args:
            waveform: Audio to resample.
            orig_sample_rate: Current sample rate in Hz.
            target_sample_rate: Desired rate in Hz; when omitted (or 0) the
                instance's ``target_sample_rate`` is used.

        Returns:
            The resampled audio, or ``waveform`` itself when the rates match.
        """
        target_sr = target_sample_rate or self.target_sample_rate
        if orig_sample_rate == target_sr:
            return waveform
        logger.info(f"Resampling from {orig_sample_rate} Hz to {target_sr} Hz")
        return torchaudio.functional.resample(waveform, orig_sample_rate, target_sr)

    def to_mono(self, waveform: torch.Tensor) -> torch.Tensor:
        """Average all channels of a ``(channels, samples)`` waveform into one.

        A waveform that is 1-D or already has a single channel is returned
        unchanged.
        """
        # A 1-D tensor has no channel axis: averaging over dim 0 would average
        # over time and collapse the whole signal into a single value.
        if waveform.dim() == 1 or waveform.shape[0] == 1:
            return waveform
        logger.info(f"Converting {waveform.shape[0]} channels to mono")
        return waveform.mean(dim=0, keepdim=True)

    def apply_normalization(self, waveform: torch.Tensor) -> torch.Tensor:
        """Scale ``waveform`` so its absolute peak equals ``normalize_db`` dB.

        Empty or all-zero audio is returned unchanged.
        """
        # max() is undefined for an empty tensor, so there is nothing to scale.
        if waveform.numel() == 0:
            return waveform
        peak = waveform.abs().max()
        if peak > 0:
            # dB to linear amplitude: 0 dB -> 1.0, -3 dB -> ~0.708
            target_peak = 10 ** (self.normalize_db / 20)
            waveform = waveform * (target_peak / peak)
            logger.info(f"Normalized audio to {self.normalize_db} dB peak")
        return waveform

    def preprocess(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Apply the configured steps: down-mix, normalize, custom preprocessors.

        Args:
            waveform: Audio to process.
            sample_rate: Sample rate of ``waveform``; passed to each custom
                preprocessor.

        Returns:
            The processed waveform.
        """
        if self.mono:
            waveform = self.to_mono(waveform)
        if self.normalize:
            waveform = self.apply_normalization(waveform)
        for preprocessor in self.preprocessors:
            waveform = preprocessor(waveform, sample_rate)
        return waveform

    def segment(
        self,
        waveform: torch.Tensor,
        sample_rate: int,
        segment_size: float = 15.0,
        overlap: float = 2.0,
        min_segment_size: float = 0.2,
        extra_samples: int = 128,
    ) -> SegmentationResult:
        """Cut ``waveform`` into fixed-size, overlapping segments.

        Args:
            waveform: 1-D ``(samples,)`` or 2-D ``(channels, samples)`` audio.
            sample_rate: Sample rate of ``waveform`` in Hz.
            segment_size: Nominal segment length in seconds.
            overlap: Overlap between consecutive segments in seconds.
            min_segment_size: A remainder shorter than this many seconds ends
                the segmentation and is dropped.
            extra_samples: Added to both the segment length and the overlap,
                so the stride between segment starts is unaffected by it.

        Returns:
            A ``SegmentationResult``. Segments of single-channel audio hold
            1-D waveforms; multi-channel segments keep ``(channels, samples)``.

        Raises:
            ValueError: If ``waveform`` is not 1-D or 2-D, or if more than one
                segment is needed but the overlap is not smaller than the
                segment size (segmentation could never advance).
        """
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        if waveform.dim() != 2:
            raise ValueError(
                f"Expected a 1-D or 2-D waveform, got shape {tuple(waveform.shape)}"
            )

        num_channels, total_samples = waveform.shape
        segment_size_samples = int(sample_rate * segment_size) + extra_samples
        overlap_samples = int(sample_rate * overlap) + extra_samples
        min_segment_samples = int(sample_rate * min_segment_size)
        step_size = segment_size_samples - overlap_samples

        logger.info(f"Segmenting: total={total_samples}, seg_size={segment_size_samples}, overlap={overlap_samples}, step={step_size}")

        segments: List[AudioSegment] = []
        offset = 0

        while offset < total_samples:
            end = min(offset + segment_size_samples, total_samples)
            segment_length = end - offset

            # A remainder that is too short is dropped rather than emitted.
            if segment_length < min_segment_samples:
                break

            segment_waveform = waveform[:, offset:end]
            if num_channels == 1:
                segment_waveform = segment_waveform.squeeze(0)

            segments.append(
                AudioSegment(
                    waveform=segment_waveform,
                    sample_rate=sample_rate,
                    offset_samples=offset,
                    length_samples=segment_length,
                    segment_index=len(segments),
                )
            )

            # The segment that reaches the end of the audio is the last one.
            if end >= total_samples:
                break

            # More audio remains, so the window has to move forward. With a
            # non-positive stride it never would and the loop would not end.
            if step_size <= 0:
                raise ValueError(
                    f"overlap ({overlap}s) must be smaller than segment_size "
                    f"({segment_size}s) to segment audio longer than one segment"
                )
            offset += step_size

        logger.info(f"Created {len(segments)} segments")

        return SegmentationResult(
            segments=segments,
            original_duration_samples=total_samples,
            original_duration_seconds=total_samples / sample_rate,
            sample_rate=sample_rate,
            segment_size_samples=segment_size_samples,
            overlap_samples=overlap_samples,
            num_segments=len(segments),
        )

    def process(
        self,
        audio_path: Union[str, Path],
        segment_size: float = 15.0,
        overlap: float = 2.0,
        min_segment_size: float = 0.2,
        extra_samples: int = 128,
    ) -> SegmentationResult:
        """Load a file and run the full pipeline on it.

        Equivalent to ``load`` followed by ``process_waveform``; the
        segmentation arguments are forwarded to ``segment``.

        Raises:
            FileNotFoundError: If ``audio_path`` does not exist.
        """
        waveform, orig_sample_rate = self.load(audio_path)
        return self.process_waveform(
            waveform, orig_sample_rate,
            segment_size=segment_size, overlap=overlap,
            min_segment_size=min_segment_size, extra_samples=extra_samples,
        )

    def process_waveform(
        self,
        waveform: torch.Tensor,
        sample_rate: int,
        segment_size: float = 15.0,
        overlap: float = 2.0,
        min_segment_size: float = 0.2,
        extra_samples: int = 128,
    ) -> SegmentationResult:
        """Run resample -> preprocess -> segment on an in-memory waveform.

        Args:
            waveform: Audio to process.
            sample_rate: Sample rate of ``waveform`` in Hz.
            segment_size, overlap, min_segment_size, extra_samples: Forwarded
                to ``segment``.

        Returns:
            The ``SegmentationResult`` at ``target_sample_rate``.
        """
        waveform = self.resample(waveform, sample_rate)
        waveform = self.preprocess(waveform, self.target_sample_rate)
        return self.segment(
            waveform, self.target_sample_rate,
            segment_size=segment_size, overlap=overlap,
            min_segment_size=min_segment_size, extra_samples=extra_samples,
        )


def segment_audio(
    audio_path: Union[str, Path],
    target_sample_rate: int = 16000,
    segment_size: float = 15.0,
    overlap: float = 2.0,
    mono: bool = True,
    normalize: bool = False,
    min_segment_size: float = 0.2,
    extra_samples: int = 128,
) -> SegmentationResult:
    """Load an audio file and split it into overlapping segments in one call.

    Builds an ``AudioFrontend`` from the given options and runs its
    ``process`` pipeline on ``audio_path``.

    Args:
        audio_path: Path of the audio file.
        target_sample_rate: Sample rate in Hz to resample to.
        segment_size: Nominal segment length in seconds.
        overlap: Overlap between consecutive segments in seconds.
        mono: If True, average all channels into one.
        normalize: If True, peak-normalize the audio.
        min_segment_size: Forwarded to ``AudioFrontend.process``.
        extra_samples: Forwarded to ``AudioFrontend.process``.

    Returns:
        The ``SegmentationResult`` produced by ``AudioFrontend.process``.
    """
    frontend = AudioFrontend(
        target_sample_rate=target_sample_rate,
        mono=mono,
        normalize=normalize,
    )
    return frontend.process(
        audio_path,
        segment_size=segment_size,
        overlap=overlap,
        min_segment_size=min_segment_size,
        extra_samples=extra_samples,
    )


print("✅ Audio Frontend module loaded successfully!")

## Download Test Audio

We'll use Meta's Q1 2025 earnings call as test audio (same as in Tutorial.py).

In [None]:
# Download test audio (Meta Q1 2025 Earnings Call - ~1 hour)
!wget -q https://static.seekingalpha.com/cdn/s3/transcripts_audio/4780182.mp3 -O test_audio.mp3
!ls -lh test_audio.mp3

TEST_AUDIO = "test_audio.mp3"
print("✅ Test audio downloaded")

## Test 1: Load Audio

In [None]:
banner = "=" * 60
print(banner)
print("Test 1: AudioFrontend.load()")
print(banner)

try:
    # load() returns the file as-is: no resampling or channel mixing yet
    frontend = AudioFrontend(target_sample_rate=16000)
    waveform, sample_rate = frontend.load(TEST_AUDIO)

    print("\nResults:")
    print(f"  Waveform shape: {waveform.shape}")
    print(f"  Sample rate: {sample_rate} Hz")
    duration_seconds = waveform.shape[1] / sample_rate
    print(f"  Duration: {duration_seconds:.2f} seconds")
    print(f"  Duration: {duration_seconds / 60:.2f} minutes")

    # Expect a channel axis plus a time axis, and a usable sample rate
    assert waveform.dim() == 2, "Waveform should be 2D"
    assert sample_rate > 0, "Sample rate should be positive"
    print("\n✅ Test 1 PASSED")
except Exception as e:
    print(f"\n❌ Test 1 FAILED: {e}")

## Test 2: Resample Audio

In [None]:
banner = "=" * 60
print(banner)
print("Test 2: AudioFrontend.resample()")
print(banner)

try:
    # Uses `frontend`, `waveform` and `sample_rate` produced by Test 1
    print(f"\nOriginal sample rate: {sample_rate} Hz")
    source_len = waveform.shape[1]
    print(f"Original samples: {source_len}")

    resampled = frontend.resample(waveform, sample_rate, 16000)
    resampled_len = resampled.shape[1]

    # Length scales with the ratio of the two sample rates
    expected_samples = int(source_len * 16000 / sample_rate)
    print(f"Resampled samples: {resampled_len}")
    print(f"Expected samples (approx): {expected_samples}")

    length_error = abs(resampled_len - expected_samples)
    assert length_error < 100, "Resampled length mismatch"
    print("\n✅ Test 2 PASSED")
except Exception as e:
    print(f"\n❌ Test 2 FAILED: {e}")

## Test 3: Convert to Mono

In [None]:
print("=" * 60)
print("Test 3: AudioFrontend.to_mono()")
print("=" * 60)

try:
    # Uses `frontend` and `waveform` produced by Test 1
    print(f"\nOriginal channels: {waveform.shape[0]}")

    mono = frontend.to_mono(waveform)

    print(f"Mono channels: {mono.shape[0]}")
    assert mono.shape[0] == 1, "Should have 1 channel"
    assert mono.shape[1] == waveform.shape[1], "Down-mix should keep the number of samples"

    # The downloaded file may already be mono, in which case the checks above
    # never reach the averaging code. Exercise it on a tiny synthetic stereo
    # signal whose two channels cancel out exactly.
    stereo = torch.stack([torch.ones(8), -torch.ones(8)])
    stereo_mono = frontend.to_mono(stereo)
    print(f"Synthetic stereo {tuple(stereo.shape)} -> mono {tuple(stereo_mono.shape)}")
    assert stereo_mono.shape == (1, 8), "Synthetic down-mix should have shape (1, 8)"
    assert torch.equal(stereo_mono, torch.zeros(1, 8)), "Down-mix should average the channels"
    print("\n✅ Test 3 PASSED")
except Exception as e:
    print(f"\n❌ Test 3 FAILED: {e}")

## Test 4: Segment Audio

In [None]:
print("=" * 60)
print("Test 4: AudioFrontend.segment()")
print("=" * 60)

try:
    frontend = AudioFrontend(target_sample_rate=16000, mono=True)
    # Cell-local names keep the raw `waveform` / `sample_rate` pair loaded in
    # Test 1 intact, so Tests 2 and 3 can still be re-run afterwards.
    raw_waveform, orig_sr = frontend.load(TEST_AUDIO)
    segment_input = frontend.to_mono(frontend.resample(raw_waveform, orig_sr))

    result = frontend.segment(
        segment_input,
        sample_rate=16000,
        segment_size=15.0,
        overlap=2.0,
        min_segment_size=0.2,
    )

    print(f"\nResults:")
    print(f"  Original duration: {result.original_duration_seconds:.2f} seconds ({result.original_duration_seconds/60:.2f} min)")
    print(f"  Number of segments: {result.num_segments}")
    print(f"  Segment size: {result.segment_size_samples} samples ({result.segment_size_samples/16000:.2f}s)")
    print(f"  Overlap: {result.overlap_samples} samples ({result.overlap_samples/16000:.2f}s)")

    # Check for segments before indexing into the list, so an empty result
    # fails with this message instead of an IndexError.
    assert result.num_segments > 0, "Should have at least one segment"

    print(f"\nFirst 3 segments:")
    for i, seg in enumerate(result.segments[:3]):
        print(f"  Segment {i}: offset={seg.offset_seconds:.2f}s, duration={seg.duration_seconds:.2f}s, shape={seg.waveform.shape}")

    print(f"\nLast segment:")
    last_seg = result.segments[-1]
    print(f"  Segment {last_seg.segment_index}: offset={last_seg.offset_seconds:.2f}s, duration={last_seg.duration_seconds:.2f}s")

    assert all(seg.sample_rate == 16000 for seg in result.segments), "All segments should have correct sample rate"

    # Segment starts advance by a fixed stride, beginning at the first sample.
    segment_offsets = [seg.offset_samples for seg in result.segments]
    expected_step = result.segment_size_samples - result.overlap_samples
    assert segment_offsets[0] == 0, "First segment should start at sample 0"
    assert all(
        later - earlier == expected_step
        for earlier, later in zip(segment_offsets, segment_offsets[1:])
    ), "Consecutive segments should be one stride apart"

    # With these settings the overlap is longer than min_segment_size, so no
    # tail is dropped and the last segment must reach the end of the audio.
    assert (
        last_seg.offset_samples + last_seg.length_samples == result.original_duration_samples
    ), "Segments should cover the audio up to its last sample"
    print("\n✅ Test 4 PASSED")
except Exception as e:
    print(f"\n❌ Test 4 FAILED: {e}")

## Test 5: Full Processing Pipeline

In [None]:
banner = "=" * 60
print(banner)
print("Test 5: AudioFrontend.process() - Full Pipeline")
print(banner)

try:
    # One call covering load -> resample -> preprocess -> segment
    frontend = AudioFrontend(target_sample_rate=16000, mono=True, normalize=False)
    result = frontend.process(TEST_AUDIO, segment_size=15.0, overlap=2.0)

    print("\nResults:")
    print(f"  Original duration: {result.original_duration_seconds:.2f} seconds")
    print(f"  Number of segments: {result.num_segments}")

    # `result` is reused by Tests 6 and 7 below
    assert isinstance(result, SegmentationResult)
    assert result.num_segments > 0
    print("\n✅ Test 5 PASSED")
except Exception as e:
    print(f"\n❌ Test 5 FAILED: {e}")

## Test 6: Batching for GPU Inference

In [None]:
banner = "=" * 60
print(banner)
print("Test 6: SegmentationResult.get_waveforms_batched()")
print(banner)

try:
    # Operates on the `result` produced by Test 5
    waveforms, lengths = result.get_waveforms_batched()

    print("\nResults:")
    print(f"  Batched waveforms shape: {waveforms.shape}")
    print(f"  Lengths shape: {lengths.shape}")
    head_lengths = lengths[:5].tolist()
    tail_lengths = lengths[-5:].tolist()
    print(f"  First 5 lengths: {head_lengths}")
    print(f"  Last 5 lengths: {tail_lengths}")

    # One row per segment in both outputs; mono audio batches to (batch, time)
    expected_batch = result.num_segments
    assert waveforms.shape[0] == expected_batch, "Batch size mismatch"
    assert lengths.shape[0] == expected_batch, "Lengths mismatch"
    assert waveforms.dim() == 2, "Should be 2D for mono"
    print("\n✅ Test 6 PASSED")
except Exception as e:
    print(f"\n❌ Test 6 FAILED: {e}")

## Test 7: Frame Offset Calculation

In [None]:
banner = "=" * 60
print(banner)
print("Test 7: SegmentationResult.get_offsets_in_frames()")
print(banner)

try:
    # MMS model has 20ms frame duration
    frame_duration = 0.02
    # Operates on the `result` produced by Test 5
    offsets = result.get_offsets_in_frames(frame_duration)

    print("\nResults:")
    print(f"  Frame duration: {frame_duration}s (20ms)")
    print(f"  Frame offsets shape: {offsets.shape}")
    print(f"  First 5 offsets (frames): {offsets[:5].tolist()}")

    # Every offset must be strictly larger than the one before it
    is_monotonic = bool(torch.all(offsets[:-1] < offsets[1:]))
    print(f"  Monotonically increasing: {is_monotonic}")

    assert is_monotonic, "Offsets should be monotonically increasing"
    print("\n✅ Test 7 PASSED")
except Exception as e:
    print(f"\n❌ Test 7 FAILED: {e}")

## Test 8: Convenience Function

In [None]:
banner = "=" * 60
print(banner)
print("Test 8: segment_audio() convenience function")
print(banner)

try:
    # Single-call equivalent of building an AudioFrontend and calling process()
    result = segment_audio(TEST_AUDIO, target_sample_rate=16000, segment_size=15.0, overlap=2.0)

    print("\nResults:")
    print(f"  Duration: {result.original_duration_seconds:.2f}s")
    print(f"  Segments: {result.num_segments}")

    assert isinstance(result, SegmentationResult)
    print("\n✅ Test 8 PASSED")
except Exception as e:
    print(f"\n❌ Test 8 FAILED: {e}")

## Test 9: Normalization

In [None]:
banner = "=" * 60
print(banner)
print("Test 9: Audio Normalization")
print(banner)

try:
    frontend_norm = AudioFrontend(target_sample_rate=16000, mono=True, normalize=True, normalize_db=-3.0)

    # Prepare 16 kHz mono audio step by step so the peak can be measured
    # before and after normalization.
    waveform, sr = frontend_norm.load(TEST_AUDIO)
    waveform = frontend_norm.to_mono(frontend_norm.resample(waveform, sr))

    original_peak = waveform.abs().max().item()
    print(f"\nOriginal peak: {original_peak:.4f}")

    normalized = frontend_norm.apply_normalization(waveform.clone())
    normalized_peak = normalized.abs().max().item()
    print(f"Normalized peak: {normalized_peak:.4f}")

    expected_peak = 10 ** (-3.0 / 20)  # -3 dB
    print(f"Expected peak (-3dB): {expected_peak:.4f}")

    peak_error = abs(normalized_peak - expected_peak)
    assert peak_error < 0.01, "Normalized peak mismatch"
    print("\n✅ Test 9 PASSED")
except Exception as e:
    print(f"\n❌ Test 9 FAILED: {e}")

## Test 10: Listen to a Segment

In [None]:
banner = "=" * 60
print(banner)
print("Test 10: Listen to a Segment (Visual/Audio Check)")
print(banner)

try:
    import IPython.display as ipd

    result = segment_audio(TEST_AUDIO, segment_size=15.0, overlap=2.0)
    mid_idx = result.num_segments // 2

    # Play the first segment, then one from the middle of the recording
    playlist = (
        (0, "\nPlaying Segment 0:"),
        (mid_idx, f"\nPlaying Segment {mid_idx} (middle):"),
    )
    for seg_idx, heading in playlist:
        seg = result.segments[seg_idx]
        print(heading)
        print(f"  Offset: {seg.offset_seconds:.2f}s")
        print(f"  Duration: {seg.duration_seconds:.2f}s")
        ipd.display(ipd.Audio(seg.waveform.numpy(), rate=seg.sample_rate))

    print("\n✅ Test 10 PASSED (verify audio plays correctly)")
except Exception as e:
    print(f"\n❌ Test 10 FAILED: {e}")

## Summary

In [None]:
print("=" * 60)
print("TEST SUMMARY")
print("=" * 60)
print("\nAudio Frontend module tests complete.")
# This cell does not inspect the outcome of the tests, so it lists what was
# exercised without claiming that every check passed.
print("\nKey features exercised (see the ✅/❌ line of each test cell above):")
exercised_features = (
    "Load audio (MP3 test file)",
    "Resample to target sample rate",
    "Convert to mono",
    "Uniform segmentation with overlap",
    "Batch waveforms for GPU inference",
    "Calculate frame offsets for acoustic models",
    "Audio normalization",
)
for feature in exercised_features:
    print(f"  - {feature}")
print("\n" + "=" * 60)
print("Now you can clear all outputs and save the file.")
print("=" * 60)