# MIDI Dataset Exploration
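This notebook explores the processed MIDI dataset for the melody-to-chord prediction project. It expects the data pipeline to have already written `processed_data/processed_data.pkl` and `processed_data/processing_stats.pkl` (relative to the notebook's working directory). If those files are missing, the later cells skip their analysis.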

In [None]:
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from collections import Counter, defaultdict

# Set up plotting.
# Order matters: plt.style.use('default') resets rcParams (including the colour
# cycle) to matplotlib's defaults, so the seaborn "husl" palette must be applied after it.
plt.style.use('default')
sns.set_palette("husl")
# IPython magic (not plain Python): render figures inline in the notebook output.
%matplotlib inline

## Load Processed Data

In [None]:
# Artifacts written by the data pipeline, relative to the notebook's working directory.
data_path = Path('./processed_data/processed_data.pkl')
stats_path = Path('./processed_data/processing_stats.pkl')


def _read_pickle(path):
    """Unpickle and return the object stored at ``path``.

    NOTE: unpickling can execute arbitrary code — only load files produced by the
    trusted local pipeline, never files from an untrusted source.
    """
    with path.open('rb') as handle:
        return pickle.load(handle)


# Fall back to empty containers so later cells can simply test truthiness.
if not data_path.exists():
    print("Processed data not found. Run the data pipeline first.")
    processed_data = []
else:
    processed_data = _read_pickle(data_path)
    print(f"Loaded {len(processed_data)} processed MIDI files")

if not stats_path.exists():
    print("Statistics not found.")
    stats = {}
else:
    stats = _read_pickle(stats_path)
    print("Loaded processing statistics")

## Dataset Overview

In [None]:
if processed_data:
    # Per-file pair counts drive every number reported below.
    pairs_per_file = [len(record['aligned_pairs']) for record in processed_data]
    total_files = len(processed_data)
    total_pairs = sum(pairs_per_file)

    print("Dataset Overview:")
    print(f"- Total MIDI files processed: {total_files:,}")
    print(f"- Total melody-chord pairs: {total_pairs:,}")
    print(f"- Average pairs per file: {total_pairs / total_files:.1f}")

    # Spread of the per-file counts (numpy's default linear-interpolation percentiles).
    print("\nPairs per file statistics:")
    print(f"- Min: {min(pairs_per_file)}")
    print(f"- Max: {max(pairs_per_file)}")
    print(f"- Median: {np.median(pairs_per_file):.1f}")
    for q in (25, 75):
        print(f"- {q}th percentile: {np.percentile(pairs_per_file, q):.1f}")

## Data Distribution Analysis

In [None]:
# One row per melody-chord pair; file-level metadata is repeated on every row.
PAIR_COLUMNS = [
    'file_path', 'key_signature', 'time_signature',
    'melody_pitch', 'melody_duration',
    'chord_root', 'chord_quality', 'chord_duration',
]

if processed_data:
    all_pairs = [
        {
            'file_path': file_data['file_path'],
            'key_signature': file_data['key_signature'],
            'time_signature': file_data['time_signature'],
            'melody_pitch': pair['melody_pitch'],
            'melody_duration': pair['melody_duration'],
            'chord_root': pair['chord_root'],
            'chord_quality': pair['chord_quality'],
            'chord_duration': pair['chord_duration'],
        }
        for file_data in processed_data
        for pair in file_data['aligned_pairs']
    ]

    # Explicit columns keep the schema stable even when no file contributed any pairs.
    df = pd.DataFrame(all_pairs, columns=PAIR_COLUMNS)
    print(f"Created DataFrame with {len(df):,} melody-chord pairs")
    print(f"Columns: {list(df.columns)}")
else:
    # Always (re)bind df: later cells guard on `'df' in locals()`, which would otherwise
    # silently reuse a stale frame left over from an earlier kernel run.
    df = pd.DataFrame(columns=PAIR_COLUMNS)

## Visualization: Key Signatures

In [None]:
# Number of most frequent key signatures shown in the bar chart.
TOP_N_KEYS = 15

if processed_data:
    # Files per key signature, ranked by frequency. Slicing Counter.keys() would give
    # insertion order (first keys encountered), not the most common keys.
    key_counts = Counter(d['key_signature'] for d in processed_data)
    top_keys = key_counts.most_common(TOP_N_KEYS)
    # str() so non-string key values (e.g. None) still render as categorical labels.
    keys = [str(key) for key, _ in top_keys]
    counts = [count for _, count in top_keys]

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(keys, counts)
    ax.set_title(f'Top {TOP_N_KEYS} Key Signatures in Dataset')
    ax.set_xlabel('Key Signature')
    ax.set_ylabel('Number of Files')
    plt.xticks(rotation=45)
    fig.tight_layout()
    plt.show()

## Visualization: Chord Distribution

In [None]:
def _plot_count_bars(counts, title, xlabel, rotation):
    """Draw a pandas value-count Series as a bar chart on a fresh 10x6 figure."""
    plt.figure(figsize=(10, 6))
    counts.plot(kind='bar')
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel('Frequency')
    plt.xticks(rotation=rotation)
    plt.tight_layout()
    plt.show()


if 'df' in locals() and not df.empty:
    # Chord qualities: only the ten most frequent (value_counts sorts descending).
    chord_quality_counts = df['chord_quality'].value_counts()
    _plot_count_bars(chord_quality_counts.head(10), 'Top 10 Chord Qualities in Dataset', 'Chord Quality', 45)

    # Chord roots: every root, most frequent first.
    chord_root_counts = df['chord_root'].value_counts()
    _plot_count_bars(chord_root_counts, 'Chord Root Distribution in Dataset', 'Chord Root', 0)

## Visualization: Melody Pitch Distribution

In [None]:
if 'df' in locals() and not df.empty:
    pitches = df['melody_pitch']

    # Histogram of melody pitches with dashed guides at three C octaves.
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.hist(pitches, bins=50, alpha=0.7, edgecolor='black')
    ax.set_title('Distribution of Melody Pitches (MIDI Note Numbers)')
    ax.set_xlabel('MIDI Note Number')
    ax.set_ylabel('Frequency')

    octave_guides = [
        (60, 'red', 'Middle C (C4)'),
        (72, 'orange', 'C5'),
        (48, 'blue', 'C3'),
    ]
    for midi_note, colour, label in octave_guides:
        ax.axvline(x=midi_note, color=colour, linestyle='--', alpha=0.7, label=label)

    ax.legend()
    fig.tight_layout()
    plt.show()

    print("Melody pitch statistics:")
    print(f"- Range: {pitches.min()} to {pitches.max()}")
    print(f"- Mean: {pitches.mean():.1f}")
    print(f"- Median: {pitches.median():.1f}")

## Duration Analysis

In [None]:
def _draw_duration_hist(ax, values, title, guides):
    """Draw a 30-bin duration histogram on ``ax`` with dashed reference lines.

    ``guides`` is a sequence of (x position, colour, legend label). The x-axis is
    labelled in quarter notes; this assumes the pipeline stores durations in
    quarter-note units — TODO confirm.
    """
    ax.hist(values, bins=30, alpha=0.7, edgecolor='black')
    ax.set_title(title)
    ax.set_xlabel('Duration (Quarter Notes)')
    ax.set_ylabel('Frequency')
    for position, colour, label in guides:
        ax.axvline(x=position, color=colour, linestyle='--', alpha=0.7, label=label)
    ax.legend()


if 'df' in locals() and not df.empty:
    # Melody note durations (left) next to chord durations (right).
    fig, (melody_ax, chord_ax) = plt.subplots(1, 2, figsize=(15, 5))

    _draw_duration_hist(
        melody_ax, df['melody_duration'], 'Distribution of Melody Note Durations',
        [(1.0, 'red', 'Quarter Note'), (0.5, 'orange', 'Eighth Note')],
    )
    _draw_duration_hist(
        chord_ax, df['chord_duration'], 'Distribution of Chord Durations',
        [(1.0, 'red', 'Quarter Note'), (4.0, 'blue', 'Whole Note')],
    )

    plt.tight_layout()
    plt.show()

## Chord Progression Patterns

In [None]:
# Files scanned (taken from the start of the list, not a random sample) to keep this cell fast; None = all files.
N_TRANSITION_FILES = 100
# Each aligned pair holds a single melody note, so a chord held under several melody
# notes shows up in consecutive pairs; counting those yields X → X self-transitions that
# swamp the real chord changes. Set to False to count them anyway.
COLLAPSE_REPEATED_CHORDS = True
TOP_N_TRANSITIONS = 20


def chord_label(pair):
    """Return a '<root>_<quality>' label for one aligned melody-chord pair."""
    return f"{pair['chord_root']}_{pair['chord_quality']}"


if processed_data:
    files_to_scan = processed_data if N_TRANSITION_FILES is None else processed_data[:N_TRANSITION_FILES]

    # Transitions are only counted between consecutive pairs within the same file.
    chord_transitions = []
    for file_data in files_to_scan:
        labels = [chord_label(pair) for pair in file_data['aligned_pairs']]
        for current_chord, next_chord in zip(labels, labels[1:]):
            if COLLAPSE_REPEATED_CHORDS and current_chord == next_chord:
                continue
            chord_transitions.append((current_chord, next_chord))

    transition_counts = Counter(chord_transitions)
    top_transitions = transition_counts.most_common(TOP_N_TRANSITIONS)

    print(f"Top {TOP_N_TRANSITIONS} Chord Transitions:")
    for rank, ((from_chord, to_chord), count) in enumerate(top_transitions, 1):
        print(f"{rank:2d}. {from_chord:15} → {to_chord:15} ({count:3d} times)")

## Sample Data Inspection

In [None]:
if processed_data:
    # Peek at the first file's record to illustrate the per-file structure.
    print("Sample file data structure:")
    sample_file = processed_data[0]

    print(f"\nFile: {sample_file['file_path']}")
    summary_fields = [
        ("Key", 'key_signature'),
        ("Time Signature", 'time_signature'),
        ("Total melody notes", 'total_melody_notes'),
        ("Total chords", 'total_chords'),
    ]
    for label, field in summary_fields:
        print(f"{label}: {sample_file[field]}")
    print(f"Aligned pairs: {len(sample_file['aligned_pairs'])}")

    print("\nFirst 5 melody-chord pairs:")
    for rank, pair in enumerate(sample_file['aligned_pairs'][:5], start=1):
        chord = f"{pair['chord_root']} {pair['chord_quality']}"
        print(f"  {rank}. Melody: MIDI {pair['melody_pitch']:3d} → Chord: {chord}")

## Processing Statistics

In [None]:
# How many entries of each stored distribution to list.
TOP_N_STATS = 10


def _top_by_count(distribution, n):
    """Return the ``n`` (label, count) items with the largest counts.

    ``distribution`` must expose ``.items()`` yielding (label, count). Ties keep their
    original order (sorted() is stable even with reverse=True). Slicing ``.items()``
    directly would return the first-inserted entries, not the most common ones.
    """
    return sorted(distribution.items(), key=lambda item: item[1], reverse=True)[:n]


if stats:
    print("Processing Statistics:")
    print(f"- Files processed successfully: {stats.get('total_files_processed', 'N/A')}")
    print(f"- Files failed: {stats.get('total_files_failed', 'N/A')}")
    print(f"- Total melody-chord pairs: {stats.get('total_melody_chord_pairs', 'N/A')}")

    if 'chord_distribution' in stats:
        print("\nTop chord types in dataset:")
        for chord, count in _top_by_count(stats['chord_distribution'], TOP_N_STATS):
            print(f"  {chord}: {count}")

    if 'key_distribution' in stats:
        print("\nTop key signatures in dataset:")
        for key, count in _top_by_count(stats['key_distribution'], TOP_N_STATS):
            print(f"  {key}: {count}")

## Next Steps

Based on this data exploration, here are the next steps for building the melody-to-chord model:

1. **Feature Engineering**: Convert raw data into model-ready features
2. **Sequence Preparation**: Create input-output sequences for training
3. **Vocabulary Building**: Create mappings for notes and chords
4. **Data Splitting**: Train/validation/test splits
5. **Model Architecture**: Implement the Transformer-based model
6. **Training Pipeline**: Set up training loop with proper metrics