In [1]:
from music21 import converter, note, chord, instrument, meter, roman, stream, key, tempo
import numpy as np
import os
import pandas as pd
from fractions import Fraction
from scipy.stats import entropy
from pathlib import Path
from typing import List, Dict, Optional
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing as mp
from tqdm import tqdm
import traceback
import warnings
warnings.filterwarnings("ignore")

## Symbolic Feature Extraction for Topological Analysis of MIDI Data

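Each MIDI file in the dataset (organized as one subfolder per genre) is parsed with music21 to extract event-level symbolic features: onset, duration, pitch (with pitch class and octave), polyphony, global and local key, time signature, tempo, metric position, velocity, and instrument information. Every note, chord tone, and rest becomes one row labeled with its genre and source file, and all rows are compiled into a structured DataFrame for further analysis.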

In [2]:
def get_midi_features(midi_file, instrument_family_map=None):
    """
    Parse one MIDI file with music21 and return one feature dict per musical event.

    Every note, every pitch of a chord (expanded into its own row, lowest pitch
    first) and every rest of the flattened score becomes one row; rows are grouped
    by increasing onset. Elements of ``notesAndRests`` that are not a Note, Chord
    or Rest are skipped.

    Args:
        midi_file: Path of the MIDI file, passed to ``music21.converter.parse``.
        instrument_family_map: Optional dict mapping music21 instrument names
            (e.g. ``'Piano'``) to a family label. Names not in the map, and events
            whose instrument cannot be resolved, get the family ``'Other'``.

    Returns:
        List[dict]: one dict per event with keys ``onset``, ``duration``,
        ``polyphony``, ``key``, ``local_key``, ``time_signature``, ``tempo``,
        ``measure``, ``beat_position``, ``beat_fraction``, ``metric_weight``,
        ``articulation_ratio``, ``instrument``, ``instrument_family``, ``track``,
        ``pitch``, ``pitch_class``, ``pitch_octave``, ``interval_to_prev``,
        ``is_chord_tone``, ``velocity`` and ``is_rest``. Rests use -1 for the pitch
        fields and velocity; a missing velocity is also -1.
        An empty list is returned (and the error is logged with its traceback)
        if anything fails while parsing or analysing the file.
    """
    import logging
    import math

    try:
        # Imported locally too so the function is self-contained (e.g. when it is
        # submitted to worker threads or copied into a script).
        from music21 import converter, key, meter, instrument, note, chord, stream, tempo

        # Parse the MIDI file and flatten all parts/measures into a single stream.
        score = converter.parse(midi_file)
        flat = score.flatten()

        # === GLOBAL METADATA EXTRACTION ===

        # Global key from music21's key analysis, encoded 0-11 for major keys
        # (tonic pitch class) and 12-23 for minor keys.
        key_signature = flat.analyze('key')
        key_index = key_signature.tonic.pitchClass + (12 if key_signature.mode == 'minor' else 0)

        # Only the FIRST time signature is used (later meter changes are ignored);
        # 4/4 is assumed when the file has none.
        time_signature = flat.recurse().getElementsByClass(meter.TimeSignature).first()
        time_sig = str(time_signature.ratioString) if time_signature else '4/4'
        time_sig_numerator = time_signature.numerator if time_signature else 4
        bar_duration = time_signature.barDuration.quarterLength if time_signature else 4.0

        def extract_tempo(score, flat):
            """Return the first tempo in BPM found by three strategies, or 120.0."""
            # Method 1: metronomeMarkBoundaries (most reliable)
            try:
                tempo_marks = score.metronomeMarkBoundaries()
                if tempo_marks:
                    return float(tempo_marks[0][2].number)
            except (AttributeError, IndexError, TypeError):
                pass

            # Method 2: MetronomeMark in flattened score
            try:
                for mark in flat.recurse().getElementsByClass(tempo.MetronomeMark):
                    if hasattr(mark, 'number') and mark.number:
                        return float(mark.number)
            except (AttributeError, TypeError):
                pass

            # Method 3: TempoIndication
            try:
                tempo_indication = flat.recurse().getElementsByClass('TempoIndication').first()
                if tempo_indication and hasattr(tempo_indication, 'number') and tempo_indication.number:
                    return float(tempo_indication.number)
            except (AttributeError, TypeError):
                pass

            return 120.0  # Default value

        tempo_bpm = extract_tempo(score, flat)

        # === PRE-COMPUTE CONTEXTS ===

        # Group events by onset and run the expensive context look-ups once per
        # element. The cache is keyed by id(element); the elements stay referenced
        # from offset_dict, so their ids remain valid while the cache is used.
        offset_dict = {}
        element_contexts = {}
        for element in flat.notesAndRests:
            offset_dict.setdefault(element.offset, []).append(element)
            element_contexts[id(element)] = {
                'instrument': element.getContextByClass(instrument.Instrument),
                'track': element.getContextByClass(stream.Part),
                'key': element.getContextByClass(key.Key),
            }

        def extract_instrument_info(element):
            """Return (instrument name, family) from the element's cached Instrument context."""
            instr_name = "Unknown"
            instr_family = "Other"

            instr = element_contexts[id(element)]['instrument']
            if instr and hasattr(instr, 'midiProgram') and instr.midiProgram is not None:
                try:
                    # Normalise to music21's canonical instrument for the MIDI program.
                    instr_obj = instrument.instrumentFromMidiProgram(instr.midiProgram)
                    instr_name = instr_obj.instrumentName if hasattr(instr_obj, 'instrumentName') else instr_name
                    if instrument_family_map and instr_name in instrument_family_map:
                        instr_family = instrument_family_map[instr_name]
                except Exception:
                    # Best effort: music21 can raise its own exception type for an
                    # unmapped program; an unknown instrument must not discard the file.
                    pass

            return instr_name, instr_family

        def get_track_name(element):
            """Return the enclosing Part's name, or 'Unknown'."""
            track_context = element_contexts[id(element)]['track']
            if track_context and hasattr(track_context, 'partName') and track_context.partName:
                return track_context.partName
            return 'Unknown'

        def get_local_key(element):
            """Return the nearest Key context encoded like key_index, else the global key."""
            try:
                local_key = element_contexts[id(element)]['key']
                if local_key:
                    return local_key.tonic.pitchClass + (12 if local_key.mode == 'minor' else 0)
            except (AttributeError, TypeError):
                pass
            return key_index

        # === METRIC CALCULATIONS ===

        def calculate_metric_weight(beat_position, time_sig_numerator):
            """Heuristic accent weight of the (integer part of the) beat within the bar."""
            beat_int = int(beat_position)

            # Weights for common time signatures
            if time_sig_numerator == 4:
                weights = {1: 1.0, 2: 0.4, 3: 0.6, 4: 0.2}
            elif time_sig_numerator == 3:
                weights = {1: 1.0, 2: 0.4, 3: 0.6}
            elif time_sig_numerator == 2:
                weights = {1: 1.0, 2: 0.5}
            else:
                # Generic: strong first beat, others weaker
                weights = {1: 1.0}
                for i in range(2, time_sig_numerator + 1):
                    weights[i] = 0.4 if i % 2 == 0 else 0.6

            return weights.get(beat_int, 0.3)

        def get_beat_position(element):
            """Return the element's 1-based beat, or 1.0 when music21 cannot place it."""
            beat = getattr(element, 'beat', None)
            # music21 can report NaN for `beat` (e.g. without a reachable
            # TimeSignature); int(NaN) in calculate_metric_weight would raise and
            # discard the whole file, so fall back to the downbeat instead.
            if beat is None or math.isnan(float(beat)):
                return 1.0
            return beat

        def get_velocity(element):
            """Return the element's MIDI velocity, or -1 when it is not available."""
            volume = element.volume
            if volume and hasattr(volume, 'velocity') and volume.velocity is not None:
                return volume.velocity
            return -1

        # === FEATURE EXTRACTION ===

        features = []
        # Last emitted pitch across ALL tracks (events are visited in onset order),
        # used for interval_to_prev.
        previous_pitch = None

        for offset, simultaneous_elements in sorted(offset_dict.items()):
            # Number of notes/chords/rests that START at this onset: a chord counts
            # once, rests are included, notes sustained from earlier onsets are not.
            polyphony = len(simultaneous_elements)

            for element in simultaneous_elements:
                instr_name, instr_family = extract_instrument_info(element)
                track_name = get_track_name(element)
                local_key_index = get_local_key(element)

                beat_position = get_beat_position(element)
                beat_fraction = beat_position / time_sig_numerator
                metric_weight = calculate_metric_weight(beat_position, time_sig_numerator)
                measure_number = element.measureNumber if hasattr(element, 'measureNumber') and element.measureNumber is not None else 0
                duration = float(element.quarterLength)
                articulation_ratio = duration / bar_duration

                common_features = {
                    'onset': float(offset),
                    'duration': duration,
                    'polyphony': polyphony,
                    'key': key_index,
                    'local_key': local_key_index,
                    'time_signature': time_sig,
                    'tempo': tempo_bpm,
                    'measure': measure_number,
                    'beat_position': beat_position,
                    'beat_fraction': beat_fraction,
                    'metric_weight': metric_weight,
                    'articulation_ratio': articulation_ratio,
                    'instrument': instr_name,
                    'instrument_family': instr_family,
                    'track': track_name,
                }

                if isinstance(element, note.Rest):
                    features.append({
                        **common_features,
                        'pitch': -1,
                        'pitch_class': -1,
                        'pitch_octave': -1,
                        'interval_to_prev': 0,
                        'is_chord_tone': 0,
                        'velocity': -1,
                        'is_rest': 1
                    })
                    continue

                if isinstance(element, note.Note):
                    pitches = [element.pitch.midi]
                    is_chord_tone = 0
                elif isinstance(element, chord.Chord):
                    # Sorted so chord tones are emitted lowest-first, consistently.
                    pitches = sorted(p.midi for p in element.pitches)
                    is_chord_tone = 1
                else:
                    continue  # other element types are not featurised

                velocity = get_velocity(element)
                for pitch in pitches:
                    interval_to_prev = abs(pitch - previous_pitch) if previous_pitch is not None else 0
                    previous_pitch = pitch

                    features.append({
                        **common_features,
                        'pitch': pitch,
                        'pitch_class': pitch % 12,
                        # Octave in music21's convention (MIDI 60 = C4 -> 4), derived
                        # from the MIDI number for notes AND chord tones so both agree
                        # (plain pitch // 12 would be one octave too high).
                        'pitch_octave': pitch // 12 - 1,
                        'interval_to_prev': interval_to_prev,
                        'is_chord_tone': is_chord_tone,
                        'velocity': velocity,
                        'is_rest': 0
                    })

        return features

    except Exception:
        # Best effort: one unreadable file must not stop a batch run; callers treat
        # an empty list as "no features". logging.exception records the traceback.
        logging.exception(f"[ERROR] Could not parse {midi_file}")
        return []

In [3]:
# Add instrument family to each event for higher-level grouping (e.g., strings, winds, percussion).
# This enables downstream analysis of timbral or orchestration patterns across genres or pieces.
# Keys must match the music21 `instrumentName` strings that get_midi_features looks up;
# names that are not listed here fall back to the family 'Other'.

instrument_family_map = {
    # Keyboard
    'Piano': 'Keyboard',
    'Celesta': 'Keyboard',
    'Organ': 'Keyboard',
    'Electric Organ': 'Keyboard',
    'Harpsichord': 'Keyboard',

    # Guitar
    'Acoustic Guitar': 'Guitar',
    'Electric Guitar': 'Guitar',

    # Bass
    'Acoustic Bass': 'Bass',
    'Electric Bass': 'Bass',
    'Fretless Bass': 'Bass',
    'Contrabass': 'Bass',

    # Strings
    'Violoncello': 'Strings',
    'Violin': 'Strings',
    'Viola': 'Strings',
    'Double Bass': 'Strings',
    'StringInstrument': 'Strings',
    'Harp': 'Strings',

    # Brass
    'Trumpet': 'Brass',
    'Trombone': 'Brass',
    'Horn': 'Brass',         # music21's instrumentName for the French horn
    'French Horn': 'Brass',  # kept for names coming from other sources
    'Tuba': 'Brass',

    # Woodwind
    'Clarinet': 'Woodwind',
    'Bassoon': 'Woodwind',
    'Oboe': 'Woodwind',
    'English Horn': 'Woodwind',
    'Saxophone': 'Woodwind',
    'Soprano Saxophone': 'Woodwind',
    'Alto Saxophone': 'Woodwind',
    'Tenor Saxophone': 'Woodwind',
    'Baritone Saxophone': 'Woodwind',
    'Recorder': 'Woodwind',
    'Piccolo': 'Woodwind',
    'Flute': 'Woodwind',
    'Pan Flute': 'Woodwind',
    'Whistle': 'Woodwind',

    # Percussion
    'Timpani': 'Percussion',
    'Taiko': 'Percussion',
    'Marimba': 'Percussion',
    'Glockenspiel': 'Percussion',
    'Vibraphone': 'Percussion',
    'Xylophone': 'Percussion',
    'Drums': 'Percussion',

    # Voice
    'Voice': 'Voice',
    'Choir': 'Voice',
    'Vocals': 'Voice',
    'Background Vocals': 'Voice',

    # Electronic (Electric Piano is classified as Electronic, not Keyboard)
    'Electric Piano': 'Electronic',
    'Sampler': 'Electronic',
    'Synthesizer': 'Electronic',

    # Other / Catch-all
    'Bagpipes': 'Other',
    'Ocarina': 'Other',
    'Unknown': 'Other',
}

In [4]:
# Configure logging for better error tracking: every record goes both to
# 'midi_extraction.log' and to the notebook output.
# NOTE: basicConfig does nothing if the root logger already has handlers
# (e.g. when this cell is re-run in the same kernel).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8: without an encoding, FileHandler uses the platform's locale
        # encoding, which can fail on non-ASCII file or genre names.
        logging.FileHandler('midi_extraction.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)

def process_single_file(file_path: str, genre: str, filename: str, instrument_family_map: Optional[Dict] = None) -> List[Dict]:
    """
    Extract features from one MIDI file and tag every row with its genre and filename.

    Args:
        file_path: Path to the MIDI file.
        genre: Genre label stored in each row's ``'genre'`` key.
        filename: File name stored in each row's ``'filename'`` key.
        instrument_family_map: Optional mapping of instrument names to families,
            forwarded to ``get_midi_features``.

    Returns:
        List[Dict]: The feature dicts (mutated in place to carry ``genre`` and
        ``filename``), or an empty list when nothing could be extracted or an
        error occurred. Never raises for per-file problems; they are logged.
    """
    try:
        features = get_midi_features(file_path, instrument_family_map)

        if not features:
            # get_midi_features returns [] when it cannot parse or analyse a file,
            # so an empty result must not be logged as a success.
            logging.warning(f"No features extracted from {filename} ({genre})")
            return []

        for feature in features:
            feature['genre'] = genre
            feature['filename'] = filename

        logging.info(f"Successfully processed {filename} ({genre}): {len(features)} features extracted")
        return features

    except Exception as e:
        logging.error(f"Error processing {filename} ({genre}): {str(e)}")
        logging.debug(traceback.format_exc())
        return []

def get_midi_files_info(root_path: str) -> List[tuple]:
    """
    Scan ``root_path/<genre>/`` folders and collect the MIDI files they contain.

    Genre folders and the files inside them are visited in sorted name order, so
    the result (and every DataFrame built from it) is reproducible across runs and
    platforms; ``Path.iterdir`` alone yields filesystem-dependent order. Extensions
    are matched case-insensitively (``.mid``, ``.MID``, ``.Mid``, ``.midi`` ...),
    and only regular files are accepted.

    Args:
        root_path: Path to root directory containing genre subfolders.

    Returns:
        List[tuple]: ``(file_path, genre, filename)`` tuples.

    Raises:
        FileNotFoundError: If ``root_path`` does not exist.
    """
    midi_files_info = []
    root_path = Path(root_path)

    if not root_path.exists():
        raise FileNotFoundError(f"Root path does not exist: {root_path}")

    midi_extensions = {'.mid', '.midi'}  # compared against the lower-cased suffix

    for genre_dir in sorted(p for p in root_path.iterdir() if p.is_dir()):
        genre_name = genre_dir.name
        midi_files = sorted(
            f for f in genre_dir.iterdir()
            if f.is_file() and f.suffix.lower() in midi_extensions
        )
        midi_files_info.extend((str(f), genre_name, f.name) for f in midi_files)

        logging.info(f"Found {len(midi_files)} MIDI files in genre: {genre_name}")

    logging.info(f"Total MIDI files found: {len(midi_files_info)}")
    return midi_files_info

def extract_features_threaded(
    root_path: str,
    instrument_family_map: Optional[Dict] = None,
    max_workers: Optional[int] = None,
    use_threading: bool = True
) -> pd.DataFrame:
    """
    Extract features from all MIDI files, optionally using a thread pool.

    Files are discovered with ``get_midi_files_info`` and each one is handled by
    ``process_single_file``. Whatever order the worker threads finish in, the
    per-file results are concatenated in discovery order, so the row order of the
    returned DataFrame is reproducible.

    Args:
        root_path: Path to root directory containing genre subfolders.
        instrument_family_map: Optional mapping of instruments to families.
        max_workers: Number of worker threads (None -> min(16, number of files)).
        use_threading: If False, or if there is at most one file, files are
            processed sequentially in the calling thread.

    Returns:
        pd.DataFrame: All extracted feature rows (checked with
        ``validate_dataframe``), or an empty DataFrame if nothing was extracted.
    """
    midi_files_info = get_midi_files_info(root_path)

    if not midi_files_info:
        logging.warning("No MIDI files found!")
        return pd.DataFrame()

    all_features = []

    if use_threading and len(midi_files_info) > 1:
        if max_workers is None:
            max_workers = min(16, len(midi_files_info))  # Cap at 16 threads

        logging.info(f"Using {max_workers} parallel threads for feature extraction")

        # NOTE(review): music21 parsing/analysis is largely Python-level CPU work, so
        # the GIL probably limits the speed-up from threads; measure before tuning.
        results_by_index: Dict[int, List[Dict]] = {}
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_index = {
                executor.submit(process_single_file, file_path, genre, filename, instrument_family_map): index
                for index, (file_path, genre, filename) in enumerate(midi_files_info)
            }

            with tqdm(total=len(midi_files_info), desc="Processing MIDI files") as pbar:
                for future in as_completed(future_to_index):
                    index = future_to_index[future]
                    try:
                        results_by_index[index] = future.result()
                    except Exception as e:
                        logging.error(f"Failed to process {midi_files_info[index][2]}: {str(e)}")
                    finally:
                        pbar.update(1)

        # as_completed yields in completion order; reassemble in discovery order.
        for index in range(len(midi_files_info)):
            features = results_by_index.get(index)
            if features:
                all_features.extend(features)
    else:
        logging.info("Using single-threaded processing")
        for file_path, genre, filename in tqdm(midi_files_info, desc="Processing MIDI files"):
            features = process_single_file(file_path, genre, filename, instrument_family_map)
            if features:
                all_features.extend(features)

    if not all_features:
        logging.warning("No features extracted!")
        return pd.DataFrame()

    df = pd.DataFrame(all_features)
    logging.info(f"Created DataFrame with {len(df)} feature rows from {len(midi_files_info)} files")

    validate_dataframe(df)
    return df

def extract_features_sequential(
    root_path: str,
    instrument_family_map: Optional[Dict] = None
) -> pd.DataFrame:
    """
    Extract features from every MIDI file under ``root_path``, one file at a time.

    Expects one subfolder per genre directly under ``root_path``; the subfolder
    name becomes the ``genre`` label. Genre folders and files are visited in
    sorted order, so the output row order is reproducible. MIDI extensions are
    matched case-insensitively. Files that yield no features are counted as
    failures and skipped.

    Args:
        root_path: Path to root directory containing genre subfolders.
        instrument_family_map: Optional mapping of instruments to families.

    Returns:
        pd.DataFrame: All extracted feature rows with ``genre`` and ``filename``
        columns (checked with ``validate_dataframe``), or an empty DataFrame if
        nothing was extracted.

    Raises:
        FileNotFoundError: If ``root_path`` does not exist.
    """
    logging.info("=== Sequential MIDI Feature Extraction ===")
    logging.info(f"Root path: {root_path}")

    root_path = Path(root_path)
    if not root_path.exists():
        raise FileNotFoundError(f"Root path does not exist: {root_path}")

    midi_extensions = {'.mid', '.midi'}  # compared against the lower-cased suffix

    # Scan the directory tree once: [(genre, [midi files...]), ...] in sorted order.
    genre_files = [
        (genre_dir.name, sorted(
            f for f in genre_dir.iterdir()
            if f.is_file() and f.suffix.lower() in midi_extensions
        ))
        for genre_dir in sorted(p for p in root_path.iterdir() if p.is_dir())
    ]
    total_files = sum(len(files) for _, files in genre_files)
    logging.info(f"Found {total_files} MIDI files to process")

    all_features = []
    processed_files = 0
    failed_files = 0

    with tqdm(total=total_files, desc="Processing MIDI files") as pbar:
        for genre, midi_files in genre_files:
            genre_file_count = 0

            for file_path in midi_files:
                filename = file_path.name
                try:
                    features = get_midi_features(str(file_path), instrument_family_map)

                    if features:
                        for feature in features:
                            feature['genre'] = genre
                            feature['filename'] = filename
                        all_features.extend(features)
                        processed_files += 1
                        genre_file_count += 1
                    else:
                        logging.warning(f"No features extracted from {filename}")
                        failed_files += 1

                except Exception as e:
                    logging.error(f"Error processing {filename} ({genre}): {str(e)}")
                    failed_files += 1

                pbar.update(1)

            logging.info(f"Processed {genre_file_count} files from genre: {genre}")

    if not all_features:
        logging.error("No features extracted from any files!")
        return pd.DataFrame()

    df = pd.DataFrame(all_features)
    logging.info(f"Successfully created DataFrame with {len(df)} feature rows")
    logging.info(f"Files processed: {processed_files}/{total_files} (Failed: {failed_files})")

    validate_dataframe(df)
    return df

def validate_dataframe(df: pd.DataFrame) -> None:
    """
    Log basic sanity checks for an extracted-features DataFrame.

    Only logs: it never raises for missing columns and never modifies ``df``.
    Each check is skipped when its column is absent.

    Checks:
        * shape, number of distinct genres (``genre``) and files (``filename``);
        * missing values in the critical columns;
        * pitch range of non-rest rows (rests use pitch == -1), with a warning
          when it falls outside the MIDI range 0-127;
        * non-positive durations (warning) and the overall duration range.

    Args:
        df: DataFrame to validate.
    """
    logging.info("=== DataFrame Validation ===")
    logging.info(f"Shape: {df.shape}")

    if 'genre' in df.columns:
        genres = sorted(df['genre'].dropna().astype(str).unique())
        logging.info(f"Genres: {len(genres)} unique ({genres})")
    if 'filename' in df.columns:
        logging.info(f"Files processed: {df['filename'].nunique()}")

    # Check for missing values in critical columns
    critical_columns = ['pitch', 'onset', 'duration', 'genre', 'filename']
    for col in critical_columns:
        if col in df.columns:
            missing_count = int(df[col].isnull().sum())
            if missing_count > 0:
                logging.warning(f"Column '{col}' has {missing_count} missing values")

    if 'pitch' in df.columns:
        pitched = df.loc[df['pitch'] != -1, 'pitch'].dropna()
        if pitched.empty:
            logging.info("Pitch range (excluding rests): no pitched events")
        else:
            # .tolist() yields plain Python scalars, so the log shows (21, 108)
            # rather than numpy reprs such as np.int64(21).
            pitch_range = tuple(pitched.agg(['min', 'max']).tolist())
            logging.info(f"Pitch range (excluding rests): {pitch_range}")
            if pitch_range[0] < 0 or pitch_range[1] > 127:
                logging.warning(f"Unusual pitch range detected: {pitch_range}")

    if 'duration' in df.columns and not df['duration'].dropna().empty:
        non_positive = int((df['duration'] <= 0).sum())
        if non_positive > 0:
            # Zero-length events may be legitimate, but they are worth flagging.
            logging.warning(f"Found {non_positive} non-positive durations")
        duration_range = tuple(df['duration'].agg(['min', 'max']).tolist())
        logging.info(f"Duration range: {duration_range}")

def extract_features_with_summary(
    root_path: str = 'data',
    instrument_family_map: Optional[Dict] = None,
    method: str = 'sequential',  # 'sequential', 'threaded'
    max_workers: Optional[int] = None,
    save_excel: bool = True,
    excel_filename: str = 'datasets/midi_features.xlsx'
) -> pd.DataFrame:
    """
    Run feature extraction, print per-genre summaries, and optionally save to Excel.

    Args:
        root_path: Path to root directory containing genre subfolders.
        instrument_family_map: Optional mapping of instruments to families.
        method: ``'sequential'`` (extract_features_sequential) or ``'threaded'``
            (extract_features_threaded).
        max_workers: Number of worker threads (threaded method only).
        save_excel: Whether to write the DataFrame to ``excel_filename``.
        excel_filename: Output Excel path; missing parent folders are created.

    Returns:
        pd.DataFrame: Extracted features (empty if nothing was extracted, in
        which case nothing is printed or saved).

    Raises:
        ValueError: If ``method`` is not recognised.
    """
    logging.info("=== MIDI Feature Extraction Pipeline ===")
    logging.info(f"Root path: {root_path}")
    logging.info(f"Method: {method}")

    if method == 'sequential':
        df = extract_features_sequential(
            root_path=root_path,
            instrument_family_map=instrument_family_map
        )
    elif method == 'threaded':
        df = extract_features_threaded(
            root_path=root_path,
            instrument_family_map=instrument_family_map,
            max_workers=max_workers
        )
    else:
        raise ValueError(f"Unknown method: {method}. Use 'sequential' or 'threaded'")

    if df.empty:
        logging.error("No features extracted. Check your MIDI files and paths.")
        return df

    logging.info("=== Summary Statistics ===")
    genre_counts = df['genre'].value_counts()
    print("\nGenre distribution:")
    for genre, count in genre_counts.items():
        print(f"  {genre}: {count} features")

    print("\nFiles per genre:")
    files_per_genre = df.groupby('genre')['filename'].nunique()
    for genre, file_count in files_per_genre.items():
        feature_count = genre_counts[genre]
        print(f"  {genre}: {file_count} files, avg {feature_count/file_count:.1f} features/file")

    numeric_cols = df.select_dtypes(include=['number']).columns
    if len(numeric_cols) > 0:  # describe() raises on a frame without columns
        print("\nNumeric features summary:")
        print(df[numeric_cols].describe().round(2))

    if save_excel:
        excel_path = Path(excel_filename)
        # to_excel does not create missing folders (the default target is 'datasets/').
        excel_path.parent.mkdir(parents=True, exist_ok=True)
        df.to_excel(excel_path, index=False)
        logging.info(f"DataFrame saved to {excel_filename}")

    return df

In [5]:
# Run the full pipeline. extract_features_with_summary already writes the Excel
# file (save_excel=True by default), so no second to_excel call is needed here.
df = extract_features_with_summary(
    root_path='MIDI_files',
    instrument_family_map=instrument_family_map,
    method='sequential',  # safe and stable
    excel_filename='datasets/midi_features.xlsx',
)

2025-08-06 14:54:13,175 - INFO - === MIDI Feature Extraction Pipeline ===
2025-08-06 14:54:13,177 - INFO - Root path: data
2025-08-06 14:54:13,178 - INFO - Method: sequential
2025-08-06 14:54:13,179 - INFO - === Sequential MIDI Feature Extraction ===
2025-08-06 14:54:13,179 - INFO - Root path: data
2025-08-06 14:54:13,183 - INFO - Found 45 MIDI files to process
Processing MIDI files:   7%|▋         | 3/45 [00:17<03:49,  5.47s/it]2025-08-06 14:54:30,374 - INFO - Processed 3 files from genre: pop
Processing MIDI files:  13%|█▎        | 6/45 [00:30<02:48,  4.33s/it]2025-08-06 14:54:43,232 - INFO - Processed 3 files from genre: heavy_metal
Processing MIDI files:  20%|██        | 9/45 [00:46<03:02,  5.06s/it]2025-08-06 14:55:00,074 - INFO - Processed 3 files from genre: disco
Processing MIDI files:  27%|██▋       | 12/45 [00:57<02:05,  3.82s/it]2025-08-06 14:55:10,577 - INFO - Processed 3 files from genre: blues
Processing MIDI files:  33%|███▎      | 15/45 [01:08<01:50,  3.67s/it]2025-08-0


Genre distribution:
  reggae: 22552 features
  rap: 20989 features
  country: 20690 features
  soul: 18681 features
  rnb: 18310 features
  new_age: 17543 features
  pop: 17299 features
  jazz: 15127 features
  disco: 13343 features
  heavy_metal: 10838 features
  blues: 10184 features
  classical: 10173 features
  dance: 9658 features
  rock: 9565 features
  alternative_rock: 8878 features

Files per genre:
  alternative_rock: 3 files, avg 2959.3 features/file
  blues: 3 files, avg 3394.7 features/file
  classical: 3 files, avg 3391.0 features/file
  country: 3 files, avg 6896.7 features/file
  dance: 3 files, avg 3219.3 features/file
  disco: 3 files, avg 4447.7 features/file
  heavy_metal: 3 files, avg 3612.7 features/file
  jazz: 3 files, avg 5042.3 features/file
  new_age: 3 files, avg 5847.7 features/file
  pop: 3 files, avg 5766.3 features/file
  rap: 3 files, avg 6996.3 features/file
  reggae: 3 files, avg 7517.3 features/file
  rnb: 3 files, avg 6103.3 features/file
  rock: 3

2025-08-06 14:58:07,381 - INFO - DataFrame saved to midi_features.csv
