# 🎵 ManuAI Bird Call Classifier - Universal Training

**Scalable preprocessing for all bird species**

This streamlined notebook provides an efficient training pipeline with:
- Essential imports and configuration
- Universal adaptive preprocessing for all species
- Efficient data loading and training
- Quick evaluation and results
- Ready for dataset expansion without species-specific dependencies

In [1]:
# Essential imports
import os
import glob
import numpy as np
import pandas as pd
import torch
import librosa
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from scipy import signal
from collections import Counter
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from torch.utils.data import Dataset
from transformers import (
    ViTImageProcessor, ViTForImageClassification, 
    Trainer, TrainingArguments, EarlyStoppingCallback
)
import torchvision.transforms as transforms

# Import for high-quality image resizing
try:
    from skimage import transform as sk_transform
    print("✅ scikit-image available for high-quality resizing")
except ImportError:
    sk_transform = None
    print("⚠️ scikit-image not available, using PIL for resizing")

print("✅ Essential libraries imported")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"📱 Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

  from .autonotebook import tqdm as notebook_tqdm


✅ scikit-image available for high-quality resizing
✅ Essential libraries imported
🔥 PyTorch version: 2.7.1
📱 Device: CPU


In [2]:
# Fine-tuning Configuration
config = {
    'audio_dir': 'segments',  # Updated to use pre-segmented data
    'dataset_splits_path': 'segment_splits.npz',  # Updated filename
    'model_name': 'google/vit-base-patch16-224-in21k', # Pre-trained ViT model 
    'num_epochs': 20,
    'batch_size': 16,
    'learning_rate': 2e-5,
    'warmup_steps': 500, 
    'eval_steps': 100,
    'save_steps': 200,
    'early_stopping_patience': 3,
    'target_size': (224, 224),
    'validate_quality': True,
    'seed': 42
}

# Set random seeds
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])

print("⚙️ Configuration loaded")
print(f" Using pre-segmented data from: {config['audio_dir']}/")
print(f" Batch size: {config['batch_size']}")
print(f" Max epochs: {config['num_epochs']}")
print(f" Learning rate: {config['learning_rate']}")
print(" Note: Each segment = 1 mel spectrogram for training")

⚙️ Configuration loaded
 Using pre-segmented data from: segments/
 Batch size: 16
 Max epochs: 20
 Learning rate: 2e-05
 Note: Each segment = 1 mel spectrogram for training


In [3]:
# Dataset Splitting Configuration
splitting_config = {
    'force_new_splits': False,  # Set to True to always create new splits
    'test_size': 0.2,          # 20% for test set
    'val_size': 0.2,           # 20% for validation set (60% train, 20% val, 20% test)
    'stratify': True,          # Ensure balanced splits across species
    'show_split_info': True    # Display detailed split information
}

print("⚙️ Dataset splitting configuration:")
print(f"   Force new splits: {splitting_config['force_new_splits']}")
print(f"   Test size: {splitting_config['test_size']*100:.0f}%")
print(f"   Validation size: {splitting_config['val_size']*100:.0f}%")
print(f"   Train size: {(1-splitting_config['test_size']-splitting_config['val_size'])*100:.0f}%")
print("   💡 Set 'force_new_splits=True' to regenerate splits")
print("   🎛️ Use dataset_size_config above to control dataset size and balancing")

⚙️ Dataset splitting configuration:
   Force new splits: False
   Test size: 20%
   Validation size: 20%
   Train size: 60%
   💡 Set 'force_new_splits=True' to regenerate splits
   🎛️ Use dataset_size_config above to control dataset size and balancing


In [4]:
# Universal Dataset Class with Adaptive Preprocessing
class UniversalAudioDataset(Dataset):
    """Dataset with robust universal preprocessing for all bird species."""
    
    def __init__(self, audio_paths, labels, label_encoder, transform=None, target_size=(224, 224)):
        self.audio_paths = audio_paths
        self.labels = labels
        self.label_encoder = label_encoder
        self.transform = transform
        self.target_size = target_size
        
        # Initialize ViT processor
        self.processor = ViTImageProcessor.from_pretrained(
            'google/vit-base-patch16-224-in21k'
        )
    
    def adaptive_preprocessing(self, audio, sr):
        """Adaptive preprocessing based on audio characteristics."""
        # Normalize loudness
        if np.max(np.abs(audio)) > 0:
            audio = audio / np.max(np.abs(audio))
        
        # Adaptive parameters based on signal characteristics
        rms_energy = np.sqrt(np.mean(audio**2))
        spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=audio, sr=sr))
        
        # Adaptive mel parameters
        if rms_energy < 0.1:  # Low energy signal
            n_mels = 128
            hop_length = 256  
        else:  # Normal signal
            n_mels = 128
            hop_length = 256
            
        # Frequency adaptation
        if spectral_centroid > 4000:  # High frequency signal
            fmax = 10000
        else:  # Normal range
            fmax = 8000
            
        return n_mels, hop_length, fmax
    
    def create_robust_spectrogram(self, audio, sr):
        """Create robust mel spectrogram with universal parameters."""
        try:
            # Get adaptive parameters
            n_mels, hop_length, fmax = self.adaptive_preprocessing(audio, sr)
            
            # Generate mel spectrogram
            mel_spec = librosa.feature.melspectrogram(
                y=audio, 
                sr=sr, 
                n_mels=n_mels,
                hop_length=hop_length,
                win_length=1024,
                fmax=fmax
            )
            
            # Convert to dB and resize
            mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
            
            # Resize to target dimensions
            if mel_spec_db.shape != self.target_size:
                # Adjust height (mel bins)
                if mel_spec_db.shape[0] != self.target_size[0]:
                    if mel_spec_db.shape[0] < self.target_size[0]:
                        mel_spec_db = np.pad(mel_spec_db, 
                                           ((0, self.target_size[0] - mel_spec_db.shape[0]), (0, 0)), 
                                           mode='constant')
                    else:
                        mel_spec_db = mel_spec_db[:self.target_size[0], :]
                
                # Adjust width (time)
                if mel_spec_db.shape[1] != self.target_size[1]:
                    if mel_spec_db.shape[1] < self.target_size[1]:
                        mel_spec_db = np.pad(mel_spec_db,
                                           ((0, 0), (0, self.target_size[1] - mel_spec_db.shape[1])),
                                           mode='constant')
                    else:
                        mel_spec_db = mel_spec_db[:, :self.target_size[1]]
            
            # Normalize to 0-255
            mel_spec_normalized = ((mel_spec_db - mel_spec_db.min()) / 
                                 (mel_spec_db.max() - mel_spec_db.min() + 1e-8) * 255).astype(np.uint8)
            
            # Convert to RGB PIL Image
            spectrogram_image = Image.fromarray(mel_spec_normalized, mode='L').convert('RGB')
            
            return spectrogram_image
            
        except Exception as e:
            print(f"Error in spectrogram creation: {e}")
            # Return black image as fallback
            return Image.new('RGB', self.target_size, color='black')
    
    def __len__(self):
        return len(self.audio_paths)
    
    def __getitem__(self, idx):
        audio_path = self.audio_paths[idx]
        label = self.labels[idx]
        
        try:
            # Load audio
            audio, sr = librosa.load(audio_path, sr=None)
            
            # Create spectrogram 
            spectrogram_image = self.create_robust_spectrogram(audio, sr)
            
            # Apply transforms if provided
            if self.transform:
                spectrogram_image = self.transform(spectrogram_image)
                # Convert back to PIL if transform returned tensor
                if isinstance(spectrogram_image, torch.Tensor):
                    spectrogram_image = transforms.ToPILImage()(spectrogram_image)
            
            # Process for ViT
            inputs = self.processor(images=spectrogram_image, return_tensors="pt")
            
            return {
                'pixel_values': inputs['pixel_values'].squeeze(),
                'labels': torch.tensor(label, dtype=torch.long)
            }
            
        except Exception as e:
            print(f"Error processing {audio_path}: {e}")
            # Robust fallback
            black_image = Image.new('RGB', self.target_size, color='black')
            inputs = self.processor(images=black_image, return_tensors="pt")
            return {
                'pixel_values': inputs['pixel_values'].squeeze(),
                'labels': torch.tensor(label, dtype=torch.long)
            }

print("✅ UniversalAudioDataset class ready - scalable for all species")

✅ UniversalAudioDataset class ready - scalable for all species


In [5]:
# Load or create segment-based dataset splits with size control
print("📂 Loading/creating segment-based dataset splits with size control...")

# Use segments directory - this contains the actual training data (pre-segmented)
segments_directory = 'segments'

# Create a unique splits filename based on configuration
size_suffix = f"_{dataset_size_config['size']}"
if dataset_size_config['balance_classes']:
    size_suffix += "_balanced"
splits_filename = f"segment_splits{size_suffix}.npz"
controlled_splits_path = splits_filename

# Determine if we need to create new splits
splits_exist = os.path.exists(controlled_splits_path)
create_new_splits = splitting_config['force_new_splits'] or not splits_exist

if splits_exist and not splitting_config['force_new_splits']:
    print(f"✅ Found existing segment splits file: {controlled_splits_path}")
    print("💡 Set splitting_config['force_new_splits']=True to regenerate")
elif splits_exist and splitting_config['force_new_splits']:
    print(f"🔄 Forcing creation of new splits (existing file will be overwritten)")
else:
    print(f"❌ No splits file found at {controlled_splits_path}")
    print("🆕 Creating new segment splits automatically")

# Load existing or create new splits
if create_new_splits:
    print(f"\n🎛️ Creating {dataset_size_config['size']} segment dataset with selected species...")
    
    X_train, X_val, X_test, y_train, y_val, y_test = create_controlled_segment_splits(
        segments_directory, 
        selected_species_limits,
        test_size=splitting_config['test_size'], 
        val_size=splitting_config['val_size'], 
        seed=config['seed'],
        show_info=splitting_config['show_split_info']
    )
    
    # Save the splits for future use
    np.savez(
        controlled_splits_path,
        X_train=X_train, X_val=X_val, X_test=X_test,
        y_train=y_train, y_val=y_val, y_test=y_test,
        # Save configuration for reference
        size_config=dataset_size_config['size'],
        species_limits=selected_species_limits,
        data_type='segments'  # Flag to indicate this uses segments
    )
    print(f"💾 Controlled segment splits saved to {controlled_splits_path}")
    
else:
    # Load existing splits
    data_splits = np.load(controlled_splits_path, allow_pickle=True)
    X_train, X_val, X_test = data_splits['X_train'], data_splits['X_val'], data_splits['X_test']
    y_train, y_val, y_test = data_splits['y_train'], data_splits['y_val'], data_splits['y_test']
    
    # Try to load configuration info if available
    try:
        saved_size_config = data_splits['size_config'].item() if 'size_config' in data_splits else 'unknown'
        saved_species_limits = data_splits['species_limits'].item() if 'species_limits' in data_splits else {}
        data_type = data_splits['data_type'].item() if 'data_type' in data_splits else 'unknown'
        print(f"📁 Loaded existing {saved_size_config} {data_type} dataset splits:")
    except:
        print(f"📁 Loaded existing splits:")
    
    if splitting_config['show_split_info']:
        print(f"   Train: {len(X_train)} spectrograms")
        print(f"   Validation: {len(X_val)} spectrograms") 
        print(f"   Test: {len(X_test)} spectrograms")

# Create label encoder
all_labels = np.concatenate([y_train, y_val, y_test])
label_encoder = LabelEncoder()
label_encoder.fit(all_labels)

# Convert to numeric labels
y_train_encoded = label_encoder.transform(y_train)
y_val_encoded = label_encoder.transform(y_val)
y_test_encoded = label_encoder.transform(y_test)

print(f"\n✅ {dataset_size_config['size'].title()} segment dataset ready:")
print(f"   Train: {len(X_train)} spectrograms")
print(f"   Validation: {len(X_val)} spectrograms")
print(f"   Test: {len(X_test)} spectrograms")
print(f"   Classes: {len(label_encoder.classes_)}")
print(f"   Species: {', '.join(map(str, label_encoder.classes_))}")

# Enhanced data transforms for better generalization
transform_train = transforms.Compose([
    transforms.RandomRotation(degrees=(-3, 3)),  # Slight temporal shifts
    transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.1),  # Spectral variations
    transforms.RandomHorizontalFlip(p=0.1),  # Rare but helps with temporal reversals
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5))
    ], p=0.1),  # Slight smoothing to simulate recording conditions
])

transform_val = None  # No augmentation for validation

print("✅ Enhanced universal data transforms ready")
print("🎵 Note: Using pre-segmented audio files - each segment becomes 1 mel spectrogram")

📂 Loading/creating segment-based dataset splits with size control...


NameError: name 'dataset_size_config' is not defined

In [None]:
# Analyze actual segmented dataset (the real training data)
print("🔍 Analyzing segmented dataset (actual training data)...")

# Update to use segments directory - this is the real training data!
segments_dir = 'segments'

def analyze_segmented_dataset(segments_dir):
    """Analyze the actual segmented dataset that will be used for training."""
    print(f"📊 Analyzing segmented dataset in {segments_dir}...")
    
    # Get all segment files from species folders
    segment_paths = []
    labels = []
    
    for species_folder in os.listdir(segments_dir):
        species_path = os.path.join(segments_dir, species_folder)
        if os.path.isdir(species_path):
            # Find all .wav files in this species folder (including subdirectories)
            species_files = glob.glob(os.path.join(species_path, '**', '*.wav'), recursive=True)
            
            segment_paths.extend(species_files)
            labels.extend([species_folder] * len(species_files))
            
            print(f"   {species_folder}: {len(species_files)} segments")
    
    # Analyze distribution
    species_counts = Counter(labels)
    total_segments = len(segment_paths)
    
    print(f"\n📈 Segmented Dataset Distribution ({total_segments} total segments):")
    print("=" * 60)
    
    sorted_species = sorted(species_counts.items(), key=lambda x: x[1], reverse=True)
    min_segments = min(species_counts.values())
    max_segments = max(species_counts.values())
    
    for species, count in sorted_species:
        percentage = (count / total_segments) * 100
        print(f"   {species.ljust(12)}: {count:5d} segments ({percentage:5.1f}%)")
    
    print("=" * 60)
    print(f"📊 Segment Distribution Summary:")
    print(f"   • Minimum segments per class: {min_segments}")
    print(f"   • Maximum segments per class: {max_segments}")
    print(f"   • Imbalance ratio: {max_segments/min_segments:.2f}:1")
    
    # Determine if dataset is balanced
    is_balanced = (max_segments / min_segments) <= 2.0  # Allow up to 2:1 ratio as "balanced"
    balance_status = "✅ Relatively balanced" if is_balanced else "⚠️ Highly imbalanced"
    print(f"   • Balance status: {balance_status}")
    
    # Calculate expected spectrograms (1:1 ratio with segments for individual processing)
    print(f"\n🎵 Spectrogram Count (1 per segment):")
    print(f"   • Total spectrograms for training: {total_segments}")
    print(f"   • Each segment → 1 mel spectrogram")
    
    return species_counts, sorted_species, segment_paths, labels

# Analyze actual segmented dataset
species_counts, sorted_species, all_segment_paths, all_segment_labels = analyze_segmented_dataset(segments_dir)

In [None]:
# Segment-Based Dataset Size Selection Configuration
dataset_size_config = {
    'size': 'medium',  # Options: 'small', 'medium', 'large', 'full'
    'balance_classes': True,  # Whether to balance classes by limiting to min class size
    'min_segments_per_class': 50,  # Minimum segments required per class
    'show_size_info': True
}

def create_sized_segment_dataset(species_counts, size='medium', balance_classes=True, min_segments_per_class=50):
    """
    Create different sized datasets based on segment counts (actual spectrograms).
    All sizes include ALL classes - only the training data amount per class varies.
    
    Args:
        species_counts: Counter object with species and their segment counts
        size: 'small', 'medium', 'large', or 'full'
        balance_classes: If True, limit all classes to the size of the smallest class
        min_segments_per_class: Minimum segments required per class
    """
    
    print(f"\n🎛️ Creating '{size}' segment-based dataset...")
    
    # Define size parameters based on spectrogram counts
    # All sizes use ALL CLASSES - only training data amount varies
    size_configs = {
        'small': {
            'max_segments_per_class': 200,
            'description': 'Quick prototyping (200 spectrograms/class, all classes)',
            'estimated_time': '5-15 minutes',
            'total_estimate': 2000  # 200 * 10 classes
        },
        'medium': {
            'max_segments_per_class': 800,
            'description': 'Balanced experimentation (800 spectrograms/class, all classes)',
            'estimated_time': '30-60 minutes',
            'total_estimate': 8000  # 800 * 10 classes
        },
        'large': {
            'max_segments_per_class': 1500,
            'description': 'Comprehensive training (1500 spectrograms/class, all classes)',
            'estimated_time': '2-4 hours',
            'total_estimate': 15000  # 1500 * 10 classes
        },
        'full': {
            'max_segments_per_class': float('inf'),
            'description': 'Complete dataset (all available spectrograms, all classes)',
            'estimated_time': '4-8 hours',
            'total_estimate': 28706
        }
    }
    
    if size not in size_configs:
        raise ValueError(f"Size must be one of: {list(size_configs.keys())}")
    
    config = size_configs[size]
    print(f"📝 Configuration: {config['description']}")
    print(f"⏱️ Estimated training time: {config['estimated_time']}")
    
    # Use ALL species - sort by segment count for consistency
    sorted_species = sorted(species_counts.items(), key=lambda x: x[1], reverse=True)
    print(f"🎯 Using all {len(sorted_species)} classes for consistent model architecture")
    
    # Determine segments per class
    species_limits = {}
    
    if balance_classes and size != 'full':
        # For balanced datasets, find the limiting factor
        available_segments = [count for species, count in sorted_species]
        
        if size in ['small', 'medium', 'large']:
            # Use the configured max, but don't exceed what's available
            # Also consider the minimum available if balancing
            if balance_classes:
                limit = min(config['max_segments_per_class'], min(available_segments))
            else:
                limit = config['max_segments_per_class']
        else:
            limit = min(available_segments)
        
        print(f"⚖️ Balancing classes: {limit} segments per class")
        for species, _ in sorted_species:
            species_limits[species] = limit
            
    else:
        # Use available segments up to the configured maximum
        print(f"📊 Using available segments (max {config['max_segments_per_class']} per class)")
        for species, count in sorted_species:
            if size == 'full':
                species_limits[species] = count
            else:
                species_limits[species] = min(count, config['max_segments_per_class'])
    
    # Verify minimum requirements and handle classes with insufficient data
    valid_species = {}
    insufficient_species = []
    
    for species, limit in species_limits.items():
        if limit >= min_segments_per_class:
            valid_species[species] = limit
        else:
            # For classes with insufficient data, use what's available but warn
            available = species_counts[species]
            valid_species[species] = available
            insufficient_species.append((species, available))
            print(f"⚠️ {species}: only {available} segments available (< {min_segments_per_class} desired)")
    
    # Calculate totals
    total_segments = sum(valid_species.values())
    
    print(f"\n📈 Segment Dataset Summary:")
    print("=" * 70)
    for species, limit in sorted(valid_species.items(), key=lambda x: x[1], reverse=True):
        percentage = (limit / total_segments) * 100
        available = species_counts[species]
        status = "⚠️" if limit < min_segments_per_class else "✅"
        print(f"   {status} {species.ljust(12)}: {limit:4d} segments ({percentage:5.1f}%) [of {available} available]")
    print("=" * 70)
    print(f"   Total: {total_segments} segments across {len(valid_species)} classes")
    print(f"   🎵 = {total_segments} mel spectrograms for training")
    
    # Calculate approximate train/val/test split sizes
    train_size = int(total_segments * 0.6)
    val_size = int(total_segments * 0.2)
    test_size = total_segments - train_size - val_size
    
    print(f"\n🔄 Approximate split sizes:")
    print(f"   Train: ~{train_size} spectrograms")
    print(f"   Validation: ~{val_size} spectrograms")
    print(f"   Test: ~{test_size} spectrograms")
    
    # Balance efficiency info
    if balance_classes:
        original_total = sum(species_counts[species] for species in valid_species.keys())
        efficiency = (total_segments / original_total) * 100
        print(f"\n📊 Dataset efficiency: {efficiency:.1f}% of available segments used")
        if efficiency < 50:
            print("   💡 Consider 'balance_classes=False' to use more data")
    
    # Model consistency info
    print(f"\n🏗️ Model Architecture:")
    print(f"   Classes: {len(valid_species)} (same across all dataset sizes)")
    print(f"   Species: {list(valid_species.keys())}")
    if insufficient_species:
        print(f"   Note: {len(insufficient_species)} classes have limited data")
    
    return valid_species, config

# Create the configured segment dataset size
selected_species_limits, size_config = create_sized_segment_dataset(
    species_counts, 
    size=dataset_size_config['size'],
    balance_classes=dataset_size_config['balance_classes'],
    min_segments_per_class=dataset_size_config['min_segments_per_class']
)

print(f"\n✅ Segment-based dataset configuration complete!")
print(f"💡 To change size, modify dataset_size_config['size'] to: 'small', 'medium', 'large', or 'full'")
print(f"🔧 Toggle balancing with dataset_size_config['balance_classes'] = True/False")
print(f"🎯 All dataset sizes use the same {len(selected_species_limits)} classes for consistent comparison")

In [None]:
# Updated Segment-Based Dataset Creation with Size Control
def create_controlled_segment_splits(segments_dir, species_limits, test_size=0.2, val_size=0.2, seed=42, show_info=True):
    """
    Create train/val/test splits with controlled species and segment limits.
    
    Args:
        segments_dir: Directory containing segmented audio files organized by species
        species_limits: Dict mapping species names to maximum segment counts
        test_size, val_size: Split proportions
        seed: Random seed for reproducibility
        show_info: Whether to display detailed information
    """
    print(f"🔧 Creating controlled segment-based dataset splits...")
    
    # Get all segment files organized by species
    species_segments = {}
    total_available = 0
    
    for species in species_limits.keys():
        species_path = os.path.join(segments_dir, species)
        if os.path.isdir(species_path):
            # Find all .wav files in this species folder (including subdirectories)
            species_files = glob.glob(os.path.join(species_path, '**', '*.wav'), recursive=True)
            species_segments[species] = species_files
            total_available += len(species_files)
            
            if show_info:
                print(f"   Found {len(species_files)} {species} segments")
    
    # Apply segment limits and shuffle
    np.random.seed(seed)
    controlled_paths = []
    controlled_labels = []
    
    total_selected = 0
    for species, limit in species_limits.items():
        if species in species_segments:
            files = species_segments[species]
            # Shuffle and take up to the limit
            np.random.shuffle(files)
            selected_files = files[:limit]
            
            for file_path in selected_files:
                controlled_paths.append(file_path)
                controlled_labels.append(species)
                total_selected += 1
    
    if show_info:
        unique_species = sorted(set(controlled_labels))
        print(f"\n✅ Segment selection complete:")
        print(f"   Selected: {total_selected} segments from {len(unique_species)} species")
        print(f"   Species: {sorted(unique_species)} ({len(unique_species)} total)")
        efficiency = (total_selected / total_available) * 100 if total_available > 0 else 0
        print(f"   📈 Dataset efficiency: {efficiency:.1f}% of available segments used")
    
    # Convert to arrays
    X = np.array(controlled_paths)
    y = np.array(controlled_labels)
    
    # Create splits with stratification
    X_temp, X_test, y_temp, y_test = train_test_split(
        X, y, test_size=test_size, random_state=seed, stratify=y
    )
    
    val_size_adjusted = val_size / (1 - test_size)
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=val_size_adjusted, random_state=seed, stratify=y_temp
    )
    
    if show_info:
        print(f"\n✅ Controlled segment splits created:")
        print(f"   Train: {len(X_train)} spectrograms ({len(X_train)/len(X)*100:.1f}%)")
        print(f"   Validation: {len(X_val)} spectrograms ({len(X_val)/len(X)*100:.1f}%)")
        print(f"   Test: {len(X_test)} spectrograms ({len(X_test)/len(X)*100:.1f}%)")
        
        # Show balance in each split
        print(f"\n⚖️ Class balance verification:")
        for split_name, split_labels in [("Train", y_train), ("Val", y_val), ("Test", y_test)]:
            species_in_split = Counter(split_labels)
            min_samples = min(species_in_split.values()) if species_in_split else 0
            max_samples = max(species_in_split.values()) if species_in_split else 0
            balance_ratio = max_samples / min_samples if min_samples > 0 else float('inf')
            print(f"   {split_name}: {min_samples}-{max_samples} spectrograms/class (ratio: {balance_ratio:.2f}:1)")
    
    return X_train, X_val, X_test, y_train, y_val, y_test

print("✅ Controlled segment-based dataset creation function ready")

In [None]:
# 🎛️ Segment-Based Dataset Size Options Demo
print("🎛️ Available Segment-Based Dataset Size Options:")
print("="*70)
print("🎯 All sizes use the SAME number of classes - only training data amount varies")
print("="*70)

# Demo all size options with segment counts
demo_sizes = ['small', 'medium', 'large', 'full']

for size in demo_sizes:
    print(f"\n📊 {size.upper()} Segment Dataset:")
    demo_limits, demo_config = create_sized_segment_dataset(
        species_counts, 
        size=size,
        balance_classes=True,  # Show balanced version
        min_segments_per_class=50
    )
    
    total_segments = sum(demo_limits.values())
    train_approx = int(total_segments * 0.6)
    val_approx = int(total_segments * 0.2)
    test_approx = total_segments - train_approx - val_approx
    
    print(f"   📈 Training time estimate: {demo_config['estimated_time']}")
    print(f"   🎵 Total spectrograms: {total_segments}")
    print(f"   🏗️ Classes: {len(demo_limits)} (consistent across all sizes)")
    if size == 'small':
        print("      🎯 Perfect for: Rapid prototyping, algorithm testing")
        print("      💡 Best for: Initial experiments and debugging")
    elif size == 'medium':
        print("      🎯 Perfect for: Hyperparameter tuning, model comparison")
        print("      💡 Best for: Finding optimal settings")
    elif size == 'large':
        print("      🎯 Perfect for: Comprehensive evaluation, final training")
        print("      💡 Best for: Thorough model assessment")
    elif size == 'full':
        print("      🎯 Perfect for: Production model, maximum performance")
        print("      💡 Best for: Final deployment model")

print("\n"+"="*70)
print("🚀 Recommended Development Workflow:")
print("   1. 'small' → Rapid prototyping & debugging (~2K spectrograms, 10 classes)")
print("   2. 'medium' → Hyperparameter optimization (~8K spectrograms, 10 classes)")
print("   3. 'large' → Comprehensive evaluation (~15K spectrograms, 10 classes)")
print("   4. 'full' → Production training (~29K spectrograms, 10 classes)")
print("\n🎯 Key Benefits:")
print("   ✅ Consistent model architecture across all experiments")
print("   ✅ Fair comparison between different dataset sizes")
print("   ✅ No need to retrain with different class counts")
print("   ✅ Same evaluation metrics across all experiments")
print("\n🔧 To change size: modify dataset_size_config['size'] above and re-run dataset loading")
print("⚖️ Toggle balancing: dataset_size_config['balance_classes'] = True/False")
print("🎵 Note: Each segment = 1 mel spectrogram for training")

In [None]:
# 🧪 Quick Test: Different Dataset Size Configurations
print("🧪 Testing Different Dataset Size Configurations:")
print("-" * 50)

# Test different dataset sizes to show consistent class counts
test_sizes = ['small', 'medium', 'large', 'full']

for test_size in test_sizes:
    print(f"\n📊 {test_size.upper()} Dataset Configuration:")
    test_limits, test_config = create_sized_segment_dataset(
        species_counts, 
        size=test_size,
        balance_classes=True,
        min_segments_per_class=50
    )
    total_spectrograms = sum(test_limits.values())
    min_per_class = min(test_limits.values()) if test_limits else 0
    max_per_class = max(test_limits.values()) if test_limits else 0
    
    print(f"   Classes: {len(test_limits)} (same for all sizes)")
    print(f"   Spectrograms per class: {min_per_class}-{max_per_class}")
    print(f"   Total spectrograms: {total_spectrograms}")
    print(f"   Estimated time: {test_config['estimated_time']}")

print(f"\n🔄 Current setting: {dataset_size_config['size']} dataset")
print(f"📊 Current classes: {len(selected_species_limits)} (consistent across all sizes)")
print(f"📈 Current total spectrograms: {sum(selected_species_limits.values())}")

print(f"\n💡 To switch dataset size:")
print("   1. Change: dataset_size_config['size'] = 'small'/'medium'/'large'/'full'")
print("   2. Set: splitting_config['force_new_splits'] = True")
print("   3. Re-run the dataset loading cells")
print(f"   4. Result: Always 10 classes, different training data amounts")

print(f"\n🎯 Key Benefits:")
print("   ✅ Same model architecture for all experiments")
print("   ✅ Fair comparison between dataset sizes")
print("   ✅ Consistent evaluation metrics")
print("   ✅ No need to change model configuration")

In [None]:
# Create universal datasets
print("🔧 Creating universal datasets...")

train_dataset = UniversalAudioDataset(
    X_train, y_train_encoded, label_encoder, 
    transform=transform_train, target_size=config['target_size']
)

val_dataset = UniversalAudioDataset(
    X_val, y_val_encoded, label_encoder, 
    transform=transform_val, target_size=config['target_size']
)

test_dataset = UniversalAudioDataset(
    X_test, y_test_encoded, label_encoder, 
    transform=transform_val, target_size=config['target_size']
)

print("✅ Datasets created with universal adaptive preprocessing")

# Quick test
sample = train_dataset[0]
print(f"📊 Sample check: {sample['pixel_values'].shape}, label: {label_encoder.classes_[sample['labels'].item()]}")

In [None]:
# Model setup and training
print("🤖 Setting up model...")

# Load model and processor
model = ViTForImageClassification.from_pretrained(
    config['model_name'], 
    num_labels=len(label_encoder.classes_),
    ignore_mismatched_sizes=True
)
processor = ViTImageProcessor.from_pretrained(config['model_name'])

# Training arguments
training_args = TrainingArguments(
    output_dir='./vit-base-manuai',
    num_train_epochs=config['num_epochs'],
    per_device_train_batch_size=config['batch_size'],
    per_device_eval_batch_size=config['batch_size'],
    warmup_steps=config['warmup_steps'],
    learning_rate=config['learning_rate'],
    logging_dir='./logs',
    logging_steps=50,
    eval_steps=config['eval_steps'],
    save_steps=config['save_steps'],
    eval_strategy="steps",
    save_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    greater_is_better=True,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
)

# Metrics function
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support
    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='weighted')
    
    return {
        'accuracy': accuracy,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

# Data collator
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([item['pixel_values'] for item in batch]),
        'labels': torch.stack([item['labels'] for item in batch])
    }

# Early stopping
early_stopping = EarlyStoppingCallback(early_stopping_patience=config['early_stopping_patience'])

print(f"✅ Model loaded: {config['model_name']}")
print(f"🎯 Output classes: {len(label_encoder.classes_)}")
print(f"⚙️ Training configuration set")

In [None]:
# Give a quick overview of the training setup
print("\n🔍 Training setup:")
print(f"   Model: {config['model_name']}")
print(f"   Epochs: {config['num_epochs']}")
print(f"   Batch size: {config['batch_size']}")
print(f"   Learning rate: {config['learning_rate']}")
print(f"   Warmup steps: {config['warmup_steps']}")
print(f"   Evaluation steps: {config['eval_steps']}")
print(f"   Save steps: {config['save_steps']}")
print(f"   Early stopping patience: {config['early_stopping_patience']}")

In [None]:
# Training
print("🚀 Starting training...")

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=processor,
    callbacks=[early_stopping]
)

# Train the model
train_results = trainer.train()

print("✅ Training completed!")
print(f"📈 Final train loss: {train_results.training_loss:.4f}")

# Save the model
trainer.save_model()
print("💾 Model saved")

In [None]:
# Evaluation
print("📊 Evaluating model performance...")

# Evaluate on validation set
val_results = trainer.evaluate(eval_dataset=val_dataset)
print(f"\n🎯 Validation Results:")
print(f"   Accuracy: {val_results['eval_accuracy']:.4f}")
print(f"   F1 Score: {val_results['eval_f1']:.4f}")
print(f"   Precision: {val_results['eval_precision']:.4f}")
print(f"   Recall: {val_results['eval_recall']:.4f}")

# Evaluate on test set
test_results = trainer.evaluate(eval_dataset=test_dataset)
print(f"\n🧪 Test Results:")
print(f"   Accuracy: {test_results['eval_accuracy']:.4f}")
print(f"   F1 Score: {test_results['eval_f1']:.4f}")
print(f"   Precision: {test_results['eval_precision']:.4f}")
print(f"   Recall: {test_results['eval_recall']:.4f}")

# Detailed per-class analysis
predictions = trainer.predict(test_dataset)
y_pred = np.argmax(predictions.predictions, axis=1)
y_true = predictions.label_ids

# Classification report
print("\n📋 Detailed Classification Report:")
print(classification_report(y_true, y_pred, target_names=label_encoder.classes_, digits=3))

# Focus on challenging species performance
print("\n🎯 Species Performance Overview:")
for species in label_encoder.classes_:
    species_idx = np.where(label_encoder.classes_ == species)[0][0]
    species_mask = y_true == species_idx
    if np.sum(species_mask) > 0:  # Only if species has test samples
        species_accuracy = np.mean(y_pred[species_mask] == y_true[species_mask])
        print(f"   {species.title()}: {species_accuracy:.3f} ({species_accuracy*100:.1f}%)")

print("\n✅ Training and evaluation complete!")
print("\n🎯 Quick Summary:")
print("   • Universal adaptive preprocessing that scales with dataset expansion")
print("   • Robust preprocessing based on audio characteristics, not species names")
print("   • Enhanced data augmentation for better generalization")
print("   • Ready for production use with any number of bird species")