# Test VEGAS Dataset Embeddings Cache (v2 - Integrato)

Test del sistema di caching degli embeddings AST integrato in `VEGASDataset`.

## Features:
1. Salvataggio incrementale (append) per classe
2. Memory mapping (lazy loading)
3. Cache key: `{class_name}:{file_id}`
4. Accesso O(1) agli embeddings

In [None]:
# Setup
import os
import sys

# Locate the project checkout inside the user's home directory.
home_dir = os.path.expanduser('~')
project_dir = os.path.join(home_dir, 'fedgfe')
system_dir = os.path.join(project_dir, 'system')

# Run from the project root and put its `system` directory first on the
# import path.
os.chdir(project_dir)
sys.path.insert(0, system_dir)

print(f"Working directory: {os.getcwd()}")

In [None]:
# Imports
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
import json

from system.datautils.dataset_vegas import VEGASDataset
from transformers import ASTFeatureExtractor, ASTModel

# Print the PyTorch and NumPy versions and whether CUDA is available.
print(f"PyTorch: {torch.__version__}")
print(f"NumPy: {np.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

In [None]:
# Configuration shared by the cells of this notebook.
AST_MODEL_NAME = 'MIT/ast-finetuned-audioset-10-10-0.4593'
DATASET_PATH = 'dataset/Audio/VEGAS'
CACHE_DIR = 'cache/ast/vegas'
BATCH_SIZE = 8

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Restrict the test to two classes.
SELECTED_CLASSES = ['chainsaw', 'dog']

print(f"Device: {DEVICE}")
print(f"Cache dir: {CACHE_DIR}")
print(f"Classes: {SELECTED_CLASSES}")

## 1. Carica Dataset e AST Model

In [None]:
# Build the dataset with the AST embedding cache enabled.
dataset_kwargs = dict(
    root_dir=DATASET_PATH,
    selected_classes=SELECTED_CLASSES,
    samples_per_node=50,  # kept small so the test runs quickly
    node_split_id=0,
    ast_cache_dir=CACHE_DIR,
    enable_ast_cache=True,
)
dataset = VEGASDataset(**dataset_kwargs)

print(f"Dataset: {len(dataset)} samples")
print(f"Classes: {dataset.active_classes}")

In [None]:
# Load the pretrained AST feature extractor and encoder.
ast_feature_extractor = ASTFeatureExtractor.from_pretrained(AST_MODEL_NAME)

# `.to()` and `.eval()` each return the module, so the calls are chained:
# the model is moved to DEVICE and switched to evaluation mode.
ast_model = ASTModel.from_pretrained(AST_MODEL_NAME).to(DEVICE).eval()

print(f"AST model loaded on {DEVICE}")

## 2. Simula workflow: Calcola embeddings AST

In [None]:
# Simulate the embedding computation that happens during training
# (the way clientA2V would do it).

# Nested mapping: class_name -> file_id -> embedding tensor on the CPU.
ast_outputs_by_class = {}

print(f"Calculating AST embeddings...")
for idx in tqdm(range(len(dataset))):
    sample = dataset[idx]

    # Tensors are moved to the CPU and converted to NumPy before being handed
    # to the feature extractor; anything else is passed through unchanged.
    raw_audio = sample['audio']
    audio = raw_audio.cpu().numpy() if isinstance(raw_audio, torch.Tensor) else raw_audio

    # Forward pass without gradient tracking.
    with torch.no_grad():
        features = ast_feature_extractor(
            audio,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True,
        )
        audio_inputs = features.input_values.to(DEVICE)
        ast_output = ast_model(audio_inputs).last_hidden_state

    # File the result under its class, creating the per-class dict on first use.
    class_name = sample['class_name']
    file_id = sample['file_id']
    ast_outputs_by_class.setdefault(class_name, {})[file_id] = ast_output.squeeze(0).cpu()

print(f"\nCalculated embeddings:")
for class_name, embs in ast_outputs_by_class.items():
    print(f"  {class_name}: {len(embs)} samples")

## 3. Test: Salva in cache (primo salvataggio)

In [None]:
# Save through VEGASDataset's own cache method.
# `ast_outputs_by_class` is the class_name -> file_id -> embedding mapping
# built by the embedding computation cell of this notebook.
saved_counts = dataset.save_ast_embeddings_to_cache(
    ast_outputs_dict=ast_outputs_by_class,
    cache_dir=CACHE_DIR
)

print(f"\nSaved embeddings: {saved_counts}")

In [None]:
# Inspect the on-disk layout of the cache as an indented tree.
import os

print(f"\nCache structure:")
for root, dirs, files in os.walk(CACHE_DIR):
    # Nesting depth = number of path separators left once the CACHE_DIR
    # prefix is removed from the directory path.
    relative_part = root.replace(CACHE_DIR, '')
    level = relative_part.count(os.sep)
    indent = ' ' * 2 * level
    subindent = ' ' * 2 * (level + 1)

    print(f"{indent}{os.path.basename(root)}/")
    for file in files:
        size = os.path.getsize(os.path.join(root, file)) / 1024  # KB
        print(f"{subindent}{file} ({size:.2f} KB)")

## 4. Test: Append incrementale (simula secondo batch)

In [None]:
# Simulate a later save of new embeddings (e.g. a second node or a later
# round). Random tensors stand in for real embeddings to exercise the append.

print("Simulating append of new embeddings...")

new_embeddings = {}
for class_name in SELECTED_CLASSES:
    fake_for_class = {}
    for i in range(5):  # 5 new samples per class
        fake_for_class[f'video_99{i:03d}'] = torch.randn(1214, 768)  # fake embedding
    new_embeddings[class_name] = fake_for_class

saved_counts = dataset.save_ast_embeddings_to_cache(
    ast_outputs_dict=new_embeddings,
    cache_dir=CACHE_DIR,
)

print(f"\nAppended embeddings: {saved_counts}")

In [None]:
# Read each class manifest and report its sample total and per-chunk counts,
# to check that the append produced new chunks.
for class_name in SELECTED_CLASSES:
    manifest_path = Path(CACHE_DIR) / class_name / 'manifest.json'
    manifest = json.loads(manifest_path.read_text())

    print(f"\nClass '{class_name}':")
    print(f"  Total samples: {manifest['total_samples']}")
    print(f"  Num chunks: {len(manifest['chunks'])}")
    for i, chunk in enumerate(manifest['chunks']):
        print(f"  Chunk {i}: {chunk['num_samples']} samples")

## 5. Test: Carica da cache con mmap

In [None]:
# Load the cache for the selected classes.
loaded_cache = dataset.load_ast_embeddings_from_cache(
    cache_dir=CACHE_DIR,
    classes=SELECTED_CLASSES
)

# Each returned entry is read below through its 'manifest' and 'chunks' fields.
print(f"\nLoaded cache for {len(loaded_cache)} classes:")
for class_name, cache_info in loaded_cache.items():
    print(f"  {class_name}: {cache_info['manifest']['total_samples']} samples, "
          f"{len(cache_info['chunks'])} chunks (mmap)")

## 6. Test: Accesso O(1) agli embeddings

In [None]:
# Exercise get_cached_ast_embedding() on one randomly chosen sample.
import random

# Pick a random index and read its entry from dataset.samples.
test_idx = random.randint(0, len(dataset) - 1)
sample = dataset.samples[test_idx]
class_name, file_id = sample['class_name'], sample['file_id']

print(f"Testing sample: class={class_name}, file_id={file_id}")

# Look the embedding up in the cache.
cached_emb = dataset.get_cached_ast_embedding(class_name, file_id)

if cached_emb is None:
    print("✗ Not found in cache")
else:
    print(f"✓ Retrieved from cache: shape={cached_emb.shape}")

    # Compare against the embedding computed earlier in this session, when
    # one exists for this class and file id.
    reference = ast_outputs_by_class.get(class_name, {}).get(file_id)
    if reference is not None:
        original_emb = reference.numpy()
        diff = np.abs(cached_emb - original_emb).mean()
        print(f"  Mean diff from original: {diff:.8f}")

        if diff < 1e-6:
            print("  ✓ Perfect match!")

## 7. Test: Append duplicati (deve skippare)

In [None]:
# Try to save the same embeddings again (they are expected to be skipped).
print("Testing duplicate detection...")

saved_counts = dataset.save_ast_embeddings_to_cache(
    ast_outputs_dict=ast_outputs_by_class,  # Same mapping as the first save
    cache_dir=CACHE_DIR
)

print(f"\nDuplicate test result: {saved_counts}")
# NOTE(review): the next line is printed unconditionally; nothing here checks
# that the counts in `saved_counts` are actually zero, so the result printed
# above has to be inspected by eye.
print("✓ All should be 0 (already cached)")

## 8. Benchmark: Accesso mmap vs calcolo fresh

In [None]:
# Benchmark: latency of a cache lookup versus recomputing the AST embedding
# from the raw audio, over the same randomly drawn sample indices.
import random  # imported here so the cell does not depend on an earlier cell
import time

num_tests = 20
test_indices = random.sample(range(len(dataset)), min(num_tests, len(dataset)))

# The dataset may hold fewer than `num_tests` samples, so averages use the
# number of indices actually drawn. `max(..., 1)` keeps the per-sample
# division defined when the dataset is empty.
num_sampled = len(test_indices)
per_sample_divisor = max(num_sampled, 1)

# Test 1: access through the cache
print(f"Testing {num_sampled} random samples...")

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# short durations; time.time() can be too coarse for a handful of lookups.
start = time.perf_counter()
for idx in test_indices:
    sample = dataset.samples[idx]
    _ = dataset.get_cached_ast_embedding(sample['class_name'], sample['file_id'])
cache_time = time.perf_counter() - start

print(f"Cache access time: {cache_time:.4f}s ({cache_time/per_sample_divisor*1000:.2f}ms per sample)")

# Test 2: fresh computation (dataset item fetch + feature extraction + forward pass)
start = time.perf_counter()
for idx in test_indices:
    sample = dataset[idx]
    audio = sample['audio'].cpu().numpy() if isinstance(sample['audio'], torch.Tensor) else sample['audio']

    with torch.no_grad():
        audio_inputs = ast_feature_extractor(
            audio,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True
        ).input_values.to(DEVICE)
        # Copying the result to the CPU is part of the timed work.
        _ = ast_model(audio_inputs).last_hidden_state.cpu()

fresh_time = time.perf_counter() - start

print(f"Fresh computation time: {fresh_time:.4f}s ({fresh_time/per_sample_divisor*1000:.2f}ms per sample)")

# Guard the ratio: with no samples (or a timer tick of zero) the cache time
# can be 0, which previously raised ZeroDivisionError.
if cache_time > 0:
    print(f"\nSpeedup: {fresh_time/cache_time:.1f}x")
else:
    print("\nSpeedup: n/a (cache access time was below timer resolution)")

## Conclusioni

Sistema di caching integrato in VEGASDataset:

### ✓ Features implementate:
- Salvataggio per classe (file separati)
- Append incrementale (nuovi chunks)
- Rilevamento duplicati automatico
- Memory mapping (lazy loading)
- Accesso O(1) agli embeddings
- Metadata integrati nel manifest
- Thread-safe writes (atomic rename)

### Struttura cache:
```
cache_dir/
├── class_name/
│   ├── embeddings_000.npy
│   ├── embeddings_001.npy
│   └── manifest.json
└── ...
```

### API:
```python
# Salva (da clientA2V o server)
dataset.save_ast_embeddings_to_cache(ast_outputs_by_class, cache_dir)

# Carica (all'inizio training)
dataset.load_ast_embeddings_from_cache(cache_dir)

# Accedi (in __getitem__)
emb = dataset.get_cached_ast_embedding(class_name, file_id)
```

### Prossimi passi:
1. Integrare in clientA2V per salvare durante training
2. Script standalone per pre-generare cache
3. Modificare `__getitem__` per usare cache quando disponibile