# Test VEGAS Dataset Embeddings Cache

Questo notebook testa la creazione e il caricamento della cache degli embeddings per il dataset VEGAS.

## Obiettivi:
1. Caricare il dataset VEGAS
2. Calcolare gli embeddings audio tramite AST model
3. Salvare gli embeddings in una cache
4. Testare il ricaricamento dalla cache

In [37]:
# Setup working directory
import os
import sys

# Set working directory to project root
home_dir = os.path.expanduser('~')
project_dir = os.path.join(home_dir, 'fedgfe')
os.chdir(project_dir)
print(f"Working directory: {os.getcwd()}")

# Add system directory to path
sys.path.insert(0, os.path.join(project_dir, 'system'))

Working directory: /home/lpala/fedgfe


In [None]:
# Imports
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
import json
import pickle
from datetime import datetime

# Dataset imports
from system.datautils.dataset_vegas import VEGASDataset

# Model imports
from transformers import ASTFeatureExtractor, ASTModel

print("Imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 1. Configurazione e parametri

In [24]:
# Parametri
DATASET_PATH = 'dataset/Audio/VEGAS'
CACHE_DIR = 'cache/embeddings/vegas'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 8

# AST Model
AST_MODEL_NAME = 'MIT/ast-finetuned-audioset-10-10-0.4593'

# Selected classes (tutte le 10 classi VEGAS)
# SELECTED_CLASSES = [
#     'baby_cry', 'chainsaw', 'dog', 'drum', 'fireworks',
#     'helicopter', 'printer', 'rail_transport', 'snoring', 'water_flowing'
# ]

SELECTED_CLASSES = [
    'chainsaw'
]

print(f"Device: {DEVICE}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Cache directory: {CACHE_DIR}")
print(f"Selected classes: {SELECTED_CLASSES}")

Device: cuda
Batch size: 8
Cache directory: cache/embeddings/vegas
Selected classes: ['chainsaw']


## 2. Caricamento dataset VEGAS

In [25]:
# Carica dataset VEGAS
print("Loading VEGAS dataset...")
dataset = VEGASDataset(
    root_dir=DATASET_PATH,
    ast_cache_dir = CACHE_DIR,
    samples_per_node=100,
    node_split_id=0,
    selected_classes=SELECTED_CLASSES,
    enable_ast_cache = True,
)

print(f"Dataset loaded: {len(dataset)} samples")
print(f"Active classes: {dataset.active_classes}")

# Mostra un sample di esempio
sample = dataset[0]
print(f"\nSample keys: {sample.keys()}")
print(f"Audio shape: {sample['audio'].shape}")
print(f"Class name: {sample['class_name']}")
print(f"Sample ID: {sample.get('sample_id', 'N/A')}")

Loading VEGAS dataset...
Dataset loaded: 100 samples
Active classes: {'chainsaw': 1}

Sample keys: dict_keys(['audio', 'text_emb', 'label', 'audio_filename', 'image_filename', 'video_filename', 'class_name', 'video_id', 'file_id', 'sample_idx', 'ytid', 'start_second', 'end_second', 'caption'])
Audio shape: torch.Size([80000])
Class name: chainsaw
Sample ID: N/A


## 3. Inizializzazione AST Model

In [26]:
# Carica AST model e feature extractor
print(f"Loading AST model: {AST_MODEL_NAME}")
ast_feature_extractor = ASTFeatureExtractor.from_pretrained(AST_MODEL_NAME)
ast_model = ASTModel.from_pretrained(AST_MODEL_NAME)
ast_model = ast_model.to(DEVICE)
ast_model.eval()

print(f"AST model loaded on {DEVICE}")
print(f"Model parameters: {sum(p.numel() for p in ast_model.parameters()) / 1e6:.2f}M")

Loading AST model: MIT/ast-finetuned-audioset-10-10-0.4593
AST model loaded on cuda
Model parameters: 86.19M


## 4. Funzioni per calcolare e salvare embeddings

In [None]:
def build_index_map(dataset):
    """
    Costruisce mappa cache_key → dataset_index per accesso O(1).
    
    Args:
        dataset: VEGASDataset instance
    
    Returns:
        dict: {cache_key: dataset_index}
    """
    index_map = {}
    for idx in range(len(dataset)):
        sample = dataset.samples[idx]
        cache_key = f"{sample['class_name']}:{sample['file_id']}"
        index_map[cache_key] = idx
    return index_map


def compute_audio_embeddings(dataset, ast_model, ast_feature_extractor, device, batch_size=8):
    """
    Calcola gli audio embeddings per tutti i campioni del dataset.
    
    Uses cache_key = "{class_name}:{file_id}" (e.g., "dog:video_00001")
    
    Returns:
        tuple: (embeddings_array, cache_keys, index_map)
            embeddings_array: numpy array (n_samples, seq_len, hidden_dim)
            cache_keys: list of cache keys in order
            index_map: {cache_key: array_index}
    """
    all_embeddings = []
    cache_keys = []
    index_map = {}
    
    print(f"\nComputing embeddings for {len(dataset)} samples...")
    
    # Process in batches
    for i in tqdm(range(0, len(dataset), batch_size), desc="Processing batches"):
        batch_end = min(i + batch_size, len(dataset))
        batch_indices = list(range(i, batch_end))
        batch_samples = [dataset[j] for j in batch_indices]
        
        # Extract audio and metadata
        batch_audio = []
        batch_cache_keys = []
        
        for idx, sample in zip(batch_indices, batch_samples):
            audio = sample['audio']
            if isinstance(audio, torch.Tensor):
                audio = audio.cpu().numpy()
            batch_audio.append(audio)
            
            # Create cache key: {class_name}:{file_id}
            cache_key = f"{sample['class_name']}:{sample['file_id']}"
            batch_cache_keys.append(cache_key)
        
        # Process audio through AST
        with torch.no_grad():
            # Feature extraction
            audio_inputs = ast_feature_extractor(
                batch_audio,
                sampling_rate=16000,
                return_tensors="pt",
                padding=True
            ).input_values.to(device)
            
            # Get embeddings
            ast_output = ast_model(audio_inputs).last_hidden_state  # (batch, seq_len, hidden_dim)
            
            # Convert to numpy and store
            ast_output_np = ast_output.cpu().numpy()
            
            for emb, cache_key in zip(ast_output_np, batch_cache_keys):
                all_embeddings.append(emb)
                cache_keys.append(cache_key)
                index_map[cache_key] = len(cache_keys) - 1
        
        # Cleanup
        del audio_inputs, ast_output, ast_output_np
        if device == 'cuda':
            torch.cuda.empty_cache()
    
    # Stack all embeddings into single array
    embeddings_array = np.stack(all_embeddings, axis=0)
    
    print(f"\nComputed embeddings for {len(embeddings_array)} samples")
    print(f"Embeddings array shape: {embeddings_array.shape}")
    print(f"Created index map with {len(index_map)} entries")
    
    return embeddings_array, cache_keys, index_map


def save_embeddings_cache_mmap(embeddings_array, cache_keys, index_map, cache_dir, metadata=None):
    """
    Salva gli embeddings usando numpy memmap per accesso efficiente.
    
    Args:
        embeddings_array: numpy array (n_samples, seq_len, hidden_dim)
        cache_keys: list of cache keys
        index_map: Dictionary cache_key → array_index
        cache_dir: Directory dove salvare la cache
        metadata: Optional metadata dict
    """
    cache_path = Path(cache_dir)
    cache_path.mkdir(parents=True, exist_ok=True)
    
    # Salva embeddings come numpy array (per mmap)
    embeddings_file = cache_path / 'audio_embeddings.npy'
    np.save(embeddings_file, embeddings_array)
    print(f"\nEmbeddings array saved to: {embeddings_file}")
    print(f"File size: {embeddings_file.stat().st_size / (1024**2):.2f} MB")
    print(f"Shape: {embeddings_array.shape}, dtype: {embeddings_array.dtype}")
    
    # Salva cache_keys come lista (per ricostruire cache_key da index)
    cache_keys_file = cache_path / 'cache_keys.pkl'
    with open(cache_keys_file, 'wb') as f:
        pickle.dump(cache_keys, f)
    print(f"Cache keys saved to: {cache_keys_file}")
    
    # Salva index map
    index_map_file = cache_path / 'index_map.pkl'
    with open(index_map_file, 'wb') as f:
        pickle.dump(index_map, f)
    print(f"Index map saved to: {index_map_file}")
    print(f"File size: {index_map_file.stat().st_size / 1024:.2f} KB")
    
    # Salva metadata
    metadata_dict = metadata or {}
    metadata_dict.update({
        'num_samples': len(embeddings_array),
        'embedding_shape': list(embeddings_array.shape),
        'embedding_dtype': str(embeddings_array.dtype),
        'timestamp': datetime.now().isoformat(),
        'cache_version': '3.0',  # Version 3.0 with mmap support
        'cache_key_format': '{class_name}:{file_id}',
        'mmap_enabled': True
    })
    
    metadata_file = cache_path / 'metadata.json'
    with open(metadata_file, 'w') as f:
        json.dump(metadata_dict, f, indent=2)
    print(f"Metadata saved to: {metadata_file}")
    
    return embeddings_file


def load_embeddings_cache_mmap(cache_dir, mmap_mode='r'):
    """
    Carica gli embeddings usando numpy memmap (lazy loading).
    
    Args:
        cache_dir: Directory cache
        mmap_mode: Modalità mmap ('r' = read-only, 'r+' = read-write, 'c' = copy-on-write)
    
    Returns:
        tuple: (embeddings_mmap, index_map, cache_keys, metadata_dict)
            embeddings_mmap: numpy memmap array (accesso lazy)
            index_map: {cache_key: array_index}
            cache_keys: list of cache keys
            metadata_dict: metadata
    """
    cache_path = Path(cache_dir)
    
    # Load metadata first
    metadata_file = cache_path / 'metadata.json'
    metadata_dict = {}
    if metadata_file.exists():
        with open(metadata_file, 'r') as f:
            metadata_dict = json.load(f)
        print(f"Metadata: {metadata_dict}")
    
    # Load embeddings as memmap (NO CARICAMENTO IN RAM!)
    embeddings_file = cache_path / 'audio_embeddings.npy'
    if not embeddings_file.exists():
        raise FileNotFoundError(f"Cache file not found: {embeddings_file}")
    
    print(f"\nLoading embeddings with mmap from: {embeddings_file}")
    embeddings_mmap = np.load(embeddings_file, mmap_mode=mmap_mode)
    print(f"Embeddings mmap loaded: shape={embeddings_mmap.shape}, dtype={embeddings_mmap.dtype}")
    print(f"Memory mapped (not loaded in RAM)")
    
    # Load cache_keys
    cache_keys_file = cache_path / 'cache_keys.pkl'
    with open(cache_keys_file, 'rb') as f:
        cache_keys = pickle.load(f)
    print(f"Loaded {len(cache_keys)} cache keys")
    
    # Load index map
    index_map_file = cache_path / 'index_map.pkl'
    with open(index_map_file, 'rb') as f:
        index_map = pickle.load(f)
    print(f"Loaded index map with {len(index_map)} entries")
    
    return embeddings_mmap, index_map, cache_keys, metadata_dict


def get_embedding_from_cache(cache_key, embeddings_mmap, index_map):
    """
    Recupera un embedding dalla cache usando mmap (accesso lazy).
    
    Args:
        cache_key: Cache key (format: "{class_name}:{file_id}")
        embeddings_mmap: numpy memmap array
        index_map: Index map dictionary
    
    Returns:
        numpy array: embedding (seq_len, hidden_dim)
    """
    idx = index_map.get(cache_key)
    if idx is None:
        raise KeyError(f"Cache key '{cache_key}' not found in index_map")
    
    # L'accesso all'array mmap carica SOLO questo elemento in RAM
    return embeddings_mmap[idx]


def get_sample_from_cache_key(dataset, cache_key, index_map):
    """
    Recupera un sample dal dataset usando il cache_key.
    
    Args:
        dataset: VEGASDataset instance
        cache_key: Cache key (format: "{class_name}:{file_id}")
        index_map: Index map dictionary (dataset index map, not cache index map)
    
    Returns:
        Sample dictionary or None if not found
    """
    # Parse cache_key
    class_name, file_id = cache_key.split(':')
    
    # Linear search in dataset.samples
    for idx in range(len(dataset)):
        sample = dataset.samples[idx]
        if sample['file_id'] == file_id and sample['class_name'] == class_name:
            return dataset[idx]
    
    return None


print("Functions defined (Solution 2: with mmap support)")

## 5. Calcolo e salvataggio embeddings

In [None]:
# Compute embeddings
embeddings_array, cache_keys, index_map = compute_audio_embeddings(
    dataset=dataset,
    ast_model=ast_model,
    ast_feature_extractor=ast_feature_extractor,
    device=DEVICE,
    batch_size=BATCH_SIZE
)

# Mostra statistiche
print("\n=== Embeddings Statistics ===")
print(f"Embeddings array shape: {embeddings_array.shape}")
print(f"Embeddings array dtype: {embeddings_array.dtype}")
print(f"Total samples: {len(embeddings_array)}")
print(f"Memory size: {embeddings_array.nbytes / (1024**2):.2f} MB")

# Count per class
class_counts = {}
for cache_key in cache_keys:
    class_name = cache_key.split(':')[0]
    class_counts[class_name] = class_counts.get(class_name, 0) + 1

print("\nSamples per class:")
for class_name, count in sorted(class_counts.items()):
    print(f"  {class_name}: {count}")

# Show some example cache keys
print("\nExample cache keys:")
for i, cache_key in enumerate(cache_keys[:5]):
    print(f"  {cache_key}")

In [None]:
# Salva embeddings in cache con mmap support
metadata = {
    'dataset': 'VEGAS',
    'ast_model': AST_MODEL_NAME,
    'selected_classes': SELECTED_CLASSES,
    'samples_per_class': class_counts
}

cache_file = save_embeddings_cache_mmap(
    embeddings_array=embeddings_array,
    cache_keys=cache_keys,
    index_map=index_map,
    cache_dir=CACHE_DIR,
    metadata=metadata
)

## 6. Test caricamento dalla cache

In [None]:
# Test loading from cache
print("\n=== Testing cache loading ===")
loaded_embeddings, loaded_index_map, loaded_metadata = load_embeddings_cache(CACHE_DIR)

print(f"\nLoaded {len(loaded_embeddings)} embeddings")
print(f"Loaded {len(loaded_index_map)} index map entries")
print(f"Metadata: {loaded_metadata}")

# Verify data integrity
cache_key = list(loaded_embeddings.keys())[0]
sample_data = loaded_embeddings[cache_key]
print(f"\nExample cache key: {cache_key}")
print(f"Sample data keys: {sample_data.keys()}")
print(f"Embedding shape: {sample_data['embedding'].shape}")
print(f"Class name: {sample_data['class_name']}")
print(f"File ID: {sample_data['file_id']}")
print(f"Dataset index: {loaded_index_map[cache_key]}")

## 7. Test accesso con cache_key e index_map (Soluzione 2)

In [None]:
import random

# Prendi un cache_key casuale
test_cache_key = random.choice(list(loaded_embeddings.keys()))
print(f"Testing cache key: {test_cache_key}")

# Metodo 1: Accesso diretto usando index_map (O(1))
dataset_idx = loaded_index_map[test_cache_key]
print(f"Dataset index from index_map: {dataset_idx}")
sample = dataset[dataset_idx]

# Metodo 2: Usando la funzione helper
sample_alt = get_sample_from_cache_key(dataset, test_cache_key, loaded_index_map)

print(f"\nSample class: {sample['class_name']}")
print(f"Sample file_id: {sample['file_id']}")

# Compute fresh embedding
audio = sample['audio']
if isinstance(audio, torch.Tensor):
    audio = audio.cpu().numpy()

with torch.no_grad():
    audio_inputs = ast_feature_extractor(
        audio,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True
    ).input_values.to(DEVICE)
    
    fresh_embedding = ast_model(audio_inputs).last_hidden_state.cpu()

# Get from cache
cached_embedding = loaded_embeddings[test_cache_key]['embedding']

# Compare
print(f"\nFresh embedding shape: {fresh_embedding.shape}")
print(f"Cached embedding shape: {cached_embedding.shape}")

# Calculate difference
diff = torch.abs(fresh_embedding - cached_embedding).mean()
print(f"\nMean absolute difference: {diff.item():.8f}")

if diff < 1e-5:
    print("✓ Embeddings match! Cache is working correctly.")
else:
    print("⚠ Warning: Embeddings don't match perfectly (might be due to padding/batching)")

# Verify both methods return same sample
print(f"\n✓ Both access methods return same sample: {sample['file_id'] == sample_alt['file_id']}")

## 8. Benchmark: tempo di caricamento e accesso

In [None]:
import time

print("=== Benchmark 1: Cache Loading Speed ===")

# Time cache loading
start = time.time()
cached_embs, cached_index_map, _ = load_embeddings_cache(CACHE_DIR)
cache_time = time.time() - start
print(f"Cache loading time: {cache_time:.3f}s")

# Time computing a few samples
num_test_samples = min(50, len(dataset))
print(f"\n=== Benchmark 2: On-the-fly vs Cache ({num_test_samples} samples) ===")
start = time.time()
for i in range(num_test_samples):
    sample = dataset[i]
    audio = sample['audio']
    if isinstance(audio, torch.Tensor):
        audio = audio.cpu().numpy()
    
    with torch.no_grad():
        audio_inputs = ast_feature_extractor(
            audio,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True
        ).input_values.to(DEVICE)
        _ = ast_model(audio_inputs).last_hidden_state.cpu()

compute_time = time.time() - start
print(f"On-the-fly computation time: {compute_time:.3f}s")
print(f"Average time per sample: {compute_time/num_test_samples:.4f}s")

# Extrapolate to full dataset
estimated_full_time = (compute_time / num_test_samples) * len(dataset)
print(f"\nEstimated time for full dataset ({len(dataset)} samples): {estimated_full_time:.2f}s")
print(f"Speedup with cache: {estimated_full_time/cache_time:.1f}x")

print(f"\n=== Benchmark 3: Sample Access Speed ===")

# Test index_map lookup speed
num_lookups = 1000
test_cache_keys = list(cached_index_map.keys())[:min(num_lookups, len(cached_index_map))]

# Method 1: With index_map (O(1))
start = time.time()
for cache_key in test_cache_keys:
    idx = cached_index_map[cache_key]
    _ = dataset[idx]
fast_time = time.time() - start

print(f"With index_map ({len(test_cache_keys)} lookups): {fast_time:.4f}s")
print(f"Average per lookup: {fast_time/len(test_cache_keys)*1000:.4f}ms")

# Method 2: Linear search (O(n)) - just for comparison on a few samples
num_slow_tests = min(10, len(test_cache_keys))
start = time.time()
for cache_key in test_cache_keys[:num_slow_tests]:
    class_name, file_id = cache_key.split(':')
    for idx in range(len(dataset)):
        sample = dataset.samples[idx]
        if sample['file_id'] == file_id and sample['class_name'] == class_name:
            _ = dataset[idx]
            break
slow_time = time.time() - start

print(f"\nLinear search ({num_slow_tests} lookups): {slow_time:.4f}s")
print(f"Average per lookup: {slow_time/num_slow_tests*1000:.4f}ms")
print(f"Speedup with index_map: {(slow_time/num_slow_tests)/(fast_time/len(test_cache_keys)):.1f}x")

## 9. Test integrazione con VEGASDataset

Testiamo i metodi appena implementati in VEGASDataset:
- `save_ast_embeddings_to_cache()` - salva embeddings organizzati per classe
- `load_ast_embeddings_from_cache()` - carica con mmap
- `get_cached_ast_embedding()` - accesso O(1)