In [1]:
import torch
import pandas as pd
import soundfile as sf
from pathlib import Path
from tqdm.auto import tqdm
import pickle
import warnings
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
import gc
import time
import os

warnings.filterwarnings('ignore')

# Import your encoder
from latent_audio_encoder import LatentAudioEncoder

2025-11-24 11:09:21.402708: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-24 11:09:30.791776: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-24 11:09:57.898420: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
CONFIG = {
    'csv_path': '../SoccerNet_audio_labels.csv',
    'output_dir': '../audio_embeddings_cache',
    'output_filename': 'audio_embeddings.pkl',
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'batch_size': 32,  # Process multiple audio files at once
    'num_audio_workers': 4,  # Parallel audio loading
    'segment_duration': 20.0,  # seconds
    'target_sample_rate': 16000,
    'checkpoint_frequency': 50,  # Save checkpoint every N batches
}

print("="*80)
print("üéµ AUDIO EMBEDDING EXTRACTION PIPELINE")
print("="*80)
print(f"üìç Device: {CONFIG['device']}")
print(f"üì¶ Batch size: {CONFIG['batch_size']}")
print(f"üîß Audio workers: {CONFIG['num_audio_workers']}")
print(f"‚è±Ô∏è  Segment duration: {CONFIG['segment_duration']}s")
print("="*80 + "\n")

üéµ AUDIO EMBEDDING EXTRACTION PIPELINE
üìç Device: cuda
üì¶ Batch size: 32
üîß Audio workers: 4
‚è±Ô∏è  Segment duration: 20.0s



In [3]:
def create_unique_key(audio_path, timestamp):
    """Create unique key for caching"""
    path_stem = Path(audio_path).stem
    return f"{path_stem}_{timestamp:.2f}"

def load_audio_segment(audio_path, center_time, segment_duration=20.0, target_sr=16000):
    """
    Load audio segment centered at timestamp
    Returns: waveform as numpy array, actual sample rate
    """
    try:
        # Load full audio
        data, sr = sf.read(audio_path, always_2d=True)
        
        # Calculate start/end frames
        start_time = max(0, center_time - segment_duration / 2)
        end_time = center_time + segment_duration / 2
        
        start_frame = int(start_time * sr)
        end_frame = int(end_time * sr)
        end_frame = min(end_frame, len(data))
        
        # Extract segment (mono)
        segment = data[start_frame:end_frame, 0]
        
        # Resample if needed
        if sr != target_sr:
            from scipy import signal
            num_samples = int(len(segment) * target_sr / sr)
            segment = signal.resample(segment, num_samples)
        
        # Pad or trim to exact length
        target_length = int(segment_duration * target_sr)
        if len(segment) < target_length:
            segment = np.pad(segment, (0, target_length - len(segment)))
        else:
            segment = segment[:target_length]
        
        return {
            'success': True,
            'waveform': segment,
            'sample_rate': target_sr
        }
    
    except Exception as e:
        return {
            'success': False,
            'error': str(e)
        }

def load_audio_batch_parallel(batch_df, segment_duration, target_sr, num_workers=8):
    """
    Load multiple audio files in parallel
    Returns: list of loaded audio data
    """
    loaded_data = []
    
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {
            executor.submit(
                load_audio_segment,
                row['audio_path'],
                row['time_seconds'],
                segment_duration,
                target_sr
            ): idx for idx, row in batch_df.iterrows()
        }
        
        for future in as_completed(futures):
            idx = futures[future]
            result = future.result()
            
            if result['success']:
                row = batch_df.loc[idx]
                loaded_data.append({
                    'key': create_unique_key(row['audio_path'], row['time_seconds']),
                    'waveform': result['waveform'],
                    'sample_rate': result['sample_rate'],
                    'metadata': {
                        'audio_path': row['audio_path'],
                        'timestamp': row['time_seconds'],
                        'label': row['label']
                    }
                })
            else:
                print(f"‚ö†Ô∏è  Failed to load audio for index {idx}: {result['error']}")
    
    return loaded_data

def process_batch_gpu(loaded_data, encoder, device):
    """
    Safe batch processing with automatic waveform normalization,
    stereo ‚Üí mono handling, and GPU-safe fallback.
    """
    embeddings = {}
    failed = []

    try:
        # --- Normalize all waveforms before batching ---
        normalized = []
        for item in loaded_data:
            wav = np.array(item["waveform"], dtype=np.float32)

            # Shape fixes
            if wav.ndim == 2:
                # stereo or multi-channel ‚Üí average to mono
                wav = wav.mean(axis=0).astype(np.float32)

            if wav.ndim != 1:
                failed.append({"key": item["key"],
                               "error": f"Invalid waveform shape: {wav.shape}"})
                continue

            # Now wav is guaranteed shape: (time,)
            normalized.append((item, wav))

        if not normalized:
            return {}, failed

        items, waveforms_np = zip(*normalized)

        # Convert to tensor batch [batch, time]
        waveforms = [torch.tensor(w, dtype=torch.float32) for w in waveforms_np]
        waveforms_batch = torch.stack(waveforms).to(device)

        # --- Try batch encoding ---
        with torch.no_grad():
            try:
                audio_embs = encoder(waveforms_batch, sampling_rate=16000)

                # Ensure shape [batch, dim]
                if audio_embs.dim() == 3:
                    # e.g. [batch, 1, dim]
                    audio_embs = audio_embs.squeeze(1)
                elif audio_embs.dim() == 1:
                    # single item case: [dim]
                    audio_embs = audio_embs.unsqueeze(0)

                for i, (item, _) in enumerate(normalized):
                    embeddings[item['key']] = {
                        "embedding": audio_embs[i].cpu(),
                        "metadata": item["metadata"]
                    }

            except Exception as batch_error:
                print(f"‚ö†Ô∏è Batch processing failed: {batch_error}")
                print("   Falling back to individual processing...")

                # --- GPU-safe fallback: one item at a time ---
                for item, wav in normalized:
                    try:
                        waveform = torch.tensor(wav, dtype=torch.float32).to(device)

                        with torch.no_grad():
                            audio_emb = encoder(waveform, sampling_rate=16000)

                        if audio_emb.dim() == 2:
                            audio_emb = audio_emb.squeeze(0)

                        embeddings[item["key"]] = {
                            "embedding": audio_emb.cpu(),
                            "metadata": item["metadata"]
                        }

                        # Prevent memory accumulation
                        del waveform
                        torch.cuda.empty_cache()

                    except Exception as e:
                        failed.append({"key": item["key"], "error": str(e)})

        # Clean up batch tensors
        del waveforms_batch
        torch.cuda.empty_cache()

    except Exception as e:
        print(f"‚ùå process_batch_gpu failed: {e}")
        import traceback
        traceback.print_exc()

        for item in loaded_data:
            failed.append({"key": item["key"], "error": str(e)})

    return embeddings, failed


In [4]:
def extract_audio_embeddings(
    csv_path,
    output_dir,
    output_filename,
    encoder,
    device,
    batch_size=32,
    num_audio_workers=4,
    segment_duration=20.0,
    target_sr=16000,
    checkpoint_freq=50
):
    """
    Extract audio embeddings for all samples in CSV
    """
    
    # Load dataset
    print(f"üìÇ Loading dataset from: {csv_path}")
    df = pd.read_csv(csv_path)
    print(f"üìä Total samples: {len(df)}\n")
    
    # Create output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    
    checkpoint_path = output_path / 'checkpoints'
    checkpoint_path.mkdir(exist_ok=True)
    checkpoint_file = checkpoint_path / 'embedding_checkpoint.pkl'
    
    # Initialize storage
    embeddings_cache = {}
    failed_samples = []
    metadata_cache = {}
    
    # Load checkpoint if exists
    start_batch = 0
    if checkpoint_file.exists():
        try:
            print(f"üìÅ Loading checkpoint...")
            with open(checkpoint_file, 'rb') as f:
                checkpoint = pickle.load(f)
            embeddings_cache = checkpoint['embeddings']
            failed_samples = checkpoint.get('failed', [])
            metadata_cache = checkpoint.get('metadata', {})
            start_batch = checkpoint.get('last_batch', 0) + 1
            print(f"‚úÖ Resuming from batch {start_batch}")
            print(f"   Already processed: {len(embeddings_cache)} embeddings\n")
        except Exception as e:
            print(f"‚ö†Ô∏è  Could not load checkpoint: {e}")
            print("   Starting fresh...\n")
    
    # Calculate batches
    num_batches = (len(df) + batch_size - 1) // batch_size
    
    # Track performance
    batch_times = []
    total_start = time.time()
    
    print("üöÄ Starting extraction...\n")
    
    with tqdm(total=len(df), initial=start_batch * batch_size, 
              desc="Extracting embeddings") as pbar:
        
        for batch_idx in range(start_batch, num_batches):
            batch_start = time.time()
            
            # Get batch
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(df))
            batch_df = df.iloc[start_idx:end_idx]
            
            # Step 1: Load audio in parallel
            loaded_data = load_audio_batch_parallel(
                batch_df,
                segment_duration,
                target_sr,
                num_workers=num_audio_workers
            )
            
            if not loaded_data:
                pbar.update(len(batch_df))
                continue
            
            # Step 2: Process on GPU
            batch_embeddings, batch_failed = process_batch_gpu(
                loaded_data,
                encoder,
                device
            )
            
            # Store results
            embeddings_cache.update(batch_embeddings)
            failed_samples.extend(batch_failed)
            
            # Update metadata
            for key, data in batch_embeddings.items():
                metadata_cache[key] = data['metadata']
            
            # Track timing
            batch_time = time.time() - batch_start
            batch_times.append(batch_time)
            
            # Calculate stats
            if len(batch_times) > 10:
                avg_time = sum(batch_times[-10:]) / 10
                samples_per_sec = batch_size / avg_time
                remaining = len(df) - end_idx
                eta_seconds = remaining / samples_per_sec
                
                pbar.set_postfix({
                    'embeddings': len(embeddings_cache),
                    'failed': len(failed_samples),
                    'speed': f'{samples_per_sec:.1f}/s',
                    'ETA': f'{eta_seconds/60:.0f}m'
                })
            else:
                pbar.set_postfix({
                    'embeddings': len(embeddings_cache),
                    'failed': len(failed_samples)
                })
            
            pbar.update(len(batch_df))
            
            # Checkpoint
            if (batch_idx + 1) % checkpoint_freq == 0:
                with open(checkpoint_file, 'wb') as f:
                    pickle.dump({
                        'embeddings': embeddings_cache,
                        'metadata': metadata_cache,
                        'failed': failed_samples,
                        'last_batch': batch_idx
                    }, f)
                
                pbar.write(f"üíæ Checkpoint: {len(embeddings_cache):,} embeddings saved")
                
                # Memory cleanup
                gc.collect()
                if device == 'cuda':
                    torch.cuda.empty_cache()
    
    # Final save
    total_time = time.time() - total_start
    
    print(f"\n{'='*80}")
    print("‚úÖ EXTRACTION COMPLETE!")
    print(f"{'='*80}")
    print(f"Total embeddings: {len(embeddings_cache):,}")
    print(f"Failed: {len(failed_samples)}")
    print(f"Total time: {total_time/60:.1f} minutes")
    print(f"Average speed: {len(df)/total_time:.1f} samples/second")
    
    # Save final results
    output_file = output_path / output_filename
    
    # Prepare final data structure
    final_data = {
        'embeddings': {k: v['embedding'] for k, v in embeddings_cache.items()},
        'metadata': metadata_cache,
        'config': {
            'segment_duration': segment_duration,
            'sample_rate': target_sr,
            'encoder': 'LatentAudioEncoder',
            'total_samples': len(embeddings_cache)
        }
    }
    
    print(f"\nüíæ Saving to {output_file}...")
    with open(output_file, 'wb') as f:
        pickle.dump(final_data, f)
    
    file_size = output_file.stat().st_size / (1024 * 1024)
    print(f"‚úÖ Saved! File size: {file_size:.1f} MB")
    
    # Save failed samples log
    if failed_samples:
        failed_file = output_path / 'failed_samples.pkl'
        with open(failed_file, 'wb') as f:
            pickle.dump(failed_samples, f)
        print(f"‚ö†Ô∏è  Failed samples log: {failed_file}")
    
    # Clean up checkpoint
    if checkpoint_file.exists():
        checkpoint_file.unlink()
        print(f"üóëÔ∏è  Removed checkpoint file")
    
    return embeddings_cache, metadata_cache, failed_samples


In [None]:
print("üîß Loading LatentAudioEncoder...")
encoder = LatentAudioEncoder().to(CONFIG['device'])
encoder.eval()

# Freeze parameters
for p in encoder.parameters():
    p.requires_grad = False

print(f"‚úÖ Encoder loaded on {CONFIG['device']}")
print(f"üì¶ Parameters: {sum(p.numel() for p in encoder.parameters()):,}\n")

# Extract embeddings
embeddings_cache, metadata_cache, failed_samples = extract_audio_embeddings(
    csv_path=CONFIG['csv_path'],
    output_dir=CONFIG['output_dir'],
    output_filename=CONFIG['output_filename'],
    encoder=encoder,
    device=CONFIG['device'],
    batch_size=CONFIG['batch_size'],
    num_audio_workers=CONFIG['num_audio_workers'],
    segment_duration=CONFIG['segment_duration'],
    target_sr=CONFIG['target_sample_rate'],
    checkpoint_freq=CONFIG['checkpoint_frequency']
)

print("\n" + "="*80)
print("üéâ PIPELINE COMPLETE!")
print("="*80)
print(f"\nüìä Statistics:")
print(f"   Total embeddings: {len(embeddings_cache):,}")
print(f"   Failed samples: {len(failed_samples)}")
print(f"   Success rate: {100 * len(embeddings_cache) / (len(embeddings_cache) + len(failed_samples)):.1f}%")

# Show embedding shape
if embeddings_cache:
    first_key = list(embeddings_cache.keys())[0]
    first_emb = embeddings_cache[first_key]
    

print("\n‚úÖ Ready for training!")
print(f"   Use: 'audio_embeddings_cache/audio_embeddings.pkl'")

üîß Loading LatentAudioEncoder...


`torch_dtype` is deprecated! Use `dtype` instead!
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-large-960h-lv60-self and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


‚úÖ Encoder loaded on cuda
üì¶ Parameters: 315,438,720

üìÇ Loading dataset from: ../SoccerNet_audio_labels.csv
üìä Total samples: 66460

üöÄ Starting extraction...



Extracting embeddings:   0%|          | 0/66460 [00:00<?, ?it/s]

‚ö†Ô∏è Batch processing failed: cannot select an axis to squeeze out which has size not equal to one
   Falling back to individual processing...
‚ö†Ô∏è Batch processing failed: cannot select an axis to squeeze out which has size not equal to one
   Falling back to individual processing...
‚ö†Ô∏è Batch processing failed: cannot select an axis to squeeze out which has size not equal to one
   Falling back to individual processing...
‚ö†Ô∏è Batch processing failed: cannot select an axis to squeeze out which has size not equal to one
   Falling back to individual processing...
‚ö†Ô∏è Batch processing failed: cannot select an axis to squeeze out which has size not equal to one
   Falling back to individual processing...
‚ö†Ô∏è Batch processing failed: cannot select an axis to squeeze out which has size not equal to one
   Falling back to individual processing...
‚ö†Ô∏è Batch processing failed: cannot select an axis to squeeze out which has size not equal to one
   Falling back to individual

AttributeError: 'dict' object has no attribute 'shape'