# 🎹 Piano Perception Transformer - Comprehensive Evaluation

**Phase 3: Performance Analysis and Baseline Comparisons**

This notebook evaluates the fine-tuned Piano Perception Transformer on the PercePiano test split. It measures overall and per-dimension performance, compares the model against classical baselines, and examines model interpretability.

**Pipeline Overview:**
1. 🔧 **Setup & Environment** - Dependencies, model loading, evaluation framework
2. 📊 **Model Performance Analysis** - Correlation, MSE, per-dimension analysis
3. 🏆 **Baseline Comparisons** - Linear Regression, Ridge Regression, Random Forest, and MLP baselines
4. 🔍 **Model Interpretability** - Attention visualization, feature analysis
5. 📈 **Results Visualization** - Comprehensive performance plots and analysis
6. 📝 **Final Report** - Summary of findings and conclusions

**Input:** Fine-tuned AST model from Phase 2  
**Output:** Comprehensive evaluation report with performance metrics and visualizations

---
## 🔧 Cell 1: Setup and Model Loading
---

In [None]:
print("🚀 Setting up Piano Perception Transformer - Evaluation Phase...")

# Clone repo (skip if already exists). An absolute path keeps re-runs from cloning inside the repo.
import os
repo_dir = '/content/piano-perception-transformer'
if not os.path.exists(repo_dir):
    !git clone https://github.com/Jai-Dhiman/piano-perception-transformer.git {repo_dir}
else:
    print("Repository already exists, skipping clone...")

%cd {repo_dir}

# Install uv
!curl -LsSf https://astral.sh/uv/install.sh | sh

# Put the uv binary on PATH (the installer may use ~/.local/bin, ~/.cargo/bin, or /usr/local/bin), then install dependencies
print("📦 Installing enhanced dependencies with uv...")
!export PATH="$HOME/.local/bin:$HOME/.cargo/bin:/usr/local/bin:$PATH" && uv pip install --system "jax[tpu]" flax optax librosa pandas wandb requests zipfile36 scikit-learn scipy seaborn matplotlib pretty_midi soundfile plotly kaleido

# Import core libraries
import sys
import json
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.model_selection import cross_val_score
import jax
import jax.numpy as jnp
from datetime import datetime
from flax import linen as nn
import time

# Initialize WandB for evaluation tracking
import wandb

try:
    wandb.login()  # This will prompt for API key in Colab
    
    run = wandb.init(
        project="piano-perception-transformer-evaluation",
        name=f"ast-evaluation-{datetime.now().strftime('%Y%m%d-%H%M')}",
        config={
            "phase": "comprehensive_evaluation",
            "evaluation_type": "final_performance_analysis",
            "metrics": ["correlation", "mse", "mae", "r2", "per_dimension_analysis"],
            "baselines": ["linear_regression", "ridge_regression", "random_forest", "mlp"],
            "interpretability": ["attention_maps", "feature_importance", "error_analysis"],
            "target_correlation": 0.7
        },
        tags=["evaluation", "ast", "percepiano", "analysis", "baselines"]
    )
    
    print("✅ WandB initialized successfully!")
    print(f"   • Project: piano-perception-transformer-evaluation")
    print(f"   • Run name: {run.name}")
    print(f"   • Tracking: https://wandb.ai/{run.entity}/{run.project}/runs/{run.id}")
    
except Exception as e:
    print(f"⚠️ WandB initialization failed: {e}")
    print("   • Continuing without experiment tracking")

# Mount Google Drive
from google.colab import drive
print("🔗 Mounting Google Drive...")
drive.mount('/content/drive')

# Create directory structure
base_dir = '/content/drive/MyDrive/piano_transformer'
directories = [
    f'{base_dir}/processed_spectrograms',
    f'{base_dir}/checkpoints/evaluation',
    f'{base_dir}/logs',
    f'{base_dir}/temp'
]

print("📁 Setting up directory structure...")
for directory in directories:
    os.makedirs(directory, exist_ok=True)
    print(f"✅ Created: {directory}")

# Set up plotting style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print(f"\n🧠 JAX Configuration:")
print(f"   • Backend: {jax.default_backend()}")
print(f"   • Devices: {jax.device_count()}")
print(f"   • Device type: {jax.devices()[0].device_kind}")

print("\n✅ Evaluation setup completed!")

---
## 📊 Cell 2: Load Fine-tuned Model and Evaluation Helpers
---

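`ProductionASTForRegression` (defined in the cell below) spreads its stochastic-depth rate linearly across the 12 layers and, in training mode, zeroes each residual branch per sample with probability `drop_rate`, rescaling the surviving samples by `1 / keep_prob`. At evaluation time (`training=False`) both dropout and stochastic depth are skipped, so predictions are deterministic. The NumPy sketch below reproduces the schedule and the per-sample drop-path mask for illustration only; `stochastic_depth_schedule` and `drop_path` are hypothetical helper names, not functions from the repository.

```python
import numpy as np


def stochastic_depth_schedule(rate, num_layers):
    """Linearly increasing drop rates, same formula as the model's setup()."""
    return [rate * i / (num_layers - 1) for i in range(num_layers)]


def drop_path(branch, drop_rate, rng, training=True):
    """Per-sample stochastic depth on a residual branch shaped [batch, tokens, dim].

    Hypothetical NumPy re-implementation, for illustration only.
    """
    if not training or drop_rate == 0.0:
        return branch  # inference (or layer 0): branch always kept, unscaled
    keep_prob = 1.0 - drop_rate
    # One Bernoulli draw per sample, broadcast over tokens and channels
    mask = (rng.uniform(size=(branch.shape[0], 1, 1)) < keep_prob).astype(branch.dtype)
    # Rescale survivors so the expected value of the branch is unchanged
    return branch * mask / keep_prob


rates = stochastic_depth_schedule(0.1, 12)
print([round(r, 4) for r in rates])  # 0.0 for the first layer, 0.1 for the last

rng = np.random.default_rng(0)
x = np.ones((4096, 2, 3), dtype=np.float32)
out = drop_path(x, drop_rate=0.1, rng=rng)

# Each sample is either dropped entirely (all zeros) or kept and scaled by 1/0.9
per_sample = out.reshape(out.shape[0], -1)
print(np.unique(np.round(per_sample, 4)))
print(per_sample.mean())  # close to 1.0 on average
```
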
In [None]:
import sys
import os
import pickle
import jax
import jax.numpy as jnp
from flax import linen as nn

sys.path.append('/content/piano-perception-transformer/src')

print("📂 Loading Fine-tuned Model and Test Data")
print("="*60)

# Define the model architecture (same as fine-tuning)
class ProductionASTForRegression(nn.Module):
    """AST model with regression head for perceptual prediction"""
    
    patch_size: int = 16
    embed_dim: int = 768
    num_layers: int = 12
    num_heads: int = 12
    mlp_ratio: float = 4.0
    dropout_rate: float = 0.1
    attention_dropout: float = 0.1
    stochastic_depth_rate: float = 0.1
    num_outputs: int = 19
    
    def setup(self):
        self.drop_rates = [
            self.stochastic_depth_rate * i / (self.num_layers - 1) 
            for i in range(self.num_layers)
        ]
    
    @nn.compact
    def __call__(self, x, training: bool = True):
        batch_size, time_frames, freq_bins = x.shape
        
        # === PATCH EMBEDDING ===
        patch_size = self.patch_size
        
        time_pad = (patch_size - time_frames % patch_size) % patch_size
        freq_pad = (patch_size - freq_bins % patch_size) % patch_size
        
        if time_pad > 0 or freq_pad > 0:
            x = jnp.pad(x, ((0, 0), (0, time_pad), (0, freq_pad)), mode='constant', constant_values=-80.0)
        
        time_patches = x.shape[1] // patch_size
        freq_patches = x.shape[2] // patch_size
        num_patches = time_patches * freq_patches
        
        x = x.reshape(batch_size, time_patches, patch_size, freq_patches, patch_size)
        x = x.transpose(0, 1, 3, 2, 4)
        x = x.reshape(batch_size, num_patches, patch_size * patch_size)
        
        x = nn.Dense(
            self.embed_dim, 
            kernel_init=nn.initializers.truncated_normal(stddev=0.02),
            bias_init=nn.initializers.zeros,
            name='patch_embedding'
        )(x)
        
        # === POSITIONAL ENCODING ===
        pos_embedding = self.param(
            'pos_embedding',
            nn.initializers.truncated_normal(stddev=0.02),
            (1, num_patches, self.embed_dim)
        )
        x = x + pos_embedding
        
        x = nn.Dropout(self.dropout_rate, deterministic=not training)(x)
        
        # === TRANSFORMER LAYERS ===
        for layer_idx in range(self.num_layers):
            drop_rate = self.drop_rates[layer_idx]
            
            # Self-Attention
            residual = x
            x = nn.LayerNorm(epsilon=1e-6, name=f'norm1_layer{layer_idx}')(x)
            
            attention = nn.MultiHeadDotProductAttention(
                num_heads=self.num_heads,
                dropout_rate=self.attention_dropout,
                kernel_init=nn.initializers.truncated_normal(stddev=0.02),
                bias_init=nn.initializers.zeros,
                name=f'attention_layer{layer_idx}'
            )(x, x, deterministic=not training)
            
            # Stochastic depth for attention
            if training and drop_rate > 0:
                random_tensor = jax.random.uniform(
                    self.make_rng('stochastic_depth'), (batch_size, 1, 1)
                )
                keep_prob = 1.0 - drop_rate
                binary_tensor = (random_tensor < keep_prob).astype(x.dtype)
                attention = attention * binary_tensor / keep_prob
            
            x = residual + nn.Dropout(self.dropout_rate, deterministic=not training)(attention)
            
            # MLP
            residual = x
            x = nn.LayerNorm(epsilon=1e-6, name=f'norm2_layer{layer_idx}')(x)
            
            mlp_hidden = int(self.embed_dim * self.mlp_ratio)
            
            mlp = nn.Dense(
                mlp_hidden, 
                kernel_init=nn.initializers.truncated_normal(stddev=0.02),
                bias_init=nn.initializers.zeros,
                name=f'mlp_dense1_layer{layer_idx}'
            )(x)
            mlp = nn.gelu(mlp)
            mlp = nn.Dropout(self.dropout_rate, deterministic=not training)(mlp)
            
            mlp = nn.Dense(
                self.embed_dim,
                kernel_init=nn.initializers.truncated_normal(stddev=0.02),
                bias_init=nn.initializers.zeros,
                name=f'mlp_dense2_layer{layer_idx}'
            )(mlp)
            
            # Stochastic depth for MLP
            if training and drop_rate > 0:
                random_tensor = jax.random.uniform(
                    self.make_rng('stochastic_depth'), (batch_size, 1, 1)
                )
                keep_prob = 1.0 - drop_rate
                binary_tensor = (random_tensor < keep_prob).astype(x.dtype)
                mlp = mlp * binary_tensor / keep_prob
            
            x = residual + nn.Dropout(self.dropout_rate, deterministic=not training)(mlp)
        
        # === FINAL PROCESSING ===
        x = nn.LayerNorm(epsilon=1e-6, name='final_norm')(x)
        
        # Global average pooling
        x = jnp.mean(x, axis=1)  # [batch, embed_dim]
        
        # === REGRESSION HEAD ===
        x = nn.Dense(
            512, 
            kernel_init=nn.initializers.truncated_normal(stddev=0.02),
            bias_init=nn.initializers.zeros,
            name='regression_hidden'
        )(x)
        x = nn.gelu(x)
        x = nn.Dropout(self.dropout_rate, deterministic=not training)(x)
        
        predictions = nn.Dense(
            self.num_outputs,
            kernel_init=nn.initializers.truncated_normal(stddev=0.02),
            bias_init=nn.initializers.zeros,
            name='regression_output'
        )(x)
        
        return predictions

def load_finetuned_model(checkpoint_path):
    """Load fine-tuned model checkpoint"""
    print(f"📂 Loading fine-tuned model: {checkpoint_path}")
    
    try:
        with open(checkpoint_path, 'rb') as f:
            checkpoint = pickle.load(f)
        
        print(f"✅ Model checkpoint loaded successfully")
        
        # best_val_correlation may be missing (-> 'N/A'), None, or a string; only format numbers
        best_val_corr = checkpoint.get('best_val_correlation', 'N/A')
        if isinstance(best_val_corr, (int, float)):
            print(f"   • Best validation correlation: {best_val_corr:.4f}")
        else:
            print(f"   • Best validation correlation: {best_val_corr}")
            
        print(f"   • Training epochs: {checkpoint.get('finetuning_results', {}).get('total_epochs', 'N/A')}")
        print(f"   • Model parameters: {sum(x.size for x in jax.tree.leaves(checkpoint['params'])):,}")
        
        return checkpoint
        
    except Exception as e:
        print(f"❌ Failed to load model checkpoint: {e}")
        raise

def model_predict(model, params, batch_specs):
    """Forward pass in inference mode (training=False disables dropout and stochastic depth); not JIT-compiled"""
    return model.apply(params, batch_specs, training=False)

def evaluate_model_on_dataset(model, params, dataset, batch_size=32, description="Dataset"):
    """Evaluate the model on every sample of a dataset, in order, and return predictions and targets"""
    print(f"🔍 Evaluating model on {description}...")
    
    # Slice the split sequentially so every sample is scored exactly once, in dataset order.
    # (dataset.get_batch(shuffle=False) starts at a random offset and wraps around, so it
    # can repeat some samples and skip others.)
    all_specs, all_labels = dataset.get_all_data()
    num_samples = len(all_specs)
    num_batches = (num_samples + batch_size - 1) // batch_size
    
    all_predictions = []
    all_targets = []
    
    for batch_idx in range(num_batches):
        start = batch_idx * batch_size
        end = min(start + batch_size, num_samples)
        batch_specs = jnp.array(all_specs[start:end])
        batch_targets = all_labels[start:end]
        
        # Model prediction
        predictions = model_predict(model, params, batch_specs)
        
        # Convert to numpy and store
        all_predictions.append(np.array(predictions))
        all_targets.append(np.array(batch_targets))
        
        if batch_idx % 10 == 0:
            print(f"   Processed {batch_idx + 1}/{num_batches} batches...")
    
    # Concatenate all results (already exactly num_samples rows, in dataset order)
    predictions = np.concatenate(all_predictions, axis=0)
    targets = np.concatenate(all_targets, axis=0)
    
    print(f"✅ Evaluation completed on {predictions.shape[0]} samples")
    print(f"   • Predictions shape: {predictions.shape}")
    print(f"   • Targets shape: {targets.shape}")
    
    return predictions, targets

# Load fine-tuned model
finetuned_checkpoint_path = '/content/drive/MyDrive/piano_transformer/checkpoints/finetuning/final_finetuned_model.pkl'

if os.path.exists(finetuned_checkpoint_path):
    finetuned_checkpoint = load_finetuned_model(finetuned_checkpoint_path)
    model_params = finetuned_checkpoint['params']
    label_scaler = finetuned_checkpoint['label_scaler']
    print(f"✅ Fine-tuned model loaded successfully!")
else:
    print(f"❌ Fine-tuned model not found: {finetuned_checkpoint_path}")
    print(f"   Please run the fine-tuning notebook first")
    raise FileNotFoundError("Fine-tuned model checkpoint not found")

# Initialize model
eval_model = ProductionASTForRegression(
    patch_size=16,
    embed_dim=768,
    num_layers=12,
    num_heads=12,
    mlp_ratio=4.0,
    dropout_rate=0.1,
    attention_dropout=0.1,
    stochastic_depth_rate=0.1,
    num_outputs=19
)

print(f"\n🧠 Model Architecture Summary:")
print(f"   • Total parameters: {sum(x.size for x in jax.tree.leaves(model_params)):,}")
print(f"   • Architecture: 12-layer AST + regression head")
print(f"   • Input: 128×128 mel-spectrograms")
print(f"   • Output: 19 perceptual dimensions")

# Note: We'll load the test dataset in the next cell when we evaluate baselines
print(f"\n🎯 Model loaded and ready for comprehensive evaluation!")
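
The patch-embedding step in `ProductionASTForRegression` pads each `[time, freq]` spectrogram to multiples of 16 with -80 dB, then uses a reshape/transpose/reshape to cut it into non-overlapping 16×16 tiles. For the 128×128 inputs used here that gives 8 × 8 = 64 patches of 256 values each, ordered time-major. The NumPy sketch below mirrors that sequence on a toy batch and checks that each flattened patch equals the corresponding tile; `patchify` is a hypothetical helper name.

```python
import numpy as np


def patchify(x, patch_size=16, pad_value=-80.0):
    """Split [batch, time, freq] spectrograms into flattened, non-overlapping patches.

    Mirrors the reshape/transpose sequence in the model's patch embedding.
    """
    batch, time_frames, freq_bins = x.shape
    time_pad = (patch_size - time_frames % patch_size) % patch_size
    freq_pad = (patch_size - freq_bins % patch_size) % patch_size
    if time_pad or freq_pad:
        # Fill partial patches with the dB floor, as the model does
        x = np.pad(x, ((0, 0), (0, time_pad), (0, freq_pad)),
                   mode="constant", constant_values=pad_value)
    t_p, f_p = x.shape[1] // patch_size, x.shape[2] // patch_size
    x = x.reshape(batch, t_p, patch_size, f_p, patch_size)
    x = x.transpose(0, 1, 3, 2, 4)  # [batch, t_p, f_p, patch, patch]
    return x.reshape(batch, t_p * f_p, patch_size * patch_size)


specs = np.arange(2 * 128 * 128, dtype=np.float32).reshape(2, 128, 128)
patches = patchify(specs)
print(patches.shape)  # (2, 64, 256)

# Patch k covers time block k // 8 and frequency block k % 8
k = 10
t_blk, f_blk = divmod(k, 8)
tile = specs[0, t_blk * 16:(t_blk + 1) * 16, f_blk * 16:(f_blk + 1) * 16]
print(np.array_equal(patches[0, k], tile.ravel()))  # True
```
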

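For test-set metrics every sample should be scored exactly once, in a fixed order, so that the transformer's predictions line up row for row with the `test_labels` used for the baselines. `PercePianoDataset.get_batch(shuffle=False)` does not guarantee this: each call starts its window at a random offset and wraps around. The sketch below contrasts contiguous slicing with that random-window pattern; `batch_slices` and `random_window_indices` are hypothetical helpers written for illustration.

```python
import numpy as np


def batch_slices(num_samples, batch_size):
    """Yield (start, end) bounds that cover range(num_samples) exactly once, in order."""
    for start in range(0, num_samples, batch_size):
        yield start, min(start + batch_size, num_samples)


def random_window_indices(num_samples, batch_size, rng):
    """Indices from a random-start, wrap-around window (the get_batch(shuffle=False) pattern)."""
    start = rng.integers(0, max(1, num_samples - batch_size + 1))
    return np.arange(start, start + batch_size) % num_samples


n, bs = 70, 32
ordered = np.concatenate([np.arange(s, e) for s, e in batch_slices(n, bs)])
print(len(ordered), np.array_equal(ordered, np.arange(n)))  # 70 True

rng = np.random.default_rng(0)
num_batches = (n + bs - 1) // bs
windowed = np.concatenate([random_window_indices(n, bs, rng) for _ in range(num_batches)])
print(len(windowed), len(np.unique(windowed)))  # 96 draws over 70 samples, so repeats are unavoidable
```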
---
## 🏆 Cell 3: Load Test Data and Train Baseline Models
---

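`midi_to_spectrogram` (in the cell below) keeps exactly 128 frames per performance. With librosa's default centered framing (`center=True`, assumed here), a signal of `n` samples yields `1 + n // hop_length` frames, so at `sr=22050` and `hop_length=512` the 128 frames cover roughly the first 3 s of audio: longer renderings are cropped and shorter ones are padded with -80 dB. The NumPy sketch below works through that arithmetic and the crop/pad step; `num_frames` and `fix_num_frames` are hypothetical helper names.

```python
import numpy as np


def num_frames(num_samples, hop_length=512):
    """Frame count for a centered STFT (librosa's default center=True)."""
    return 1 + num_samples // hop_length


def fix_num_frames(spec, target_length=128, pad_value=-80.0):
    """Crop or pad a [frames, mel_bins] spectrogram to exactly target_length frames."""
    if spec.shape[0] >= target_length:
        return spec[:target_length]
    pad = target_length - spec.shape[0]
    return np.pad(spec, ((0, pad), (0, 0)), mode="constant", constant_values=pad_value)


sr, hop = 22050, 512
print(num_frames(int(2.0 * sr), hop))  # 87 frames for the 2 s minimum-length clip
print(round(128 * hop / sr, 2))        # 2.97 s of hop coverage in 128 frames

short = fix_num_frames(np.zeros((87, 128)))
long_spec = fix_num_frames(np.zeros((400, 128)))
print(short.shape, long_spec.shape)    # (128, 128) (128, 128)
```
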
In [None]:
# Clone and setup PercePiano dataset
print("📂 Setting up PercePiano dataset for evaluation...")

# Define PercePiano directory path
percepiano_dir = '/content/drive/MyDrive/PercePiano'

# Clone PercePiano dataset if not exists
if not os.path.exists(percepiano_dir):
    print("📥 Cloning PercePiano dataset repository...")
    !git clone https://github.com/JonghoKimSNU/PercePiano.git {percepiano_dir}
    print("✅ PercePiano dataset cloned successfully!")
else:
    print("✅ PercePiano dataset already exists")

# Verify essential directory structure
required_paths = [
    f'{percepiano_dir}/labels/label_2round_mean_reg_19_with0_rm_highstd0.json',
    f'{percepiano_dir}/virtuoso/data/all_2rounds'
]

print("🔍 Verifying dataset structure...")
missing_paths = []
for path in required_paths:
    if not os.path.exists(path):
        missing_paths.append(path)
    else:
        if path.endswith('.json'):
            print(f"   ✅ Labels file found: {os.path.basename(path)}")
        else:
            midi_count = len([f for f in os.listdir(path) if f.endswith('.mid')])
            print(f"   ✅ MIDI directory found: {midi_count} MIDI files")

if missing_paths:
    print("❌ Missing required files/directories:")
    for path in missing_paths:
        print(f"   • {path}")
    print("\n💡 The dataset may need to be downloaded separately.")
    print("   Please refer to the PercePiano repository instructions.")
    raise FileNotFoundError("PercePiano dataset structure incomplete")

print("✅ Dataset structure verified successfully!")

# Import required libraries for baselines and data loading
import json
import pretty_midi
import librosa
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr

print("\n📂 Loading Test Data and Creating Baseline Models")
print("="*60)

def midi_to_spectrogram(midi_path, sr=22050, n_mels=128, hop_length=512, n_fft=2048, target_length=128):
    """Convert MIDI to mel-spectrogram"""
    try:
        midi_data = pretty_midi.PrettyMIDI(midi_path)
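        try:
            audio = midi_data.fluidsynth(fs=sr)
        except Exception:
            # FluidSynth (or a SoundFont) is unavailable: fall back to pretty_midi's simple waveform synthesis
            audio = midi_data.synthesize(fs=sr)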
        
        # Ensure minimum audio length
        min_duration = 2.0
        min_samples = int(min_duration * sr)
        if len(audio) < min_samples:
            padding = min_samples - len(audio)
            audio = np.pad(audio, (0, padding), mode='constant')
        
        # Convert to mel-spectrogram
        mel_spec = librosa.feature.melspectrogram(
            y=audio, sr=sr, n_mels=n_mels, hop_length=hop_length, n_fft=n_fft,
            power=2.0, fmin=20, fmax=sr//2
        )
        
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
        mel_spec_transposed = mel_spec_db.T
        
        current_length = mel_spec_transposed.shape[0]
        if current_length >= target_length:
            normalized_spec = mel_spec_transposed[:target_length, :]
        else:
            pad_width = target_length - current_length
            normalized_spec = np.pad(
                mel_spec_transposed, ((0, pad_width), (0, 0)), 
                mode='constant', constant_values=-80.0
            )
        
        return normalized_spec
    except Exception as e:
        print(f"Error converting MIDI {midi_path}: {str(e)}")
        return None

def load_percepiano_data(percepiano_dir):
    """Load PercePiano dataset"""
    labels_file = f'{percepiano_dir}/labels/label_2round_mean_reg_19_with0_rm_highstd0.json'
    
    print(f"📋 Loading labels from: {labels_file}")
    with open(labels_file, 'r') as f:
        labels_data = json.load(f)

    print(f"📊 Loaded PercePiano labels: {len(labels_data)} samples")

    midi_dir = f'{percepiano_dir}/virtuoso/data/all_2rounds'
    
    if not os.path.exists(midi_dir):
        raise FileNotFoundError(f"MIDI directory not found: {midi_dir}")
    
    midi_files = [f for f in os.listdir(midi_dir) if f.endswith('.mid')]

    if len(midi_files) == 0:
        raise FileNotFoundError(f"No MIDI files found in: {midi_dir}")

    print(f"🎵 Found {len(midi_files)} MIDI files")

    samples = []
    processed_count = 0

    for filename, label_data in labels_data.items():
        # Find corresponding MIDI file (flexible matching)
        midi_filename = None
        for midi_file in midi_files:
            # Try multiple matching strategies
            if (filename in midi_file or 
                midi_file.replace('.mid', '') in filename or
                filename.replace('.mid', '') in midi_file.replace('.mid', '')):
                midi_filename = midi_file
                break

        if midi_filename is None:
            continue

        # Extract the 19 perceptual features
        if isinstance(label_data, list) and len(label_data) >= 19:
            perceptual_features = np.array(label_data[:19], dtype=np.float32)
        else:
            continue

        # Convert MIDI to spectrogram
        midi_path = os.path.join(midi_dir, midi_filename)
        spectrogram = midi_to_spectrogram(midi_path, target_length=128)
        
        if spectrogram is not None and spectrogram.shape == (128, 128):
            samples.append({
                'spectrogram': spectrogram,
                'labels': perceptual_features,
                'filename': filename
            })
            processed_count += 1
            
            if processed_count % 25 == 0:
                print(f"📊 Processed {processed_count} samples...")

    print(f"✅ Successfully processed {processed_count} samples")
    return samples

class PercePianoDataset:
    """PercePiano dataset class"""
    def __init__(self, samples, split='train', train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_seed=42):
        self.split = split
        self.random_seed = random_seed
        
        np.random.seed(random_seed)
        
        spectrograms = [s['spectrogram'] for s in samples]
        labels = [s['labels'] for s in samples]
        filenames = [s['filename'] for s in samples]
        
        # Create splits
        train_specs, temp_specs, train_labels, temp_labels, train_files, temp_files = train_test_split(
            spectrograms, labels, filenames,
            test_size=(val_ratio + test_ratio), 
            random_state=random_seed
        )
        
        val_size = val_ratio / (val_ratio + test_ratio)
        val_specs, test_specs, val_labels, test_labels, val_files, test_files = train_test_split(
            temp_specs, temp_labels, temp_files,
            test_size=(1 - val_size), 
            random_state=random_seed
        )
        
        # Assign data based on split
        if split == 'train':
            self.spectrograms = np.array(train_specs)
            self.labels = np.array(train_labels)
            self.filenames = train_files
        elif split == 'val':
            self.spectrograms = np.array(val_specs)
            self.labels = np.array(val_labels)
            self.filenames = val_files
        elif split == 'test':
            self.spectrograms = np.array(test_specs)
            self.labels = np.array(test_labels)
            self.filenames = test_files
        else:
            raise ValueError(f"Unknown split '{split}'; expected 'train', 'val', or 'test'")
        
        self.num_samples = len(self.spectrograms)
        
        print(f"📊 {split.title()} split: {self.num_samples} samples")
    
    def set_label_scaler(self, scaler):
        """Apply label normalization"""
        self.labels = scaler.transform(self.labels)
        print(f"✅ Applied label normalization to {self.split} split")
    
    def __len__(self):
        return self.num_samples
    
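    def get_batch(self, batch_size, shuffle=False):
        """Get a batch of data.

        shuffle=True samples indices with replacement; shuffle=False returns a contiguous
        window starting at a random offset (wrapping around). Neither mode iterates the
        split in order or guarantees that each sample is seen exactly once.
        """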
        if shuffle:
            indices = np.random.choice(self.num_samples, size=batch_size, replace=True)
        else:
            start_idx = np.random.randint(0, max(1, self.num_samples - batch_size + 1))
            indices = np.arange(start_idx, start_idx + batch_size) % self.num_samples
        
        batch_specs = self.spectrograms[indices]
        batch_labels = self.labels[indices]
        
        return batch_specs, batch_labels
    
    def get_all_data(self):
        """Get all data at once"""
        return self.spectrograms, self.labels

# Load PercePiano dataset for evaluation
print("\n🔄 Loading and processing PercePiano dataset for evaluation...")

try:
    # Load data
    raw_samples = load_percepiano_data(percepiano_dir)

    # Create datasets
    train_dataset = PercePianoDataset(raw_samples, split='train', random_seed=42)
    val_dataset = PercePianoDataset(raw_samples, split='val', random_seed=42)
    test_dataset = PercePianoDataset(raw_samples, split='test', random_seed=42)

    # Apply the label scaler saved with the fine-tuned model (transform only: refitting
    # would replace the training-time statistics with ones recomputed here)
    train_dataset.set_label_scaler(label_scaler)
    val_dataset.set_label_scaler(label_scaler)
    test_dataset.set_label_scaler(label_scaler)

    print(f"\n✅ Datasets loaded successfully!")
    print(f"   • Training dataset: {len(train_dataset)} samples")
    print(f"   • Validation dataset: {len(val_dataset)} samples")
    print(f"   • Test dataset: {len(test_dataset)} samples")

    # Prepare baseline features (flattened spectrograms + statistical features)
    def extract_baseline_features(spectrograms):
        """Extract features for baseline models"""
        features = []
        
        for spec in spectrograms:
            # Flatten spectrogram
            flat_spec = spec.flatten()
            
            # Statistical features
            stats = [
                np.mean(spec), np.std(spec), np.min(spec), np.max(spec),
                np.median(spec), np.percentile(spec, 25), np.percentile(spec, 75),
                np.mean(spec, axis=0).mean(), np.std(spec, axis=0).mean(),  # Per-mel-bin stats over time, averaged across bins
                np.mean(spec, axis=1).mean(), np.std(spec, axis=1).mean(),  # Per-frame stats over frequency, averaged across frames
            ]
            
            # Combine flattened spec and stats; the stride keeps every 4th mel bin of each frame to save memory
            subsampled_spec = flat_spec[::4]
            combined = np.concatenate([subsampled_spec, stats])
            features.append(combined)
        
        return np.array(features)

    print("\n🔧 Extracting baseline features...")
    train_specs, train_labels = train_dataset.get_all_data()
    val_specs, val_labels = val_dataset.get_all_data()
    test_specs, test_labels = test_dataset.get_all_data()

    # Extract features for baselines
    train_features = extract_baseline_features(train_specs)
    val_features = extract_baseline_features(val_specs)
    test_features = extract_baseline_features(test_specs)

    print(f"✅ Baseline features extracted")
    print(f"   • Feature dimensions: {train_features.shape[1]}")
    print(f"   • Train features: {train_features.shape}")
    print(f"   • Val features: {val_features.shape}")
    print(f"   • Test features: {test_features.shape}")

    # Train baseline models
    print("\n🏗️ Training baseline models...")

    baselines = {
        'Linear Regression': LinearRegression(),
        'Ridge Regression': Ridge(alpha=1.0),
        'Random Forest': RandomForestRegressor(n_estimators=100, max_depth=10, random_state=42, n_jobs=-1),
        'MLP': MLPRegressor(hidden_layer_sizes=(512, 256), max_iter=300, random_state=42)
    }

    trained_baselines = {}

    for name, model in baselines.items():
        print(f"   Training {name}...")
        start_time = time.time()
        
        try:
            model.fit(train_features, train_labels)
            trained_baselines[name] = model
            
            training_time = time.time() - start_time
            print(f"     ✅ {name} trained in {training_time:.1f}s")
        except Exception as e:
            print(f"     ❌ {name} training failed: {e}")

    print(f"\n✅ Baseline models trained successfully!")
    print(f"   • {len(trained_baselines)} baseline models ready")
    print(f"\n🎯 Ready for comprehensive evaluation!")

except Exception as e:
    print(f"❌ PercePiano data loading failed: {e}")
    print(f"   Error details: {str(e)}")
    print(f"\n💡 Troubleshooting tips:")
    print(f"   1. Check if PercePiano repository was cloned correctly")
    print(f"   2. Verify internet connection for cloning")
    print(f"   3. Ensure sufficient disk space in Google Drive")
    raise RuntimeError(f"PercePiano dataset setup failed: {e}") from e
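
`PercePianoDataset` builds its 70/15/15 split in two stages: `train_test_split` first holds out `val_ratio + test_ratio` = 30% of the samples, then splits that remainder with `test_size = 1 - val_ratio / (val_ratio + test_ratio)` = 0.5. Because the train, validation, and test instances all pass the same `random_seed`, they carve out the same disjoint partitions. The scikit-learn sketch below reproduces the two calls on 100 made-up `piece_*` filenames and checks the sizes, disjointness, and reproducibility.

```python
from sklearn.model_selection import train_test_split

files = [f"piece_{i:03d}" for i in range(100)]  # made-up stand-in filenames
val_ratio, test_ratio, seed = 0.15, 0.15, 42

# Stage 1: hold out validation + test together
train_files, temp_files = train_test_split(
    files, test_size=val_ratio + test_ratio, random_state=seed)

# Stage 2: split the held-out part into validation and test
val_share = val_ratio / (val_ratio + test_ratio)
val_files, test_files = train_test_split(
    temp_files, test_size=1 - val_share, random_state=seed)

print(len(train_files), len(val_files), len(test_files))  # 70 15 15

# Repeating both calls with the same seed reproduces the same test set,
# which is why separately constructed dataset instances agree on the partition
_, temp_again = train_test_split(files, test_size=val_ratio + test_ratio, random_state=seed)
_, test_again = train_test_split(temp_again, test_size=1 - val_share, random_state=seed)
print(test_again == test_files)  # True
```
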

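`extract_baseline_features` keeps every 4th value of the flattened spectrogram. Because the 128×128 array is flattened time-major and 128 is divisible by 4, that stride keeps mel bins 0, 4, ..., 124 of every frame, i.e. 128 × 32 = 4,096 values, and the 11 summary statistics bring each clip to 4,107 features. The NumPy check below confirms the equivalence on random data, and also shows that two of the statistics (the per-bin and per-frame means, averaged) simply repeat the global mean.

```python
import numpy as np

rng = np.random.default_rng(0)
spec = rng.normal(loc=-40.0, scale=10.0, size=(128, 128))  # [time, mel] in dB

strided = spec.flatten()[::4]         # what extract_baseline_features keeps
every_4th_bin = spec[:, ::4].ravel()  # every 4th mel bin of every frame

print(strided.shape, np.array_equal(strided, every_4th_bin))  # (4096,) True
print(strided.size + 11)  # 4107 features once the 11 summary statistics are appended

# Averaging per-bin (or per-frame) means just recovers the global mean
print(np.isclose(np.mean(spec, axis=0).mean(), spec.mean()))  # True
```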
---
## 📈 Cell 4: Comprehensive Model Evaluation
---

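`compute_detailed_metrics` (defined in the next code cell) reports two correlations that can disagree: the pooled Pearson correlation over all flattened `(sample, dimension)` pairs, and the mean of the per-dimension correlations. The pooled number also rewards getting each dimension's overall *level* right, so it can be high even when the model tracks no within-dimension variation. Standardizing the labels per dimension, as this notebook does, shrinks that effect, but the synthetic NumPy/SciPy example below shows why the two metrics are not interchangeable.

```python
import numpy as np
from scipy.stats import pearsonr

rng = np.random.default_rng(0)
n_samples, n_dims = 200, 5
offsets = np.arange(n_dims) * 10.0  # each dimension sits at a different level

# Targets and predictions share the per-dimension offsets but nothing else
targets = offsets + rng.normal(size=(n_samples, n_dims))
predictions = offsets + rng.normal(size=(n_samples, n_dims))

pooled, _ = pearsonr(predictions.ravel(), targets.ravel())
per_dim = [pearsonr(predictions[:, d], targets[:, d])[0] for d in range(n_dims)]

print(f"pooled Pearson:       {pooled:.3f}")           # close to 1
print(f"mean per-dim Pearson: {np.mean(per_dim):.3f}")  # close to 0
```
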
In [None]:
print("📈 Comprehensive Model Evaluation")
print("="*60)

def compute_detailed_metrics(predictions, targets, model_name="Model"):
    """Compute comprehensive evaluation metrics"""
    metrics = {}
    
    # Overall metrics
    mse = mean_squared_error(targets, predictions)
    mae = mean_absolute_error(targets, predictions)
    
    # Overall correlation
    flat_pred = predictions.flatten()
    flat_target = targets.flatten()
    pearson_corr, _ = pearsonr(flat_pred, flat_target)
    
    # R² score
    r2 = r2_score(targets, predictions)
    
    metrics.update({
        'mse': mse,
        'mae': mae,
        'rmse': np.sqrt(mse),
        'pearson_correlation': pearson_corr,
        'r2_score': r2
    })
    
    # Per-dimension metrics
    per_dim_corr = []
    per_dim_mse = []
    per_dim_mae = []
    
    for dim in range(predictions.shape[1]):
        dim_pred = predictions[:, dim]
        dim_target = targets[:, dim]
        
        if np.std(dim_pred) > 1e-8 and np.std(dim_target) > 1e-8:
            corr, _ = pearsonr(dim_pred, dim_target)
            per_dim_corr.append(corr)
        else:
            per_dim_corr.append(0.0)
        
        per_dim_mse.append(mean_squared_error(dim_target, dim_pred))
        per_dim_mae.append(mean_absolute_error(dim_target, dim_pred))
    
    metrics.update({
        'per_dimension_correlation': per_dim_corr,
        'per_dimension_mse': per_dim_mse,
        'per_dimension_mae': per_dim_mae,
        'mean_per_dim_correlation': np.mean(per_dim_corr),
        'std_per_dim_correlation': np.std(per_dim_corr)
    })
    
    return metrics

def print_metrics(metrics, model_name):
    """Print evaluation metrics in a formatted way"""
    print(f"\n📊 {model_name} Results:")
    print(f"   • Pearson Correlation: {metrics['pearson_correlation']:.4f}")
    print(f"   • Mean Per-Dim Correlation: {metrics['mean_per_dim_correlation']:.4f} ± {metrics['std_per_dim_correlation']:.4f}")
    print(f"   • RMSE: {metrics['rmse']:.4f}")
    print(f"   • MAE: {metrics['mae']:.4f}")
    print(f"   • R² Score: {metrics['r2_score']:.4f}")

# Evaluate Piano Transformer (our model)
print("🎹 Evaluating Piano Perception Transformer...")

# Get predictions on test set
ast_test_predictions, ast_test_targets = evaluate_model_on_dataset(
    eval_model, model_params, test_dataset, batch_size=32, description="Test Set"
)

# Compute metrics for AST
ast_metrics = compute_detailed_metrics(ast_test_predictions, ast_test_targets, "Piano Transformer")
print_metrics(ast_metrics, "🎹 Piano Perception Transformer")

# Evaluate baseline models
print("\n🏆 Evaluating Baseline Models...")

baseline_results = {}

for name, model in trained_baselines.items():
    print(f"\n   Evaluating {name}...")
    
    try:
        # Get predictions
        baseline_predictions = model.predict(test_features)
        
        # Compute metrics
        baseline_metrics = compute_detailed_metrics(baseline_predictions, test_labels, name)
        baseline_results[name] = {
            'predictions': baseline_predictions,
            'metrics': baseline_metrics
        }
        
        print_metrics(baseline_metrics, f"🏆 {name}")
        
    except Exception as e:
        print(f"     ❌ {name} evaluation failed: {e}")

# Create comparison table
print("\n📋 COMPREHENSIVE COMPARISON RESULTS")
print("="*80)

# Prepare comparison data
comparison_data = []

# Add AST results
comparison_data.append({
    'Model': '🎹 Piano Transformer (AST)',
    'Pearson Corr': ast_metrics['pearson_correlation'],
    'Mean Per-Dim Corr': ast_metrics['mean_per_dim_correlation'],
    'RMSE': ast_metrics['rmse'],
    'MAE': ast_metrics['mae'],
    'R² Score': ast_metrics['r2_score']
})

# Add baseline results
for name, result in baseline_results.items():
    comparison_data.append({
        'Model': f'🏆 {name}',
        'Pearson Corr': result['metrics']['pearson_correlation'],
        'Mean Per-Dim Corr': result['metrics']['mean_per_dim_correlation'],
        'RMSE': result['metrics']['rmse'],
        'MAE': result['metrics']['mae'],
        'R² Score': result['metrics']['r2_score']
    })

# Create comparison DataFrame
comparison_df = pd.DataFrame(comparison_data)
comparison_df = comparison_df.sort_values('Pearson Corr', ascending=False).reset_index(drop=True)  # renumber rows so .index reflects rank

print(comparison_df.to_string(index=False, float_format='%.4f'))

# Best model analysis
best_model = comparison_df.iloc[0]
print(f"\n🏆 BEST PERFORMING MODEL: {best_model['Model']}")
print(f"   • Pearson Correlation: {best_model['Pearson Corr']:.4f}")
print(f"   • Mean Per-Dimension Correlation: {best_model['Mean Per-Dim Corr']:.4f}")

# Performance analysis
ast_rank = comparison_df[comparison_df['Model'].str.contains('Piano Transformer')].index[0] + 1
total_models = len(comparison_df)

print(f"\n📈 PIANO TRANSFORMER PERFORMANCE ANALYSIS:")
print(f"   • Rank: {ast_rank}/{total_models}")
print(f"   • Performance vs Best: {ast_metrics['pearson_correlation']/best_model['Pearson Corr']*100:.1f}%")

if ast_rank == 1:
    print(f"   🎉 OUTSTANDING! Piano Transformer achieves BEST performance")
elif ast_rank <= 2:
    print(f"   ✅ EXCELLENT! Piano Transformer in top 2 performers")
elif ast_rank <= (total_models + 1) // 2:
    print(f"   ⚠️ MODERATE performance - room for improvement")
else:
    print(f"   ❌ BELOW AVERAGE performance - significant improvement needed")

# Log results to WandB
try:
    wandb.log({
        "ast_pearson_correlation": ast_metrics['pearson_correlation'],
        "ast_mean_per_dim_correlation": ast_metrics['mean_per_dim_correlation'],
        "ast_rmse": ast_metrics['rmse'],
        "ast_mae": ast_metrics['mae'],
        "ast_r2_score": ast_metrics['r2_score'],
        "ast_rank": ast_rank,
        "total_models_compared": total_models,
        "best_model_correlation": best_model['Pearson Corr']
    })
    
    # Log all per-dimension correlations in one call so they share a single WandB step
    wandb.log({f"ast_correlation_dim_{i+1}": corr
               for i, corr in enumerate(ast_metrics['per_dimension_correlation'])})
    
    print(f"\n✅ Results logged to WandB")
except Exception as e:
    print(f"\n⚠️ WandB logging failed ({e}) - continuing without logging")

print(f"\n🎯 Comprehensive evaluation completed!")

# Store results for visualization
evaluation_results = {
    'ast_metrics': ast_metrics,
    'ast_predictions': ast_test_predictions,
    'ast_targets': ast_test_targets,
    'baseline_results': baseline_results,
    'test_features': test_features,
    'test_labels': test_labels,
    'comparison_df': comparison_df
}

print(f"\n📊 Results stored for visualization in next cell")

---
## 📊 Cell 5: Results Visualization and Analysis
---
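The plots in this cell read `pearson_correlation`, `per_dimension_correlation`, `per_dimension_mse`, `rmse`, `mae` and `r2_score` from `ast_metrics`, which is computed in an earlier cell not shown here. As a minimal sketch, assuming predictions and targets are `(n_samples, n_dims)` arrays, such a dictionary could be built as follows. `summarize_regression` is a hypothetical helper and the data are synthetic. The key distinction is that the overall correlation pools every (sample, dimension) pair into one vector, while the mean per-dimension correlation averages one Pearson r per output column, so the two numbers can differ.

```python
import numpy as np
from scipy.stats import pearsonr
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score


def summarize_regression(predictions: np.ndarray, targets: np.ndarray) -> dict:
    """Hypothetical helper: metrics for (n_samples, n_dims) regression outputs."""
    # Overall correlation: flatten all (sample, dimension) pairs into one vector
    overall_r, _ = pearsonr(predictions.ravel(), targets.ravel())

    # Per-dimension correlation: one Pearson r per output column
    per_dim_r = np.array([
        pearsonr(predictions[:, d], targets[:, d])[0] for d in range(targets.shape[1])
    ])

    # With 2-D inputs, sklearn averages the per-column scores uniformly by default
    mse = mean_squared_error(targets, predictions)
    return {
        "pearson_correlation": float(overall_r),
        "per_dimension_correlation": per_dim_r,
        "mean_per_dim_correlation": float(per_dim_r.mean()),
        "std_per_dim_correlation": float(per_dim_r.std()),
        "per_dimension_mse": ((predictions - targets) ** 2).mean(axis=0),
        "rmse": float(np.sqrt(mse)),
        "mae": float(mean_absolute_error(targets, predictions)),
        "r2_score": float(r2_score(targets, predictions)),
    }


# Synthetic data: noisy copies of random targets (illustrative only)
rng = np.random.default_rng(0)
targets = rng.normal(size=(64, 3))
predictions = targets + 0.5 * rng.normal(size=targets.shape)
metrics = summarize_regression(predictions, targets)
print({k: v for k, v in metrics.items() if np.isscalar(v)})
```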

In [None]:
print("📊 Creating Comprehensive Visualizations")
print("="*60)

# Set up plotting
plt.rcParams['figure.figsize'] = (15, 10)
plt.rcParams['font.size'] = 12

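# Perceptual dimension labels. These names are assumed; they must follow the same
# column order as the PercePiano label array used for training and evaluation.
DIMENSION_NAMES = [
    'Articulation', 'Attack', 'Brightness', 'Depth', 'Dynamics',
    'Fluidity', 'Pace', 'Pedaling', 'Precision', 'Rhythm',
    'Rubato', 'Softness', 'Spacing', 'Stability', 'Strength',
    'Tension', 'Texture', 'Timing', 'Touch'
]
assert len(DIMENSION_NAMES) == len(evaluation_results['ast_metrics']['per_dimension_correlation']), \
    "DIMENSION_NAMES must have exactly one entry per predicted dimension"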

# 1. Model Comparison Bar Chart
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

# Overall correlation comparison
models = evaluation_results['comparison_df']['Model'].values
correlations = evaluation_results['comparison_df']['Pearson Corr'].values

colors = ['#FF6B6B' if 'Piano Transformer' in model else '#4ECDC4' for model in models]
bars1 = ax1.bar(range(len(models)), correlations, color=colors, alpha=0.7)
ax1.set_xlabel('Models')
ax1.set_ylabel('Pearson Correlation')
ax1.set_title('🏆 Model Performance Comparison - Overall Correlation')
ax1.set_xticks(range(len(models)))
ax1.set_xticklabels([m.replace('🎹 ', '').replace('🏆 ', '') for m in models], rotation=45, ha='right')
ax1.grid(axis='y', alpha=0.3)

# Add value labels on bars
for bar, corr in zip(bars1, correlations):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height + 0.01,
             f'{corr:.3f}', ha='center', va='bottom', fontweight='bold')

# Per-dimension correlation comparison
per_dim_correlations = evaluation_results['comparison_df']['Mean Per-Dim Corr'].values
bars2 = ax2.bar(range(len(models)), per_dim_correlations, color=colors, alpha=0.7)
ax2.set_xlabel('Models')
ax2.set_ylabel('Mean Per-Dimension Correlation')
ax2.set_title('📊 Model Performance - Mean Per-Dimension Correlation')
ax2.set_xticks(range(len(models)))
ax2.set_xticklabels([m.replace('🎹 ', '').replace('🏆 ', '') for m in models], rotation=45, ha='right')
ax2.grid(axis='y', alpha=0.3)

# Add value labels
for bar, corr in zip(bars2, per_dim_correlations):
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height + 0.01,
             f'{corr:.3f}', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.savefig('/content/drive/MyDrive/piano_transformer/evaluation_model_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

# 2. Per-Dimension Performance Analysis
fig, ax = plt.subplots(figsize=(20, 10))

ast_per_dim_corr = evaluation_results['ast_metrics']['per_dimension_correlation']
x_pos = np.arange(len(DIMENSION_NAMES))

# Create bars with color coding based on performance
colors = ['#2ECC71' if corr >= 0.7 else '#F39C12' if corr >= 0.5 else '#E74C3C' for corr in ast_per_dim_corr]
bars = ax.bar(x_pos, ast_per_dim_corr, color=colors, alpha=0.7)

ax.set_xlabel('Perceptual Dimensions')
ax.set_ylabel('Pearson Correlation')
ax.set_title('🎹 Piano Transformer: Per-Dimension Performance Analysis')
ax.set_xticks(x_pos)
ax.set_xticklabels(DIMENSION_NAMES, rotation=45, ha='right')
ax.grid(axis='y', alpha=0.3)

# Add horizontal lines for performance thresholds
ax.axhline(y=0.7, color='green', linestyle='--', alpha=0.7, label='Excellent (≥0.7)')
ax.axhline(y=0.5, color='orange', linestyle='--', alpha=0.7, label='Good (≥0.5)')
ax.axhline(y=0.3, color='red', linestyle='--', alpha=0.7, label='Moderate (≥0.3)')

# Add value labels on bars
for bar, corr in zip(bars, ast_per_dim_corr):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
             f'{corr:.3f}', ha='center', va='bottom', fontweight='bold', fontsize=10)

ax.legend()
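# Extend the lower limit when any dimension has a negative correlation
ax.set_ylim(min(0.0, float(np.min(ast_per_dim_corr)) - 0.05), 1.05)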
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/piano_transformer/evaluation_per_dimension.png', dpi=300, bbox_inches='tight')
plt.show()

# 3. Prediction vs. Target Scatter Plots (for best and worst dimensions)
best_dim_idx = np.argmax(ast_per_dim_corr)
worst_dim_idx = np.argmin(ast_per_dim_corr)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Best dimension
best_predictions = evaluation_results['ast_predictions'][:, best_dim_idx]
best_targets = evaluation_results['ast_targets'][:, best_dim_idx]
ax1.scatter(best_targets, best_predictions, alpha=0.6, color='#2ECC71')
ax1.plot([best_targets.min(), best_targets.max()], [best_targets.min(), best_targets.max()], 'r--', lw=2)
ax1.set_xlabel('True Values')
ax1.set_ylabel('Predicted Values')
ax1.set_title(f'✅ Best Dimension: {DIMENSION_NAMES[best_dim_idx]}\n(r = {ast_per_dim_corr[best_dim_idx]:.3f})')
ax1.grid(alpha=0.3)

# Worst dimension
worst_predictions = evaluation_results['ast_predictions'][:, worst_dim_idx]
worst_targets = evaluation_results['ast_targets'][:, worst_dim_idx]
ax2.scatter(worst_targets, worst_predictions, alpha=0.6, color='#E74C3C')
ax2.plot([worst_targets.min(), worst_targets.max()], [worst_targets.min(), worst_targets.max()], 'r--', lw=2)
ax2.set_xlabel('True Values')
ax2.set_ylabel('Predicted Values')
ax2.set_title(f'❌ Worst Dimension: {DIMENSION_NAMES[worst_dim_idx]}\n(r = {ast_per_dim_corr[worst_dim_idx]:.3f})')
ax2.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('/content/drive/MyDrive/piano_transformer/evaluation_scatter_plots.png', dpi=300, bbox_inches='tight')
plt.show()

# 4. Error Analysis
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Residuals distribution
residuals = evaluation_results['ast_predictions'] - evaluation_results['ast_targets']
ax1.hist(residuals.flatten(), bins=50, alpha=0.7, color='skyblue', edgecolor='black')
ax1.set_xlabel('Prediction Residuals')
ax1.set_ylabel('Frequency')
ax1.set_title('📊 Residuals Distribution')
ax1.axvline(x=0, color='red', linestyle='--', linewidth=2)
ax1.grid(alpha=0.3)

# Per-dimension RMSE
per_dim_rmse = np.sqrt(evaluation_results['ast_metrics']['per_dimension_mse'])
bars = ax2.bar(range(len(DIMENSION_NAMES)), per_dim_rmse, color='lightcoral', alpha=0.7)
ax2.set_xlabel('Perceptual Dimensions')
ax2.set_ylabel('RMSE')
ax2.set_title('📈 Per-Dimension Root Mean Squared Error')
ax2.set_xticks(range(len(DIMENSION_NAMES)))
ax2.set_xticklabels(DIMENSION_NAMES, rotation=45, ha='right')
ax2.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig('/content/drive/MyDrive/piano_transformer/evaluation_error_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

# 5. Performance Summary Heatmap
fig, ax = plt.subplots(figsize=(12, 8))

# Create heatmap data: models vs metrics
metrics_names = ['Pearson Corr', 'Mean Per-Dim Corr', 'R² Score']
model_names = [m.replace('🎹 ', '').replace('🏆 ', '') for m in evaluation_results['comparison_df']['Model']]

heatmap_data = []
for _, row in evaluation_results['comparison_df'].iterrows():
    heatmap_data.append([
        row['Pearson Corr'],
        row['Mean Per-Dim Corr'],
        row['R² Score']
    ])

heatmap_data = np.array(heatmap_data)

im = ax.imshow(heatmap_data, cmap='RdYlGn', aspect='auto')
ax.set_xticks(range(len(metrics_names)))
ax.set_xticklabels(metrics_names)
ax.set_yticks(range(len(model_names)))
ax.set_yticklabels(model_names)
ax.set_title('🔥 Model Performance Heatmap')

# Add text annotations
for i in range(len(model_names)):
    for j in range(len(metrics_names)):
        text = ax.text(j, i, f'{heatmap_data[i, j]:.3f}',
                      ha="center", va="center", color="black", fontweight='bold')

plt.colorbar(im)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/piano_transformer/evaluation_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()

# Performance Summary Statistics
print("\n📋 DETAILED PERFORMANCE ANALYSIS")
print("="*60)

print(f"\n🎹 Piano Transformer Performance:")
print(f"   • Overall Correlation: {evaluation_results['ast_metrics']['pearson_correlation']:.4f}")
print(f"   • Best Dimension: {DIMENSION_NAMES[best_dim_idx]} (r = {ast_per_dim_corr[best_dim_idx]:.4f})")
print(f"   • Worst Dimension: {DIMENSION_NAMES[worst_dim_idx]} (r = {ast_per_dim_corr[worst_dim_idx]:.4f})")

# Count dimensions by performance tier
excellent_dims = sum(1 for r in ast_per_dim_corr if r >= 0.7)
good_dims = sum(1 for r in ast_per_dim_corr if 0.5 <= r < 0.7)
moderate_dims = sum(1 for r in ast_per_dim_corr if 0.3 <= r < 0.5)
poor_dims = sum(1 for r in ast_per_dim_corr if r < 0.3)

print(f"\n📊 Performance Distribution:")
print(f"   • Excellent (≥0.7): {excellent_dims}/19 dimensions ({excellent_dims/19*100:.1f}%)")
print(f"   • Good (0.5-0.7): {good_dims}/19 dimensions ({good_dims/19*100:.1f}%)")
print(f"   • Moderate (0.3-0.5): {moderate_dims}/19 dimensions ({moderate_dims/19*100:.1f}%)")
print(f"   • Poor (<0.3): {poor_dims}/19 dimensions ({poor_dims/19*100:.1f}%)")

print(f"\n✅ All visualizations saved to Google Drive!")
print(f"   • /content/drive/MyDrive/piano_transformer/evaluation_*.png")

---
## 📝 Cell 6: Final Report and Conclusions
---
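The report lists the five best and five worst dimensions. For the worst list to read "1 = very worst", the bottom entries must be taken in ascending order; slicing the tail of a descending sort (`sorted_dims[-5:]`) yields them from least to most severe unless it is reversed. A minimal sketch with `np.argsort`, using made-up names and values:

```python
import numpy as np

# Hypothetical dimension names and correlations, for illustration only
names = ["A", "B", "C", "D", "E", "F", "G"]
corr = np.array([0.61, 0.12, 0.78, 0.45, 0.05, 0.33, 0.52])

k = 3
order = np.argsort(corr, kind="stable")  # ascending; stable keeps ties in input order
worst_first = order[:k]                  # k lowest, starting with the very worst
best_first = order[::-1][:k]             # k highest, starting with the very best

print("Best:")
for rank, idx in enumerate(best_first, start=1):
    print(f"{rank}. {names[idx]}: {corr[idx]:.4f}")
print("Worst:")
for rank, idx in enumerate(worst_first, start=1):
    print(f"{rank}. {names[idx]}: {corr[idx]:.4f}")
```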

In [None]:
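from datetime import datetime  # report timestamps
import jax  # parameter count and device name in the report

print("📝 Generating Final Evaluation Report")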
print("="*70)

# Generate comprehensive final report
report = f"""
# 🎹 Piano Perception Transformer - Final Evaluation Report

**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}  
**Model:** 12-layer Audio Spectrogram Transformer with Regression Head  
**Dataset:** PercePiano (19 perceptual dimensions)  
**Evaluation:** Comprehensive comparison with baseline models

## 📊 Executive Summary

The Piano Perception Transformer achieved the following performance metrics:

### 🎯 Key Performance Indicators
- **Overall Pearson Correlation:** {evaluation_results['ast_metrics']['pearson_correlation']:.4f}
- **Mean Per-Dimension Correlation:** {evaluation_results['ast_metrics']['mean_per_dim_correlation']:.4f} ± {evaluation_results['ast_metrics']['std_per_dim_correlation']:.4f}
- **Root Mean Squared Error (RMSE):** {evaluation_results['ast_metrics']['rmse']:.4f}
- **Mean Absolute Error (MAE):** {evaluation_results['ast_metrics']['mae']:.4f}
- **R² Score:** {evaluation_results['ast_metrics']['r2_score']:.4f}

### 🏆 Competitive Analysis
- **Model Ranking:** {ast_rank}/{total_models} among all evaluated models
- **Performance vs. Best Model:** {ast_metrics['pearson_correlation']/best_model['Pearson Corr']*100:.1f}%
- **Best Baseline Beaten:** {'✅ Yes' if ast_rank == 1 else '❌ No'}

## 📈 Detailed Performance Analysis

### Per-Dimension Performance Distribution:
- **Excellent (r ≥ 0.7):** {excellent_dims}/19 dimensions ({excellent_dims/19*100:.1f}%)
- **Good (0.5 ≤ r < 0.7):** {good_dims}/19 dimensions ({good_dims/19*100:.1f}%)
- **Moderate (0.3 ≤ r < 0.5):** {moderate_dims}/19 dimensions ({moderate_dims/19*100:.1f}%)
- **Poor (r < 0.3):** {poor_dims}/19 dimensions ({poor_dims/19*100:.1f}%)

### Best Performing Dimensions:
"""

# Add top 5 dimensions
sorted_dims = sorted(enumerate(ast_per_dim_corr), key=lambda x: x[1], reverse=True)
for i, (dim_idx, corr) in enumerate(sorted_dims[:5]):
    report += f"{i+1}. **{DIMENSION_NAMES[dim_idx]}:** {corr:.4f}\n"

report += f"""
### Worst Performing Dimensions:
"""

# Add bottom 5 dimensions, listed from the very worst upward
for i, (dim_idx, corr) in enumerate(reversed(sorted_dims[-5:])):
    report += f"{i+1}. **{DIMENSION_NAMES[dim_idx]}:** {corr:.4f}\n"

report += f"""
## 🏆 Baseline Comparison

| Model | Pearson Corr | Mean Per-Dim Corr | RMSE | MAE | R² Score |
|-------|--------------|-------------------|------|-----|----------|
"""

# Add comparison table
for _, row in evaluation_results['comparison_df'].iterrows():
    model_name = row['Model'].replace('🎹 ', '').replace('🏆 ', '')
    report += f"| {model_name} | {row['Pearson Corr']:.4f} | {row['Mean Per-Dim Corr']:.4f} | {row['RMSE']:.4f} | {row['MAE']:.4f} | {row['R² Score']:.4f} |\n"

# Performance assessment
if ast_metrics['pearson_correlation'] >= 0.7:
    performance_assessment = "🎉 **OUTSTANDING PERFORMANCE** - The model achieves excellent correlation (≥0.7) indicating strong predictive capability for perceptual dimensions."
elif ast_metrics['pearson_correlation'] >= 0.5:
    performance_assessment = "✅ **GOOD PERFORMANCE** - The model demonstrates solid predictive capability with good correlation (≥0.5) across perceptual dimensions."
elif ast_metrics['pearson_correlation'] >= 0.3:
    performance_assessment = "⚠️ **MODERATE PERFORMANCE** - The model shows moderate predictive capability but has room for significant improvement."
else:
    performance_assessment = "❌ **POOR PERFORMANCE** - The model demonstrates limited predictive capability and requires substantial improvements."

report += f"""
## 🔍 Key Findings

### Overall Assessment
{performance_assessment}

### Strengths
- **Architecture:** The 12-layer Audio Spectrogram Transformer successfully processes mel-spectrograms for perceptual prediction
- **Pre-training Benefits:** {'✅ Pre-training on MAESTRO likely contributed to feature learning' if os.path.exists('/content/drive/MyDrive/piano_transformer/checkpoints/ssast_pretraining/pretrained_for_finetuning.pkl') else '⚠️ No pre-training detected - may limit performance'}
- **Best Dimensions:** Strong performance on {DIMENSION_NAMES[best_dim_idx]} (r = {ast_per_dim_corr[best_dim_idx]:.4f})
- **Consistency:** {'Good' if evaluation_results['ast_metrics']['std_per_dim_correlation'] < 0.2 else 'Variable'} consistency across dimensions (std = {evaluation_results['ast_metrics']['std_per_dim_correlation']:.4f})

### Areas for Improvement
- **Challenging Dimensions:** {DIMENSION_NAMES[worst_dim_idx]} shows lowest performance (r = {ast_per_dim_corr[worst_dim_idx]:.4f})
- **Model Complexity:** Consider architecture adjustments for better dimension-specific modeling
- **Data Augmentation:** Additional data augmentation strategies may improve generalization
- **Loss Function:** Experiment with dimension-weighted or adversarial losses

## 📋 Technical Specifications

### Model Architecture
- **Backbone:** 12-layer Audio Spectrogram Transformer
- **Parameters:** {sum(x.size for x in jax.tree.leaves(model_params)):,} total parameters
- **Input:** 128×128 mel-spectrograms from MIDI synthesis
- **Output:** 19 perceptual dimensions (normalized)
- **Patch Size:** 16×16
- **Embedding Dimension:** 768
- **Attention Heads:** 12 per layer

### Training Configuration
- **Pre-training:** Self-supervised learning on MAESTRO dataset
- **Fine-tuning:** Supervised learning on PercePiano dataset
- **Loss Function:** MSE + correlation-based loss
- **Optimization:** AdamW with cosine learning rate schedule
- **Regularization:** Dropout (0.1) + stochastic depth (0.1)

### Dataset Statistics
- **Training Samples:** {len(train_dataset)}
- **Validation Samples:** {len(val_dataset)}
- **Test Samples:** {len(test_dataset)}
- **Label Normalization:** StandardScaler applied

## 🚀 Recommendations

### For Production Deployment
{'✅ **READY FOR DEPLOYMENT** - Model shows strong performance suitable for practical applications' if ast_metrics['pearson_correlation'] >= 0.6 else '⚠️ **NEEDS IMPROVEMENT** - Additional development recommended before deployment'}

### For Further Research
1. **Architecture Improvements:**
   - Experiment with different patch sizes and attention mechanisms
   - Consider hierarchical or multi-scale processing
   - Investigate dimension-specific attention heads

2. **Training Enhancements:**
   - Implement curriculum learning or progressive training
   - Explore advanced augmentation techniques
   - Consider meta-learning approaches for few-shot adaptation

3. **Evaluation Extensions:**
   - Human evaluation studies for perceptual validation
   - Cross-dataset generalization tests
   - Real-time inference performance optimization

## 📊 Supporting Materials

All evaluation visualizations and detailed results have been saved to:
- `evaluation_model_comparison.png` - Model performance comparison
- `evaluation_per_dimension.png` - Per-dimension analysis
- `evaluation_scatter_plots.png` - Prediction vs. target analysis
- `evaluation_error_analysis.png` - Error distribution analysis
- `evaluation_heatmap.png` - Performance heatmap

---
**Report Generated by Piano Perception Transformer Evaluation Pipeline**  
**Framework:** JAX/Flax | **Hardware:** {jax.devices()[0].device_kind} | **Date:** {datetime.now().strftime('%Y-%m-%d')}
"""

# Save report to file
report_path = '/content/drive/MyDrive/piano_transformer/Final_Evaluation_Report.md'
with open(report_path, 'w', encoding='utf-8') as f:  # report contains emoji and non-ASCII symbols
    f.write(report)

# Display key sections of the report
print("🎹 PIANO PERCEPTION TRANSFORMER - FINAL EVALUATION SUMMARY")
print("="*70)

print(f"\n📊 PERFORMANCE METRICS:")
print(f"   • Overall Correlation: {evaluation_results['ast_metrics']['pearson_correlation']:.4f}")
print(f"   • Mean Per-Dimension: {evaluation_results['ast_metrics']['mean_per_dim_correlation']:.4f} ± {evaluation_results['ast_metrics']['std_per_dim_correlation']:.4f}")
print(f"   • RMSE: {evaluation_results['ast_metrics']['rmse']:.4f}")
print(f"   • R² Score: {evaluation_results['ast_metrics']['r2_score']:.4f}")

print(f"\n🏆 COMPETITIVE RANKING:")
print(f"   • Rank: {ast_rank}/{total_models} models")
print(f"   • vs Best: {ast_metrics['pearson_correlation']/best_model['Pearson Corr']*100:.1f}%")

print(f"\n📈 DIMENSION PERFORMANCE:")
print(f"   • Excellent (≥0.7): {excellent_dims}/19 ({excellent_dims/19*100:.1f}%)")
print(f"   • Good (0.5-0.7): {good_dims}/19 ({good_dims/19*100:.1f}%)")
print(f"   • Needs Work (<0.5): {moderate_dims + poor_dims}/19 ({(moderate_dims + poor_dims)/19*100:.1f}%)")

print(f"\n🎯 ASSESSMENT: {performance_assessment.split(' - ')[0].replace('**', '')}")

# Final WandB logging
try:
    wandb.log({
        "final_overall_correlation": evaluation_results['ast_metrics']['pearson_correlation'],
        "final_model_rank": ast_rank,
        "excellent_dimensions_count": excellent_dims,
        "good_dimensions_count": good_dims,
        "evaluation_complete": True
    })
    
    # Save evaluation artifacts
    wandb.save('/content/drive/MyDrive/piano_transformer/evaluation_*.png')
    wandb.save('/content/drive/MyDrive/piano_transformer/Final_Evaluation_Report.md')
    
    print(f"\n✅ Results logged to WandB and artifacts saved")
except Exception as e:
    print(f"\n⚠️ WandB logging failed ({e}) - results saved locally")

print(f"\n📋 COMPLETE EVALUATION REPORT SAVED:")
print(f"   📄 {report_path}")
print(f"   📊 /content/drive/MyDrive/piano_transformer/evaluation_*.png")

print(f"\n🎉 COMPREHENSIVE EVALUATION COMPLETED SUCCESSFULLY!")
print(f"🎹 Piano Perception Transformer evaluation pipeline finished.")

# Cleanup
try:
    wandb.finish()
except Exception:
    pass

print(f"\n✨ Thank you for using the Piano Perception Transformer! ✨")
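The baseline table in the report is assembled row by row with f-strings. A small helper can render any DataFrame as a GitHub-style Markdown table with consistent float formatting. This is a sketch; `to_markdown_table` is a hypothetical name, and pandas' own `DataFrame.to_markdown` does the same job but requires the optional `tabulate` package.

```python
import pandas as pd


def to_markdown_table(df: pd.DataFrame, float_fmt: str = ".4f") -> str:
    """Hypothetical helper: render a DataFrame as a GitHub-style Markdown table."""
    def fmt(value):
        # numpy float64 subclasses float, so DataFrame floats are formatted too
        return format(value, float_fmt) if isinstance(value, float) else str(value)

    header = "| " + " | ".join(map(str, df.columns)) + " |"
    divider = "|" + "|".join("---" for _ in df.columns) + "|"
    rows = ["| " + " | ".join(fmt(v) for v in row) + " |" for row in df.itertuples(index=False)]
    return "\n".join([header, divider, *rows]) + "\n"


# Illustrative values only
demo = pd.DataFrame({"Model": ["Piano Transformer (AST)", "Random Forest"], "Pearson Corr": [0.5, 0.25]})
table = to_markdown_table(demo)
print(table)
```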

---
## 🎉 Evaluation Complete!

**Final Results:**
- 📊 **Performance Metrics**: Comprehensive correlation and error analysis
- 🏆 **Baseline Comparisons**: Evaluated against 4 baseline models
- 📈 **Visualizations**: 5 detailed performance plots saved
- 📝 **Final Report**: Complete markdown report generated

**Key Deliverables:**
1. **Final Evaluation Report**: `Final_Evaluation_Report.md`
2. **Performance Visualizations**: `evaluation_*.png` files
3. **WandB Logging**: Comprehensive metrics and artifacts
4. **Model Assessment**: Production readiness evaluation

**Next Steps:**
- Review the detailed report for insights and recommendations
- Consider architectural improvements based on dimension analysis
- Plan production deployment or further research iterations

---

**🎹 Piano Perception Transformer Evaluation Pipeline Complete! 🎹**