# VEGAS Dataset - Node Splits Testing

Questo notebook testa le funzionalità di splitting del dataset VEGAS:
- Split per nodo con singola classe
- Split per nodo con più classi
- Visualizzazione immagini
- Riproduzione audio
- Verifica bilanciamento dataset

## Setup e Import

In [None]:
import os
import sys
from pathlib import Path

# Run everything relative to the project root so relative dataset paths
# (e.g. "dataset/Audio/VEGAS") resolve the same way on every run.
project_root = Path.home() / 'fedgfe'
os.chdir(project_root)
print(f"Working directory set to: {os.getcwd()}\n")

# Make the project's `system` directory importable. Use an absolute path (so a
# later chdir cannot break it) and add it only once, so re-running this cell
# does not keep growing sys.path.
system_dir = str(project_root / 'system')
if system_dir not in sys.path:
    sys.path.append(system_dir)

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import IPython.display as ipd
import json

from datautils.dataset_vegas import VEGASDataset

print("‚úì Imports completed")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## Configurazione

In [6]:
# Path to the VEGAS dataset (relative to the working directory set in the setup cell)
VEGAS_ROOT = Path("dataset/Audio/VEGAS")

# Seed NumPy's global RNG so the random sample choices made by later cells
# (via np.random) are reproducible under Restart & Run All.
SEED = 42
np.random.seed(SEED)

# Check the dataset is present before any cell tries to load it
if not VEGAS_ROOT.exists():
    print(f"‚ö†Ô∏è  VEGAS dataset not found at: {VEGAS_ROOT}")
    print("Please update the path or download the dataset")
else:
    print(f"‚úì VEGAS dataset found at: {VEGAS_ROOT}")

# Loading configuration.
# NOTE(review): no VEGASDataset call in this notebook reads this dict; each call
# passes its own arguments. Kept for reference; confirm whether it should be wired in.
config = {
    'dataset_path': str(VEGAS_ROOT),
    'audio_sample_rate': 16000,
    'audio_duration': 10.0,
    'image_size': (224, 224),
}

‚úì VEGAS dataset found at: dataset/Audio/VEGAS


## 1. Esplora Dataset VEGAS Completo

In [None]:
# List the classes declared in the dataset metadata, ordered by label id
print(f"Available classes in VEGAS dataset:")
print(f"\nTotal classes: {len(VEGASDataset.CLASS_LABELS)}")
print("\nClass labels:")
labels_by_id = sorted(VEGASDataset.CLASS_LABELS.items(), key=lambda pair: pair[1])
for name, label_id in labels_by_id:
    print(f"  {label_id}: {name}")

# Load the whole train split once, as a smoke test that the dataset works
separator = '=' * 80
print(f"\n{separator}")
print("Loading dataset (this may take a moment)...")
print(f"{separator}\n")

full_dataset = VEGASDataset(
    root_dir=str(VEGAS_ROOT),
    split='train',
    enable_ast_cache=False,
    load_audio=True,
    load_image=True,
    load_video=False,
)

print(f"‚úì Dataset loaded successfully!")
print(f"Total samples in VEGAS train split: {len(full_dataset)}")

# Tally class names over a prefix only: every access loads audio + image, so the full split is slow
print(f"\nSampling first 100 samples to check class distribution...")
sample_limit = min(100, len(full_dataset))
class_counts = {}
for position in range(sample_limit):
    name = full_dataset[position].get('class_name', 'unknown')
    class_counts[name] = class_counts.get(name, 0) + 1

print(f"\nClass distribution (first {sample_limit} samples):")
for name, count in sorted(class_counts.items(), key=lambda pair: pair[1], reverse=True):
    print(f"  {name:30s}: {count:4d} samples")

Available classes in VEGAS dataset:

Total classes: 10

Class labels:
  0: baby_cry
  1: chainsaw
  2: dog
  3: drum
  4: fireworks
  5: helicopter
  6: printer
  7: rail_transport
  8: snoring
  9: water_flowing

Loading dataset (this may take a moment)...

‚úì Dataset loaded successfully!
Total samples in VEGAS train split: 19674

Sampling first 100 samples to check class distribution...

Class distribution (first 100 samples):
  printer                       :   15 samples
  water_flowing                 :   15 samples
  rail_transport                :   15 samples
  dog                           :   14 samples
  fireworks                     :   13 samples
  helicopter                    :    8 samples
  drum                          :    7 samples
  chainsaw                      :    6 samples
  baby_cry                      :    4 samples
  snoring                       :    3 samples


: 

## 2. Test Split Singola Classe per Nodo

In [None]:
print("=" * 80)
print("TEST 1: Single Class per Node")
print("=" * 80)

# One small dataset per node, each restricted to a single class.
# Everything goes to the train split (train_ratio=1.0) to keep the test simple.
SINGLE_CLASS_NODES = ['dog', 'chainsaw']  # node i loads SINGLE_CLASS_NODES[i]
SINGLE_CLASS_SAMPLES = 20                 # reduced sample count for speed
SINGLE_CLASS_SPLIT_ID = 0

single_class_datasets = []
for node_idx, class_name in enumerate(SINGLE_CLASS_NODES):
    print(f"\nNode {node_idx}: Loading '{class_name}' class "
          f"({SINGLE_CLASS_SAMPLES} samples, split_id={SINGLE_CLASS_SPLIT_ID})...")
    node_dataset = VEGASDataset(
        root_dir=str(VEGAS_ROOT),
        selected_classes=[class_name],
        samples_per_node=SINGLE_CLASS_SAMPLES,
        node_split_id=SINGLE_CLASS_SPLIT_ID,
        train_ratio=1.0,
        val_ratio=0.0,
        test_ratio=0.0,
        split='train',
        enable_ast_cache=False,
        load_audio=True,
        load_image=True,
        load_video=False,
    )
    single_class_datasets.append((class_name, node_dataset))
    print(f"  ‚úì Loaded {len(node_dataset)} samples")

# Convenience handles for the first two nodes
node_0, node_1 = single_class_datasets[0][1], single_class_datasets[1][1]

print("\n" + "=" * 80)
print(f"‚úì Created {len(single_class_datasets)} single-class datasets")
print("=" * 80)

## 3. Test Split Multi-Classe per Nodo

In [None]:
banner = "=" * 80
print(banner)
print("TEST 2: Multiple Classes per Node")
print(banner)

# A single node holding several classes (30 samples in total, reduced for speed)
multi_class_datasets = []

node_classes = ['dog', 'baby_cry', 'drum']
print(f"\nNode 0: Loading ['dog', 'baby_cry', 'drum'] (30 samples total)...")
multi_node_0 = VEGASDataset(
    root_dir=str(VEGAS_ROOT),
    selected_classes=list(node_classes),
    samples_per_node=30,
    node_split_id=0,
    train_ratio=1.0,
    val_ratio=0.0,
    test_ratio=0.0,
    split='train',
    enable_ast_cache=False,
    load_audio=True,
    load_image=True,
    load_video=False,
)
multi_class_datasets.append((node_classes, multi_node_0))
print(f"  ‚úì Loaded {len(multi_node_0)} samples")

print("\n" + banner)
print(f"‚úì Created {len(multi_class_datasets)} multi-class dataset")
print(banner)

# Per-class sample counts for each multi-class node
for node_idx, (expected_classes, node_ds) in enumerate(multi_class_datasets):
    print(f"\nMulti-Class Node {node_idx} - Class distribution:")
    class_counts = {}
    for position in range(len(node_ds)):
        label = node_ds[position].get('class_name', 'unknown')
        class_counts[label] = class_counts.get(label, 0) + 1

    for cls in expected_classes:
        print(f"  {cls:20s}: {class_counts.get(cls, 0):3d} samples")

## 4. Visualizzazione: Immagine + Audio

Funzione helper per visualizzare un sample con immagine e audio

In [None]:
def visualize_sample(sample, title="Sample", show_audio=True):
    """
    Plot one VEGAS sample: its image and, optionally, the audio waveform,
    spectrogram and an inline audio player.

    Args:
        sample: dict with 'image' and optionally 'audio', 'sample_rate'
            (16000 is used when absent), 'class_name' and 'sample_id'.
        title: title prefix for the image panel.
        show_audio: if True and the sample contains 'audio', also draw the
            waveform and spectrogram and display an audio player.
    """
    class_name = sample.get('class_name', 'Unknown')
    sample_id = sample.get('sample_id', 'N/A')
    with_audio = show_audio and 'audio' in sample

    plt.figure(figsize=(15, 8))

    # Layout on a 2x2 grid:
    #   row 1: image spanning both columns
    #   row 2: waveform (left), spectrogram (right)
    ax1 = plt.subplot(2, 2, (1, 2))
    ax1.imshow(_image_for_display(sample['image']))
    ax1.set_title(f"{title}\nClass: {class_name} | ID: {sample_id}", fontsize=14, fontweight='bold')
    ax1.axis('off')

    if with_audio:
        audio = _audio_for_display(sample['audio'])
        sample_rate = sample.get('sample_rate', 16000)

        # Waveform
        ax2 = plt.subplot(2, 2, 3)
        time_axis = np.arange(len(audio)) / sample_rate
        ax2.plot(time_axis, audio, linewidth=0.5)
        ax2.set_title("Audio Waveform")
        ax2.set_xlabel("Time (s)")
        ax2.set_ylabel("Amplitude")
        ax2.grid(True, alpha=0.3)

        # Spectrogram
        ax3 = plt.subplot(2, 2, 4)
        ax3.specgram(audio, Fs=sample_rate, cmap='viridis')
        ax3.set_title("Spectrogram")
        ax3.set_xlabel("Time (s)")
        ax3.set_ylabel("Frequency (Hz)")

    plt.tight_layout()
    plt.show()

    # Audio player (explicit ipd.display so this also works outside the IPython builtins)
    if with_audio:
        print(f"\nüîä Audio Player for: {class_name}")
        ipd.display(ipd.Audio(audio, rate=sample_rate))


def _image_for_display(image):
    """Convert an image (torch tensor, numpy array or PIL image) to a numpy array for imshow."""
    if isinstance(image, torch.Tensor):
        # detach + cpu: .numpy() fails on tensors that require grad or live on the GPU
        image = image.detach().cpu()
        if image.ndim == 3:
            # presumably channel-first [C, H, W] -> [H, W, C] for imshow - TODO confirm
            image = image.permute(1, 2, 0)
        image = image.numpy()
    image = np.asarray(image)
    if image.ndim == 3 and image.shape[-1] == 1:
        # imshow expects single-channel images as [H, W]
        image = image[..., 0]
    # Normalized images (negative values) are min-max rescaled to [0, 1] for display
    if image.min() < 0:
        value_range = image.max() - image.min()
        if value_range > 0:
            image = (image - image.min()) / value_range
        else:
            image = np.zeros(image.shape, dtype=np.float32)  # constant image: avoid 0/0
    return image


def _audio_for_display(audio):
    """Convert audio (torch tensor or array) to a numpy array for plotting and playback."""
    if isinstance(audio, torch.Tensor):
        audio = audio.detach().cpu().numpy()
    # Drop singleton dims (e.g. a [1, T] mono tensor): plot/specgram need 1-D input
    return np.squeeze(np.asarray(audio))


print("‚úì Visualization function defined")

## 5. Visualizza Samples da Nodo Single-Class

In [None]:
banner = "=" * 80
print(banner)
print("VISUALIZING SAMPLES FROM SINGLE-CLASS NODES")
print(banner)

# Show one randomly chosen sample from each single-class node
for class_name, dataset in single_class_datasets:
    print(f"\n{banner}")
    print(f"Class: {class_name} ({len(dataset)} samples)")
    print(f"{banner}\n")

    if not len(dataset):
        print(f"‚ö†Ô∏è  No samples in dataset")
        continue

    idx = np.random.randint(0, len(dataset))
    sample = dataset[idx]
    visualize_sample(
        sample,
        title=f"Class: {class_name} - Sample {idx}",
        show_audio=True,
    )

## 6. Visualizza Samples da Nodo Multi-Class

In [None]:
print("=" * 80)
print("VISUALIZING SAMPLES FROM MULTI-CLASS NODES")
print("=" * 80)

# For each multi-class node, show one sample per distinct class (at most 3 per node)
MAX_SAMPLES_PER_NODE = 3
MAX_SCAN = 100  # upper bound on samples inspected per node (each access loads audio + image)

for i, (classes, dataset) in enumerate(multi_class_datasets):
    print(f"\n{'='*80}")
    print(f"Multi-Class Node {i}: Classes {classes}")
    print(f"{'='*80}\n")

    # np.random.randint(0, 0) would raise on an empty node
    if len(dataset) == 0:
        print("WARNING: No samples in dataset")
        continue

    samples_to_show = min(MAX_SAMPLES_PER_NODE, len(classes))
    shown_classes = set()

    # Walk a random permutation of the indices: every inspected sample is distinct, so no
    # loads are wasted on repeats, and on nodes with at most MAX_SCAN samples every class
    # present is guaranteed to be found.
    for idx in np.random.permutation(len(dataset))[:MAX_SCAN]:
        if len(shown_classes) >= samples_to_show:
            break
        sample = dataset[int(idx)]
        class_name = sample.get('class_name', 'unknown')
        if class_name in shown_classes:
            continue
        shown_classes.add(class_name)
        visualize_sample(
            sample,
            title=f"Multi-Class Node {i} - Class: {class_name}",
            show_audio=True,
        )

    # Report expected classes that never showed up instead of failing silently
    if len(shown_classes) < samples_to_show:
        missing = [cls for cls in classes if cls not in shown_classes]
        print(f"WARNING: classes not found among inspected samples: {missing}")

## 7. Test: Verifica Non-Overlap tra Node Splits

In [None]:
print("=" * 80)
print("TEST: Verifico non-overlap tra nodi con stessa classe")
print("=" * 80)

# Two datasets for the same class that differ only in node_split_id:
# the node-splitting logic should give them disjoint sets of samples.
test_class = 'dog'

print(f"\nCreating two datasets for class '{test_class}' with different node_split_id...")

split_0_dataset, split_1_dataset = (
    VEGASDataset(
        root_dir=str(VEGAS_ROOT),
        selected_classes=[test_class],
        node_split_id=split_id,
        samples_per_node=20,  # reduced for speed
        train_ratio=1.0,
        val_ratio=0.0,
        test_ratio=0.0,
        split='train',
        enable_ast_cache=False,
        load_audio=False,  # only IDs are compared: skip audio/image loading for speed
        load_image=False,
        load_video=False,
    )
    for split_id in (0, 1)
)


def _collect_sample_ids(dataset):
    """Return ``(ids, n_missing)``: the 'sample_id' of every sample, and how many samples lack one."""
    ids, n_missing = [], 0
    for idx in range(len(dataset)):
        sample_id = dataset[idx].get('sample_id')
        if sample_id is None:
            n_missing += 1
        else:
            ids.append(sample_id)
    return ids, n_missing


split_0_list, split_0_missing = _collect_sample_ids(split_0_dataset)
split_1_list, split_1_missing = _collect_sample_ids(split_1_dataset)
split_0_ids, split_1_ids = set(split_0_list), set(split_1_list)

overlap = split_0_ids & split_1_ids

print(f"\nClass: {test_class}")
print(f"Node Split 0 samples: {len(split_0_ids)}")
print(f"Node Split 1 samples: {len(split_1_ids)}")
print(f"Overlap: {len(overlap)} samples")

n_missing = split_0_missing + split_1_missing
if n_missing:
    # A positional index is not an ID: substituting it would make both splits look like
    # {0, 1, ...} and report a spurious overlap, so the check is reported as inconclusive.
    print(f"\nWARNING: inconclusive - {n_missing} samples have no 'sample_id'")
elif not split_0_ids or not split_1_ids:
    print("\nWARNING: inconclusive - at least one split is empty, nothing to compare")
elif overlap:
    print(f"\n‚ö†Ô∏è  WARNING: Found {len(overlap)} overlapping samples!")
    print(f"Overlapping IDs: {sorted(overlap, key=str)[:10]}...")
else:
    print("\n‚úì SUCCESS: No overlap between node splits!")

# Duplicated IDs inside one split are also a splitting bug, and the set sizes above hide them
for split_id, id_list, id_set in ((0, split_0_list, split_0_ids), (1, split_1_list, split_1_ids)):
    n_duplicates = len(id_list) - len(id_set)
    if n_duplicates:
        print(f"WARNING: Node Split {split_id} contains {n_duplicates} duplicate sample IDs")

## 8. Statistiche Complete

In [None]:
def print_dataset_statistics(datasets_list, title="Dataset Statistics"):
    """Print the sample count of every dataset in `datasets_list`, followed by a grand total.

    Each entry is either a ``(label, dataset)`` tuple or a bare dataset; bare
    datasets are labelled "Dataset <i>" after their position in the list.
    """
    rule = "=" * 80
    print("\n" + rule)
    print(title)
    print(rule)

    grand_total = 0
    for position, entry in enumerate(datasets_list):
        label, dataset = entry if isinstance(entry, tuple) else (f"Dataset {position}", entry)
        size = len(dataset)
        print(f"\n{label}:")
        print(f"  Total samples: {size:4d}")
        grand_total += size

    print("\n" + "-" * 80)
    print(f"Grand Total: {grand_total:5d} samples")
    print(rule)

# Summary statistics, first for the single-class nodes, then for the multi-class nodes
for group, heading in (
    (single_class_datasets, "Single-Class Datasets Statistics"),
    (multi_class_datasets, "Multi-Class Datasets Statistics"),
):
    print_dataset_statistics(group, heading)

## 9. Visualizzazione Comparativa: Stessa Classe, Split Diversi

In [None]:
rule = "=" * 80
print(rule)
print("COMPARATIVE VISUALIZATION: Same Class, Different Node Splits")
print(rule)

# Same class, two different node splits (reduced from 3 splits to 2 for speed)
test_class = 'dog'
print(f"\nCreating 2 datasets for class '{test_class}' with different node_split_id...")

comparison_datasets = []
for split_id in range(2):
    split_dataset = VEGASDataset(
        root_dir=str(VEGAS_ROOT),
        selected_classes=[test_class],
        node_split_id=split_id,
        samples_per_node=10,  # reduced for speed
        train_ratio=1.0,
        val_ratio=0.0,
        test_ratio=0.0,
        split='train',
        enable_ast_cache=False,
        load_audio=True,
        load_image=True,
        load_video=False,
    )
    comparison_datasets.append(split_dataset)
    print(f"  Split {split_id}: {len(split_dataset)} samples")

# Show the first sample of each split, one figure per split
print(f"\nComparing samples from 2 different node splits:\n")

for split_id, split_dataset in enumerate(comparison_datasets):
    if len(split_dataset) == 0:
        continue
    visualize_sample(
        split_dataset[0],
        title=f"Split {split_id} - Class: {test_class}",
        show_audio=True,
    )

## 10. Test Bilanciamento Classi (Multi-Class Node)

In [None]:
print("=" * 80)
print("TEST: Class Balance in Multi-Class Datasets")
print("=" * 80)

for i, (classes, dataset) in enumerate(multi_class_datasets):
    total = len(dataset)
    print(f"\nMulti-Class Dataset {i}:")
    print(f"Expected classes: {classes}")
    print(f"Total samples: {total}")

    # Count samples per class over the whole node
    class_counts = {}
    for idx in range(total):
        class_name = dataset[idx].get('class_name', 'unknown')
        class_counts[class_name] = class_counts.get(class_name, 0) + 1

    print(f"\nClass distribution:")
    for cls in classes:
        count = class_counts.get(cls, 0)
        percentage = (count / total) * 100 if total > 0 else 0
        print(f"  {cls:20s}: {count:3d} samples ({percentage:5.1f}%)")

    # The table above lists only expected classes. Surface the two failure modes it
    # would otherwise hide: requested classes with no samples, and samples from
    # classes that were not requested.
    empty_classes = [cls for cls in classes if cls not in class_counts]
    if empty_classes:
        print(f"  WARNING: expected classes with no samples: {empty_classes}")
    unexpected = {name: n for name, n in class_counts.items() if name not in classes}
    if unexpected:
        print(f"  WARNING: samples from unexpected classes: {unexpected}")

    print("-" * 80)

## 11. Export Sample per Debugging

In [None]:
import scipy.io.wavfile as wavfile

# Save the first sample (image + audio) of a couple of single-class nodes for manual inspection
output_dir = Path("debug_samples")
output_dir.mkdir(exist_ok=True)

print(f"Exporting debug samples to: {output_dir}\n")


def _image_to_uint8(image):
    """Convert an image (torch tensor, numpy array or PIL image) to a uint8 array PIL can save."""
    if isinstance(image, torch.Tensor):
        # detach + cpu: .numpy() fails on tensors that require grad or live on the GPU
        image = image.detach().cpu()
        if image.ndim == 3:
            # presumably channel-first [C, H, W] -> [H, W, C] - TODO confirm
            image = image.permute(1, 2, 0)
        image = image.numpy()
    image = np.asarray(image)
    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]  # Image.fromarray cannot handle an HxWx1 array
    if image.dtype == np.uint8:
        return image  # already 0-255: multiplying by 255 again would overflow
    image = image.astype(np.float32)
    if image.min() < 0:
        # normalized image: min-max rescale to [0, 1] (a constant image maps to 0)
        value_range = image.max() - image.min()
        image = (image - image.min()) / value_range if value_range > 0 else np.zeros_like(image)
    elif image.max() > 1:
        image = image / 255.0  # assumes a 0-255 float image - TODO confirm
    return (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)


def _audio_for_wav(audio):
    """Convert audio to the layout wavfile.write expects: (n_samples,) or (n_samples, n_channels)."""
    if isinstance(audio, torch.Tensor):
        audio = audio.detach().cpu().numpy()
    # A [1, T] mono array would otherwise be written as 1 frame with T channels
    audio = np.squeeze(np.asarray(audio))
    if audio.ndim == 2 and audio.shape[0] < audio.shape[1]:
        audio = audio.T  # looks channel-first [C, T]; wavfile.write wants [T, C]
    if audio.dtype == np.float64:
        audio = audio.astype(np.float32)  # float32 is the common float WAV format
    return audio


for class_name, dataset in single_class_datasets[:2]:  # only the first 2 nodes
    if len(dataset) == 0:
        continue
    sample = dataset[0]

    # Save the image
    img_path = output_dir / f"{class_name}_image.png"
    Image.fromarray(_image_to_uint8(sample['image'])).save(img_path)
    print(f"‚úì Saved: {img_path}")

    # Save the audio
    if 'audio' in sample:
        audio_path = output_dir / f"{class_name}_audio.wav"
        sample_rate = int(sample.get('sample_rate', 16000))  # wavfile.write needs an int rate
        wavfile.write(audio_path, sample_rate, _audio_for_wav(sample['audio']))
        print(f"‚úì Saved: {audio_path}")

print(f"\n‚úì Debug samples exported to {output_dir}")

## Summary

Questo notebook ha testato:

✅ **Single-class datasets**: Dataset con una classe per nodo  
✅ **Multi-class datasets**: Dataset con più classi per nodo  
✅ **Visualizzazione**: Immagini + audio per ogni sample  
✅ **Non-overlap**: Verifica che split diversi non condividano samples  
✅ **Bilanciamento**: Distribuzione classi in dataset multi-classe  
✅ **Statistiche**: Numero samples per dataset  

### Parametri Chiave VEGASDataset:

- `selected_classes`: Lista classi da caricare
- `node_split_id`: ID split (0, 1, 2...) per evitare overlap tra nodi
- `samples_per_node`: Numero totale di samples da caricare per il nodo (su tutte le classi selezionate)
- `train_ratio`, `val_ratio`, `test_ratio`: Proporzioni split
- `split`: 'train', 'val', 'test' - quale split caricare

### Esempio Uso:

```python
# Dataset con singola classe
dataset = VEGASDataset(
    root_dir=str(VEGAS_ROOT),
    selected_classes=['dog'],
    node_split_id=0,
    samples_per_node=200,
    split='train',
    enable_ast_cache=False
)

# Dataset con più classi
dataset = VEGASDataset(
    root_dir=str(VEGAS_ROOT),
    selected_classes=['dog', 'chainsaw', 'helicopter'],
    node_split_id=0,
    samples_per_node=300,
    split='train',
    enable_ast_cache=False
)
```