# 🎹 Piano Perception Transformer - SSAST Pre-training

**Phase 1: Self-Supervised Pre-training on MAESTRO Dataset**

This notebook implements SSAST (Self-Supervised Audio Spectrogram Transformer) pre-training on the MAESTRO dataset to learn general audio representations that will be fine-tuned for perceptual tasks.

**Pipeline Overview:**
1. 🔧 **Setup & Environment** - Dependencies, WandB tracking, JAX configuration
2. 💾 **MAESTRO Data Processing** - Streaming download and spectrogram conversion
3. 📊 **Dataset Creation** - Train/val/test splits with augmentation
4. 🧠 **AST Model Architecture** - 12-layer production transformer
5. 🚀 **SSAST Pre-training** - Self-supervised learning execution

**Output:** Pre-trained model checkpoint ready for fine-tuning

---
## 🔧 Cell 1: Enhanced Setup with WandB Integration
---

In [None]:
print("🚀 Setting up Piano Perception Transformer - Production Version...")

# Clone repo (skip if already exists)
import os
if not os.path.exists('piano-perception-transformer'):
    !git clone https://github.com/Jai-Dhiman/piano-perception-transformer.git
else:
    print("Repository already exists, skipping clone...")

%cd piano-perception-transformer

# Install uv
!curl -LsSf https://astral.sh/uv/install.sh | sh

# Install enhanced dependencies including ML research tools
print("📦 Installing enhanced dependencies with uv...")
!export PATH="/usr/local/bin:$PATH" && uv pip install --system jax[tpu] flax optax librosa pandas wandb requests zipfile36 scikit-learn scipy seaborn matplotlib pretty_midi soundfile

# Initialize WandB for experiment tracking
import wandb
import jax
from datetime import datetime

# WandB Setup
try:
    wandb.login()  # This will prompt for API key in Colab
    
    run = wandb.init(
        project="piano-perception-transformer-pretraining",
        name=f"ssast-pretraining-{datetime.now().strftime('%Y%m%d-%H%M')}",
        config={
            "phase": "ssast_pretraining",
            "architecture": "Production AST-SSAST",
            "model_layers": 12,
            "embed_dim": 768,
            "num_heads": 12,
            "patch_size": 16,
            "learning_rate": 5e-5,
            "batch_size": 32,
            "dropout": 0.1,
            "stochastic_depth": 0.1,
            "dataset": "MAESTRO-v3",
            "experiment_type": "self_supervised_pretraining"
        },
        tags=["pretraining", "ssast", "maestro", "self-supervised"]
    )
    
    print("✅ WandB initialized successfully!")
    print(f"   • Project: piano-perception-transformer-pretraining")
    print(f"   • Run name: {run.name}")
    print(f"   • Tracking: https://wandb.ai/{run.entity}/{run.project}/runs/{run.id}")
    
except Exception as e:
    print(f"⚠️ WandB initialization failed: {e}")
    print("   • Continuing without experiment tracking")
    print("   • Set up WandB API key: https://wandb.ai/settings")

# Verify JAX setup
print(f"\n🧠 JAX Configuration:")
print(f"   • Backend: {jax.default_backend()}")
print(f"   • Devices: {jax.device_count()}")
print(f"   • Device type: {jax.devices()[0].device_kind}")

print("\n✅ Enhanced setup completed!")

---
## 💾 Cell 2: Mount Google Drive & Setup Storage
---

In [None]:
from google.colab import drive
import os

print("🔗 Mounting Google Drive for persistent storage...")
drive.mount('/content/drive')

# Create directory structure
base_dir = '/content/drive/MyDrive/piano_transformer'
directories = [
    f'{base_dir}/processed_spectrograms',
    f'{base_dir}/checkpoints/ssast_pretraining',
    f'{base_dir}/logs',
    f'{base_dir}/temp'
]

print("📁 Setting up directory structure...")
for directory in directories:
    os.makedirs(directory, exist_ok=True)
    print(f"✅ Created: {directory}")

print("\n✅ Google Drive mounted and directories ready!")

---
## 🌊 Cell 3: Streaming MAESTRO Processing
---

In [None]:
import os
import requests
import json
import librosa
import numpy as np
import zipfile
import tempfile
from pathlib import Path
import sys
from io import BytesIO
sys.path.append('./src')

print("🌊 Starting streaming MAESTRO processing...")

def download_and_process_maestro_streaming(max_files=None):
    """Download MAESTRO ZIP as stream, extract and process audio→spectrograms, save to Drive"""
    
    # Download metadata first to get real file paths
    print("📋 Downloading MAESTRO metadata...")
    metadata_url = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0.json"
    
    try:
        metadata_response = requests.get(metadata_url, timeout=30)
        metadata_response.raise_for_status()
        maestro_metadata = metadata_response.json()
    except requests.exceptions.RequestException as e:
        print(f"❌ Failed to download metadata: {e}")
        raise Exception(f"Cannot download MAESTRO metadata: {e}")
    
    print(f"📊 Found metadata for MAESTRO dataset")
    
    # Save metadata to Drive
    try:
        with open('/content/drive/MyDrive/piano_transformer/maestro_metadata.json', 'w') as f:
            json.dump(maestro_metadata, f)
        print("✅ Metadata saved to Drive")
    except IOError as e:
        print(f"❌ Failed to save metadata: {e}")
        raise Exception(f"Cannot save metadata to Drive: {e}")
    
    # Process MAESTRO metadata structure
    if not isinstance(maestro_metadata, dict):
        raise Exception(f"Expected dict metadata, got {type(maestro_metadata)}")
    
    # Check for required fields
    required_fields = ['audio_filename', 'canonical_composer', 'canonical_title']
    for field in required_fields:
        if field not in maestro_metadata:
            raise Exception(f"Required field '{field}' not found in metadata. Available fields: {list(maestro_metadata.keys())}")
    
    # Get the audio filenames from the pandas-style structure
    audio_filenames = maestro_metadata['audio_filename']
    if not isinstance(audio_filenames, dict):
        raise Exception(f"Expected dict for audio_filename field, got {type(audio_filenames)}")
    
    total_files = len(audio_filenames)
    print(f"📝 Found {total_files} audio files in metadata")
    
    # Get list of audio files to process
    target_files = set()
    files_to_process = list(audio_filenames.items())
    if max_files:
        files_to_process = files_to_process[:max_files]
        print(f"🎯 Processing first {max_files} files for demo/testing")
    else:
        print(f"🎯 Processing all {total_files} files")
    
    for idx, filename in files_to_process:
        if filename and isinstance(filename, str) and filename.endswith('.wav'):
            target_files.add(filename)
    
    if not target_files:
        raise Exception("No valid .wav files found in metadata")
    
    print(f"🎵 Target: {len(target_files)} audio files from ZIP")
    
    # Download and stream process the MAESTRO ZIP
    zip_url = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0.zip"
    print(f"📦 Downloading MAESTRO ZIP stream from: {zip_url}")
    
    processed_count = 0
    
    try:
        # Stream download the ZIP file
        with requests.get(zip_url, stream=True, timeout=300) as zip_response:
            zip_response.raise_for_status()
            
            print("✅ ZIP stream connected, processing...")
            
            # Create a temporary file to hold the ZIP stream
            with tempfile.NamedTemporaryFile(suffix='.zip') as temp_zip:
                # Download ZIP in chunks to avoid memory issues
                total_size = int(zip_response.headers.get('content-length', 0))
                downloaded = 0
                
                print(f"📊 ZIP size: {total_size / (1024**3):.1f}GB")
                
                for chunk in zip_response.iter_content(chunk_size=8192 * 1024):  # 8MB chunks
                    if chunk:
                        temp_zip.write(chunk)
                        downloaded += len(chunk)
                        
                        # Show progress every 1GB
                        if downloaded % (1024**3) < (8192 * 1024):
                            progress = (downloaded / total_size) * 100 if total_size > 0 else 0
                            print(f"📥 Downloaded: {downloaded / (1024**3):.1f}GB ({progress:.1f}%)")
                
                print("✅ ZIP download completed, extracting audio files...")
                temp_zip.seek(0)  # Reset file pointer
                
                # Process ZIP contents
                with zipfile.ZipFile(temp_zip, 'r') as zip_file:
                    # Get list of files in ZIP
                    zip_files = zip_file.namelist()
                    audio_files_in_zip = [f for f in zip_files if f.endswith('.wav')]
                    
                    print(f"📂 Found {len(audio_files_in_zip)} audio files in ZIP")
                    
                    # Process target files found in ZIP
                    for zip_audio_path in audio_files_in_zip:
                        # Check if this file is in our target list
                        audio_filename = Path(zip_audio_path).name
                        if not any(audio_filename in target_file for target_file in target_files):
                            continue
                            
                        try:
                            print(f"🎛️ Processing: {audio_filename}...")
                            
                            # Extract audio file to memory
                            with zip_file.open(zip_audio_path) as audio_file:
                                audio_data = audio_file.read()
                            
                            # Save to temp file for librosa
                            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_audio:
                                temp_audio.write(audio_data)
                                temp_audio_path = temp_audio.name
                            
                            try:
                                # Load audio (limit duration to save memory)
                                y, sr = librosa.load(temp_audio_path, sr=22050, duration=60.0)  # 60 seconds
                                
                                # Generate mel-spectrogram
                                mel_spec = librosa.feature.melspectrogram(
                                    y=y, sr=sr, n_fft=2048, hop_length=512, n_mels=128
                                )
                                mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
                                
                                # Save spectrogram to Drive
                                spec_filename = Path(audio_filename).stem + '_mel.npy'
                                spec_path = f'/content/drive/MyDrive/piano_transformer/processed_spectrograms/{spec_filename}'
                                
                                np.save(spec_path, mel_spec_db)
                                print(f"✅ Saved: {spec_filename} (shape: {mel_spec_db.shape})")
                                processed_count += 1
                                
                                # Check if we've reached our target (if max_files is set)
                                if max_files and processed_count >= max_files:
                                    print(f"🎯 Reached target limit of {processed_count} files")
                                    break
                                    
                            except Exception as audio_error:
                                print(f"❌ Audio processing error: {audio_error}")
                                continue
                            finally:
                                # Cleanup temp audio file
                                if os.path.exists(temp_audio_path):
                                    os.remove(temp_audio_path)
                                    
                        except Exception as extract_error:
                            print(f"❌ Extraction error for {zip_audio_path}: {extract_error}")
                            continue
                        
                        # Storage check periodically
                        if processed_count % 10 == 0:
                            try:
                                storage_info = os.statvfs('/content')
                                free_gb = (storage_info.f_bavail * storage_info.f_frsize) / (1024**3)
                                print(f"💾 Storage: {free_gb:.1f}GB free, {processed_count} files processed")
                            except OSError:
                                pass
                        
                        # Break if we've reached our target
                        if max_files and processed_count >= max_files:
                            break
    
    except requests.exceptions.RequestException as download_error:
        raise Exception(f"Failed to download MAESTRO ZIP: {download_error}")
    except zipfile.BadZipFile as zip_error:
        raise Exception(f"Invalid ZIP file: {zip_error}")
    except Exception as general_error:
        raise Exception(f"Processing error: {general_error}")
    
    print(f"\n🎉 Streaming processing completed!")
    print(f"✅ Successfully processed: {processed_count} files")
    print(f"💾 Spectrograms saved to: /content/drive/MyDrive/piano_transformer/processed_spectrograms/")
    
    if processed_count == 0:
        raise Exception("No files were successfully processed")
    
    return processed_count


# Run streaming processing with proper error handling
try:
    # Set max_files=None to process all files, or set a number for testing
    # For testing: max_files=50
    # For full dataset: max_files=None
    num_processed = download_and_process_maestro_streaming(max_files=None)
    print(f"\n✅ SUCCESS: {num_processed} MAESTRO files processed!")
    print("🎯 Ready to proceed with pre-training on processed spectrograms")
        
except Exception as main_error:
    print(f"❌ Processing failed: {main_error}")
    raise Exception(f"MAESTRO processing failed: {main_error}")

---
## 📊 Cell 4: Enhanced Dataset Setup with Train/Val/Test Splits
---

In [None]:
import os
import numpy as np
import jax.numpy as jnp
from pathlib import Path
import random
from sklearn.model_selection import train_test_split

print("📊 Enhanced Dataset Setup with Proper Splits & Augmentation")
print("="*60)

class EnhancedMAESTRODataset:
    """Enhanced MAESTRO dataset with proper train/val/test splits and augmentation"""
    
    def __init__(self, spec_dir, split='train', train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, 
                 augmentation=True, target_shape=(128, 128), random_seed=42):
        """
        Initialize dataset with proper splits
        
        Args:
            spec_dir: Directory containing processed spectrograms
            split: 'train', 'val', or 'test'
            train_ratio: Proportion for training
            val_ratio: Proportion for validation  
            test_ratio: Proportion for testing
            augmentation: Whether to apply data augmentation (train only)
            target_shape: Target spectrogram shape (time, freq)
            random_seed: Random seed for reproducible splits
        """
        self.spec_dir = spec_dir
        self.split = split
        self.augmentation = augmentation and (split == 'train')
        self.target_shape = target_shape
        self.random_seed = random_seed
        
        # Validate split ratios
        assert abs((train_ratio + val_ratio + test_ratio) - 1.0) < 1e-6, "Split ratios must sum to 1.0"
        
        # Get all spectrogram files
        all_files = [f for f in os.listdir(spec_dir) if f.endswith('_mel.npy')]
        
        if len(all_files) == 0:
            raise FileNotFoundError(f"No spectrogram files found in {spec_dir}")
        
        print(f"📁 Found {len(all_files)} total spectrogram files")
        
        # Create reproducible train/val/test splits
        np.random.seed(random_seed)
        random.seed(random_seed)
        
        # First split: train vs (val + test)
        train_files, temp_files = train_test_split(
            all_files, 
            test_size=(val_ratio + test_ratio), 
            random_state=random_seed
        )
        
        # Second split: val vs test
        val_size = val_ratio / (val_ratio + test_ratio)
        val_files, test_files = train_test_split(
            temp_files, 
            test_size=(1 - val_size), 
            random_state=random_seed
        )
        
        # Assign files based on split
        if split == 'train':
            self.files = train_files
        elif split == 'val':
            self.files = val_files
        elif split == 'test':
            self.files = test_files
        else:
            raise ValueError(f"Invalid split: {split}. Must be 'train', 'val', or 'test'")
        
        self.num_files = len(self.files)
        
        print(f"📊 Split Statistics:")
        print(f"   • Train: {len(train_files)} files ({len(train_files)/len(all_files)*100:.1f}%)")
        print(f"   • Val:   {len(val_files)} files ({len(val_files)/len(all_files)*100:.1f}%)")
        print(f"   • Test:  {len(test_files)} files ({len(test_files)/len(all_files)*100:.1f}%)")
        print(f"   • Using: {self.num_files} files for '{split}' split")
        
        if self.augmentation:
            print(f"✨ Data augmentation enabled for training")
        else:
            print(f"🔒 No augmentation (validation/test mode)")
    
    def __len__(self):
        return self.num_files
    
    def load_spectrogram(self, filename):
        """Load and normalize a single spectrogram"""
        filepath = os.path.join(self.spec_dir, filename)
        
        try:
            spec = np.load(filepath)
            
            # Transpose to [time, freq] if needed
            if spec.shape[0] > spec.shape[1]:  # Likely [freq, time], need [time, freq]
                spec = spec.T
            
            # Resize to target shape
            target_time, target_freq = self.target_shape
            current_time, current_freq = spec.shape
            
            # Handle time dimension
            if current_time >= target_time:
                # Truncate
                spec = spec[:target_time, :]
            else:
                # Pad with silence (-80 dB)
                pad_width = target_time - current_time
                spec = np.pad(spec, ((0, pad_width), (0, 0)), mode='constant', constant_values=-80.0)
            
            # Handle frequency dimension 
            if current_freq >= target_freq:
                # Truncate 
                spec = spec[:, :target_freq]
            else:
                # Pad with silence
                pad_width = target_freq - current_freq
                spec = np.pad(spec, ((0, 0), (0, pad_width)), mode='constant', constant_values=-80.0)
            
            # Final shape verification
            assert spec.shape == self.target_shape, f"Shape mismatch: got {spec.shape}, expected {self.target_shape}"
            
            return spec.astype(np.float32)
            
        except Exception as e:
            print(f"❌ Error loading {filename}: {e}")
            # Return silence spectrogram as fallback
            return np.full(self.target_shape, -80.0, dtype=np.float32)
    
    def augment_spectrogram(self, spec):
        """Apply data augmentation to spectrogram"""
        if not self.augmentation:
            return spec
        
        # Time masking (SpecAugment style)
        if np.random.random() < 0.5:
            time_mask_length = np.random.randint(1, min(20, spec.shape[0] // 4))
            time_mask_start = np.random.randint(0, spec.shape[0] - time_mask_length)
            spec = spec.copy()
            spec[time_mask_start:time_mask_start + time_mask_length, :] = -80.0
        
        # Frequency masking
        if np.random.random() < 0.5:
            freq_mask_length = np.random.randint(1, min(15, spec.shape[1] // 4))
            freq_mask_start = np.random.randint(0, spec.shape[1] - freq_mask_length)
            spec = spec.copy()
            spec[:, freq_mask_start:freq_mask_start + freq_mask_length] = -80.0
        
        # Gaussian noise
        if np.random.random() < 0.3:
            noise_factor = np.random.uniform(0.01, 0.05)
            noise = np.random.normal(0, noise_factor, spec.shape)
            spec = spec + noise
        
        # Volume scaling
        if np.random.random() < 0.4:
            scale_factor = np.random.uniform(0.8, 1.2)
            spec = spec * scale_factor
        
        return spec
    
    def get_batch(self, batch_size, shuffle=None):
        """Get a batch of spectrograms"""
        # Default shuffle behavior: True for train, False for val/test
        if shuffle is None:
            shuffle = (self.split == 'train')
        
        if shuffle:
            # Random sampling with replacement for training
            indices = np.random.choice(self.num_files, size=batch_size, replace=True)
        else:
            # Sequential sampling for consistent validation/test results
            start_idx = np.random.randint(0, max(1, self.num_files - batch_size + 1))
            indices = np.arange(start_idx, start_idx + batch_size) % self.num_files
        
        batch_specs = []
        
        for idx in indices:
            filename = self.files[idx]
            spec = self.load_spectrogram(filename)
            spec = self.augment_spectrogram(spec)  # Will be no-op if augmentation disabled
            batch_specs.append(spec)
        
        return np.array(batch_specs)

# Initialize enhanced datasets with proper splits
spec_dir = '/content/drive/MyDrive/piano_transformer/processed_spectrograms'

print(f"\n🔧 Creating enhanced MAESTRO datasets...")

try:
    # Create datasets for each split
    train_dataset = EnhancedMAESTRODataset(
        spec_dir=spec_dir, 
        split='train',
        train_ratio=0.7, 
        val_ratio=0.15, 
        test_ratio=0.15,
        augmentation=True,  # Enable augmentation for training
        target_shape=(128, 128),
        random_seed=42
    )
    
    val_dataset = EnhancedMAESTRODataset(
        spec_dir=spec_dir,
        split='val', 
        train_ratio=0.7, 
        val_ratio=0.15, 
        test_ratio=0.15,
        augmentation=False,  # No augmentation for validation
        target_shape=(128, 128),
        random_seed=42
    )
    
    test_dataset = EnhancedMAESTRODataset(
        spec_dir=spec_dir,
        split='test',
        train_ratio=0.7, 
        val_ratio=0.15, 
        test_ratio=0.15,
        augmentation=False,  # No augmentation for testing
        target_shape=(128, 128),
        random_seed=42
    )
    
    print(f"\n✅ Enhanced datasets created successfully!")
    print(f"   • Training dataset: {len(train_dataset)} files (with augmentation)")
    print(f"   • Validation dataset: {len(val_dataset)} files (no augmentation)")
    print(f"   • Test dataset: {len(test_dataset)} files (no augmentation)")
    
    # Test batch loading with augmentation
    print(f"\n🧪 Testing enhanced data pipeline...")
    train_batch = train_dataset.get_batch(4)
    val_batch = val_dataset.get_batch(4)
    
    print(f"   • Train batch shape: {train_batch.shape}")
    print(f"   • Val batch shape: {val_batch.shape}")
    print(f"   • Train batch stats: min={train_batch.min():.2f}, max={train_batch.max():.2f}, mean={train_batch.mean():.2f}")
    print(f"   • Val batch stats: min={val_batch.min():.2f}, max={val_batch.max():.2f}, mean={val_batch.mean():.2f}")
    
    print(f"\n🎯 Ready for SSAST pre-training!")
    
except Exception as e:
    print(f"❌ Dataset creation failed: {e}")
    raise Exception(f"Enhanced dataset setup failed: {e}")

---
## 🧠 Cell 5: Production AST Model Architecture
---

In [None]:
import sys
import os
import json
import pickle
from pathlib import Path
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datetime import datetime
from flax import linen as nn
from flax.training import train_state
import time

sys.path.append('/content/piano-perception-transformer/src')

print("🧠 Production AST Model + Enhanced Training Pipeline")
print("="*60)

class ProductionASTForSSAST(nn.Module):
    """Production AST implementation for SSAST pre-training
    
    Based on Gong et al. 2021 - Audio Spectrogram Transformer
    - 12 transformer layers (full production scale)
    - 768 embedding dimensions
    - 12 attention heads
    - Stochastic depth regularization
    - Self-supervised pre-training objective
    """
    
    patch_size: int = 16
    embed_dim: int = 768
    num_layers: int = 12
    num_heads: int = 12
    mlp_ratio: float = 4.0
    dropout_rate: float = 0.1
    attention_dropout: float = 0.1
    stochastic_depth_rate: float = 0.1
    
    def setup(self):
        # Pre-compute stochastic depth drop rates (linearly increasing)
        self.drop_rates = [
            self.stochastic_depth_rate * i / (self.num_layers - 1) 
            for i in range(self.num_layers)
        ]
    
    @nn.compact
    def __call__(self, x, training: bool = True):
        """
        Full production AST forward pass for pre-training
        Args:
            x: Mel-spectrogram [batch, time, freq] -> (batch, 128, 128)
        Returns:
            features: [batch, num_patches, embed_dim] - representations for self-supervised learning
        """
        batch_size, time_frames, freq_bins = x.shape
        
        # === PATCH EMBEDDING ===
        patch_size = self.patch_size
        
        # Ensure input can be divided into patches  
        time_pad = (patch_size - time_frames % patch_size) % patch_size
        freq_pad = (patch_size - freq_bins % patch_size) % patch_size
        
        if time_pad > 0 or freq_pad > 0:
            x = jnp.pad(x, ((0, 0), (0, time_pad), (0, freq_pad)), mode='constant', constant_values=-80.0)
        
        time_patches = x.shape[1] // patch_size
        freq_patches = x.shape[2] // patch_size
        num_patches = time_patches * freq_patches
        
        # Reshape to patches: [batch, num_patches, patch_dim]
        x = x.reshape(batch_size, time_patches, patch_size, freq_patches, patch_size)
        x = x.transpose(0, 1, 3, 2, 4)  # [batch, time_patches, freq_patches, patch_size, patch_size]
        x = x.reshape(batch_size, num_patches, patch_size * patch_size)
        
        # Linear patch embedding
        x = nn.Dense(
            self.embed_dim, 
            kernel_init=nn.initializers.truncated_normal(stddev=0.02),
            bias_init=nn.initializers.zeros,
            name='patch_embedding'
        )(x)
        
        # === 2D POSITIONAL ENCODING ===
        pos_embedding = self.param(
            'pos_embedding',
            nn.initializers.truncated_normal(stddev=0.02),
            (1, num_patches, self.embed_dim)
        )
        x = x + pos_embedding
        
        # Embedding dropout
        x = nn.Dropout(self.dropout_rate, deterministic=not training)(x)
        
        # === 12-LAYER TRANSFORMER ENCODER ===
        for layer_idx in range(self.num_layers):
            # Stochastic depth probability for this layer
            drop_rate = self.drop_rates[layer_idx]
            
            # Multi-Head Self-Attention Block
            residual = x
            x = nn.LayerNorm(epsilon=1e-6, name=f'norm1_layer{layer_idx}')(x)
            
            attention = nn.MultiHeadDotProductAttention(
                num_heads=self.num_heads,
                dropout_rate=self.attention_dropout,
                kernel_init=nn.initializers.truncated_normal(stddev=0.02),
                bias_init=nn.initializers.zeros,
                name=f'attention_layer{layer_idx}'
            )(x, x, deterministic=not training)
            
            # Stochastic depth for attention (training only)
            if training and drop_rate > 0:
                random_tensor = jax.random.uniform(
                    self.make_rng('stochastic_depth'), (batch_size, 1, 1)
                )
                keep_prob = 1.0 - drop_rate
                binary_tensor = (random_tensor < keep_prob).astype(x.dtype)
                attention = attention * binary_tensor / keep_prob
            
            x = residual + nn.Dropout(self.dropout_rate, deterministic=not training)(attention)
            
            # Feed-Forward Network Block
            residual = x
            x = nn.LayerNorm(epsilon=1e-6, name=f'norm2_layer{layer_idx}')(x)
            
            # MLP with 4x expansion
            mlp_hidden = int(self.embed_dim * self.mlp_ratio)
            
            mlp = nn.Dense(
                mlp_hidden, 
                kernel_init=nn.initializers.truncated_normal(stddev=0.02),
                bias_init=nn.initializers.zeros,
                name=f'mlp_dense1_layer{layer_idx}'
            )(x)
            mlp = nn.gelu(mlp)
            mlp = nn.Dropout(self.dropout_rate, deterministic=not training)(mlp)
            
            mlp = nn.Dense(
                self.embed_dim,
                kernel_init=nn.initializers.truncated_normal(stddev=0.02),
                bias_init=nn.initializers.zeros,
                name=f'mlp_dense2_layer{layer_idx}'
            )(mlp)
            
            # Stochastic depth for MLP
            if training and drop_rate > 0:
                random_tensor = jax.random.uniform(
                    self.make_rng('stochastic_depth'), (batch_size, 1, 1)
                )
                keep_prob = 1.0 - drop_rate
                binary_tensor = (random_tensor < keep_prob).astype(x.dtype)
                mlp = mlp * binary_tensor / keep_prob
            
            x = residual + nn.Dropout(self.dropout_rate, deterministic=not training)(mlp)
        
        # === FINAL NORMALIZATION ===
        x = nn.LayerNorm(epsilon=1e-6, name='final_norm')(x)
        
        # For SSAST pre-training, we return the raw features
        # These will be used for self-supervised objectives
        return x  # Shape: [batch, num_patches, embed_dim]

def create_advanced_optimizer(total_steps, learning_rate=5e-5, weight_decay=0.01, warmup_steps=1000):
    """Create advanced optimizer with cosine decay and warmup"""
    
    # Cosine decay schedule with warmup
    warmup_schedule = optax.linear_schedule(
        init_value=1e-8,
        end_value=learning_rate,
        transition_steps=warmup_steps
    )
    
    cosine_schedule = optax.cosine_decay_schedule(
        init_value=learning_rate,
        decay_steps=total_steps - warmup_steps,
        alpha=0.01  # Final LR = 1% of initial LR
    )
    
    lr_schedule = optax.join_schedules(
        schedules=[warmup_schedule, cosine_schedule],
        boundaries=[warmup_steps]
    )
    
    # AdamW optimizer with gradient clipping
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),  # Gradient clipping
        optax.adamw(
            learning_rate=lr_schedule,
            weight_decay=weight_decay,
            b1=0.9,
            b2=0.999,
            eps=1e-8
        )
    )
    
    return optimizer

def get_learning_rate_from_optimizer(opt_state, step):
    """Extract current learning rate from optimizer state"""
    try:
        # Try to get learning rate from the schedule directly
        # This works for most modern optax optimizers with schedules
        if hasattr(opt_state, 'hyperparams') and 'learning_rate' in opt_state.hyperparams:
            return opt_state.hyperparams['learning_rate']
        
        # For chained optimizers (like adamw with gradient clipping)
        # Navigate the chain to find the schedule
        current_state = opt_state
        for i in range(10):  # Safety limit
            if hasattr(current_state, 'hyperparams'):
                if isinstance(current_state.hyperparams, dict) and 'learning_rate' in current_state.hyperparams:
                    lr = current_state.hyperparams['learning_rate']
                    # If it's a schedule function, call it with the step
                    if callable(lr):
                        return lr(step)
                    return lr
                elif hasattr(current_state.hyperparams, 'learning_rate'):
                    lr = current_state.hyperparams.learning_rate
                    if callable(lr):
                        return lr(step)
                    return lr
            
            # Try to go deeper into nested state
            if hasattr(current_state, 'inner_state') and len(current_state.inner_state) > 0:
                current_state = current_state.inner_state[0]
            elif hasattr(current_state, 'inner_state'):
                break
            else:
                break
        
        # Fallback: return a default learning rate
        return 5e-5
        
    except Exception as e:
        # If all else fails, return the base learning rate
        return 5e-5

@jax.jit
def enhanced_train_step(train_state_obj, batch_specs, dropout_rng, stochastic_rng):
    """Enhanced training step with advanced SSAST loss"""
    
    def loss_fn(params):
        # Forward pass
        features = train_state_obj.apply_fn(
            params, batch_specs,
            training=True,
            rngs={'dropout': dropout_rng, 'stochastic_depth': stochastic_rng}
        )  # Shape: [batch, num_patches, embed_dim]
        
        # SSAST Self-Supervised Loss Components:
        
        # 1. Consistency Loss: Features should be consistent across patches
        # Encourage similar features for similar patches
        patch_mean = jnp.mean(features, axis=1, keepdims=True)  # [batch, 1, embed_dim]
        consistency_loss = jnp.mean(jnp.var(features - patch_mean, axis=1))
        
        # 2. Magnitude Regularization: Prevent feature explosion
        magnitude_loss = jnp.mean(jnp.square(features))
        
        # 3. Diversity Loss: Encourage diverse representations across embedding dimensions
        feature_std = jnp.std(features, axis=(0, 1))  # [embed_dim]
        diversity_loss = -jnp.mean(jnp.log(feature_std + 1e-8))  # Encourage high std
        
        # Combined loss
        total_loss = consistency_loss + 0.1 * magnitude_loss + 0.01 * diversity_loss
        
        # Additional metrics for monitoring
        metrics = {
            'total_loss': total_loss,
            'consistency_loss': consistency_loss,
            'magnitude_loss': magnitude_loss,
            'diversity_loss': diversity_loss,
            'output_mean': jnp.mean(features),
            'output_std': jnp.std(features)
        }
        
        return total_loss, metrics
    
    # Compute gradients
    (loss_val, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(train_state_obj.params)
    
    # Gradient norm for monitoring
    grad_norm = optax.global_norm(grads)
    
    # Update parameters
    new_train_state = train_state_obj.apply_gradients(grads=grads)
    
    # Get current learning rate (using safe extraction function)
    current_lr = get_learning_rate_from_optimizer(train_state_obj.opt_state, train_state_obj.step)
    
    # Update metrics
    metrics.update({
        'grad_norm': grad_norm,
        'learning_rate': current_lr
    })
    
    return new_train_state, metrics

# Initialize production AST model
print(f"🏗️ Initializing Production AST Model...")
ast_model = ProductionASTForSSAST(
    patch_size=16,
    embed_dim=768,
    num_layers=12,  # Full 12 layers for production
    num_heads=12,   # Full 12 heads
    mlp_ratio=4.0,
    dropout_rate=0.1,
    attention_dropout=0.1,
    stochastic_depth_rate=0.1
)

print(f"✅ Production AST Model Initialized!")
print(f"   • Patch size: 16x16")
print(f"   • Embedding dimension: 768")
print(f"   • Transformer layers: 12")
print(f"   • Attention heads: 12")
print(f"   • MLP ratio: 4.0")
print(f"   • Stochastic depth: 0.1")
print(f"   • Total patches per spectrogram: 64 (8x8)")

# Test model initialization
print(f"\n🧪 Testing model initialization...")
dummy_input = jnp.ones((2, 128, 128))  # Batch of 2 spectrograms
rng = jax.random.PRNGKey(42)
init_rng, dropout_rng, stochastic_rng = jax.random.split(rng, 3)

params = ast_model.init(
    {'params': init_rng, 'dropout': dropout_rng, 'stochastic_depth': stochastic_rng},
    dummy_input,
    training=False
)

# Count parameters
param_count = sum(x.size for x in jax.tree.leaves(params))
print(f"✅ Model initialized successfully!")
print(f"   • Total parameters: {param_count:,}")
print(f"   • Memory usage: ~{param_count * 4 / 1024**2:.1f} MB (FP32)")

# Test forward pass
print(f"\n🚀 Testing forward pass...")
output = ast_model.apply(
    params, dummy_input,
    training=False,
    rngs={'dropout': dropout_rng, 'stochastic_depth': stochastic_rng}
)

print(f"✅ Forward pass successful!")
print(f"   • Input shape: {dummy_input.shape}")
print(f"   • Output shape: {output.shape}")
print(f"   • Output stats: min={output.min():.4f}, max={output.max():.4f}, mean={output.mean():.4f}")

print(f"\n🎯 Ready for SSAST pre-training!")

---
## 🚀 Cell 6: Execute SSAST Pre-training
---

In [None]:
import sys
import os

sys.path.append('./src')

print("🚀 SSAST PRE-TRAINING - PRODUCTION EXECUTION")
print("="*70)

# Check that we have our components ready
if 'train_dataset' not in locals():
    raise RuntimeError("Run Cell 4 first to set up datasets with train/val/test splits")

if 'ast_model' not in locals():
    raise RuntimeError("Run Cell 5 first to initialize production AST model")

print("✅ All prerequisites ready")
print(f"   • Enhanced datasets with augmentation: ✅")
print(f"   • Production 12-layer AST model: ✅") 
print(f"   • Advanced training pipeline: ✅")
print(f"   • WandB experiment tracking: ✅")

def execute_ssast_pretraining(
    model, train_dataset, val_dataset, 
    num_epochs=50, batch_size=32, patience=15
):
    """Execute SSAST pre-training with all improvements"""
    print("🚀 Starting SSAST Pre-training...")
    print("="*60)
    
    # Initialize model parameters
    rng = jax.random.PRNGKey(42)
    rng, init_rng, dropout_rng, stochastic_rng = jax.random.split(rng, 4)
    
    dummy_input = jnp.ones((batch_size, 128, 128))
    params = model.init(
        {'params': init_rng, 'dropout': dropout_rng, 'stochastic_depth': stochastic_rng},
        dummy_input,
        training=False
    )
    
    # Training configuration
    train_size = len(train_dataset)
    steps_per_epoch = max(train_size // batch_size, 10)
    total_steps = num_epochs * steps_per_epoch
    
    print(f"📊 Training Configuration:")
    print(f"   • Model: {model.__class__.__name__}")
    print(f"   • Parameters: {sum(x.size for x in jax.tree.leaves(params)):,}")
    print(f"   • Train size: {train_size} spectrograms")
    print(f"   • Val size: {len(val_dataset)} spectrograms")
    print(f"   • Batch size: {batch_size}")
    print(f"   • Steps per epoch: {steps_per_epoch}")
    print(f"   • Total steps: {total_steps:,}")
    print(f"   • Epochs: {num_epochs}")
    print(f"   • Early stopping patience: {patience}")
    
    # Create advanced optimizer
    optimizer = create_advanced_optimizer(
        total_steps=total_steps,
        learning_rate=5e-5,
        weight_decay=0.01
    )
    
    # Create training state
    train_state_obj = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer
    )
    
    # Training tracking
    best_val_loss = float('inf')
    patience_counter = 0
    training_history = {
        'train_loss': [],
        'val_loss': [],
        'learning_rates': [],
        'grad_norms': []
    }
    
    # Create checkpoint directory
    checkpoint_dir = '/content/drive/MyDrive/piano_transformer/checkpoints/ssast_pretraining'
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    print(f"\n🎯 Starting training loop...")
    start_time = time.time()
    
    for epoch in range(num_epochs):
        epoch_start = time.time()
        
        # === TRAINING PHASE ===
        train_metrics = []
        
        for step in range(steps_per_epoch):
            # Get training batch
            batch_specs = train_dataset.get_batch(batch_size, shuffle=True)
            batch_specs = jnp.array(batch_specs)
            
            # Generate RNG keys for this step
            rng, dropout_rng, stochastic_rng = jax.random.split(rng, 3)
            
            # Training step
            train_state_obj, metrics = enhanced_train_step(
                train_state_obj, batch_specs, dropout_rng, stochastic_rng
            )
            
            train_metrics.append(metrics)
            
            # Log to WandB every 10 steps
            if step % 10 == 0:
                try:
                    wandb.log({
                        "train/loss": float(metrics['total_loss']),
                        "train/consistency_loss": float(metrics['consistency_loss']),
                        "train/magnitude_loss": float(metrics['magnitude_loss']),
                        "train/diversity_loss": float(metrics['diversity_loss']),
                        "train/output_mean": float(metrics['output_mean']),
                        "train/output_std": float(metrics['output_std']),
                        "train/grad_norm": float(metrics['grad_norm']),
                        "train/learning_rate": float(metrics['learning_rate']),
                        "epoch": epoch,
                        "step": int(train_state_obj.step)
                    })
                except:
                    pass  # Continue if WandB fails
        
        # === VALIDATION PHASE ===
        val_metrics = []
        val_steps = max(len(val_dataset) // batch_size, 1)
        
        for val_step in range(val_steps):
            batch_specs = val_dataset.get_batch(batch_size, shuffle=False)
            batch_specs = jnp.array(batch_specs)
            
            # Validation forward pass (no training)
            rng, dropout_rng, stochastic_rng = jax.random.split(rng, 3)
            
            features = model.apply(
                train_state_obj.params, batch_specs,
                training=False,
                rngs={'dropout': dropout_rng, 'stochastic_depth': stochastic_rng}
            )
            
            # Compute validation loss (same as training but without gradients)
            patch_mean = jnp.mean(features, axis=1, keepdims=True)
            val_consistency_loss = jnp.mean(jnp.var(features - patch_mean, axis=1))
            val_magnitude_loss = jnp.mean(jnp.square(features))
            feature_std = jnp.std(features, axis=(0, 1))
            val_diversity_loss = -jnp.mean(jnp.log(feature_std + 1e-8))
            val_loss = val_consistency_loss + 0.1 * val_magnitude_loss + 0.01 * val_diversity_loss
            
            val_metrics.append({
                'val_loss': val_loss,
                'val_consistency_loss': val_consistency_loss,
                'val_magnitude_loss': val_magnitude_loss,
                'val_diversity_loss': val_diversity_loss
            })
        
        # === EPOCH SUMMARY ===
        # Average metrics
        avg_train_loss = np.mean([m['total_loss'] for m in train_metrics])
        avg_val_loss = np.mean([m['val_loss'] for m in val_metrics])
        avg_lr = np.mean([m['learning_rate'] for m in train_metrics])
        avg_grad_norm = np.mean([m['grad_norm'] for m in train_metrics])
        
        # Store history
        training_history['train_loss'].append(avg_train_loss)
        training_history['val_loss'].append(avg_val_loss)
        training_history['learning_rates'].append(avg_lr)
        training_history['grad_norms'].append(avg_grad_norm)
        
        epoch_time = time.time() - epoch_start
        total_time = time.time() - start_time
        
        print(f"Epoch {epoch+1:3d}: "
              f"Train Loss={avg_train_loss:.4f}, "
              f"Val Loss={avg_val_loss:.4f}, "
              f"LR={avg_lr:.6f}, "
              f"Time={epoch_time:.1f}s")
        
        # Log epoch metrics to WandB
        try:
            wandb.log({
                "epoch/train_loss": avg_train_loss,
                "epoch/val_loss": avg_val_loss,
                "epoch/learning_rate": avg_lr,
                "epoch/grad_norm": avg_grad_norm,
                "epoch/time_seconds": epoch_time,
                "epoch/total_time_hours": total_time / 3600,
                "epoch/epoch": epoch + 1
            })
        except:
            pass
        
        # === EARLY STOPPING & CHECKPOINTING ===
        improved = avg_val_loss < best_val_loss
        
        if improved:
            best_val_loss = avg_val_loss
            patience_counter = 0
            
            # Save best model
            best_checkpoint = {
                'params': train_state_obj.params,
                'step': train_state_obj.step,
                'epoch': epoch + 1,
                'best_val_loss': best_val_loss,
                'training_history': training_history,
                'model_config': {
                    'embed_dim': 768,
                    'num_layers': 12,
                    'num_heads': 12,
                    'patch_size': 16,
                    'stochastic_depth_rate': 0.1
                }
            }
            
            best_path = os.path.join(checkpoint_dir, 'best_ssast_model.pkl')
            with open(best_path, 'wb') as f:
                pickle.dump(best_checkpoint, f)
            
            print(f"   ✅ New best model saved (val_loss: {best_val_loss:.4f})")
            
        else:
            patience_counter += 1
            print(f"   ⏳ No improvement ({patience_counter}/{patience})")
        
        # Regular checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pkl')
            regular_checkpoint = {
                'params': train_state_obj.params,
                'step': train_state_obj.step,
                'epoch': epoch + 1,
                'val_loss': avg_val_loss,
                'training_history': training_history
            }
            with open(checkpoint_path, 'wb') as f:
                pickle.dump(regular_checkpoint, f)
        
        # Early stopping
        if patience_counter >= patience:
            print(f"\n🛑 Early stopping after {patience} epochs without improvement")
            print(f"   Best validation loss: {best_val_loss:.4f}")
            break
    
    # === TRAINING COMPLETE ===
    total_training_time = time.time() - start_time
    
    print(f"\n" + "="*60)
    print(f"🎉 SSAST PRE-TRAINING COMPLETED!")
    print(f"="*60)
    print(f"📈 Final Results:")
    print(f"   • Best validation loss: {best_val_loss:.4f}")
    print(f"   • Total epochs: {epoch + 1}")
    print(f"   • Total steps: {train_state_obj.step:,}")
    print(f"   • Training time: {total_training_time/3600:.1f} hours")
    print(f"   • Final learning rate: {avg_lr:.2e}")
    
    return train_state_obj, best_val_loss, training_history

# Execute SSAST pre-training
try:
    print(f"\n🎯 Starting SSAST Pre-training on MAESTRO dataset...")
    print(f"   • Using production 12-layer AST architecture")
    print(f"   • Training set: {len(train_dataset)} spectrograms (with augmentation)")
    print(f"   • Validation set: {len(val_dataset)} spectrograms (no augmentation)")
    print(f"   • Advanced optimization with cosine LR scheduling")
    print(f"   • Early stopping with patience=15")
    print(f"   • Comprehensive WandB logging")
    
    # Execute training with production settings
    final_state, best_loss, history = execute_ssast_pretraining(
        model=ast_model,
        train_dataset=train_dataset, 
        val_dataset=val_dataset,
        num_epochs=50,   # Reduced from 100 for efficiency, increase as needed
        batch_size=32,   # Optimal for most GPUs
        patience=15      # Early stopping patience
    )
    
    print(f"\n🎉 SSAST PRE-TRAINING COMPLETED SUCCESSFULLY!")
    print(f"="*70)
    
    # Save pre-trained model for fine-tuning
    pretrained_model_path = '/content/drive/MyDrive/piano_transformer/checkpoints/ssast_pretraining/pretrained_for_finetuning.pkl'
    pretrained_checkpoint = {
        'params': final_state.params,
        'model_config': {
            'embed_dim': 768,
            'num_layers': 12,
            'num_heads': 12,
            'patch_size': 16,
            'stochastic_depth_rate': 0.1
        },
        'pretraining_results': {
            'best_val_loss': float(best_loss),
            'total_epochs': len(history['train_loss']),
            'convergence_achieved': best_loss < 1.0
        }
    }
    
    with open(pretrained_model_path, 'wb') as f:
        pickle.dump(pretrained_checkpoint, f)
    
    print(f"💾 Pre-trained model saved for fine-tuning: {pretrained_model_path}")
    print(f"🎯 READY FOR FINE-TUNING PHASE!")
    
except Exception as e:
    print(f"❌ SSAST pre-training failed: {str(e)}")
    raise

---
## 🎯 Pre-training Complete!

**Next Steps:**
1. 🎯 **Fine-tuning**: Run `2_Piano_Transformer_Finetuning.ipynb` to fine-tune on PercePiano dataset
2. 📊 **Evaluation**: Run `3_Piano_Transformer_Evaluation.ipynb` to evaluate performance

**Pre-trained Model Location:**
```
/content/drive/MyDrive/piano_transformer/checkpoints/ssast_pretraining/pretrained_for_finetuning.pkl
```
---