# Audio Quality Analysis - GPU-Accelerated Batch Processing

**Optimizations:**
- Parallel Azure blob downloads (50-100 files at once)
- GPU-accelerated audio processing with torchaudio
- Batch spectrogram computation on GPU
- Automatic cleanup of downloaded files
- Multi-threading for download + GPU processing pipeline

**Target:** Process 5,000 files in roughly 1-3 hours (versus 25-30 hours for the original version)

**Output:** Same as original - `audio_quality_analysis.parquet`

In [None]:
import sys
sys.path.append("../scripts")

import torch
import torchaudio
import numpy as np
import pandas as pd
from pathlib import Path
import pyloudnorm as pyln
from tqdm import tqdm
import io
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed
import gc
import os

# Azure blob utilities
from scripts.cloud.azure_utils import list_blobs, download_blob_to_memory

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# On a GPU run, report the name and total memory (in GB) of device index 0.
if device.type == 'cuda':
    gpu_name = torch.cuda.get_device_name(0)
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU: {gpu_name}")
    print(f"GPU Memory: {gpu_total_bytes / 1e9:.2f} GB")

In [None]:
import os
from dotenv import load_dotenv
# Load environment variables from the credentials file; the path is relative to
# the notebook's working directory.
# NOTE(review): presumably these are the Azure storage credentials consumed by
# the blob helpers imported above — confirm against azure_utils.
load_dotenv(dotenv_path='../credentials/creds.env')

## Configuration

In [None]:
# Processing configuration
BATCH_SIZE = 64  # Number of downloaded files passed to analyze_audio_batch_gpu per call
DOWNLOAD_WORKERS = 100  # Thread-pool size used for parallel blob downloads
DOWNLOAD_BATCH_SIZE = 200  # Files downloaded (and held in memory) before processing starts
TARGET_SR = 16000  # Sample rate in Hz that every file is resampled to before analysis

# Scratch directory created eagerly here and removed in the Cleanup cell.
# NOTE(review): the download helpers visible in this notebook keep blobs in
# memory, so nothing here appears to write into TEMP_DIR — confirm it is needed.
TEMP_DIR = Path(tempfile.mkdtemp())
print(f"Temp directory: {TEMP_DIR}")
print(f"Batch size: {BATCH_SIZE} files")
print(f"Download workers: {DOWNLOAD_WORKERS}")
print(f"Download batch size: {DOWNLOAD_BATCH_SIZE}")

## GPU-Accelerated Audio Metric Functions

In [None]:
def load_audio_to_tensor(audio_bytes, target_sr=16000):
    """
    Decode encoded audio/video bytes into a mono float32 waveform tensor.

    Decoding is delegated to pydub's ``AudioSegment.from_file``. The segment is
    down-mixed to one channel, resampled to ``target_sr`` and its integer PCM
    samples are scaled by the full-scale value of the segment's sample width
    (``2 ** (8 * sample_width - 1)``), which gives 32768 for 16-bit and
    2147483648 for 32-bit audio.

    Args:
        audio_bytes: Encoded file contents (bytes-like).
        target_sr: Sample rate in Hz to resample to.

    Returns:
        Tuple ``(waveform, target_sr)`` where ``waveform`` is a 1D float32
        CPU tensor.

    Raises:
        Any exception pydub raises while decoding is propagated to the caller.
    """
    from pydub import AudioSegment

    # Decode, down-mix and resample in one chain; each call returns a new segment.
    audio_segment = (
        AudioSegment.from_file(io.BytesIO(audio_bytes))
        .set_channels(1)
        .set_frame_rate(target_sr)
    )

    samples = np.array(audio_segment.get_array_of_samples())

    # Scale by the full-scale magnitude of the sample width so that every width
    # is normalised, not only 16- and 32-bit.
    # NOTE(review): this assumes pydub exposes samples as signed integers for
    # every width (including 8-bit) — confirm against pydub if 8-bit input matters.
    full_scale = float(1 << (8 * audio_segment.sample_width - 1))
    waveform = samples.astype(np.float32) / full_scale

    return torch.from_numpy(waveform), target_sr


def batch_compute_spectrogram(waveforms, sr, n_fft=2048, hop_length=512):
    """
    Compute a magnitude spectrogram for each waveform on the global ``device``.

    The transform uses a Hann window of length ``n_fft``, centred frames and
    reflect padding (settings the original author chose to mirror librosa's
    STFT defaults). Waveforms shorter than ``n_fft`` samples are right-padded
    with zeros up to ``n_fft`` before the transform.

    Args:
        waveforms: Iterable of 1D tensors; lengths may differ.
        sr: Sample rate in Hz, used only to build the frequency axis.
        n_fft: FFT size (also the window length).
        hop_length: Hop between successive frames, in samples.

    Returns:
        List with one ``(magnitude_spectrogram, frequencies)`` tuple per input
        waveform, in input order. ``frequencies`` is the same tensor object in
        every tuple: ``n_fft // 2 + 1`` values evenly spaced over ``[0, sr / 2]``.
    """
    # Frequency axis shared by every spectrogram in the batch.
    bin_freqs = torch.linspace(0, sr/2, n_fft//2 + 1, device=device)

    # Build the transform once and reuse it for every waveform.
    stft = torchaudio.transforms.Spectrogram(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=n_fft,
        window_fn=torch.hann_window,
        power=None,  # keep the complex STFT; magnitude is taken below
        center=True,
        pad_mode='reflect',
        normalized=False
    ).to(device)

    def _magnitude(wv):
        # Move to the compute device, zero-pad short clips up to n_fft, transform.
        signal = wv.to(device)
        shortfall = n_fft - len(signal)
        if shortfall > 0:
            signal = torch.nn.functional.pad(signal, (0, shortfall), mode='constant', value=0)
        return torch.abs(stft(signal))

    return [(_magnitude(wv), bin_freqs) for wv in waveforms]


def snr_cal_batch(waveform, sr):
    """
    Estimate a signal-to-noise ratio in dB from framed RMS energy.

    The waveform is split into frames of 2048 samples with a hop of 512 after
    reflect-padding half a frame on each side (centred frames). The mean RMS
    of the frames covering roughly the first 0.5 seconds is taken as the noise
    level and the mean RMS of all frames as the signal level.

    Args:
        waveform: 1D floating-point tensor (any device).
        sr: Sample rate in Hz.

    Returns:
        ``20 * log10(signal_rms / noise_rms)`` as a Python float, or ``0.0``
        when the clip has no more frames than the noise window or the noise
        RMS is not positive.
    """
    frame_length = 2048
    hop_length = 512

    # Zero-pad short clips so at least one full frame exists and the reflect
    # padding below (which needs pad < length) is valid.
    if len(waveform) < frame_length:
        waveform = torch.nn.functional.pad(
            waveform, (0, frame_length - len(waveform)), mode='constant', value=0
        )

    # Non-constant padding modes are only implemented for batched inputs, so
    # add (batch, channel) dimensions before reflect-padding; a plain 1D tensor
    # would raise NotImplementedError here.
    half_frame = frame_length // 2
    signal = torch.nn.functional.pad(
        waveform.reshape(1, 1, -1), (half_frame, half_frame), mode='reflect'
    )

    # Mean of squares over each frame == average pooling of the squared signal.
    # This replaces a Python loop over frames with a single call and keeps
    # memory proportional to the clip length.
    mean_square = torch.nn.functional.avg_pool1d(
        signal ** 2, kernel_size=frame_length, stride=hop_length
    )
    rms_tensor = torch.sqrt(mean_square).flatten()

    # Treat the first 0.5 seconds of frames as noise.
    noise_frames = int(0.5 * sr / hop_length)
    if not 0 < noise_frames < len(rms_tensor):
        return 0.0

    noise_rms = torch.mean(rms_tensor[:noise_frames])
    if not noise_rms > 0:
        return 0.0

    signal_rms = torch.mean(rms_tensor)
    return (20 * torch.log10(signal_rms / noise_rms)).item()


def spectral_rolloff_batch(spec_mag, freqs, sr, roll_percent=0.85):
    """
    Return the median spectral rolloff frequency across frames.

    For each frame the rolloff is the frequency of the first bin at which the
    cumulative magnitude reaches ``roll_percent`` of that frame's total.

    Args:
        spec_mag: Magnitude spectrogram, shape (freq_bins, time_frames).
        freqs: Frequency in Hz of each bin, length freq_bins.
        sr: Sample rate (accepted for signature parity; not used here).
        roll_percent: Fraction of the total magnitude to reach.

    Returns:
        Median of the per-frame rolloff frequencies as a Python float
        (``torch.median`` picks the lower middle value for an even count).
    """
    running_total = spec_mag.cumsum(dim=0)

    # The last row of the cumulative sum is each frame's total magnitude.
    cutoff = roll_percent * running_total[-1]

    # argmax over a 0/1 mask yields the first bin where the cutoff is reached.
    reached = running_total >= cutoff[None, :]
    first_bin = reached.float().argmax(dim=0)

    return freqs[first_bin].median().item()


def spectral_centroid_batch(spec_mag, freqs, sr):
    """
    Return the median spectral centroid across frames.

    Each frame's centroid is the magnitude-weighted mean frequency; a small
    constant (1e-8) in the denominator guards against all-zero frames.

    Args:
        spec_mag: Magnitude spectrogram, shape (freq_bins, time_frames).
        freqs: Frequency in Hz of each bin, length freq_bins.
        sr: Sample rate (accepted for signature parity; not used here).

    Returns:
        Median of the per-frame centroids as a Python float.
    """
    # Broadcast the frequency axis across frames and sum over bins.
    weighted = (freqs[:, None] * spec_mag).sum(dim=0)
    total = spec_mag.sum(dim=0) + 1e-8

    return (weighted / total).median().item()


def spectral_flatness_batch(spec_mag):
    """
    Return the mean spectral flatness across frames.

    Flatness per frame is the geometric mean of the magnitudes divided by
    their arithmetic mean. A 1e-10 offset is added inside the logarithm and to
    the denominator to avoid log(0) and division by zero.

    Args:
        spec_mag: Magnitude spectrogram, shape (freq_bins, time_frames).

    Returns:
        Mean of the per-frame flatness values as a Python float.
    """
    eps = 1e-10

    # Geometric mean via exp(mean(log(x))) over the frequency axis.
    geometric = torch.log(spec_mag + eps).mean(dim=0).exp()
    arithmetic = spec_mag.mean(dim=0)

    return (geometric / (arithmetic + eps)).mean().item()


def zcr_batch(waveform):
    """
    Compute the mean and variance of the framed zero-crossing rate.

    The waveform is zero-padded by half a frame on each side (centred frames)
    and split into frames of 2048 samples with a hop of 512. A crossing is a
    change of sign between adjacent samples, with zero treated as positive so
    that digital silence and the zero padding do not count as crossings. Each
    frame's rate is its crossing count divided by ``frame_length - 1``.

    Args:
        waveform: 1D tensor (any device).

    Returns:
        Tuple ``(zcr_mean, zcr_variance)`` of Python floats taken over frames.
    """
    frame_length = 2048
    hop_length = 512

    # Zero-pad short clips so that at least one full frame exists.
    if len(waveform) < frame_length:
        waveform = torch.nn.functional.pad(
            waveform, (0, frame_length - len(waveform)), mode='constant', value=0
        )

    half_frame = frame_length // 2
    padded = torch.nn.functional.pad(
        waveform, (half_frame, half_frame), mode='constant', value=0
    )

    # Zero counts as positive. The previous `signs[signs == 0] = 0` assignment
    # was a no-op, so every zero sample produced spurious crossings.
    non_negative = padded >= 0
    sign_changes = non_negative[1:] != non_negative[:-1]

    # Mean of the 0/1 crossing indicator over each window of frame_length - 1
    # adjacent-sample pairs (diff shortens the signal by one), computed in a
    # single pooling call instead of a Python loop over frames.
    crossings = sign_changes.to(torch.float32).reshape(1, 1, -1)
    zcr_tensor = torch.nn.functional.avg_pool1d(
        crossings, kernel_size=frame_length - 1, stride=hop_length
    ).flatten()

    # The padded signal always spans several frames, so the tensor is never
    # empty and no zero-frame fallback is needed.
    return torch.mean(zcr_tensor).item(), torch.var(zcr_tensor).item()


def loudness_cal(waveform, sr):
    """
    Measure integrated loudness with pyloudnorm on the CPU.

    Clips shorter than 0.4 seconds are right-padded with zeros to that length
    before measurement.

    Args:
        waveform: 1D tensor; it is copied to the CPU and converted to NumPy.
        sr: Sample rate in Hz.

    Returns:
        The value returned by ``Meter.integrated_loudness`` as a Python float.
    """
    samples = waveform.cpu().numpy()

    # Pad up to the 0.4 s minimum when the clip is too short.
    shortfall = int(0.4 * sr) - len(samples)
    if shortfall > 0:
        samples = np.pad(samples, (0, shortfall), mode='constant', constant_values=0)

    return float(pyln.Meter(sr).integrated_loudness(samples))


def low_frequency_energy_batch(spec_mag, freqs, sr, cutoff_hz=80):
    """
    Return the share of total spectral magnitude that lies below ``cutoff_hz``.

    Args:
        spec_mag: Magnitude spectrogram, shape (freq_bins, time_frames).
        freqs: Frequency in Hz of each bin, length freq_bins.
        sr: Sample rate (accepted for signature parity; not used here).
        cutoff_hz: Bins with a frequency strictly below this value count as low.

    Returns:
        Low-band sum divided by the overall sum as a Python float, or ``0.0``
        when the overall sum is not positive.
    """
    # Sum the rows below the cutoff, and everything, over all frames.
    below_cutoff = spec_mag[freqs < cutoff_hz, :].sum()
    overall = spec_mag.sum()

    if not overall > 0:
        return 0.0
    return (below_cutoff / overall).item()

## Batch Processing Function

In [None]:
def analyze_audio_batch_gpu(audio_data_list):
    """
    Decode and analyse a batch of audio files, one result dict per file.

    Every file is first decoded with ``load_audio_to_tensor``. Each decoded
    waveform is then analysed on its own: its spectrogram is computed inside
    the per-file ``try`` block, so a failure (for example running out of GPU
    memory on one very long recording) is recorded against that file instead
    of aborting the whole batch, and only one spectrogram is resident on the
    device at a time.

    Args:
        audio_data_list: List of ``(audio_id, audio_bytes)`` tuples.

    Returns:
        List of dicts. Load failures come first, followed by one entry per
        decoded file in input order. Successful entries carry ``status ==
        'success'`` plus the metric fields; failures carry ``status ==
        'error'`` and an ``error_message``.
    """
    results = []
    loaded = []  # (audio_id, waveform) for every file that decoded cleanly

    # Stage 1: decode. A bad file must not stop the rest of the batch.
    for audio_id, audio_bytes in audio_data_list:
        try:
            wv, _ = load_audio_to_tensor(audio_bytes, target_sr=TARGET_SR)
        except Exception as e:
            results.append({
                'audio_id': audio_id,
                'status': 'error',
                'error_message': f"Load error: {str(e)}"
            })
        else:
            loaded.append((audio_id, wv))

    if not loaded:
        return results

    # Stage 2: per-file metrics.
    for audio_id, wv in loaded:
        spec_mag = freqs = wv_gpu = None
        try:
            # Spectrogram for this file only (the helper takes a list).
            spec_mag, freqs = batch_compute_spectrogram([wv], TARGET_SR)[0]

            # Waveform-domain metrics run on the compute device as well.
            wv_gpu = wv.to(device)

            duration_sec = len(wv) / TARGET_SR

            snr = snr_cal_batch(wv_gpu, TARGET_SR)
            rolloff = spectral_rolloff_batch(spec_mag, freqs, TARGET_SR)
            centroid = spectral_centroid_batch(spec_mag, freqs, TARGET_SR)
            flatness = spectral_flatness_batch(spec_mag)
            zcr_mean, zcr_var = zcr_batch(wv_gpu)
            low_freq_ratio = low_frequency_energy_batch(spec_mag, freqs, TARGET_SR)

            # Loudness is computed on the CPU copy of the waveform.
            loudness = loudness_cal(wv, TARGET_SR)

            results.append({
                'audio_id': audio_id,
                'sample_rate': TARGET_SR,
                'duration_sec': float(duration_sec),
                'snr_db': float(snr),
                'spectral_rolloff_hz': float(rolloff),
                'spectral_flatness': float(flatness),
                'spectral_centroid_hz': float(centroid),
                'zcr_mean': float(zcr_mean),
                'zcr_var': float(zcr_var),
                'loudness_lufs': float(loudness),
                'low_freq_energy_ratio': float(low_freq_ratio),
                'status': 'success'
            })

        except Exception as e:
            results.append({
                'audio_id': audio_id,
                'status': 'error',
                'error_message': f"Processing error: {str(e)}"
            })
        finally:
            # Drop references so this file's device tensors can be reclaimed
            # before the next file is processed.
            spec_mag = freqs = wv_gpu = None

    # Release cached GPU memory held by the allocator.
    del loaded
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    return results

## Parallel Download Functions

In [None]:
def download_single_blob(blob_path):
    """
    Download one blob into memory without ever raising.

    The audio ID is taken from the name of the blob's parent directory.

    Args:
        blob_path: Path of the blob to download.

    Returns:
        ``(audio_id, audio_bytes, None)`` on success, or
        ``(audio_id, None, error_message)`` when the download raised.
    """
    audio_id = Path(blob_path).parent.name
    try:
        payload = download_blob_to_memory(blob_path)
    except Exception as exc:
        # Report the failure as data so the caller can log it per file.
        return (audio_id, None, str(exc))
    return (audio_id, payload, None)


def download_batch_parallel(blob_paths, max_workers=100):
    """
    Download multiple blobs concurrently with a thread pool.

    Args:
        blob_paths: List of blob paths to download.
        max_workers: Number of download threads.

    Returns:
        Tuple ``(audio_data, errors)``:
        - ``audio_data``: list of ``(audio_id, audio_bytes)`` for successful
          downloads, in completion order (not input order).
        - ``errors``: list of error dicts (``audio_id``, ``status``,
          ``error_message``) for failed downloads.
    """
    audio_data = []
    errors = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(download_single_blob, path) for path in blob_paths]

        for future in tqdm(as_completed(futures), total=len(futures), desc="Downloading"):
            audio_id, audio_bytes, error = future.result()

            # Compare against None rather than truthiness: an exception with an
            # empty message yields error == "" and must still count as a failure.
            if error is not None:
                errors.append({
                    'audio_id': audio_id,
                    'status': 'error',
                    'error_message': f"Download error: {error}"
                })
            else:
                audio_data.append((audio_id, audio_bytes))

    return audio_data, errors

## List All Audio Files

In [None]:
# Enumerate every blob stored under the VHP prefix.
blob_prefix = "loc_vhp/"

print(f"Listing blobs with prefix: {blob_prefix}")
audio_blobs = list_blobs(blob_prefix)

# Keep only the original VHP source media (audio.mp3 / video.mp4 inside each
# item's folder); this drops the NFA-segmented WAV files.
audio_blobs = [
    blob_name
    for blob_name in audio_blobs
    if blob_name.endswith(('/audio.mp3', '/video.mp4'))
]

print(f"Found {len(audio_blobs)} original VHP media files")
print(f"\nFirst 5 files:")
for blob in audio_blobs[:5]:
    print(f"  - {blob}")

In [None]:
# TESTING: Uncomment to limit processing for testing
# SAMPLE_SIZE = 100
# audio_blobs = audio_blobs[:SAMPLE_SIZE]
# print(f"\n🧪 TEST MODE: Processing only first {SAMPLE_SIZE} files")

## Process All Files (Download + GPU Processing Pipeline)

In [None]:
import time

# One result dict per file (successes, download errors and processing errors).
all_results = []
total_files = len(audio_blobs)

print(f"\nProcessing {total_files} files...")
print(f"Download batch size: {DOWNLOAD_BATCH_SIZE}")
print(f"GPU batch size: {BATCH_SIZE}")
print(f"Download workers: {DOWNLOAD_WORKERS}")
print("="*60)

start_time = time.time()

# Outer loop: download DOWNLOAD_BATCH_SIZE blobs at a time so only one download
# batch is held in memory.
for i in range(0, total_files, DOWNLOAD_BATCH_SIZE):
    batch_blobs = audio_blobs[i:i + DOWNLOAD_BATCH_SIZE]

    print(f"\n[Batch {i//DOWNLOAD_BATCH_SIZE + 1}] Downloading {len(batch_blobs)} files...")

    # Download this batch in parallel; failed downloads become error rows.
    audio_data, download_errors = download_batch_parallel(batch_blobs, max_workers=DOWNLOAD_WORKERS)
    all_results.extend(download_errors)

    print(f"Downloaded: {len(audio_data)} files, Errors: {len(download_errors)}")

    # Inner loop: analyse the downloaded files BATCH_SIZE at a time.
    print(f"Processing on GPU...")
    for j in range(0, len(audio_data), BATCH_SIZE):
        gpu_batch = audio_data[j:j + BATCH_SIZE]
        batch_results = analyze_audio_batch_gpu(gpu_batch)
        all_results.extend(batch_results)

    # Drop the raw bytes of this batch before downloading the next one.
    del audio_data
    gc.collect()
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    # Progress update. Guard the divisions: elapsed can be 0 on a coarse clock.
    elapsed = time.time() - start_time
    processed = min(i + DOWNLOAD_BATCH_SIZE, total_files)
    rate = processed / elapsed if elapsed > 0 else 0.0
    remaining = (total_files - processed) / rate if rate > 0 else 0

    print(f"Progress: {processed}/{total_files} ({processed/total_files*100:.1f}%)")
    print(f"Elapsed: {elapsed/60:.1f} min, Rate: {rate:.1f} files/sec, ETA: {remaining/60:.1f} min")

total_time = time.time() - start_time
# Same guard for the overall rate (e.g. when there were no files to process).
average_rate = len(all_results) / total_time if total_time > 0 else 0.0

print("\n" + "="*60)
print(f"✓ COMPLETED")
print(f"Total files: {len(all_results)}")
print(f"Total time: {total_time/60:.1f} minutes ({total_time/3600:.2f} hours)")
print(f"Average rate: {average_rate:.2f} files/second")
print(f"Success: {sum(1 for r in all_results if r.get('status') == 'success')}")
print(f"Errors: {sum(1 for r in all_results if r.get('status') == 'error')}")

## Issue Detection and Preprocessing Recommendations

In [None]:
def detect_audio_issues(row):
    """
    Map one result row's quality metrics to a list of issue tags.

    Args:
        row: Mapping-like row (dict or DataFrame row) holding ``status`` and,
            for successful rows, the metric fields read below.

    Returns:
        List of issue tag strings in a fixed order; empty when the row's
        status is not ``'success'`` or no threshold is crossed.
    """
    if row['status'] != 'success':
        return []

    found = []

    # Bandwidth: at most one tag, the severe one taking precedence.
    rolloff_hz = row['spectral_rolloff_hz']
    if rolloff_hz < 1000:
        found.append('bandwidth_limited_severe')
    elif rolloff_hz < 4000:
        found.append('bandwidth_limited_moderate')

    # The remaining checks are independent of each other; order is preserved.
    independent_checks = (
        ('high_noise_zcr', row['zcr_mean'] > 0.05),
        ('high_noise_snr', row['snr_db'] < 15),
        ('low_frequency_rumble', row['low_freq_energy_ratio'] > 0.15),
        ('low_loudness', row['loudness_lufs'] < -30),
        ('very_flat_spectrum', row['spectral_flatness'] > 0.8),
    )
    found.extend(tag for tag, triggered in independent_checks if triggered)

    return found


def recommend_preprocessing(issues):
    """
    Translate detected issue tags into an ordered list of preprocessing steps.

    Loudness normalization is always recommended; each further step is added
    when any of its triggering issue tags is present.

    Args:
        issues: Collection of issue tags as produced by ``detect_audio_issues``.

    Returns:
        List of preprocessing step names.
    """
    plan = ['loudness_normalization']

    # (step, issue tags that trigger it), in the order steps are reported.
    triggers = (
        ('eq_high_freq_boost', ('bandwidth_limited_severe', 'bandwidth_limited_moderate')),
        ('noise_reduction', ('high_noise_zcr', 'high_noise_snr')),
        ('highpass_filter', ('low_frequency_rumble',)),
    )
    for step, tags in triggers:
        if any(tag in issues for tag in tags):
            plan.append(step)

    return plan


# One DataFrame row per processed file (successful and failed alike).
df = pd.DataFrame(all_results)

# Tag each row with its issues, then derive the preprocessing plan from the tags.
df['issues'] = df.apply(detect_audio_issues, axis=1)
df['recommended_preprocessing'] = df['issues'].apply(recommend_preprocessing)

files_with_issues = (df['issues'].str.len() > 0).sum()
print(f"\nFiles with issues: {files_with_issues} / {len(df)}")

from collections import Counter

# Flatten the per-file issue lists and count how often each tag occurs.
all_issues = []
for file_issues in df['issues']:
    all_issues.extend(file_issues)
issue_counts = Counter(all_issues)

print("\nIssue breakdown:")
for issue, count in issue_counts.most_common():
    share = count/len(df)*100
    print(f"  {issue}: {count} files ({share:.1f}%)")

## Summary Statistics

In [None]:
# Descriptive statistics cover successfully analysed files only.
df_success = df.loc[df['status'] == 'success']

summary_columns = ['snr_db', 'spectral_rolloff_hz', 'spectral_flatness',
                   'zcr_mean', 'loudness_lufs']
print("Overall Statistics:")
print(df_success[summary_columns].describe())

# Count how many files received each recommended preprocessing step.
print("\nPreprocessing Recommendations:")
all_recs = []
for file_recs in df['recommended_preprocessing']:
    all_recs.extend(file_recs)
rec_counts = Counter(all_recs)

for rec, count in rec_counts.most_common():
    share = count/len(df)*100
    print(f"{rec:30s}: {count:4d} files ({share:5.1f}%)")

## Save Results

In [None]:
output_path = Path("../data/audio_quality_analysis.parquet")

# Create the output directory if needed so the write cannot fail on a fresh
# checkout after hours of processing.
output_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(output_path, index=False)

print(f"✓ Saved: {output_path}")
print(f"Rows: {len(df)}")
print(f"Columns: {list(df.columns)}")
print(f"\nFile size: {output_path.stat().st_size / 1024 / 1024:.2f} MB")

## Cleanup

In [None]:
# Remove the scratch directory created in the Configuration cell.
import shutil

if TEMP_DIR.exists():
    shutil.rmtree(TEMP_DIR)
    print(f"✓ Cleaned up temp directory: {TEMP_DIR}")

# Release GPU memory cached by PyTorch's allocator.
if device.type == 'cuda':
    torch.cuda.empty_cache()
    print("✓ Cleared GPU memory")

print("\n✓ All done!")