# MIDI Dataset Exploration

This notebook explores classical MIDI datasets to understand:
- Dataset structure and composition
- Note distributions per composer
- Sequence lengths and statistics
- Data quality and validation

**Prerequisites**: Download sample data first using `scripts/download_hf_dataset.py`

In [None]:
import sys
from pathlib import Path

# Append the parent of the current working directory to the import search path.
# Presumably this is what lets the `utils` package imported below resolve; it
# assumes the kernel is started one level beneath the project root (e.g. from a
# notebooks/ folder) — TODO confirm for other launch locations.
sys.path.append(str(Path.cwd().parent))

import numpy as np
import matplotlib.pyplot as plt
from collections import Counter, defaultdict
import pandas as pd

from utils.midi_utils import (
    load_midi_file,
    get_midi_info,
    validate_midi_file,
    print_midi_summary,
    get_piano_midi
)

plt.style.use('seaborn-v0_8-darkgrid')
%matplotlib inline

## 1. Dataset Overview

First, let's check what data we have available.

In [None]:
# Root folder that is scanned for datasets; each immediate sub-folder is
# reported as one dataset.
data_dir = Path("../data/raw")

# is_dir() (rather than exists()) also routes a stray regular file named "raw"
# to the help message instead of letting iterdir() raise NotADirectoryError.
if data_dir.is_dir():
    print("Available datasets:")
    # iterdir() yields entries in arbitrary, filesystem-dependent order; sorting
    # makes the listing identical on every run and every machine.
    for subdir in sorted(data_dir.iterdir()):
        if subdir.is_dir():
            # Count both common MIDI extensions, searching recursively.
            midi_files = list(subdir.rglob("*.mid")) + list(subdir.rglob("*.midi"))
            print(f"  {subdir.name}: {len(midi_files)} MIDI files")
else:
    print("No data directory found. Please download data first using:")
    print("  python scripts/download_hf_dataset.py")

## 2. Sample MIDI File Analysis

Let's load and analyze a sample MIDI file to understand the structure.

In [None]:
# Gather every MIDI file below data_dir (both extensions, recursive).
# rglob() returns paths in filesystem-dependent order, so sort the combined
# list: the sample file picked here and any later "first N files" slice are
# then the same on every run and machine.
all_midi_files = sorted(
    list(data_dir.rglob("*.mid")) + list(data_dir.rglob("*.midi"))
)

if all_midi_files:
    # Inspect the first file (in sorted order) as a representative example.
    sample_file = all_midi_files[0]
    print(f"Analyzing: {sample_file}\n")
    print_midi_summary(sample_file)
else:
    print("No MIDI files found. Download data first.")

## 3. Dataset Statistics

Collect statistics across the MIDI files in the dataset. To keep the run fast, at most 1,000 files are processed.

In [None]:
def collect_dataset_stats(midi_files, max_files=None):
    """
    Collect statistics from a list of MIDI files.

    Each file is passed to ``get_midi_info``. A file whose info cannot be read
    (any exception) is counted as invalid and contributes nothing to the
    per-file lists, so list lengths never exceed ``valid_files``.

    Args:
        midi_files: List of MIDI file paths (``Path`` objects or strings)
        max_files: Maximum number of files to process (None for all,
            0 processes no files)

    Returns:
        dict with the keys:
            'total_notes': note count of each valid file
            'durations': 'duration_seconds' of each valid file that reports it
            'pitch_ranges': pitch-range 'span' of each valid file that reports it
            'num_instruments': number of instruments in each valid file
            'composers': defaultdict mapping the parent folder name of each
                valid file to the number of such files
            'valid_files': number of files processed without error
            'invalid_files': number of files that raised an error
            'errors': list of (file path, error message) tuples
    """
    stats = {
        'total_notes': [],
        'durations': [],
        'pitch_ranges': [],
        'num_instruments': [],
        'composers': defaultdict(int),
        'valid_files': 0,
        'invalid_files': 0,
        'errors': []
    }

    # Compare against None explicitly: a plain truthiness test would treat
    # max_files=0 as "no limit" and process every file.
    files_to_process = midi_files if max_files is None else midi_files[:max_files]

    for i, midi_file in enumerate(files_to_process):
        if (i + 1) % 100 == 0:
            print(f"Processed {i+1}/{len(files_to_process)} files...")

        try:
            info = get_midi_info(midi_file)

            # Read every value that can fail BEFORE mutating `stats`, so a file
            # that errors half-way does not leave partial entries behind.
            total_notes = info['total_notes']
            num_instruments = len(info['instruments'])
            has_duration = 'duration_seconds' in info
            duration = info['duration_seconds'] if has_duration else None
            has_pitch_range = 'pitch_range' in info
            pitch_span = info['pitch_range']['span'] if has_pitch_range else None
            # The containing folder name is used as the composer label;
            # Path() makes this work for string paths as well.
            composer = Path(midi_file).parent.name
        except Exception as e:
            stats['invalid_files'] += 1
            stats['errors'].append((str(midi_file), str(e)))
            continue

        stats['total_notes'].append(total_notes)
        if has_duration:
            stats['durations'].append(duration)
        if has_pitch_range:
            stats['pitch_ranges'].append(pitch_span)
        stats['num_instruments'].append(num_instruments)
        stats['composers'][composer] += 1
        stats['valid_files'] += 1

    return stats

# Run the collection over the discovered files (capped at 1000) and report
# how many could and could not be read.
if not all_midi_files:
    print("No MIDI files to analyze")
else:
    print(f"Collecting statistics from {min(len(all_midi_files), 1000)} files...\n")
    dataset_stats = collect_dataset_stats(all_midi_files, max_files=1000)
    print(f"\nProcessed {dataset_stats['valid_files']} valid files")
    print(f"Found {dataset_stats['invalid_files']} invalid files")

## 4. Visualization

Visualize the collected statistics.

In [None]:
# Only plot when the statistics cell ran and at least one file was readable.
if 'dataset_stats' in locals() and dataset_stats['valid_files'] > 0:
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Top-left: notes per file (log-scaled counts).
    axes[0, 0].hist(dataset_stats['total_notes'], bins=50, edgecolor='black')
    axes[0, 0].set_xlabel('Number of Notes')
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].set_title('Distribution of Notes per File')
    axes[0, 0].set_yscale('log')

    # Top-right: file durations, when any file reported one.
    if dataset_stats['durations']:
        axes[0, 1].hist(dataset_stats['durations'], bins=50, edgecolor='black')
        axes[0, 1].set_xlabel('Duration (seconds)')
        axes[0, 1].set_ylabel('Frequency')
        axes[0, 1].set_title('Distribution of File Durations')

    # Bottom-left: pitch range per file, when any file reported one.
    if dataset_stats['pitch_ranges']:
        axes[1, 0].hist(dataset_stats['pitch_ranges'], bins=30, edgecolor='black')
        axes[1, 0].set_xlabel('Pitch Range (semitones)')
        axes[1, 0].set_ylabel('Frequency')
        axes[1, 0].set_title('Distribution of Pitch Ranges')

    # Bottom-right: the ten composers with the MOST files. Slicing the dict keys
    # would only give the first ten encountered, which contradicts the title.
    if dataset_stats['composers']:
        top_composers = Counter(dataset_stats['composers']).most_common(10)
        composers = [name for name, _ in top_composers]
        counts = [count for _, count in top_composers]
        axes[1, 1].barh(composers, counts)
        # barh draws the first entry at the bottom; flip so the largest is on top.
        axes[1, 1].invert_yaxis()
        axes[1, 1].set_xlabel('Number of Files')
        axes[1, 1].set_title('Top 10 Composers by File Count')

    plt.tight_layout()
    plt.show()

    print("\nDataset Summary:")
    print(f"  Total notes: {np.sum(dataset_stats['total_notes']):,}")
    print(f"  Average notes per file: {np.mean(dataset_stats['total_notes']):.1f}")
    print(f"  Median notes per file: {np.median(dataset_stats['total_notes']):.1f}")

    if dataset_stats['durations']:
        print(f"\n  Total duration: {np.sum(dataset_stats['durations'])/3600:.2f} hours")
        print(f"  Average duration: {np.mean(dataset_stats['durations']):.1f} seconds")
        print(f"  Median duration: {np.median(dataset_stats['durations']):.1f} seconds")

## 5. Composer-Specific Analysis

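Analyze statistics for specific composers of interest. A file is assigned to a composer when the composer's name appears anywhere in its path (case-insensitive), and at most 100 files are processed per composer.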

In [None]:
# Composers to report on; a file belongs to a composer when the lower-cased
# name occurs anywhere in the file's lower-cased path.
target_composers = ['bach', 'mozart', 'beethoven', 'chopin']

# Maps composer name -> statistics dict returned by collect_dataset_stats.
composer_stats = {}

for composer in target_composers:
    needle = composer.lower()
    composer_files = [path for path in all_midi_files if needle in str(path).lower()]

    # Composers without any matching file are skipped silently.
    if not composer_files:
        continue

    print(f"\nAnalyzing {composer.title()}: {len(composer_files)} files")
    stats = collect_dataset_stats(composer_files, max_files=100)
    composer_stats[composer] = stats

    # Nothing to average when no note counts were collected.
    if not stats['total_notes']:
        continue

    print(f"  Avg notes: {np.mean(stats['total_notes']):.1f}")
    if stats['durations']:
        print(f"  Avg duration: {np.mean(stats['durations']):.1f}s")
    else:
        print("  No duration info")

## 6. Next Steps

Based on this analysis:
1. Identify appropriate sequence lengths for model training
2. Determine if we need to chunk long pieces
3. Filter out files that are too short/long
4. Create train/validation/test splits per composer
5. Design tokenization strategy based on pitch distributions