# 🎹 Piano Perception Transformer - Fine-tuning

**Phase 2: Supervised Fine-tuning on PercePiano Dataset**

This notebook implements supervised fine-tuning of the pre-trained AST model on the PercePiano dataset to predict 19 perceptual dimensions of piano performance.

**Pipeline Overview:**
1. 🔧 **Setup & Environment** - Dependencies, WandB tracking, JAX configuration
2. 📊 **PercePiano Data Loading** - MIDI to spectrogram conversion and labeling
3. 🧠 **Model Architecture** - Load pre-trained AST and add regression head
4. 🎯 **Supervised Fine-tuning** - Train on perceptual prediction task
5. 💾 **Save Fine-tuned Model** - Checkpoint for evaluation

**Input:** Pre-trained AST model from Phase 1  
**Output:** Fine-tuned model ready for evaluation

---
## 🔧 Cell 1: Setup with WandB Integration
---

In [None]:
print("🚀 Setting up Piano Perception Transformer - Fine-tuning Phase...")

# Clone repo (skip if already exists)
import os
if not os.path.exists('piano-perception-transformer'):
    !git clone https://github.com/Jai-Dhiman/piano-perception-transformer.git
else:
    print("Repository already exists, skipping clone...")

%cd piano-perception-transformer

# Install uv (recent installers place the binary in ~/.local/bin; older ones used ~/.cargo/bin)
!curl -LsSf https://astral.sh/uv/install.sh | sh

# Install JAX (TPU build), Flax/Optax, audio/MIDI processing, and experiment-tracking dependencies
print("📦 Installing dependencies with uv...")
!export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH" && uv pip install --system "jax[tpu]" flax optax librosa pandas wandb requests zipfile36 scikit-learn scipy seaborn matplotlib pretty_midi soundfile

# Import core libraries
import sys
import json
import pickle
import numpy as np
import jax
import jax.numpy as jnp
import optax
from datetime import datetime
from flax import linen as nn
from flax.training import train_state
import time

# Initialize WandB for fine-tuning tracking
import wandb

try:
    wandb.login()  # This will prompt for API key in Colab
    
    run = wandb.init(
        project="piano-perception-transformer-finetuning",
        name=f"ast-finetuning-{datetime.now().strftime('%Y%m%d-%H%M')}",
        config={
            "phase": "supervised_finetuning",
            "architecture": "Pre-trained AST + Regression Head",
            "model_layers": 12,
            "embed_dim": 768,
            "num_heads": 12,
            "patch_size": 16,
            "learning_rate": 1e-4,  # Lower LR for fine-tuning
            "batch_size": 16,
            "dropout": 0.1,
            "dataset": "PercePiano",
            "target_dimensions": 19,
            "experiment_type": "supervised_finetuning",
            "loss_function": "mse_with_correlation"
        },
        tags=["finetuning", "ast", "percepiano", "supervised", "regression"]
    )
    
    print("✅ WandB initialized successfully!")
    print(f"   • Project: piano-perception-transformer-finetuning")
    print(f"   • Run name: {run.name}")
    print(f"   • Tracking: https://wandb.ai/{run.entity}/{run.project}/runs/{run.id}")
    
except Exception as e:
    print(f"⚠️ WandB initialization failed: {e}")
    print("   • Continuing without experiment tracking")
    print("   • Set up WandB API key: https://wandb.ai/settings")

# Mount Google Drive
from google.colab import drive
print("🔗 Mounting Google Drive...")
drive.mount('/content/drive')

# Create directory structure
base_dir = '/content/drive/MyDrive/piano_transformer'
directories = [
    f'{base_dir}/processed_spectrograms',
    f'{base_dir}/checkpoints/finetuning',
    f'{base_dir}/logs',
    f'{base_dir}/temp'
]

print("📁 Setting up directory structure...")
for directory in directories:
    os.makedirs(directory, exist_ok=True)
    print(f"✅ Created: {directory}")

# Verify JAX setup
print(f"\n🧠 JAX Configuration:")
print(f"   • Backend: {jax.default_backend()}")
print(f"   • Devices: {jax.device_count()}")
print(f"   • Device type: {jax.devices()[0].device_kind}")

print("\n✅ Fine-tuning setup completed!")

---
## 📊 Cell 2: PercePiano Data Loading and Preprocessing
---

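Because every spectrogram is cut to a fixed number of frames, the amount of audio the model sees is set by `hop_length / sr`. With the defaults used in `midi_to_spectrogram` (`sr=22050`, `hop_length=512`, `target_length=128`), that is roughly 3 seconds per performance, so longer recordings contribute only their opening. A minimal NumPy sketch of that arithmetic and of the pad-or-truncate step; `frames_to_seconds` and `fix_length` are hypothetical helpers, and the -80 dB pad value mirrors the notebook's choice:

```python
import numpy as np

def frames_to_seconds(num_frames, hop_length=512, sr=22050):
    """Approximate audio duration spanned by num_frames STFT hops."""
    return num_frames * hop_length / sr

def fix_length(spec, target_length=128, pad_value=-80.0):
    """Truncate or pad a [time, freq] dB spectrogram along the time axis."""
    if spec.shape[0] >= target_length:
        return spec[:target_length]
    pad = target_length - spec.shape[0]
    return np.pad(spec, ((0, pad), (0, 0)), mode="constant", constant_values=pad_value)

print(f"{frames_to_seconds(128):.2f} s of audio per spectrogram")
short_spec = fix_length(np.zeros((50, 128)))   # padded with -80 dB frames
long_spec = fix_length(np.zeros((400, 128)))   # truncated to the first 128 frames
print(short_spec.shape, long_spec.shape)
```

Raising `target_length` or `hop_length` widens the window, at the cost of more patches (or coarser time resolution) in the model.
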
In [None]:
# Clone and setup PercePiano dataset
print("📂 Setting up PercePiano dataset...")

# Define PercePiano directory path
percepiano_dir = '/content/drive/MyDrive/PercePiano'

# Clone PercePiano dataset if not exists
if not os.path.exists(percepiano_dir):
    print("📥 Cloning PercePiano dataset repository...")
    !git clone https://github.com/JonghoKimSNU/PercePiano.git {percepiano_dir}
    print("✅ PercePiano dataset cloned successfully!")
else:
    print("✅ PercePiano dataset already exists")

# Verify essential directory structure
required_paths = [
    f'{percepiano_dir}/labels/label_2round_mean_reg_19_with0_rm_highstd0.json',
    f'{percepiano_dir}/virtuoso/data/all_2rounds'
]

print("🔍 Verifying dataset structure...")
missing_paths = []
for path in required_paths:
    if not os.path.exists(path):
        missing_paths.append(path)
    else:
        if path.endswith('.json'):
            print(f"   ✅ Labels file found: {os.path.basename(path)}")
        else:
            midi_count = len([f for f in os.listdir(path) if f.endswith('.mid')])
            print(f"   ✅ MIDI directory found: {midi_count} MIDI files")

if missing_paths:
    print("❌ Missing required files/directories:")
    for path in missing_paths:
        print(f"   • {path}")
    print("\n💡 The dataset may need to be downloaded separately.")
    print("   Please refer to the PercePiano repository instructions.")
    raise FileNotFoundError("PercePiano dataset structure incomplete")

print("✅ Dataset structure verified successfully!")

import json
import numpy as np
import pretty_midi
import librosa
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

def midi_to_spectrogram(midi_path, sr=22050, n_mels=128, hop_length=512, n_fft=2048, target_length=128):
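    """
    Convert a MIDI file to a fixed-size log-mel spectrogram.

    Audio is rendered with FluidSynth when available, otherwise with pretty_midi's
    basic synthesizer. The time axis is truncated or padded to `target_length`
    frames; with the defaults (128 frames x 512 hop / 22050 Hz) this covers
    roughly the first 3 seconds of the performance.
    
    Args:
        midi_path: Path to MIDI file
        sr: Sample rate for audio synthesis
        n_mels: Number of mel frequency bins
        hop_length: Hop length for STFT
        n_fft: FFT size
        target_length: Fixed time dimension for output spectrogram
    
    Returns:
        mel_spectrogram: [time, freq] array in dB with shape (target_length, n_mels),
        or None if conversion fails
    """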
    try:
        # Load MIDI file
        midi_data = pretty_midi.PrettyMIDI(midi_path)
        
        # Prefer FluidSynth (requires pyfluidsynth and a SoundFont); fall back to
        # pretty_midi's built-in sine-wave synthesis if it is unavailable
        try:
            audio = midi_data.fluidsynth(fs=sr)
        except Exception:
            audio = midi_data.synthesize(fs=sr)
        
        # Ensure minimum audio length
        min_duration = 2.0  # seconds
        min_samples = int(min_duration * sr)
        if len(audio) < min_samples:
            # Pad with silence
            padding = min_samples - len(audio)
            audio = np.pad(audio, (0, padding), mode='constant')
        
        # Convert to mel-spectrogram
        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=sr,
            n_mels=n_mels,
            hop_length=hop_length,
            n_fft=n_fft,
            power=2.0,
            fmin=20,  # Lower frequency bound
            fmax=sr//2  # Nyquist frequency
        )
        
        # Convert to log scale
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
        
        # Transpose to [time, freq] format
        mel_spec_transposed = mel_spec_db.T  # Shape: (time, freq)
        
        # Normalize to fixed time dimension
        current_length = mel_spec_transposed.shape[0]
        
        if current_length >= target_length:
            # Truncate to target length
            normalized_spec = mel_spec_transposed[:target_length, :]
        else:
            # Pad to target length
            pad_width = target_length - current_length
            normalized_spec = np.pad(
                mel_spec_transposed, 
                ((0, pad_width), (0, 0)), 
                mode='constant', 
                constant_values=-80.0  # Floor of power_to_db(ref=np.max, top_db=80) for non-silent audio
            )
        
        return normalized_spec
        
    except Exception as e:
        print(f"Error converting MIDI {midi_path}: {str(e)}")
        return None

def load_percepiano_data(percepiano_dir):
    """Load PercePiano dataset with MIDI-to-spectrogram conversion"""
    labels_file = f'{percepiano_dir}/labels/label_2round_mean_reg_19_with0_rm_highstd0.json'
    
    print(f"📋 Loading labels from: {labels_file}")
    with open(labels_file, 'r') as f:
        labels_data = json.load(f)

    if len(labels_data) == 0:
        raise ValueError("Labels file is empty")

    print(f"📊 Loaded PercePiano labels: {len(labels_data)} samples")

    # Load MIDI files and convert to spectrograms
    midi_dir = f'{percepiano_dir}/virtuoso/data/all_2rounds'
    
    if not os.path.exists(midi_dir):
        raise FileNotFoundError(f"MIDI directory not found: {midi_dir}")
    
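    midi_files = [f for f in os.listdir(midi_dir) if f.endswith('.mid')]
    # Index MIDI files by filename stem for exact matching
    midi_by_stem = {f[:-len('.mid')]: f for f in midi_files}

    if len(midi_files) == 0:
        raise FileNotFoundError(f"No MIDI files found in: {midi_dir}")

    print(f"🎵 Found {len(midi_files)} MIDI files")

    samples = []
    processed_count = 0

    for filename, label_data in labels_data.items():
        # Find the corresponding MIDI file. Try an exact stem match first: substring
        # matching alone can pair a label such as "..._1_1" with "..._1_10.mid".
        label_stem = filename[:-len('.mid')] if filename.endswith('.mid') else filename
        midi_filename = midi_by_stem.get(label_stem)
        if midi_filename is None:
            # Fallback: loose substring matching (may still pick the wrong file)
            for midi_file in midi_files:
                if (filename in midi_file or
                    midi_file.replace('.mid', '') in filename or
                    filename.replace('.mid', '') in midi_file.replace('.mid', '')):
                    midi_filename = midi_file
                    break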

        if midi_filename is None:
            print(f"⚠️ MIDI file not found for label: {filename}")
            continue

        # Extract the 19 perceptual features
        if isinstance(label_data, list) and len(label_data) >= 19:
            perceptual_features = np.array(label_data[:19], dtype=np.float32)
        else:
            print(f"⚠️ Insufficient label data for {filename}: expected 19 features, got {len(label_data) if isinstance(label_data, list) else 'non-list'}")
            continue

        # Convert MIDI to spectrogram with fixed dimensions
        midi_path = os.path.join(midi_dir, midi_filename)
        spectrogram = midi_to_spectrogram(
            midi_path, 
            sr=22050, 
            n_mels=128, 
            hop_length=512,
            n_fft=2048,
            target_length=128  # Match MAESTRO preprocessing
        )
        
        if spectrogram is not None:
            # Verify expected shape
            expected_shape = (128, 128)  # (time, freq)
            if spectrogram.shape != expected_shape:
                print(f"⚠️ Unexpected spectrogram shape for {midi_filename}: {spectrogram.shape}, expected {expected_shape}")
                continue
                
            samples.append({
                'spectrogram': spectrogram,
                'labels': perceptual_features,
                'filename': filename
            })
            processed_count += 1
            
            # Print progress every 25 files
            if processed_count % 25 == 0:
                print(f"📊 Processed {processed_count} samples...")
        else:
            print(f"⚠️ Failed to convert MIDI: {midi_filename}")
            continue

    if processed_count == 0:
        raise ValueError("No samples were successfully processed")

    print(f"✅ Successfully processed {processed_count} samples")
    return samples

class PercePianoDataset:
    """PercePiano dataset with train/val/test splits and label normalization"""
    
    def __init__(self, samples, split='train', train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_seed=42):
        self.split = split
        self.random_seed = random_seed
        
        # Validate split ratios
        assert abs((train_ratio + val_ratio + test_ratio) - 1.0) < 1e-6, "Split ratios must sum to 1.0"
        
        # Seed NumPy's global RNG (used later by get_batch); the splits themselves
        # are fixed by random_state in train_test_split below
        np.random.seed(random_seed)
        
        # Extract spectrograms and labels
        spectrograms = [s['spectrogram'] for s in samples]
        labels = [s['labels'] for s in samples]
        filenames = [s['filename'] for s in samples]
        
        # First split: train vs (val + test)
        train_specs, temp_specs, train_labels, temp_labels, train_files, temp_files = train_test_split(
            spectrograms, labels, filenames,
            test_size=(val_ratio + test_ratio), 
            random_state=random_seed,
            stratify=None  # Can't stratify continuous labels easily
        )
        
        # Second split: val vs test
        val_size = val_ratio / (val_ratio + test_ratio)
        val_specs, test_specs, val_labels, test_labels, val_files, test_files = train_test_split(
            temp_specs, temp_labels, temp_files,
            test_size=(1 - val_size), 
            random_state=random_seed
        )
        
        # Assign data based on split
        if split == 'train':
            self.spectrograms = np.array(train_specs)
            self.labels = np.array(train_labels)
            self.filenames = train_files
        elif split == 'val':
            self.spectrograms = np.array(val_specs)
            self.labels = np.array(val_labels)
            self.filenames = val_files
        elif split == 'test':
            self.spectrograms = np.array(test_specs)
            self.labels = np.array(test_labels)
            self.filenames = test_files
        else:
            raise ValueError(f"Invalid split: {split}. Must be 'train', 'val', or 'test'")
        
        self.num_samples = len(self.spectrograms)
        
        print(f"📊 PercePiano Split Statistics:")
        print(f"   • Train: {len(train_specs)} samples ({len(train_specs)/len(samples)*100:.1f}%)")
        print(f"   • Val:   {len(val_specs)} samples ({len(val_specs)/len(samples)*100:.1f}%)")
        print(f"   • Test:  {len(test_specs)} samples ({len(test_specs)/len(samples)*100:.1f}%)")
        print(f"   • Using: {self.num_samples} samples for '{split}' split")
        
        # Label normalization (fit on training data only)
        if split == 'train':
            self.label_scaler = StandardScaler()
            self.labels = self.label_scaler.fit_transform(self.labels)
            print(f"✅ Label scaler fitted on training data")
            print(f"   • Original label stats: mean={np.mean(train_labels, axis=0)[:3]}, std={np.std(train_labels, axis=0)[:3]}")
            print(f"   • Normalized label stats: mean={np.mean(self.labels, axis=0)[:3]}, std={np.std(self.labels, axis=0)[:3]}")
        else:
            self.label_scaler = None  # Will be set externally
    
    def set_label_scaler(self, scaler):
        """Set the label scaler for val/test splits"""
        self.label_scaler = scaler
        self.labels = scaler.transform(self.labels)
        print(f"✅ Applied label normalization to {self.split} split")
    
    def __len__(self):
        return self.num_samples
    
    def get_batch(self, batch_size, shuffle=None):
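        """Return one (spectrograms, labels) batch.

        shuffle=True (the default for 'train') draws indices uniformly with replacement.
        shuffle=False returns a contiguous window starting at a random offset (indices
        wrap modulo the split size if batch_size exceeds it), so repeated calls do not
        necessarily cover every sample exactly once.
        """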
        if shuffle is None:
            shuffle = (self.split == 'train')
        
        if shuffle:
            indices = np.random.choice(self.num_samples, size=batch_size, replace=True)
        else:
            start_idx = np.random.randint(0, max(1, self.num_samples - batch_size + 1))
            indices = np.arange(start_idx, start_idx + batch_size) % self.num_samples
        
        batch_specs = self.spectrograms[indices]
        batch_labels = self.labels[indices]
        
        return batch_specs, batch_labels

# Load PercePiano dataset
print("\n🔄 Loading and processing PercePiano dataset...")

try:
    # Load raw data
    raw_samples = load_percepiano_data(percepiano_dir)
    
    # Create dataset splits
    train_dataset = PercePianoDataset(raw_samples, split='train', random_seed=42)
    val_dataset = PercePianoDataset(raw_samples, split='val', random_seed=42)
    test_dataset = PercePianoDataset(raw_samples, split='test', random_seed=42)
    
    # Apply label normalization to val/test sets
    val_dataset.set_label_scaler(train_dataset.label_scaler)
    test_dataset.set_label_scaler(train_dataset.label_scaler)
    
    print(f"\n✅ PercePiano datasets created successfully!")
    print(f"   • Training dataset: {len(train_dataset)} samples")
    print(f"   • Validation dataset: {len(val_dataset)} samples")
    print(f"   • Test dataset: {len(test_dataset)} samples")
    
    # Test batch loading
    print(f"\n🧪 Testing data pipeline...")
    train_specs, train_labels = train_dataset.get_batch(4)
    val_specs, val_labels = val_dataset.get_batch(4)
    
    print(f"   • Train batch specs shape: {train_specs.shape}")
    print(f"   • Train batch labels shape: {train_labels.shape}")
    print(f"   • Val batch specs shape: {val_specs.shape}")
    print(f"   • Val batch labels shape: {val_labels.shape}")
    print(f"   • Label stats (normalized): min={train_labels.min():.2f}, max={train_labels.max():.2f}, mean={train_labels.mean():.2f}")
    
    print(f"\n🎯 Ready for supervised fine-tuning!")
    
except Exception as e:
    print(f"❌ PercePiano data loading failed: {e}")
    print(f"   Error type: {type(e).__name__}")
    print(f"\n💡 Troubleshooting tips:")
    print(f"   1. Check if PercePiano repository was cloned correctly")
    print(f"   2. Verify internet connection for cloning")
    print(f"   3. Ensure sufficient disk space in Google Drive")
    raise Exception(f"PercePiano dataset setup failed: {e}")
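
`PercePianoDataset` re-runs the same two-stage `train_test_split` in every instance and relies on the shared `random_state` to keep the three splits disjoint, then fits `StandardScaler` on the training labels only. A sketch of an equivalent approach that computes the index split once and standardizes with training-split statistics; it uses NumPy in place of scikit-learn, the helper names are hypothetical, and the labels are synthetic placeholders rather than PercePiano data:

```python
import numpy as np

def split_indices(n, train_ratio=0.7, val_ratio=0.15, seed=42):
    """Shuffle indices once and cut them into disjoint train/val/test sets."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    n_train = int(round(n * train_ratio))
    n_val = int(round(n * val_ratio))
    return perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]

def fit_standardizer(train_labels):
    """Per-dimension mean/std from the training split (population std, like StandardScaler)."""
    mean = train_labels.mean(axis=0)
    std = train_labels.std(axis=0)
    std = np.where(std == 0, 1.0, std)  # constant dimensions are left unscaled
    return mean, std

# Synthetic stand-in for 200 samples x 19 perceptual dimensions
labels = np.random.default_rng(0).uniform(0, 1, size=(200, 19)).astype(np.float32)
train_idx, val_idx, test_idx = split_indices(len(labels))
mean, std = fit_standardizer(labels[train_idx])
train_norm = (labels[train_idx] - mean) / std
val_norm = (labels[val_idx] - mean) / std   # reuse training statistics, never refit
restored = train_norm * std + mean          # map normalized predictions back to label units
```

Fitting on the training split alone keeps validation and test statistics out of the normalization, which is the same reason the notebook passes `train_dataset.label_scaler` to the other splits.
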

---
## 🧠 Cell 3: Load Pre-trained Model and Add Regression Head
---

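The patch-embedding step in `ProductionASTForRegression` reshapes each `[time, freq]` spectrogram into non-overlapping 16×16 patches, each flattened to 256 values. A 128×128 input therefore yields 8 × 8 = 64 patches, which also fixes the `pos_embedding` shape at `(1, 64, 768)`. A NumPy sketch of the same reshape/transpose sequence (`patchify` is a hypothetical name), checked against explicit slicing:

```python
import numpy as np

def patchify(x, patch_size=16):
    """[batch, time, freq] -> [batch, num_patches, patch_size * patch_size]."""
    b, t, f = x.shape
    tp, fp = t // patch_size, f // patch_size
    x = x.reshape(b, tp, patch_size, fp, patch_size)
    x = x.transpose(0, 1, 3, 2, 4)  # bring each patch's time and freq offsets together
    return x.reshape(b, tp * fp, patch_size * patch_size)

x = np.arange(2 * 128 * 128, dtype=np.float32).reshape(2, 128, 128)
patches = patchify(x)
print(patches.shape)

# Patch k covers time block k // 8 and frequency block k % 8 (row-major order)
k = 10
expected = x[0, (k // 8) * 16:(k // 8 + 1) * 16, (k % 8) * 16:(k % 8 + 1) * 16].reshape(-1)
print(np.array_equal(patches[0, k], expected))
```

Because patches are ordered time-major, a learned positional embedding is what lets the model distinguish time position from frequency position.
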
In [None]:
import sys
import os
import pickle
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import linen as nn
from flax.training import train_state
from flax import traverse_util
import time

sys.path.append('/content/piano-perception-transformer/src')

print("🧠 Loading Pre-trained AST and Creating Regression Model")
print("="*60)

class ProductionASTForRegression(nn.Module):
    """AST model with regression head for perceptual prediction
    
    Loads pre-trained AST backbone and adds a regression head for 19 perceptual dimensions
    """
    
    patch_size: int = 16
    embed_dim: int = 768
    num_layers: int = 12
    num_heads: int = 12
    mlp_ratio: float = 4.0
    dropout_rate: float = 0.1
    attention_dropout: float = 0.1
    stochastic_depth_rate: float = 0.1
    num_outputs: int = 19  # 19 perceptual dimensions
    
    def setup(self):
        # Pre-compute stochastic depth drop rates (linearly increasing)
        self.drop_rates = [
            self.stochastic_depth_rate * i / (self.num_layers - 1) 
            for i in range(self.num_layers)
        ]
    
    @nn.compact
    def __call__(self, x, training: bool = True):
        """
        Full AST forward pass with regression head
        Args:
            x: Mel-spectrogram [batch, time, freq] -> (batch, 128, 128)
        Returns:
            predictions: [batch, 19] regression outputs for perceptual dimensions
        """
        batch_size, time_frames, freq_bins = x.shape
        
        # === PATCH EMBEDDING ===
        patch_size = self.patch_size
        
        # Ensure input can be divided into patches  
        time_pad = (patch_size - time_frames % patch_size) % patch_size
        freq_pad = (patch_size - freq_bins % patch_size) % patch_size
        
        if time_pad > 0 or freq_pad > 0:
            x = jnp.pad(x, ((0, 0), (0, time_pad), (0, freq_pad)), mode='constant', constant_values=-80.0)
        
        time_patches = x.shape[1] // patch_size
        freq_patches = x.shape[2] // patch_size
        num_patches = time_patches * freq_patches
        
        # Reshape to patches: [batch, num_patches, patch_dim]
        x = x.reshape(batch_size, time_patches, patch_size, freq_patches, patch_size)
        x = x.transpose(0, 1, 3, 2, 4)
        x = x.reshape(batch_size, num_patches, patch_size * patch_size)
        
        # Linear patch embedding (pre-trained)
        x = nn.Dense(
            self.embed_dim, 
            kernel_init=nn.initializers.truncated_normal(stddev=0.02),
            bias_init=nn.initializers.zeros,
            name='patch_embedding'
        )(x)
        
        # === 2D POSITIONAL ENCODING (pre-trained) ===
        pos_embedding = self.param(
            'pos_embedding',
            nn.initializers.truncated_normal(stddev=0.02),
            (1, num_patches, self.embed_dim)
        )
        x = x + pos_embedding
        
        # Embedding dropout
        x = nn.Dropout(self.dropout_rate, deterministic=not training)(x)
        
        # === 12-LAYER TRANSFORMER ENCODER (pre-trained) ===
        for layer_idx in range(self.num_layers):
            # Stochastic depth probability for this layer
            drop_rate = self.drop_rates[layer_idx]
            
            # Multi-Head Self-Attention Block
            residual = x
            x = nn.LayerNorm(epsilon=1e-6, name=f'norm1_layer{layer_idx}')(x)
            
            attention = nn.MultiHeadDotProductAttention(
                num_heads=self.num_heads,
                dropout_rate=self.attention_dropout,
                kernel_init=nn.initializers.truncated_normal(stddev=0.02),
                bias_init=nn.initializers.zeros,
                name=f'attention_layer{layer_idx}'
            )(x, x, deterministic=not training)
            
            # Stochastic depth for attention (training only)
            if training and drop_rate > 0:
                random_tensor = jax.random.uniform(
                    self.make_rng('stochastic_depth'), (batch_size, 1, 1)
                )
                keep_prob = 1.0 - drop_rate
                binary_tensor = (random_tensor < keep_prob).astype(x.dtype)
                attention = attention * binary_tensor / keep_prob
            
            x = residual + nn.Dropout(self.dropout_rate, deterministic=not training)(attention)
            
            # Feed-Forward Network Block
            residual = x
            x = nn.LayerNorm(epsilon=1e-6, name=f'norm2_layer{layer_idx}')(x)
            
            # MLP with 4x expansion
            mlp_hidden = int(self.embed_dim * self.mlp_ratio)
            
            mlp = nn.Dense(
                mlp_hidden, 
                kernel_init=nn.initializers.truncated_normal(stddev=0.02),
                bias_init=nn.initializers.zeros,
                name=f'mlp_dense1_layer{layer_idx}'
            )(x)
            mlp = nn.gelu(mlp)
            mlp = nn.Dropout(self.dropout_rate, deterministic=not training)(mlp)
            
            mlp = nn.Dense(
                self.embed_dim,
                kernel_init=nn.initializers.truncated_normal(stddev=0.02),
                bias_init=nn.initializers.zeros,
                name=f'mlp_dense2_layer{layer_idx}'
            )(mlp)
            
            # Stochastic depth for MLP
            if training and drop_rate > 0:
                random_tensor = jax.random.uniform(
                    self.make_rng('stochastic_depth'), (batch_size, 1, 1)
                )
                keep_prob = 1.0 - drop_rate
                binary_tensor = (random_tensor < keep_prob).astype(x.dtype)
                mlp = mlp * binary_tensor / keep_prob
            
            x = residual + nn.Dropout(self.dropout_rate, deterministic=not training)(mlp)
        
        # === FINAL NORMALIZATION (pre-trained) ===
        x = nn.LayerNorm(epsilon=1e-6, name='final_norm')(x)
        
        # === REGRESSION HEAD (new, randomly initialized; the whole model is fine-tuned) ===
        # Global average pooling over patches
        x = jnp.mean(x, axis=1)  # [batch, embed_dim]
        
        # Regression layers
        x = nn.Dense(
            512, 
            kernel_init=nn.initializers.truncated_normal(stddev=0.02),
            bias_init=nn.initializers.zeros,
            name='regression_hidden'
        )(x)
        x = nn.gelu(x)
        x = nn.Dropout(self.dropout_rate, deterministic=not training)(x)
        
        # Final prediction layer
        predictions = nn.Dense(
            self.num_outputs,
            kernel_init=nn.initializers.truncated_normal(stddev=0.02),
            bias_init=nn.initializers.zeros,
            name='regression_output'
        )(x)
        
        return predictions  # [batch, 19]

def load_pretrained_weights(model, checkpoint_path):
    """Load pre-trained weights from SSAST checkpoint"""
    print(f"📂 Loading pre-trained checkpoint: {checkpoint_path}")
    
    try:
        with open(checkpoint_path, 'rb') as f:
            checkpoint = pickle.load(f)
        
        pretrained_params = checkpoint['params']
        print(f"✅ Pre-trained checkpoint loaded successfully")
        print(f"   • Checkpoint epoch: {checkpoint.get('epoch', 'N/A')}")
        print(f"   • Checkpoint loss: {checkpoint.get('best_val_loss', 'N/A')}")
        
        return pretrained_params, checkpoint
        
    except Exception as e:
        print(f"❌ Failed to load pre-trained checkpoint: {e}")
        raise

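def transfer_pretrained_weights(pretrained_params, new_params):
    """Copy pre-trained weights into the new model wherever names and shapes match.

    Parameters absent from the checkpoint (the regression head) or with a different
    shape keep their random initialization.
    """
    print("🔄 Transferring pre-trained weights...")
    
    # model.init() returns {'params': {...}}; the checkpoint may or may not include
    # that top-level key. Unwrap both so the flattened keys line up.
    if 'params' in pretrained_params:
        pretrained_params = pretrained_params['params']
    wrapped = 'params' in new_params
    target_params = new_params['params'] if wrapped else new_params
    
    # Flatten parameter trees into {('module', 'kernel'): array} dicts
    flat_pretrained = traverse_util.flatten_dict(pretrained_params)
    flat_new = traverse_util.flatten_dict(target_params)
    
    transferred_count = 0
    new_count = 0
    shape_mismatches = []
    
    for key in flat_new:
        if key not in flat_pretrained:
            # Keep randomly initialized weight (e.g. regression head)
            new_count += 1
        elif flat_pretrained[key].shape != flat_new[key].shape:
            # e.g. pos_embedding if pre-training used a different number of patches
            shape_mismatches.append('/'.join(key))
        else:
            # Transfer pre-trained weight
            flat_new[key] = flat_pretrained[key]
            transferred_count += 1
    
    # Reconstruct the parameter tree (re-wrapping in 'params' if needed)
    final_params = traverse_util.unflatten_dict(flat_new)
    if wrapped:
        final_params = {'params': final_params}
    
    print(f"✅ Weight transfer completed:")
    print(f"   • Transferred parameters: {transferred_count}")
    print(f"   • New parameters (not in checkpoint, e.g. regression head): {new_count}")
    if shape_mismatches:
        print(f"   ⚠️ Shape mismatches (kept random init): {shape_mismatches}")
    if transferred_count == 0:
        print("   ⚠️ No weights were transferred; check the checkpoint's parameter names")
    
    return final_params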

def create_finetuning_optimizer(total_steps, learning_rate=1e-4, weight_decay=0.01, warmup_steps=500):
    """Create optimizer for fine-tuning with lower learning rate"""
    
    # Warmup + cosine decay schedule
    warmup_schedule = optax.linear_schedule(
        init_value=1e-8,
        end_value=learning_rate,
        transition_steps=warmup_steps
    )
    
    cosine_schedule = optax.cosine_decay_schedule(
        init_value=learning_rate,
        decay_steps=max(total_steps - warmup_steps, 1),  # optax requires decay_steps > 0
        alpha=0.1  # Final LR = 10% of initial LR for fine-tuning
    )
    
    lr_schedule = optax.join_schedules(
        schedules=[warmup_schedule, cosine_schedule],
        boundaries=[warmup_steps]
    )
    
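    # AdamW with gradient clipping. inject_hyperparams exposes the scheduled learning
    # rate as opt_state[1].hyperparams['learning_rate'] so the train step can log it.
    optimizer = optax.chain(
        optax.clip_by_global_norm(0.5),  # Lower gradient clipping for fine-tuning
        optax.inject_hyperparams(optax.adamw)(
            learning_rate=lr_schedule,
            weight_decay=weight_decay,
            b1=0.9,
            b2=0.999,
            eps=1e-8
        )
    )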
    
    return optimizer

@jax.jit
def compute_correlation(predictions, targets):
    """Compute Pearson correlation between predictions and targets"""
    # Center the data
    pred_centered = predictions - jnp.mean(predictions, axis=0, keepdims=True)
    target_centered = targets - jnp.mean(targets, axis=0, keepdims=True)
    
    # Compute correlation
    numerator = jnp.sum(pred_centered * target_centered, axis=0)
    pred_norm = jnp.sqrt(jnp.sum(pred_centered**2, axis=0))
    target_norm = jnp.sqrt(jnp.sum(target_centered**2, axis=0))
    
    correlation = numerator / (pred_norm * target_norm + 1e-8)
    
    return jnp.mean(correlation)  # Average correlation across dimensions

@jax.jit
def finetuning_train_step(train_state_obj, batch_specs, batch_labels, dropout_rng, stochastic_rng):
    """Training step for fine-tuning with regression loss"""
    
    def loss_fn(params):
        # Forward pass
        predictions = train_state_obj.apply_fn(
            params, batch_specs,
            training=True,
            rngs={'dropout': dropout_rng, 'stochastic_depth': stochastic_rng}
        )
        
        # MSE loss
        mse_loss = jnp.mean(jnp.square(predictions - batch_labels))
        
        # Correlation-based loss (negative correlation to maximize)
        correlation = compute_correlation(predictions, batch_labels)
        correlation_loss = -correlation  # Negative to maximize correlation
        
        # Combined loss
        total_loss = mse_loss + 0.1 * correlation_loss
        
        # Metrics for monitoring
        metrics = {
            'total_loss': total_loss,
            'mse_loss': mse_loss,
            'correlation': correlation,
            'prediction_mean': jnp.mean(predictions),
            'prediction_std': jnp.std(predictions),
            'target_mean': jnp.mean(batch_labels),
            'target_std': jnp.std(batch_labels)
        }
        
        return total_loss, metrics
    
    # Compute gradients
    (loss_val, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(train_state_obj.params)
    
    # Gradient norm for monitoring
    grad_norm = optax.global_norm(grads)
    
    # Update parameters
    new_train_state = train_state_obj.apply_gradients(grads=grads)
    
    # Learning rate used for this update. This is only available when the second
    # chain element is wrapped in optax.inject_hyperparams; plain optax.adamw state
    # has no `hyperparams` attribute.
    try:
        current_lr = new_train_state.opt_state[1].hyperparams['learning_rate']
    except (AttributeError, KeyError, IndexError, TypeError):
        current_lr = float('nan')  # Not exposed by this optimizer state; avoid logging a fake value
    
    # Update metrics
    metrics.update({
        'grad_norm': grad_norm,
        'learning_rate': current_lr
    })
    
    return new_train_state, metrics

# Load pre-trained model and create regression model
print(f"🏗️ Creating AST model with regression head...")

# Initialize regression model
regression_model = ProductionASTForRegression(
    patch_size=16,
    embed_dim=768,
    num_layers=12,
    num_heads=12,
    mlp_ratio=4.0,
    dropout_rate=0.1,
    attention_dropout=0.1,
    stochastic_depth_rate=0.1,
    num_outputs=19  # 19 perceptual dimensions
)

# Initialize model parameters
dummy_input = jnp.ones((2, 128, 128))
rng = jax.random.PRNGKey(42)
init_rng, dropout_rng, stochastic_rng = jax.random.split(rng, 3)

new_params = regression_model.init(
    {'params': init_rng, 'dropout': dropout_rng, 'stochastic_depth': stochastic_rng},
    dummy_input,
    training=False
)

# Load pre-trained weights
pretrained_checkpoint_path = '/content/drive/MyDrive/piano_transformer/checkpoints/ssast_pretraining/pretrained_for_finetuning.pkl'

if os.path.exists(pretrained_checkpoint_path):
    pretrained_params, pretrained_checkpoint = load_pretrained_weights(regression_model, pretrained_checkpoint_path)
    
    # Transfer pre-trained weights
    final_params = transfer_pretrained_weights(pretrained_params, new_params)
    
    print(f"✅ Pre-trained weights loaded and transferred successfully!")
else:
    print(f"⚠️ Pre-trained checkpoint not found: {pretrained_checkpoint_path}")
    print(f"   Using randomly initialized weights for all layers")
    final_params = new_params

# Count parameters
param_count = sum(x.size for x in jax.tree.leaves(final_params))
print(f"\n📊 Regression Model Statistics:")
print(f"   • Total parameters: {param_count:,}")
print(f"   • Memory usage: ~{param_count * 4 / 1024**2:.1f} MB (FP32)")
print(f"   • Architecture: 12-layer AST + regression head")
print(f"   • Output dimensions: 19 perceptual features")

# Test forward pass
print(f"\n🚀 Testing regression model forward pass...")
output = regression_model.apply(
    final_params, dummy_input,
    training=False,
    rngs={'dropout': dropout_rng, 'stochastic_depth': stochastic_rng}
)

print(f"✅ Forward pass successful!")
print(f"   • Input shape: {dummy_input.shape}")
print(f"   • Output shape: {output.shape}")
print(f"   • Output stats: min={output.min():.4f}, max={output.max():.4f}, mean={output.mean():.4f}")

print(f"\n🎯 Ready for supervised fine-tuning!")

---
## 🎯 Cell 4: Execute Supervised Fine-tuning
---
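The cell below uses an early-stopping rule. It tracks the best mean validation correlation and stops once `patience` consecutive epochs fail to strictly improve on it. Here is a minimal sketch of that rule on its own, replaying a hypothetical list of per-epoch correlations. The function name and the example curve are illustrative and are not part of the repository.

```python
def simulate_early_stopping(val_correlations, patience=10):
    """Replay per-epoch validation correlations through a patience rule.

    Training stops once `patience` consecutive epochs fail to beat the best
    correlation seen so far. Returns (best_epoch, best_corr, stop_epoch),
    with epochs counted from 1.
    """
    best_corr = float("-inf")
    best_epoch = 0
    patience_counter = 0
    stop_epoch = len(val_correlations)  # default: ran every epoch
    for epoch, corr in enumerate(val_correlations, start=1):
        if corr > best_corr:  # strict improvement resets the counter
            best_corr, best_epoch = corr, epoch
            patience_counter = 0
        else:  # ties and NaN values both count as "no improvement"
            patience_counter += 1
        if patience_counter >= patience:
            stop_epoch = epoch
            break
    return best_epoch, best_corr, stop_epoch


# Hypothetical correlation curve that peaks at epoch 3 and then plateaus.
history = [0.20, 0.35, 0.41, 0.40, 0.39, 0.41]
print(simulate_early_stopping(history, patience=3))  # → (3, 0.41, 6)
```

The comparison is strict, so an epoch that only ties the best score still counts toward patience. A NaN correlation never registers as an improvement.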

In [None]:
import sys
import os

sys.path.append('./src')

print("🎯 SUPERVISED FINE-TUNING - EXECUTION")
print("="*70)

# Check prerequisites
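if 'train_dataset' not in globals() or 'val_dataset' not in globals():
    raise RuntimeError("Run Cell 2 first to load the PercePiano train and validation datasets")

if 'regression_model' not in globals() or 'final_params' not in globals():
    raise RuntimeError("Run Cell 3 first to create the regression model and its parameters")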

print("✅ All prerequisites ready")
print(f"   • PercePiano datasets loaded: ✅")
print(f"   • Pre-trained AST with regression head: ✅")
print(f"   • Fine-tuning pipeline: ✅")
print(f"   • WandB experiment tracking: ✅")

def execute_supervised_finetuning(
    model, params, train_dataset, val_dataset,
    num_epochs=30, batch_size=16, patience=10
):
    """Execute supervised fine-tuning on PercePiano dataset"""
    print("🎯 Starting Supervised Fine-tuning...")
    print("="*60)
    
    # Initialize random state
    rng = jax.random.PRNGKey(42)
    
    # Training configuration
    train_size = len(train_dataset)
    steps_per_epoch = max(train_size // batch_size, 5)
    total_steps = num_epochs * steps_per_epoch
    
    print(f"📊 Fine-tuning Configuration:")
    print(f"   • Model: {model.__class__.__name__}")
    print(f"   • Parameters: {sum(x.size for x in jax.tree.leaves(params)):,}")
    print(f"   • Train size: {train_size} samples")
    print(f"   • Val size: {len(val_dataset)} samples")
    print(f"   • Batch size: {batch_size}")
    print(f"   • Steps per epoch: {steps_per_epoch}")
    print(f"   • Total steps: {total_steps:,}")
    print(f"   • Epochs: {num_epochs}")
    print(f"   • Early stopping patience: {patience}")
    
    # Create fine-tuning optimizer with lower learning rate
    optimizer = create_finetuning_optimizer(
        total_steps=total_steps,
        learning_rate=1e-4,  # Lower than pre-training
        weight_decay=0.01,
        warmup_steps=total_steps // 10
    )
    
    # Create training state
    train_state_obj = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer
    )
    
    # Training tracking
    best_val_correlation = -1.0  # Best correlation (higher is better)
    patience_counter = 0
    training_history = {
        'train_loss': [],
        'val_loss': [],
        'train_correlation': [],
        'val_correlation': [],
        'learning_rates': []
    }
    
    # Create checkpoint directory
    checkpoint_dir = '/content/drive/MyDrive/piano_transformer/checkpoints/finetuning'
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    print(f"\n🎯 Starting fine-tuning loop...")
    start_time = time.time()
    
    for epoch in range(num_epochs):
        epoch_start = time.time()
        
        # === TRAINING PHASE ===
        train_metrics = []
        
        for step in range(steps_per_epoch):
            # Get training batch
            batch_specs, batch_labels = train_dataset.get_batch(batch_size, shuffle=True)
            batch_specs = jnp.array(batch_specs)
            batch_labels = jnp.array(batch_labels)
            
            # Generate RNG keys for this step
            rng, dropout_rng, stochastic_rng = jax.random.split(rng, 3)
            
            # Training step
            train_state_obj, metrics = finetuning_train_step(
                train_state_obj, batch_specs, batch_labels, dropout_rng, stochastic_rng
            )
            
            train_metrics.append(metrics)
            
            # Log to WandB every 5 steps
            if step % 5 == 0:
                try:
                    wandb.log({
                        "train/total_loss": float(metrics['total_loss']),
                        "train/mse_loss": float(metrics['mse_loss']),
                        "train/correlation": float(metrics['correlation']),
                        "train/prediction_mean": float(metrics['prediction_mean']),
                        "train/prediction_std": float(metrics['prediction_std']),
                        "train/grad_norm": float(metrics['grad_norm']),
                        "train/learning_rate": float(metrics['learning_rate']),
                        "epoch": epoch,
                        "step": int(train_state_obj.step)
                    })
                except Exception:
                    pass  # Keep training if WandB logging fails
        
        # === VALIDATION PHASE ===
        val_metrics = []
        val_steps = max(len(val_dataset) // batch_size, 1)
        
        for val_step in range(val_steps):
            batch_specs, batch_labels = val_dataset.get_batch(batch_size, shuffle=False)
            batch_specs = jnp.array(batch_specs)
            batch_labels = jnp.array(batch_labels)
            
            # Validation forward pass (no training)
            rng, dropout_rng, stochastic_rng = jax.random.split(rng, 3)
            
            predictions = model.apply(
                train_state_obj.params, batch_specs,
                training=False,
                rngs={'dropout': dropout_rng, 'stochastic_depth': stochastic_rng}
            )
            
            # Compute validation metrics
            val_mse = jnp.mean(jnp.square(predictions - batch_labels))
            val_correlation = compute_correlation(predictions, batch_labels)
            val_loss = val_mse - 0.1 * val_correlation  # Same as training loss
            
            val_metrics.append({
                'val_loss': val_loss,
                'val_mse': val_mse,
                'val_correlation': val_correlation
            })
        
        # === EPOCH SUMMARY ===
        # Average metrics
        avg_train_loss = np.mean([m['total_loss'] for m in train_metrics])
        avg_train_correlation = np.mean([m['correlation'] for m in train_metrics])
        avg_val_loss = np.mean([m['val_loss'] for m in val_metrics])
        avg_val_correlation = np.mean([m['val_correlation'] for m in val_metrics])
        avg_lr = np.mean([m['learning_rate'] for m in train_metrics])
        
        # Store history
        training_history['train_loss'].append(avg_train_loss)
        training_history['val_loss'].append(avg_val_loss)
        training_history['train_correlation'].append(avg_train_correlation)
        training_history['val_correlation'].append(avg_val_correlation)
        training_history['learning_rates'].append(avg_lr)
        
        epoch_time = time.time() - epoch_start
        total_time = time.time() - start_time
        
        print(f"Epoch {epoch+1:3d}: "
              f"Train Corr={avg_train_correlation:.4f}, "
              f"Val Corr={avg_val_correlation:.4f}, "
              f"Val Loss={avg_val_loss:.4f}, "
              f"LR={avg_lr:.6f}, "
              f"Time={epoch_time:.1f}s")
        
        # Log epoch metrics to WandB
        try:
            wandb.log({
                "epoch/train_loss": avg_train_loss,
                "epoch/val_loss": avg_val_loss,
                "epoch/train_correlation": avg_train_correlation,
                "epoch/val_correlation": avg_val_correlation,
                "epoch/learning_rate": avg_lr,
                "epoch/time_seconds": epoch_time,
                "epoch/total_time_hours": total_time / 3600,
                "epoch/epoch": epoch + 1
            })
        except Exception:
            pass  # Keep training if WandB logging fails
        
        # === EARLY STOPPING & CHECKPOINTING ===
        improved = avg_val_correlation > best_val_correlation
        
        if improved:
            best_val_correlation = avg_val_correlation
            patience_counter = 0
            
            # Save best model
            best_checkpoint = {
                'params': train_state_obj.params,
                'step': train_state_obj.step,
                'epoch': epoch + 1,
                'best_val_correlation': best_val_correlation,
                'val_loss': avg_val_loss,
                'training_history': training_history,
                'model_config': {
                    'embed_dim': 768,
                    'num_layers': 12,
                    'num_heads': 12,
                    'patch_size': 16,
                    'num_outputs': 19
                },
                'label_scaler': train_dataset.label_scaler  # Include scaler for inference
            }
            
            best_path = os.path.join(checkpoint_dir, 'best_finetuned_model.pkl')
            with open(best_path, 'wb') as f:
                pickle.dump(best_checkpoint, f)
            
            print(f"   ✅ New best model saved (val_correlation: {best_val_correlation:.4f})")
            
        else:
            patience_counter += 1
            print(f"   ⏳ No improvement ({patience_counter}/{patience})")
        
        # Regular checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pkl')
            regular_checkpoint = {
                'params': train_state_obj.params,
                'step': train_state_obj.step,
                'epoch': epoch + 1,
                'val_correlation': avg_val_correlation,
                'val_loss': avg_val_loss,
                'training_history': training_history,
                'label_scaler': train_dataset.label_scaler
            }
            with open(checkpoint_path, 'wb') as f:
                pickle.dump(regular_checkpoint, f)
        
        # Early stopping
        if patience_counter >= patience:
            print(f"\n🛑 Early stopping after {patience} epochs without improvement")
            print(f"   Best validation correlation: {best_val_correlation:.4f}")
            break
    
    # === FINE-TUNING COMPLETE ===
    total_training_time = time.time() - start_time
    
    print(f"\n" + "="*60)
    print(f"🎉 SUPERVISED FINE-TUNING COMPLETED!")
    print(f"="*60)
    print(f"📈 Final Results:")
    print(f"   • Best validation correlation: {best_val_correlation:.4f}")
    print(f"   • Final validation loss: {avg_val_loss:.4f}")
    print(f"   • Total epochs: {epoch + 1}")
    print(f"   • Total steps: {train_state_obj.step:,}")
    print(f"   • Training time: {total_training_time/60:.1f} minutes")
    print(f"   • Final learning rate: {avg_lr:.2e}")
    
    return train_state_obj, best_val_correlation, training_history

# Execute supervised fine-tuning
try:
    print(f"\n🎯 Starting Supervised Fine-tuning on PercePiano dataset...")
    print(f"   • Using pre-trained AST with regression head")
    print(f"   • Training set: {len(train_dataset)} samples (MIDI→spectrogram)")
    print(f"   • Validation set: {len(val_dataset)} samples")
    print(f"   • Target: 19 perceptual dimensions")
    print(f"   • Lower learning rate for stable fine-tuning")
    print(f"   • Early stopping with patience=10")
    print(f"   • Correlation-based evaluation")
    
    # Execute fine-tuning
    final_state, best_correlation, history = execute_supervised_finetuning(
        model=regression_model,
        params=final_params,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        num_epochs=30,   # Upper bound; early stopping can end training sooner
        batch_size=16,   # Samples per training step
        patience=10      # Epochs without val-correlation improvement before stopping
    )
    
    print(f"\n🎉 SUPERVISED FINE-TUNING COMPLETED SUCCESSFULLY!")
    print(f"="*70)
    
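    # Export the best-epoch weights (the ones best_correlation refers to) rather than
    # the last-epoch weights, which can be up to `patience` epochs past the best epoch
    checkpoint_dir = '/content/drive/MyDrive/piano_transformer/checkpoints/finetuning'
    best_path = os.path.join(checkpoint_dir, 'best_finetuned_model.pkl')
    if os.path.exists(best_path):
        with open(best_path, 'rb') as f:
            export_params = pickle.load(f)['params']
    else:
        export_params = final_state.params  # No best checkpoint written; fall back to last epoch
    finetuned_model_path = os.path.join(checkpoint_dir, 'final_finetuned_model.pkl')
    final_checkpoint = {
        'params': export_params,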
        'model_config': {
            'embed_dim': 768,
            'num_layers': 12,
            'num_heads': 12,
            'patch_size': 16,
            'num_outputs': 19
        },
        'finetuning_results': {
            'best_val_correlation': float(best_correlation),
            'total_epochs': len(history['train_loss']),
            'convergence_achieved': best_correlation > 0.5  # Good correlation threshold
        },
        'label_scaler': train_dataset.label_scaler,
        'training_complete': True
    }
    
    with open(finetuned_model_path, 'wb') as f:
        pickle.dump(final_checkpoint, f)
    
    print(f"💾 Fine-tuned model saved for evaluation: {finetuned_model_path}")
    print(f"🎯 READY FOR COMPREHENSIVE EVALUATION!")
    
    # Performance summary
    if best_correlation > 0.7:
        print(f"🎉 Excellent correlation achieved: {best_correlation:.4f} > 0.70")
    elif best_correlation > 0.5:
        print(f"✅ Good correlation achieved: {best_correlation:.4f} > 0.50")
    elif best_correlation > 0.3:
        print(f"⚠️ Moderate correlation: {best_correlation:.4f} > 0.30")
    else:
        print(f"❌ Low correlation: {best_correlation:.4f} <= 0.30 - may need more training")
    
except Exception as e:
    print(f"❌ Supervised fine-tuning failed: {str(e)}")
    raise

---
## 🎯 Fine-tuning Complete!

**Next Steps:**
1. 📊 **Evaluation**: Run `3_Piano_Transformer_Evaluation.ipynb` for comprehensive performance analysis
2. 🔍 **Analysis**: Examine per-dimension correlations and model interpretability
3. 📈 **Comparison**: Compare with baseline models and human performance
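Step 2 above calls for examining per-dimension correlations. The sketch below computes Pearson r column by column and sorts from weakest to strongest. It assumes that predictions and labels for the evaluation set have been collected into `(num_samples, 19)` NumPy arrays. The function name and the dimension names in the example are hypothetical placeholders, not the PercePiano label names.

```python
import numpy as np


def per_dimension_correlations(predictions, targets, names=None):
    """Pearson r for each output column, sorted from weakest to strongest.

    predictions, targets: arrays of shape (num_samples, num_dims).
    names: optional list of dimension labels (defaults to "dim_<i>").
    """
    num_dims = targets.shape[1]
    names = names or [f"dim_{i}" for i in range(num_dims)]
    scores = []
    for i in range(num_dims):
        # np.corrcoef returns the 2x2 correlation matrix of the two columns.
        r = np.corrcoef(predictions[:, i], targets[:, i])[0, 1]
        scores.append((names[i], float(r)))
    return sorted(scores, key=lambda item: item[1])


# Synthetic example: each column gets a different amount of noise.
rng = np.random.default_rng(0)
targets = rng.normal(size=(1000, 3))
noise = rng.normal(size=(1000, 3)) * np.array([0.1, 1.0, 3.0])
ranking = per_dimension_correlations(targets + noise, targets, ["timing", "dynamics", "pedal"])
print([name for name, _ in ranking])  # noisiest first: ['pedal', 'dynamics', 'timing']
```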

**Fine-tuned Model Location:**
```
/content/drive/MyDrive/piano_transformer/checkpoints/finetuning/final_finetuned_model.pkl
```
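The sketch below reads this checkpoint back. It assumes the key layout of the `final_checkpoint` dict written in Cell 4 (`params`, `model_config`, `label_scaler`, `training_complete`). The loader function and its 19-output sanity check are hypothetical. Unpickling needs the libraries that produced the stored objects: JAX for the parameter arrays, plus whatever class `label_scaler` is. Run it in an environment with the notebook's dependencies installed.

```python
import os
import pickle


def load_finetuned_checkpoint(path):
    """Load the fine-tuning checkpoint and return (params, model_config, label_scaler).

    Hypothetical helper: assumes the dict keys used by this notebook's final_checkpoint.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"No fine-tuned checkpoint at {path}")
    with open(path, "rb") as f:
        checkpoint = pickle.load(f)  # only unpickle files you trust
    if not checkpoint.get("training_complete", False):
        raise ValueError("Checkpoint is not marked as training_complete")
    config = checkpoint["model_config"]
    if config["num_outputs"] != 19:
        raise ValueError(f"Expected 19 perceptual outputs, got {config['num_outputs']}")
    return checkpoint["params"], config, checkpoint["label_scaler"]
```

In Colab, pass the path shown above. `label_scaler` may be a scikit-learn scaler such as `StandardScaler`; this is an assumption about the object built in Cell 2. If so, `label_scaler.inverse_transform(predictions)` maps the model's normalized outputs back to the original rating scale.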

**Model Configuration:**
- **Architecture**: Pre-trained 12-layer AST + regression head
- **Input**: 128×128 mel-spectrograms from MIDI synthesis
- **Output**: 19 normalized perceptual dimensions
- **Training**: Supervised learning with a combined loss (MSE minus 0.1 × correlation) and early stopping on validation correlation
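The validation loop computes `val_mse - 0.1 * val_correlation` and notes that this matches the training loss. Below is a NumPy sketch of that objective. It assumes that `compute_correlation` (defined in an earlier cell) returns the Pearson correlation per output dimension, averaged over all 19 dimensions. The helper names are hypothetical, and the small `eps` guards against zero variance, which requires a batch size above 1.

```python
import numpy as np


def mean_pearson(predictions, targets, eps=1e-8):
    """Mean over output dimensions of the per-dimension Pearson correlation.

    predictions, targets: arrays of shape (batch, num_outputs).
    """
    p = predictions - predictions.mean(axis=0)  # center each dimension over the batch
    t = targets - targets.mean(axis=0)
    cov = (p * t).sum(axis=0)
    denom = np.sqrt((p ** 2).sum(axis=0) * (t ** 2).sum(axis=0)) + eps
    return float((cov / denom).mean())


def combined_loss(predictions, targets, corr_weight=0.1):
    """MSE minus a weighted correlation term: lower error and higher correlation both reduce it."""
    mse = float(np.mean((predictions - targets) ** 2))
    return mse - corr_weight * mean_pearson(predictions, targets)


rng = np.random.default_rng(0)
targets = rng.normal(size=(16, 19))  # batch of 16, 19 perceptual dimensions
print(combined_loss(targets, targets))  # perfect predictions: approximately -0.1
```

Subtracting the correlation term rewards the model both for small squared error and for predictions that co-vary linearly with the labels in each dimension. Inside a jitted JAX training step, the same array calls exist under `jax.numpy`, but the `float()` conversions would have to be dropped.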
---