In [16]:
import numpy as np
import pandas as pd
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
import pretty_midi
import warnings
import os

def get_genres(path):
    """
    This function reads the genre labels and puts it into a pandas DataFrame.
    @input path: The path to the genre label file.
    @type path: String
    @return: A pandas dataframe containing the genres and midi IDs.
    @rtype: pandas.DataFrame
    """
    ids = []
    genres = []
    with open(path) as f:
        line = f.readline()
        while line:
            if line[0] != '#':
                [x, y, *_] = line.strip().split("\t")
                ids.append(x)
                genres.append(y)
            line = f.readline()
    genre_df = pd.DataFrame(data={"Genre": genres, "TrackID": ids})
    return genre_df

def get_matched_midi(midi_folder, genre_df):
    """
    This function loads in midi file paths that are found in the given folder, puts this data into a
    pandas DataFrame, then matches each entry with a genre described in get_genres.
    @input midi_folder: The path to the midi files.
    @type midi_folder: String
    @input genre_df: The genre label dataframe generated by get_genres.
    @type genre_df: pandas.DataFrame
    @return: A dataframe of track id and path to a midi file with that track id.
    @rtype: pandas.DataFrame
    """
    # Get All Midi Files
    track_ids, file_paths = [], []
    for dir_name, subdir_list, file_list in os.walk(midi_folder):
        if len(dir_name) == 36:
            track_id = dir_name[18:]
            file_path_list = ["/".join([dir_name, file]) for file in file_list]
            for file_path in file_path_list:
                track_ids.append(track_id)
                file_paths.append(file_path)
    
    all_midi_df = pd.DataFrame({"TrackID": track_ids, "Path": file_paths})
    
    # Inner Join with Genre Dataframe
    df = pd.merge(all_midi_df, genre_df, on='TrackID', how='inner')
    return df.drop(["TrackID"], axis=1)

# MAIN EXECUTION - BACK TO YOUR ORIGINAL APPROACH BUT CLEANER
print("=== LOADING GENRE DATA ===")
genre_path = "msd_tagtraum_cd1.cls"
original_genre_df = get_genres(genre_path)

print(f"Total tracks loaded: {len(original_genre_df)}")
print(f"Available genres: {sorted(original_genre_df['Genre'].unique())}")

# Define target genres
filtered_genres = ["Blues", "Jazz", "Pop_Rock"]
print(f"\nTarget genres: {filtered_genres}")

# Check if target genres exist in the data
available_target_genres = [g for g in filtered_genres if g in original_genre_df['Genre'].unique()]
print(f"Target genres found in data: {available_target_genres}")

if not available_target_genres:
    print("ERROR: None of the target genres found in the data!")
    print("Available genres are:", sorted(original_genre_df['Genre'].unique()))
else:
    print(f"Genre counts before matching with MIDI:")
    for genre in available_target_genres:
        count = len(original_genre_df[original_genre_df['Genre'] == genre])
        print(f"  {genre}: {count}")

# STEP 1: Match with MIDI files using FULL original dataframe
print("\n=== MATCHING WITH MIDI FILES ===")
midi_path = "your_midi_folder_path_here"  # SET THIS TO YOUR ACTUAL PATH
matched_midi_df = get_matched_midi(midi_path, original_genre_df)

print(f"Total MIDI files matched: {len(matched_midi_df)}")
print("All genres in matched data:")
print(matched_midi_df['Genre'].value_counts())

# STEP 2: Filter matched data to only target genres
print(f"\n=== FILTERING TO TARGET GENRES ===")
filtered_matched_df = matched_midi_df[matched_midi_df['Genre'].isin(filtered_genres)].copy()

print(f"After filtering to target genres: {len(filtered_matched_df)}")
if len(filtered_matched_df) > 0:
    print("Genre distribution after filtering:")
    print(filtered_matched_df['Genre'].value_counts())
else:
    print("ERROR: No songs found for target genres!")
    print("This might mean:")
    print("1. Your MIDI folder path is wrong")
    print("2. The genre names don't match exactly")
    print("3. No MIDI files exist for these genres")

# STEP 3: Balance the dataset
if len(filtered_matched_df) > 0:
    print(f"\n=== BALANCING DATASET ===")
    max_songs_per_genre = 2750
    
    balanced_df = filtered_matched_df.groupby('Genre').apply(
        lambda x: x.head(max_songs_per_genre) if len(x) > max_songs_per_genre else x
    ).reset_index(drop=True)
    
    print("Final balanced dataset:")
    final_counts = balanced_df['Genre'].value_counts()
    print(final_counts)
    
    # Create label dictionary
    label_list = list(final_counts.index)
    label_dict = {lbl: label_list.index(lbl) for lbl in label_list}
    
    print(f"\nLabel mapping: {label_dict}")
    print(f"Total songs for training: {len(balanced_df)}")
    
    # Show sample of final data
    print(f"\nSample of final data:")
    print(balanced_df.head())
    
else:
    print("Cannot proceed - no data to balance!")

# DEBUG: Let's also check what the MIDI folder structure looks like
print(f"\n=== DEBUGGING MIDI FOLDER ===")
print(f"Checking MIDI folder: {midi_path}")
if os.path.exists(midi_path):
    print("MIDI folder exists")
    # Count directories with length 36 (as per your logic)
    dir_count = 0
    for dir_name, subdir_list, file_list in os.walk(midi_path):
        if len(dir_name) == 36:
            dir_count += 1
            if dir_count <= 3:  # Show first 3 examples
                track_id = dir_name[18:]
                print(f"  Example directory: {dir_name}")
                print(f"  Track ID extracted: {track_id}")
                print(f"  Files in directory: {file_list}")
    print(f"Total directories with length 36: {dir_count}")
else:
    print("ERROR: MIDI folder does not exist!")
    print("Please set the correct path in the 'midi_path' variable")

=== LOADING GENRE DATA ===
Total tracks loaded: 133676
Available genres: ['Blues', 'Country', 'Electronic', 'Folk', 'International', 'Jazz', 'Latin', 'New Age', 'Pop_Rock', 'Rap', 'Reggae', 'RnB', 'Vocal']

Target genres: ['Blues', 'Jazz', 'Pop_Rock']
Target genres found in data: ['Blues', 'Jazz', 'Pop_Rock']
Genre counts before matching with MIDI:
  Blues: 2933
  Jazz: 7783
  Pop_Rock: 79937

=== MATCHING WITH MIDI FILES ===
Total MIDI files matched: 0
All genres in matched data:
Series([], Name: count, dtype: int64)

=== FILTERING TO TARGET GENRES ===
After filtering to target genres: 0
ERROR: No songs found for target genres!
This might mean:
1. Your MIDI folder path is wrong
2. The genre names don't match exactly
3. No MIDI files exist for these genres
Cannot proceed - no data to balance!

=== DEBUGGING MIDI FOLDER ===
Checking MIDI folder: your_midi_folder_path_here
ERROR: MIDI folder does not exist!
Please set the correct path in the 'midi_path' variable


In [1]:
import numpy as np
import pandas as pd
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
import pretty_midi
import warnings
import os

def get_genres(path):
    """
    This function reads the genre labels and puts it into a pandas DataFrame.
    
    @input path: The path to the genre label file.
    @type path: String
    
    @return: A pandas dataframe containing the genres and midi IDs.
    @rtype: pandas.DataFrame
    """
    ids = []
    genres = []
    with open(path) as f:
        line = f.readline()
        while line:
            if line[0] != '#':
                [x, y, *_] = line.strip().split("\t")
                ids.append(x)
                genres.append(y)
            line = f.readline()
    genre_df = pd.DataFrame(data={"Genre": genres, "TrackID": ids})
    return genre_df

# Get the Genre DataFrame
genre_path = "msd_tagtraum_cd1.cls"
genre_df = get_genres(genre_path)

# Create Genre List and Dictionary
label_list = list(set(genre_df.Genre))
label_dict = {lbl: label_list.index(lbl) for lbl in label_list}

# Print to Visualize
print(genre_df.head(), end="\n\n")
print(label_list, end="\n\n")
print(label_dict, end="\n\n")

      Genre             TrackID
0  Pop_Rock  TRAAAAK128F9318786
1       Rap  TRAAAAW128F429D538
2  Pop_Rock  TRAAABD128F429CF47
3      Jazz  TRAAAED128E0783FAB
4  Pop_Rock  TRAAAEF128F4273421

['Country', 'Electronic', 'Blues', 'Rap', 'International', 'Reggae', 'Pop_Rock', 'Latin', 'Folk', 'New Age', 'RnB', 'Jazz', 'Vocal']

{'Country': 0, 'Electronic': 1, 'Blues': 2, 'Rap': 3, 'International': 4, 'Reggae': 5, 'Pop_Rock': 6, 'Latin': 7, 'Folk': 8, 'New Age': 9, 'RnB': 10, 'Jazz': 11, 'Vocal': 12}



In [2]:
def get_matched_midi(midi_folder, genre_df):
    """
    This function loads in midi file paths that are found in the given folder, puts this data into a
    pandas DataFrame, then matches each entry with a genre described in get_genres.
    
    @input midi_folder: The path to the midi files.
    @type midi_folder: String
    @input genre_df: The genre label dataframe generated by get_genres.
    @type genre_df: pandas.DataFrame
    
    @return: A dataframe of track id and path to a midi file with that track id.
    @rtype: pandas.DataFrame
    """
    # Get All Midi Files
    track_ids, file_paths = [], []
    for dir_name, subdir_list, file_list in os.walk(midi_folder):
        if len(dir_name) == 36:
            track_id = dir_name[18:]
            file_path_list = ["/".join([dir_name, file]) for file in file_list]
            for file_path in file_path_list:
                track_ids.append(track_id)
                file_paths.append(file_path)
    all_midi_df = pd.DataFrame({"TrackID": track_ids, "Path": file_paths})
    
    # Inner Join with Genre Dataframe
    df = pd.merge(all_midi_df, genre_df, on='TrackID', how='inner')
    return df.drop(["TrackID"], axis=1)

# Obtain DataFrame with Matched Genres to File Paths
midi_path = "lmd_matched"
matched_midi_df = get_matched_midi(midi_path, genre_df)

# Print to Check Correctness
print(matched_midi_df.head())

max_songs_per_genre = 2750
    
# balanced_df = filtered_matched_df.groupby('Genre').apply(
#     lambda x: x.head(max_songs_per_genre) if len(x) > max_songs_per_genre else x
# ).reset_index(drop=True)

df_filtered = matched_midi_df.groupby('Genre').apply(lambda x: x.head(750)).reset_index(drop=True)
# Step 1: Count the number of tracks per genre
top_genres = ["Pop_Rock", "Country", "Electronic"]
# Step 2: Filter the DataFrame to keep only the top 3 genres
df_filtered = df_filtered[df_filtered['Genre'].isin(top_genres)]

# Now you can proceed with the previous filtering/limiting steps for each genre
#df_top3_filtered = df_top3.groupby('genre').apply(lambda x: x.head(2000)).reset_index(drop=True)


print("Final balanced dataset:")
final_counts = df_filtered['Genre'].value_counts()
print(final_counts)

                                                Path     Genre
0  lmd_matched/R/R/U/TRRRUFD12903CD7092/6c460e4c5...  Pop_Rock
1  lmd_matched/R/R/U/TRRRUFD12903CD7092/6ca2a1f03...  Pop_Rock
2  lmd_matched/R/R/A/TRRRAJP128E0793859/2c78d25cb...       RnB
3  lmd_matched/R/R/F/TRRRFLX128F9326186/afd3595e5...  Pop_Rock
4  lmd_matched/R/R/F/TRRRFLX128F9326186/74582fd38...  Pop_Rock
Final balanced dataset:
Genre
Country       750
Electronic    750
Pop_Rock      750
Name: count, dtype: int64


  df_filtered = matched_midi_df.groupby('Genre').apply(lambda x: x.head(750)).reset_index(drop=True)


In [23]:
%%time
def normalize_features(features):
    """
    This function normalizes the features to the range [-1, 1]
    
    @input features: The array of features.
    @type features: List of float
    
    @return: Normalized features.
    @rtype: List of float
    """
    tempo = (features[0] - 150) / 300
    num_sig_changes = (features[1] - 2) / 10
    resolution = (features[2] - 260) / 400
    time_sig_1 = (features[3] - 3) / 8
    time_sig_2 = (features[4] - 3) / 8
    return [tempo, resolution, time_sig_1, time_sig_2]


def get_features(path):
    """
    This function extracts the features from a midi file when given its path.
    
    @input path: The path to the midi file.
    @type path: String
    
    @return: The extracted features.
    @rtype: List of float
    """
    try:
        # Test for Corrupted Midi Files
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            file = pretty_midi.PrettyMIDI(path)
            
            tempo = file.estimate_tempo()
            num_sig_changes = len(file.time_signature_changes)
            resolution = file.resolution
            ts_changes = file.time_signature_changes
            ts_1 = 4
            ts_2 = 4
            if len(ts_changes) > 0:
                ts_1 = ts_changes[0].numerator
                ts_2 = ts_changes[0].denominator
            return normalize_features([tempo, num_sig_changes, resolution, ts_1, ts_2])
    except:
        return None


def extract_midi_features(path_df):
    """
    This function takes in the path DataFrame, then for each midi file, it extracts certain
    features, maps the genre to a number and concatenates these to a large design matrix to return.
    
    @input path_df: A dataframe with paths to midi files, as well as their corresponding matched genre.
    @type path_df: pandas.DataFrame
    
    @return: A matrix of features along with label.
    @rtype: numpy.ndarray of float
    """
    all_features = []
    for index, row in path_df.iterrows():
        features = get_features(row.Path)
        genre = label_dict[row.Genre]
        if features is not None:
            features.append(genre)
            all_features.append(features)
    return np.array(all_features)

labeled_features = extract_midi_features(matched_midi_df)
print(labeled_features)

In [59]:
# Shuffle Entire Dataset to Make Random
labeled_features = np.random.permutation(labeled_features)

# Partition into 3 Sets
num = len(labeled_features)
num_training = int(num * 0.7)
num_validation = int(num * 0.8)
training_data = labeled_features[:num_training]
validation_data = labeled_features[num_training:num_validation]
test_data = labeled_features[num_validation:]

# Separate Features from Labels
num_cols = training_data.shape[1] - 1
training_features = training_data[:, :num_cols]
validation_features = validation_data[:, :num_cols]
test_features = test_data[:, :num_cols]

# Format Features for Multi-class Classification
num_classes = len(label_list)
training_labels = training_data[:, num_cols].astype(int)
validation_labels = validation_data[:, num_cols].astype(int)
test_labels = test_data[:, num_cols].astype(int)

# Function for One-Hot Encoding
def one_hot(labels):
    """
    This function encodes the labels using one-hot encoding.
    
    @input num_classes: The number of genres/classes.
    @type num_classes: int
    @input labels: The genre labels to encode.
    @type labels: numpy.ndarray of int
    
    @return: The one-hot encoding of the labels.
    @rtype: numpy.ndarray of int
    """
    return np.eye(num_classes)[labels].astype(int)

# Print to Check Dimentions and to Visualize
print(test_features[:10])
print(test_labels[:10])
print(one_hot(test_labels)[:10])

Genre
Blues       2750
Jazz        2750
Pop_Rock    2750
dtype: int64
[[ 0.16203749 -0.17        0.125       0.125     ]
 [ 0.20053524  0.31        0.125       0.125     ]
 [ 0.15900013 -0.41        0.125       0.125     ]
 [ 0.29418651  0.55        0.125       0.125     ]
 [ 0.2800312   0.55        0.125       0.125     ]
 [ 0.3134415   0.31        0.125       0.125     ]
 [ 0.30103625 -0.41        0.125       0.125     ]
 [ 0.09089741 -0.05        0.125       0.125     ]
 [ 0.06666695  0.31        0.125       0.125     ]
 [ 0.13964432  0.55        0.125       0.125     ]]
[0 2 0 0 0 1 0 0 1 0]
[[1 0 0]
 [0 0 1]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [0 1 0]
 [1 0 0]
 [1 0 0]
 [0 1 0]
 [1 0 0]]


In [12]:
import numpy as np
import pandas as pd
from collections import defaultdict
import pickle
import pretty_midi
import tensorflow as tf
from tensorflow import keras
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder

class SimpleMIDIGenreClassifier:
    def __init__(self):
        self.model = None
        self.scaler = StandardScaler()
        self.label_encoder = LabelEncoder()
        
    def extract_musical_features(self, midi_file_path):
        """
        Extract exactly 20 musical features from a MIDI file
        """
        try:
            midi_data = pretty_midi.PrettyMIDI(midi_file_path)
        except Exception as e:
            print(f"Error loading {midi_file_path}: {e}")
            return None
        
        # Always return exactly 20 features - initialize with zeros
        features = np.zeros(20)
        
        # Basic info
        total_time = midi_data.get_end_time()
        if total_time == 0:
            return features  # Return all zeros if empty
            
        features[0] = total_time  # Song length
        features[1] = len(midi_data.instruments)  # Number of instruments
        
        # Tempo
        tempo_changes = midi_data.get_tempo_changes()
        if len(tempo_changes[1]) > 0:
            features[2] = np.mean(tempo_changes[1])
        else:
            features[2] = 120.0
        
        # Get all non-drum notes
        all_notes = []
        for instrument in midi_data.instruments:
            if not instrument.is_drum:
                all_notes.extend(instrument.notes)
        
        if len(all_notes) == 0:
            return features  # Return mostly zeros if no notes
        
        # Pitch features
        pitches = [note.pitch for note in all_notes]
        features[3] = np.mean(pitches)           # Average pitch
        features[4] = np.std(pitches) if len(pitches) > 1 else 0  # Pitch variation
        features[5] = max(pitches) - min(pitches) # Pitch range
        
        # Rhythm features
        durations = [note.end - note.start for note in all_notes]
        features[6] = np.mean(durations)         # Average note length
        features[7] = np.std(durations) if len(durations) > 1 else 0  # Rhythm variation
        
        # Activity features
        features[8] = len(all_notes) / total_time  # Notes per second
        
        # Velocity (loudness) features
        velocities = [note.velocity for note in all_notes]
        features[9] = np.mean(velocities)        # Average loudness
        features[10] = np.std(velocities) if len(velocities) > 1 else 0  # Dynamic range
        
        # Harmony - count simultaneous notes
        onset_groups = defaultdict(list)
        for note in all_notes:
            onset_time = round(note.start * 4) / 4  # Round to quarter beats
            onset_groups[onset_time].append(note.pitch)
        
        if onset_groups:
            chord_sizes = [len(pitches) for pitches in onset_groups.values()]
            features[11] = np.mean(chord_sizes)       # Average chord size
            features[12] = max(chord_sizes)           # Biggest chord
        
        # Melodic movement
        sorted_notes = sorted(all_notes, key=lambda n: n.start)
        if len(sorted_notes) > 1:
            intervals = []
            for i in range(len(sorted_notes) - 1):
                interval = abs(sorted_notes[i+1].pitch - sorted_notes[i].pitch)
                intervals.append(interval)
            
            if intervals:
                features[13] = np.mean(intervals)     # Average jump size
                features[14] = sum(1 for i in intervals if i <= 2) / len(intervals)  # Smooth melody ratio
        
        # Key signature analysis
        if pitches:
            pitch_classes = [pitch % 12 for pitch in pitches]
            pitch_class_counts = [pitch_classes.count(pc) for pc in range(12)]
            features[15] = np.std(pitch_class_counts)     # How focused on certain notes
            features[16] = max(pitch_class_counts) / len(pitches)  # Most common note ratio
        
        # Drum features
        drum_notes = []
        for instrument in midi_data.instruments:
            if instrument.is_drum:
                drum_notes.extend(instrument.notes)
        
        features[17] = len(drum_notes) / total_time if drum_notes else 0  # Drum density
        features[18] = len(set(note.pitch for note in drum_notes)) if drum_notes else 0  # Drum variety
        
        # One more feature - total unique pitches
        features[19] = len(set(pitches)) if pitches else 0
        
        return features
    
    def load_data_from_dataframe(self, df):
        """
        Load data from your pandas DataFrame
        
        Args:
            df: DataFrame with columns like 'Genre', 'Path' (and optionally 'TrackId')
        """
        print(f"Loading {len(df)} tracks...")
        
        X = []  # Features
        y = []  # Labels (genres)
        successful_count = 0
        
        for idx, row in df.iterrows():
            features = self.extract_musical_features(row['Path'])
            if features is not None:
                # Ensure features is always exactly 20 elements
                if len(features) == 20:
                    X.append(features)
                    y.append(row['Genre'])
                    successful_count += 1
                else:
                    print(f"Warning: Track at {row['Path']} returned {len(features)} features instead of 20")
            
            # Progress update every 100 files
            if (idx + 1) % 100 == 0:
                print(f"Processed {idx + 1}/{len(df)} files... ({successful_count} successful)")
        
        print(f"Successfully processed {len(X)} out of {len(df)} tracks")
        
        if len(X) == 0:
            return np.array([]), np.array([])
        
        # Convert to numpy array - this should work now since all features have same length
        X = np.array(X)
        print(f"Feature array shape: {X.shape}")
        
        # Convert genres to numbers (jazz=0, classical=1, rock=2, etc.)
        y_encoded = self.label_encoder.fit_transform(y)
        
        return X, y_encoded
    
    def create_simple_model(self, num_features, num_genres):
        """Create a simple but effective neural network"""
        model = keras.Sequential([
            # Input layer
            keras.layers.Dense(64, activation='relu', input_shape=(num_features,)),
            keras.layers.Dropout(0.3),  # Prevent overfitting
            
            # Hidden layer
            keras.layers.Dense(32, activation='relu'),
            keras.layers.Dropout(0.3),
            
            # Output layer
            keras.layers.Dense(num_genres, activation='softmax')
        ])
        
        model.compile(
            optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        
        return model
    
    def train_from_dataframe(self, df, epochs=50, test_size=0.2):
        """
        Train the model using your DataFrame
        
        Args:
            df: Your DataFrame with 'Genre', 'Path' columns
            epochs: How many times to train (more = better but slower)
            test_size: What fraction to use for testing (0.2 = 20%)
        """
        print("Extracting features from MIDI files...")
        X, y = self.load_data_from_dataframe(df)
        
        if len(X) == 0:
            raise ValueError("No valid MIDI files found! Check your file paths.")
        
        print(f"Training on {len(X)} tracks with {X.shape[1]} features each")
        print(f"Genres found: {list(self.label_encoder.classes_)}")
        
        # Split into training and testing sets
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=42, stratify=y
        )
        
        # Normalize the features (important for neural networks!)
        X_train_scaled = self.scaler.fit_transform(X_train)
        X_test_scaled = self.scaler.transform(X_test)
        
        # Create the model
        num_genres = len(self.label_encoder.classes_)
        self.model = self.create_simple_model(X_train_scaled.shape[1], num_genres)
        
        print(f"Model created with {X_train_scaled.shape[1]} features and {num_genres} genres")
        print("Starting training...")
        
        # Train the model
        history = self.model.fit(
            X_train_scaled, y_train,
            epochs=epochs,
            validation_data=(X_test_scaled, y_test),
            batch_size=16,
            verbose=1  # Show progress
        )
        
        # Test the final model
        test_loss, test_accuracy = self.model.evaluate(X_test_scaled, y_test, verbose=0)
        print(f"\nFinal accuracy: {test_accuracy:.3f} ({test_accuracy*100:.1f}%)")
        
        return history
    
    def predict_genre(self, midi_file_path):
        """
        Predict the genre of a single MIDI file
        
        Args:
            midi_file_path: Path to the MIDI file
            
        Returns:
            Dictionary with genre probabilities
        """
        if self.model is None:
            raise ValueError("Model hasn't been trained yet!")
        
        features = self.extract_musical_features(midi_file_path)
        if features is None:
            return "Error: Could not process MIDI file"
        
        # Scale the features and predict
        features_scaled = self.scaler.transform([features])
        probabilities = self.model.predict(features_scaled, verbose=0)[0]
        
        # Convert back to genre names
        results = {}
        for i, genre in enumerate(self.label_encoder.classes_):
            results[genre] = float(probabilities[i])
        
        return results
    
    def save_model(self, filepath):
        """Save the trained model"""
        self.model.save(f"{filepath}_model.h5")
        with open(f"{filepath}_scaler.pkl", 'wb') as f:
            pickle.dump(self.scaler, f)
        with open(f"{filepath}_labels.pkl", 'wb') as f:
            pickle.dump(self.label_encoder, f)
        print(f"Model saved to {filepath}_model.h5")
    
    def load_model(self, filepath):
        """Load a previously trained model"""
        self.model = keras.models.load_model(f"{filepath}_model.h5")
        with open(f"{filepath}_scaler.pkl", 'rb') as f:
            self.scaler = pickle.load(f)
        with open(f"{filepath}_labels.pkl", 'rb') as f:
            self.label_encoder = pickle.load(f)
        print(f"Model loaded from {filepath}_model.h5")

# SIMPLE USAGE EXAMPLE:
if __name__ == "__main__":
    # Assuming you have a DataFrame called 'df' with columns: TrackId, Genre, Path
    
    # Create the classifier
    classifier = SimpleMIDIGenreClassifier()
    
    # Train it on your DataFrame (this will take a while!)
    classifier.train_from_dataframe(df_filtered, epochs=3000)
    
    # Save the trained model
    classifier.save_model("my_genre_classifier")

Extracting features from MIDI files...
Loading 2250 tracks...




Processed 100/2250 files... (68 successful)
Processed 200/2250 files... (168 successful)
Error loading lmd_matched/T/U/A/TRTUAEM128F92DCDE6/6c9f9ef988b2f4b019c8dd5848d9ad2e.mid: data byte must be in range 0..127
Processed 300/2250 files... (267 successful)
Processed 400/2250 files... (367 successful)
Processed 500/2250 files... (467 successful)
Processed 600/2250 files... (567 successful)
Error loading lmd_matched/C/J/G/TRCJGNA128E078F797/1ef43219e92a0dba7b1da2a7a4becb7e.mid: data byte must be in range 0..127
Processed 700/2250 files... (666 successful)
Processed 800/2250 files... (766 successful)
Error loading lmd_matched/I/Z/O/TRIZOFZ128F429250C/0d6dac8f5dc8c2e9387f25b343bab1a2.mid: data byte must be in range 0..127
Error loading lmd_matched/N/T/M/TRNTMLI128F93358EB/9f79f5113aa40a71564937281bcc077e.mid: data byte must be in range 0..127
Processed 900/2250 files... (864 successful)
Processed 1000/2250 files... (964 successful)
Processed 1100/2250 files... (1064 successful)
Error loadi

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m112/112[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step - accuracy: 0.3922 - loss: 1.1668 - val_accuracy: 0.5647 - val_loss: 0.9365
Epoch 2/1000
[1m112/112[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - accuracy: 0.5061 - loss: 0.9706 - val_accuracy: 0.6027 - val_loss: 0.8465
Epoch 3/1000
[1m112/112[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - accuracy: 0.5635 - loss: 0.8889 - val_accuracy: 0.6161 - val_loss: 0.8087
Epoch 4/1000
[1m112/112[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - accuracy: 0.5564 - loss: 0.8863 - val_accuracy: 0.6094 - val_loss: 0.7894
Epoch 5/1000
[1m112/112[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - accuracy: 0.6154 - loss: 0.8371 - val_accuracy: 0.6473 - val_loss: 0.7619
Epoch 6/1000
[1m112/112[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - accuracy: 0.6094 - loss: 0.8183 - val_accuracy: 0.6652 - val_loss: 0.7525
Epoch 7/1000
[1m112/112[0m 




Final accuracy: 0.792 (79.2%)
Model saved to my_genre_classifier_model.h5


In [10]:
classifier.load_model("my_genre_classifier")
result = classifier.predict_genre("Enya_-_Bard_Dance.mid")
print(result)  # Shows probability for each genre



Model loaded from my_genre_classifier_model.h5
{'Country': 0.023443609476089478, 'Electronic': 0.6127707362174988, 'Pop_Rock': 0.3637857139110565}
