# Audio Quality Analysis - GPU-Accelerated Batch Processing

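**Optimizations:**
- Parallel Azure blob downloads straight into memory (`DOWNLOAD_WORKERS` threads, 100 by default)
- GPU-accelerated audio metrics with PyTorch/torchaudio
- Spectrogram computation on the GPU, with files analysed in groups of `BATCH_SIZE`
- Temporary files used while decoding audio are deleted automatically
- Multi-threaded downloads that alternate with GPU processing, one download batch at a time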

**Target:** Process 5000 files in ~1-3 hours (vs 25-30 hours)

**Output:** Same output file as the original notebook: `../data/audio_quality_analysis.parquet`

In [None]:
import sys
sys.path.append("../scripts")

import torch
import torchaudio
import numpy as np
import pandas as pd
from pathlib import Path
import pyloudnorm as pyln
from tqdm import tqdm
import io
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed
import gc
import os

# Azure blob utilities
from azure_utils import list_blobs, download_blob_to_memory

# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

if device.type == 'cuda':
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    gpu_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU Memory: {gpu_bytes / 1e9:.2f} GB")

In [None]:
# Set Azure authentication environment variables.
# SECURITY: the storage-account key must never be hard-coded in the notebook;
# it ends up in version control and in shared notebook outputs. A key that was
# previously committed here must be treated as compromised and rotated.
# Provide AZURE_STORAGE_CONNECTION_STRING via the shell environment or a
# secret store before starting the kernel.
os.environ.setdefault('AZURE_STORAGE_ACCOUNT', 'stgamiadata26828')
os.environ.setdefault('AZURE_STORAGE_CONTAINER', 'audio-raw')
os.environ.setdefault('AZURE_AUTH', 'connection_string')

if not os.environ.get('AZURE_STORAGE_CONNECTION_STRING'):
    raise RuntimeError(
        "AZURE_STORAGE_CONNECTION_STRING is not set. Export it in the environment "
        "before launching the notebook (do not hard-code the account key here)."
    )

print("✓ Azure credentials set in environment")

## Configuration

In [None]:
# Processing configuration
TARGET_SR = 16000          # sample rate (Hz) every file is resampled to
BATCH_SIZE = 64            # files passed to analyze_audio_batch_gpu per call
DOWNLOAD_WORKERS = 100     # threads used for parallel blob downloads
DOWNLOAD_BATCH_SIZE = 200  # blobs fetched into memory before processing starts

# Scratch directory for this run; removed in the cleanup cell.
# NOTE(review): load_audio_to_tensor writes its temp files to the system
# default temp location, not here.
TEMP_DIR = Path(tempfile.mkdtemp())

print("\n".join([
    f"Temp directory: {TEMP_DIR}",
    f"Batch size: {BATCH_SIZE} files",
    f"Download workers: {DOWNLOAD_WORKERS}",
    f"Download batch size: {DOWNLOAD_BATCH_SIZE}",
]))

## GPU-Accelerated Audio Metric Functions

In [None]:
def load_audio_to_tensor(audio_bytes, target_sr=16000):
    """
    Decode in-memory audio bytes into a mono 1-D float tensor at ``target_sr``.

    The bytes are written to a temporary file that is handed to
    ``torchaudio.load`` by path. The file is always removed afterwards,
    including when writing or decoding fails. Supported formats are whatever
    the installed torchaudio backend can decode (e.g. MP3, MP4, WAV).

    Args:
        audio_bytes: Raw encoded file contents.
        target_sr: Sample rate (Hz) to resample to.

    Returns:
        Tuple ``(waveform, target_sr)``, where ``waveform`` is a 1-D tensor of
        shape ``(num_samples,)``.
    """
    # Create the file first so its name is known before writing. This lets
    # the finally clause remove it even if the write itself fails (e.g. disk full).
    tmp = tempfile.NamedTemporaryFile(suffix='.tmp', delete=False)
    try:
        with tmp:
            tmp.write(audio_bytes)

        waveform, sr = torchaudio.load(tmp.name)  # (channels, num_samples)

        # Down-mix multi-channel audio to mono by averaging the channels.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if sr != target_sr:
            waveform = torchaudio.transforms.Resample(sr, target_sr)(waveform)

        # Index the single channel rather than squeeze(). squeeze() would also
        # drop a length-1 time axis and return a 0-d tensor for one-sample audio.
        return waveform[0], target_sr
    finally:
        os.unlink(tmp.name)


def batch_compute_spectrogram(waveforms, sr, n_fft=2048, hop_length=512):
    """
    Compute magnitude spectrograms for a list of waveforms on ``device``.

    Waveforms may differ in length, so each one is transformed separately
    (no padding or stacking). Only the STFT module is shared between them.

    Args:
        waveforms: Iterable of 1-D tensors (lengths may differ).
        sr: Sample rate. Not used by the STFT; kept for API compatibility.
        n_fft: FFT size (also the window length).
        hop_length: Hop between successive frames, in samples.

    Returns:
        List of real-valued magnitude spectrograms on ``device``, one per
        waveform, laid out as (frequency_bins, frames). Phase is not returned.
    """
    # Loop-invariant: build the transform (and its window buffer) once rather
    # than once per waveform.
    spec_transform = torchaudio.transforms.Spectrogram(
        n_fft=n_fft,
        hop_length=hop_length,
        power=None,  # complex STFT; magnitude is taken below
    ).to(device)

    return [torch.abs(spec_transform(wv.to(device))) for wv in waveforms]


def snr_cal_batch(waveform, sr):
    """
    Estimate the SNR (dB) of a single 1-D waveform tensor.

    This is a heuristic. Frames within the first 0.5 s are assumed to be
    noise only, and the mean frame RMS over the whole clip is taken as the
    signal level.

    Args:
        waveform: 1-D tensor of samples (CPU or GPU).
        sr: Sample rate of ``waveform`` in Hz.

    Returns:
        SNR in dB as a Python float.

    Raises:
        ValueError: If the waveform is shorter than one analysis frame.
    """
    frame_length = 2048
    hop_length = 512

    num_samples = waveform.shape[-1]
    if num_samples < frame_length:
        raise ValueError(
            f"waveform has {num_samples} samples; need at least {frame_length} for SNR"
        )

    # Mean square per frame, computed by average-pooling the squared signal.
    # This gives the same frames as unfolding (no padding, floor division on
    # the frame count) without building a (frame_length x num_frames) matrix,
    # which is about 1 GB for an hour of 16 kHz audio.
    mean_sq = torch.nn.functional.avg_pool1d(
        waveform.reshape(1, 1, -1).pow(2),
        kernel_size=frame_length,
        stride=hop_length,
    )
    # Use reshape(-1), not squeeze(), so a single frame stays a 1-D tensor.
    rms = torch.sqrt(mean_sq).reshape(-1)

    # Always use at least one frame for the noise estimate.
    noise_frames = max(1, int(0.5 * sr / hop_length))
    noise_rms = torch.mean(rms[:noise_frames])
    signal_rms = torch.mean(rms)

    snr_db = 20 * torch.log10(signal_rms / (noise_rms + 1e-8))
    return snr_db.item()


def spectral_rolloff_batch(spec, sr, roll_percent=0.85):
    """
    Return the median spectral rolloff (Hz) of a magnitude spectrogram.

    For each frame, the rolloff is the lowest frequency bin at which the
    cumulative magnitude reaches ``roll_percent`` of that frame's total.
    The median over all frames is returned.

    Args:
        spec: Spectrogram laid out as (frequency_bins, frames), on any device.
            Bins are assumed to span 0..sr/2 evenly (one-sided STFT), as
            produced by ``batch_compute_spectrogram``.
        sr: Sample rate in Hz.
        roll_percent: Fraction of per-frame magnitude that defines the rolloff.

    Returns:
        Median rolloff frequency in Hz as a Python float.
    """
    cumsum = torch.cumsum(spec, dim=0)
    threshold = roll_percent * cumsum[-1, :]

    # argmax over a 0/1 mask returns the index of the first bin that meets
    # the threshold.
    rolloff_bins = torch.argmax((cumsum >= threshold).float(), dim=0)

    # Build the frequency axis on the spectrogram's own device, so CPU and
    # GPU inputs both work without relying on a global device.
    freqs = torch.linspace(0, sr / 2, spec.shape[0], device=spec.device)
    return torch.median(freqs[rolloff_bins]).item()


def spectral_centroid_batch(spec, sr):
    """
    Return the median spectral centroid (Hz) of a magnitude spectrogram.

    The centroid of each frame is the magnitude-weighted mean frequency.
    The median over frames is returned.

    Args:
        spec: Spectrogram laid out as (frequency_bins, frames), on any device.
            Bins are assumed to span 0..sr/2 evenly (one-sided STFT).
        sr: Sample rate in Hz.

    Returns:
        Median centroid in Hz as a Python float.
    """
    # Column vector of bin frequencies, created on the same device as spec.
    freqs = torch.linspace(0, sr / 2, spec.shape[0], device=spec.device).unsqueeze(1)

    # The epsilon keeps silent (all-zero) frames from dividing by zero.
    centroid = torch.sum(freqs * spec, dim=0) / (torch.sum(spec, dim=0) + 1e-8)

    return torch.median(centroid).item()


def spectral_flatness_batch(spec):
    """
    Return the mean spectral flatness of a spectrogram.

    Per frame, flatness = geometric mean / arithmetic mean over frequency bins
    (axis 0). Small epsilons guard log(0) and division by zero. The result
    is averaged over frames.
    """
    log_spec = torch.log(spec + 1e-8)
    geo = log_spec.mean(dim=0).exp()
    arith = spec.mean(dim=0)
    per_frame = geo / (arith + 1e-8)
    return per_frame.mean().item()


def zcr_batch(waveform):
    """
    Compute zero-crossing-rate statistics for a 1-D waveform tensor.

    A crossing is counted between consecutive samples on opposite sides of
    zero, with exact zeros treated as non-negative. ``torch.sign`` would give
    zero its own sign, so ``[1, 0, 1, 0, ...]`` would count a crossing at every
    step even though the signal never goes negative. That pattern is common in
    quiet, quantised audio and would inflate the ZCR.

    Frames are 2048 crossing indicators long with a hop of 512, and each
    frame's rate is (crossings in frame) / 2048.

    Returns:
        ``(mean, variance)`` of the per-frame rate as Python floats. The
        variance is ``torch.var``'s default unbiased estimate, which is NaN
        when there is only one frame.

    Raises:
        ValueError: If the waveform is too short for a single frame.
    """
    frame_length = 2048
    hop_length = 512

    # True where a sample is >= 0, so zero counts as positive.
    non_negative = waveform >= 0
    crossings = (non_negative[1:] != non_negative[:-1]).float()

    if crossings.shape[0] < frame_length:
        raise ValueError(
            f"waveform too short for ZCR: need at least {frame_length + 1} samples"
        )

    frames = crossings.unfold(0, frame_length, hop_length)
    zcr = frames.sum(dim=1) / frame_length

    return torch.mean(zcr).item(), torch.var(zcr).item()


def loudness_cal(waveform, sr):
    """
    Return integrated loudness (LUFS) via pyloudnorm.

    pyloudnorm works on NumPy arrays, so the tensor is copied to the CPU first.
    """
    samples = waveform.cpu().numpy()
    return float(pyln.Meter(sr).integrated_loudness(samples))


def low_frequency_energy_batch(spec, sr, cutoff_hz=80):
    """
    Return the fraction of spectrogram mass below ``cutoff_hz``.

    The ratio is the sum of spectrogram values in bins below the cutoff,
    divided by the sum over all bins and frames. When ``spec`` comes from
    ``batch_compute_spectrogram``, these are magnitudes rather than power.

    Args:
        spec: Spectrogram laid out as (frequency_bins, frames), on any device.
            Bins are assumed to span 0..sr/2 evenly (one-sided STFT).
        sr: Sample rate in Hz.
        cutoff_hz: Upper edge (exclusive) of the "low frequency" band.

    Returns:
        Ratio in [0, 1] as a Python float (0 for an all-zero spectrogram).
    """
    # Frequency axis on spec's own device, so CPU and GPU inputs both work.
    freqs = torch.linspace(0, sr / 2, spec.shape[0], device=spec.device)
    low_freq_mask = freqs < cutoff_hz

    low_energy = torch.sum(spec[low_freq_mask, :])
    total_energy = torch.sum(spec)

    ratio = low_energy / (total_energy + 1e-8)
    return ratio.item()

## Batch Processing Function

In [None]:
def analyze_audio_batch_gpu(audio_data_list):
    """
    Analyze a batch of audio files, producing one result dict per input file.

    The function works in two stages:

    1. Decode every file with ``load_audio_to_tensor``.
    2. Handle the decoded files one at a time: spectrogram, GPU metrics, then
       loudness on the CPU.

    Every step that can fail for an individual file runs inside that file's
    ``try``. A bad file therefore produces an error row instead of aborting
    the whole batch (and with it the whole run).

    Args:
        audio_data_list: List of (audio_id, audio_bytes) tuples.

    Returns:
        List of result dicts.
        - Success rows hold the metric columns and ``status='success'``.
        - Failure rows hold ``audio_id``, ``status='error'`` and
          ``error_message``.
        Decode failures are listed first, then the remaining files in input
        order.
    """
    results = []
    decoded = []  # (audio_id, 1-D CPU waveform) for files that decoded cleanly

    for audio_id, audio_bytes in audio_data_list:
        try:
            wv, _ = load_audio_to_tensor(audio_bytes, target_sr=TARGET_SR)
        except Exception as e:
            results.append({
                'audio_id': audio_id,
                'status': 'error',
                'error_message': f"Load error: {str(e)}"
            })
        else:
            decoded.append((audio_id, wv))

    if not decoded:
        return results

    for audio_id, wv in decoded:
        wv_gpu = spec = None
        try:
            wv_gpu = wv.to(device)
            # Compute the spectrogram per file, inside the try. An STFT failure
            # (e.g. a clip too short for the STFT padding, or CUDA OOM on a very
            # long recording) then fails only this file, and only one
            # spectrogram is resident on the GPU at a time rather than one per
            # batch member.
            spec = batch_compute_spectrogram([wv_gpu], TARGET_SR)[0]

            duration_sec = wv.shape[-1] / TARGET_SR

            # Metrics computed from GPU tensors
            snr = snr_cal_batch(wv_gpu, TARGET_SR)
            rolloff = spectral_rolloff_batch(spec, TARGET_SR)
            centroid = spectral_centroid_batch(spec, TARGET_SR)
            flatness = spectral_flatness_batch(spec)
            zcr_mean, zcr_var = zcr_batch(wv_gpu)
            low_freq_ratio = low_frequency_energy_batch(spec, TARGET_SR)

            # Loudness is computed on the CPU copy
            loudness = loudness_cal(wv, TARGET_SR)

            results.append({
                'audio_id': audio_id,
                'sample_rate': TARGET_SR,
                'duration_sec': float(duration_sec),
                'snr_db': float(snr),
                'spectral_rolloff_hz': float(rolloff),
                'spectral_flatness': float(flatness),
                'spectral_centroid_hz': float(centroid),
                'zcr_mean': float(zcr_mean),
                'zcr_var': float(zcr_var),
                'loudness_lufs': float(loudness),
                'low_freq_energy_ratio': float(low_freq_ratio),
                'status': 'success'
            })

        except Exception as e:
            results.append({
                'audio_id': audio_id,
                'status': 'error',
                'error_message': f"Processing error: {str(e)}"
            })
        finally:
            # Drop this file's GPU tensors before moving on to the next one.
            del wv_gpu, spec

    if device.type == 'cuda':
        torch.cuda.empty_cache()

    return results

## Parallel Download Functions

In [None]:
def download_single_blob(blob_path):
    """
    Download a single blob into memory.

    The audio id is the name of the blob's parent "directory"
    (``prefix/<id>/file.ext`` -> ``<id>``).

    Returns:
        ``(audio_id, audio_bytes, None)`` on success, or
        ``(audio_id, None, error_text)`` if the download raised. ``error_text``
        is never empty: exceptions with no message fall back to their repr.
    """
    audio_id = Path(blob_path).parent.name
    try:
        audio_bytes = download_blob_to_memory(blob_path)
    except Exception as e:
        # str() is "" for exceptions raised without arguments (e.g. a bare
        # TimeoutError()). That would lose the error detail, and callers that
        # test truthiness would treat the failure as a success.
        return (audio_id, None, str(e) or repr(e))
    return (audio_id, audio_bytes, None)


def download_batch_parallel(blob_paths, max_workers=100):
    """
    Download multiple blobs concurrently with a thread pool.

    Args:
        blob_paths: Blob paths to download.
        max_workers: Maximum number of concurrent downloads.

    Returns:
        Tuple ``(audio_data, errors)``:
        - audio_data: list of ``(audio_id, audio_bytes)`` for successful
          downloads, in completion order (not input order).
        - errors: list of error-row dicts for failed downloads.
    """
    audio_data = []
    errors = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(download_single_blob, path) for path in blob_paths]

        for future in tqdm(as_completed(futures), total=len(futures), desc="Downloading"):
            audio_id, audio_bytes, error = future.result()

            # Compare against None rather than relying on truthiness. An
            # exception whose str() is empty would otherwise be recorded as a
            # successful download carrying no bytes.
            if error is not None:
                errors.append({
                    'audio_id': audio_id,
                    'status': 'error',
                    'error_message': f"Download error: {error}"
                })
            else:
                audio_data.append((audio_id, audio_bytes))

    return audio_data, errors

## List All Audio Files

In [None]:
# Enumerate every blob under the prefix, then keep only audio/video containers
# (extension match is case-insensitive).
blob_prefix = "loc_vhp/"
print(f"Listing blobs with prefix: {blob_prefix}")

media_extensions = ('.mp3', '.mp4', '.wav', '.m4a', '.flac', '.ogg')
audio_blobs = [
    blob for blob in list_blobs(blob_prefix)
    if blob.lower().endswith(media_extensions)
]

print(f"Found {len(audio_blobs)} media files")
print("\nFirst 5 files:")
for blob in audio_blobs[:5]:
    print(f"  - {blob}")

In [None]:
# TESTING: Uncomment to limit processing for testing
# SAMPLE_SIZE = 100
# audio_blobs = audio_blobs[:SAMPLE_SIZE]
# print(f"\n🧪 TEST MODE: Processing only first {SAMPLE_SIZE} files")

## Process All Files (Download + GPU Processing Pipeline)

In [None]:
import time

all_results = []
total_files = len(audio_blobs)

print(f"\nProcessing {total_files} files...")
print(f"Download batch size: {DOWNLOAD_BATCH_SIZE}")
print(f"GPU batch size: {BATCH_SIZE}")
print(f"Download workers: {DOWNLOAD_WORKERS}")
print("="*60)

start_time = time.time()

# Outer loop: fetch DOWNLOAD_BATCH_SIZE blobs into memory in parallel.
# Inner loop: analyse those downloads BATCH_SIZE files at a time.
for i in range(0, total_files, DOWNLOAD_BATCH_SIZE):
    batch_blobs = audio_blobs[i:i + DOWNLOAD_BATCH_SIZE]

    print(f"\n[Batch {i//DOWNLOAD_BATCH_SIZE + 1}] Downloading {len(batch_blobs)} files...")

    audio_data, download_errors = download_batch_parallel(batch_blobs, max_workers=DOWNLOAD_WORKERS)
    all_results.extend(download_errors)

    print(f"Downloaded: {len(audio_data)} files, Errors: {len(download_errors)}")

    print("Processing on GPU...")
    for j in range(0, len(audio_data), BATCH_SIZE):
        all_results.extend(analyze_audio_batch_gpu(audio_data[j:j + BATCH_SIZE]))

    # Release this batch's raw bytes before downloading the next batch.
    del audio_data
    gc.collect()
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    # Progress update (guard the divisions against a zero elapsed time).
    elapsed = time.time() - start_time
    processed = min(i + DOWNLOAD_BATCH_SIZE, total_files)
    rate = processed / elapsed if elapsed > 0 else 0.0
    remaining = (total_files - processed) / rate if rate > 0 else 0

    print(f"Progress: {processed}/{total_files} ({processed/total_files*100:.1f}%)")
    print(f"Elapsed: {elapsed/60:.1f} min, Rate: {rate:.1f} files/sec, ETA: {remaining/60:.1f} min")

total_time = time.time() - start_time
average_rate = len(all_results) / total_time if total_time > 0 else 0.0

print("\n" + "="*60)
print("✓ COMPLETED")
print(f"Total files: {len(all_results)}")
print(f"Total time: {total_time/60:.1f} minutes ({total_time/3600:.2f} hours)")
print(f"Average rate: {average_rate:.2f} files/second")
print(f"Success: {sum(1 for r in all_results if r.get('status') == 'success')}")
print(f"Errors: {sum(1 for r in all_results if r.get('status') == 'error')}")

## Issue Detection and Preprocessing Recommendations

In [None]:
def detect_audio_issues(row):
    """
    Flag quality problems for one analysis row.

    Args:
        row: Mapping-like row (e.g. a DataFrame row) containing ``status`` and,
            for successful rows, the metric columns.

    Returns:
        List of issue labels in a fixed order. Empty for non-success rows.
    """
    if row['status'] != 'success':
        return []

    found = []

    # Bandwidth limitation: severe below 1 kHz rolloff, moderate below 4 kHz.
    rolloff = row['spectral_rolloff_hz']
    if rolloff < 1000:
        found.append('bandwidth_limited_severe')
    elif rolloff < 4000:
        found.append('bandwidth_limited_moderate')

    # The remaining checks are independent thresholds, reported in this order.
    checks = (
        ('high_noise_zcr', row['zcr_mean'] > 0.05),
        ('high_noise_snr', row['snr_db'] < 15),
        ('low_frequency_rumble', row['low_freq_energy_ratio'] > 0.15),
        ('low_loudness', row['loudness_lufs'] < -30),
        ('very_flat_spectrum', row['spectral_flatness'] > 0.8),
    )
    found.extend(label for label, hit in checks if hit)
    return found


def recommend_preprocessing(issues):
    """
    Map detected issue labels to preprocessing steps.

    Loudness normalization is always recommended. The other steps are added
    in a fixed order when any of their trigger issues is present:
    bandwidth limits -> high-frequency EQ boost, noise -> noise reduction,
    low-frequency rumble -> high-pass filter.
    """
    rules = (
        ('eq_high_freq_boost', ('bandwidth_limited_severe', 'bandwidth_limited_moderate')),
        ('noise_reduction', ('high_noise_zcr', 'high_noise_snr')),
        ('highpass_filter', ('low_frequency_rumble',)),
    )
    extra = [
        step for step, triggers in rules
        if any(trigger in issues for trigger in triggers)
    ]
    return ['loudness_normalization'] + extra


from collections import Counter

# Build the results table, then annotate each row with its detected issues
# and the preprocessing steps those issues imply.
df = pd.DataFrame(all_results)
df['issues'] = df.apply(detect_audio_issues, axis=1)
df['recommended_preprocessing'] = df['issues'].apply(recommend_preprocessing)

flagged = (df['issues'].str.len() > 0).sum()
print(f"\nFiles with issues: {flagged} / {len(df)}")

# Tally how often each issue label occurs across all rows.
all_issues = [label for row_issues in df['issues'] for label in row_issues]
issue_counts = Counter(all_issues)

print("\nIssue breakdown:")
for issue, count in issue_counts.most_common():
    share = count / len(df) * 100
    print(f"  {issue}: {count} files ({share:.1f}%)")

## Summary Statistics

In [None]:
# Descriptive statistics over successfully analysed files, followed by how
# often each preprocessing step was recommended (counted over all rows).
df_success = df[df['status'] == 'success']
summary_columns = ['snr_db', 'spectral_rolloff_hz', 'spectral_flatness',
                   'zcr_mean', 'loudness_lufs']

print("Overall Statistics:")
print(df_success[summary_columns].describe())

print("\nPreprocessing Recommendations:")
all_recs = [step for steps in df['recommended_preprocessing'] for step in steps]
rec_counts = Counter(all_recs)

total_rows = len(df)
for rec, count in rec_counts.most_common():
    print(f"{rec:30s}: {count:4d} files ({count/total_rows*100:5.1f}%)")

## Save Results

In [None]:
output_path = Path("../data/audio_quality_analysis.parquet")
# Make sure ../data exists so the write does not fail on a fresh checkout.
output_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(output_path, index=False)

print(f"✓ Saved: {output_path}")
print(f"Rows: {len(df)}")
print(f"Columns: {list(df.columns)}")
print(f"\nFile size: {output_path.stat().st_size / 1024 / 1024:.2f} MB")

## Cleanup

In [None]:
# Clean up the run's temp directory and release cached GPU memory.
import shutil

if TEMP_DIR.exists():
    shutil.rmtree(TEMP_DIR)
    print(f"✓ Cleaned up temp directory: {TEMP_DIR}")

# Clear GPU memory
if device.type == 'cuda':
    torch.cuda.empty_cache()
    print("✓ Cleared GPU memory")

print("\n✓ All done!")