<a href="https://colab.research.google.com/github/kviercz/AAI-511_Group-Project/blob/main/Final_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab-only setup (this is just for my runtime, can comment out):
# mount Google Drive and make the project data folder the working directory.
import os
from google.colab import drive

dataset_path = '/content/drive/MyDrive/DataFiles/'
drive.mount('/content/drive')
os.chdir(dataset_path)

In [None]:
!pip install pretty_midi mido

import glob
import os
import random
from collections import Counter

import librosa
import matplotlib.pyplot as plt
import numpy as np
import pretty_midi
import tensorflow as tf
from mido import KeySignatureError, MidiFile
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import LSTM, Bidirectional, Dense, Dropout, Embedding
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical


## Data processing and initial loading
The following two helper functions load the data from the composer folders (Bach, Chopin, Beethoven, and Mozart). The first extracts the tempo and note sequence from a single MIDI file; the second applies it to every file in a composer's folder and pairs each file's features with the folder name as its label.

In [None]:
def extract_midi_features(file_path):
    """
    Extract the estimated tempo and the raw note list from a MIDI file.

    Parameters:
        file_path (str): Path to the MIDI file.

    Returns:
        tuple: ``(tempo, note_sequences)`` where ``tempo`` is the value of
        ``PrettyMIDI.estimate_tempo()`` and ``note_sequences`` is a list of
        ``[start, end, pitch, velocity]`` lists, one per note, gathered from
        every instrument in turn.

    Raises:
        Whatever ``pretty_midi.PrettyMIDI`` raises for an unreadable file is
        propagated to the caller.
    """
    midi_data = pretty_midi.PrettyMIDI(file_path)

    # Single global tempo estimate for the whole file.
    tempo = midi_data.estimate_tempo()

    # Flatten the notes of every instrument into one list (instrument order kept).
    note_sequences = [
        [note.start, note.end, note.pitch, note.velocity]
        for instrument in midi_data.instruments
        for note in instrument.notes
    ]

    return tempo, note_sequences

def process_composer_data(composer_path):
    """
    Extract features for every ``*.mid`` file directly inside a composer folder.

    Parameters:
        composer_path (str): Directory holding one composer's MIDI files. A
            trailing path separator is allowed.

    Returns:
        tuple: ``(features, labels)``. ``features`` holds one
        ``(tempo, note_sequences)`` tuple per successfully parsed file (see
        ``extract_midi_features``). ``labels`` holds the folder name (e.g.
        ``"Bach"``) for each of those files. Files that fail to parse are
        reported and skipped.
    """
    # normpath drops a trailing separator, so ".../Bach/" gives "Bach", not "".
    label = os.path.basename(os.path.normpath(composer_path))

    features = []
    labels = []

    for midi_file in glob.glob(os.path.join(composer_path, '*.mid')):
        try:
            tempo, note_sequences = extract_midi_features(midi_file)
        except Exception as e:  # best-effort: report the broken file and keep going
            print(f"Error processing {midi_file}: {e}")
            continue
        features.append((tempo, note_sequences))
        labels.append(label)

    return features, labels



In [None]:

def analyze_basic_midi_statistics(midi_file_path):
    """
    Summarise one MIDI file: track count, playback length, note count, channels.

    Parameters:
        midi_file_path (str): Path to the MIDI file.

    Returns:
        dict | None: ``None`` if the file cannot be opened or parsed (the error
        is printed). Otherwise a dict with keys ``num_tracks``,
        ``total_length_ticks``, ``total_notes`` and ``unique_instruments``.

    Notes:
        * ``total_length_ticks`` holds ``MidiFile.length``, which the mido
          documentation defines as playback time in seconds, not ticks. The key
          name is kept because ``process_midi_dataset`` reads it.
        * ``unique_instruments`` is the number of distinct MIDI *channels* that
          carry sounding note-on events, not distinct program numbers.
    """
    try:
        midi = MidiFile(midi_file_path)  # parsed once; reused below
    except Exception as e:  # best-effort: corrupt files are reported and skipped
        print(f"Error processing {midi_file_path}: {e}")
        return None

    note_count = 0
    channels = set()

    for track in midi.tracks:
        for msg in track:
            # A note_on with velocity 0 is a note-off by MIDI convention, so it
            # must not be counted as a note (it would inflate the total).
            if msg.type == 'note_on' and msg.velocity > 0:
                note_count += 1
                channels.add(msg.channel)

    return {
        'num_tracks': len(midi.tracks),
        'total_length_ticks': midi.length,
        'total_notes': note_count,
        'unique_instruments': len(channels),
    }

def analyze_note_distribution(midi_file_path):
    """
    Count how often each MIDI note number is struck in a file.

    Only ``note_on`` messages with a non-zero velocity are counted, across all
    tracks. Returns a ``Counter`` mapping note number -> occurrences.
    """
    midi = MidiFile(midi_file_path)
    return Counter(
        msg.note
        for track in midi.tracks
        for msg in track
        if msg.type == 'note_on' and msg.velocity > 0
    )

def analyze_rhythmic_patterns(midi_file_path):
    """
    Collect note durations, tempo changes and time-signature changes from a MIDI file.

    Each track's ``msg.time`` is a delta time in ticks relative to the previous
    message in that track. The running clock is therefore restarted at 0 for
    every track, and all reported positions and durations are in ticks from
    the start of their track.

    Parameters:
        midi_file_path (str): Path to the MIDI file.

    Returns:
        dict with these keys:
            ``average_note_duration``: mean note length in ticks, or 0 if no
            note was closed.
            ``tempo_changes``: list of ``(tick, tempo)`` pairs.
            ``time_signature_changes``: list of
            ``(tick, (numerator, denominator))`` pairs.
    """
    midi = MidiFile(midi_file_path)

    durations = []
    tempo_changes = []
    time_signature_changes = []

    for track in midi.tracks:
        current_time = 0
        # Notes still sounding, keyed by (channel, note) so that the same pitch
        # on two channels is not paired with the wrong note-off.
        open_notes = {}
        for msg in track:
            current_time += msg.time
            if msg.type == 'set_tempo':
                tempo_changes.append((current_time, msg.tempo))
            elif msg.type == 'time_signature':
                time_signature_changes.append((current_time, (msg.numerator, msg.denominator)))
            elif msg.type == 'note_on' and msg.velocity > 0:
                open_notes[(msg.channel, msg.note)] = current_time
            elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                start = open_notes.pop((msg.channel, msg.note), None)
                if start is not None:
                    durations.append(current_time - start)

    avg_duration = np.mean(durations) if durations else 0
    return {
        'average_note_duration': avg_duration,
        'tempo_changes': tempo_changes,
        'time_signature_changes': time_signature_changes,
    }

def analyze_instrument_distribution(midi_file_path):
    """
    Count ``program_change`` events per program number in a MIDI file.

    Returns a ``Counter`` mapping program number -> number of
    ``program_change`` messages that selected it, over all tracks.
    """
    midi = MidiFile(midi_file_path)
    program_counts = Counter()
    for track in midi.tracks:
        program_counts.update(
            msg.program for msg in track if msg.type == 'program_change'
        )
    return program_counts

def analyze_velocity_dynamics(midi_file_path):
    """
    Summarise the note-on velocities of a MIDI file.

    Only ``note_on`` messages with velocity > 0 are considered.

    Returns:
        dict with these keys:
            ``average_velocity``, ``max_velocity``, ``min_velocity``: all 0
            when there are no such notes.
            ``velocity_distribution``: a ``Counter`` of every velocity value.
    """
    midi = MidiFile(midi_file_path)
    velocities = [
        msg.velocity
        for track in midi.tracks
        for msg in track
        if msg.type == 'note_on' and msg.velocity > 0
    ]

    if not velocities:
        avg_velocity = max_velocity = min_velocity = 0
    else:
        avg_velocity = np.mean(velocities)
        max_velocity = np.max(velocities)
        min_velocity = np.min(velocities)

    return {
        'average_velocity': avg_velocity,
        'max_velocity': max_velocity,
        'min_velocity': min_velocity,
        'velocity_distribution': Counter(velocities),
    }


def process_midi_dataset(dataset_path):
    """
    Walk ``dataset_path`` recursively and aggregate per-file MIDI analyses.

    Every file ending in ``.mid`` under ``dataset_path`` (including subfolders)
    is run through each of these helpers:
    ``analyze_basic_midi_statistics``, ``analyze_note_distribution``,
    ``analyze_rhythmic_patterns``, ``analyze_instrument_distribution`` and
    ``analyze_velocity_dynamics``. Files for which
    ``analyze_basic_midi_statistics`` returns None (unparseable) are skipped
    entirely.

    Parameters:
        dataset_path (str): Path to the directory containing MIDI files.

    Returns:
        dict: Aggregated statistics under the keys ``basic_stats``,
        ``note_distribution``, ``rhythmic_patterns``,
        ``instrument_distribution`` and ``velocity_dynamics``. Averages over
        zero files are reported as 0 rather than NaN.
    """

    def _mean_or_zero(values):
        # np.mean([]) returns NaN (with a RuntimeWarning); report 0 instead.
        return np.mean(values) if values else 0

    basic_stats = {
        'total_tracks': 0,
        'total_length_ticks': 0,
        'total_notes': 0,
        'unique_instruments': Counter()
    }
    note_distribution = Counter()
    rhythmic_patterns = {
        'average_note_durations': [],
        'tempo_changes': [],
        'time_signature_changes': Counter()
    }
    instrument_distribution = Counter()
    velocity_dynamics = {
        'average_velocities': [],
        'max_velocities': [],
        'min_velocities': [],
        'velocity_distribution': Counter()
    }

    for root, _, files in os.walk(dataset_path):
        for file in files:
            if not file.endswith('.mid'):
                continue
            midi_file_path = os.path.join(root, file)

            # Basic statistics (also acts as the "is this file readable?" check)
            stats = analyze_basic_midi_statistics(midi_file_path)
            if stats is None:
                continue

            basic_stats['total_tracks'] += stats['num_tracks']
            basic_stats['total_length_ticks'] += stats['total_length_ticks']
            basic_stats['total_notes'] += stats['total_notes']
            # NOTE(review): stats['unique_instruments'] is a per-file count (an
            # int), so this Counter tallies how many files had N distinct
            # channels. 'unique_instruments_count' below is therefore the number
            # of distinct per-file channel counts, not of distinct instruments.
            basic_stats['unique_instruments'].update([stats['unique_instruments']])

            # Note distribution
            note_distribution.update(analyze_note_distribution(midi_file_path))

            # Rhythmic patterns
            rhythm_analysis = analyze_rhythmic_patterns(midi_file_path)
            rhythmic_patterns['average_note_durations'].append(rhythm_analysis['average_note_duration'])
            rhythmic_patterns['tempo_changes'].extend(rhythm_analysis['tempo_changes'])
            # The per-file list holds (tick, signature) pairs. Count the
            # signature itself, otherwise the same signature at different ticks
            # would be tallied as different entries.
            rhythmic_patterns['time_signature_changes'].update(
                signature for _, signature in rhythm_analysis['time_signature_changes']
            )

            # Instrument distribution
            instrument_distribution.update(analyze_instrument_distribution(midi_file_path))

            # Velocity dynamics
            dynamics = analyze_velocity_dynamics(midi_file_path)
            velocity_dynamics['average_velocities'].append(dynamics['average_velocity'])
            velocity_dynamics['max_velocities'].append(dynamics['max_velocity'])
            velocity_dynamics['min_velocities'].append(dynamics['min_velocity'])
            velocity_dynamics['velocity_distribution'].update(dynamics['velocity_distribution'])

    return {
        'basic_stats': {
            'total_tracks': basic_stats['total_tracks'],
            'total_length_ticks': basic_stats['total_length_ticks'],
            'total_notes': basic_stats['total_notes'],
            'unique_instruments_count': len(basic_stats['unique_instruments']),
            'average_instruments_per_file': _mean_or_zero(list(basic_stats['unique_instruments'].elements())),
        },
        'note_distribution': note_distribution,
        'rhythmic_patterns': {
            'average_note_duration': _mean_or_zero(rhythmic_patterns['average_note_durations']),
            'total_tempo_changes': len(rhythmic_patterns['tempo_changes']),
            'common_time_signatures': rhythmic_patterns['time_signature_changes'].most_common(3),
        },
        'instrument_distribution': instrument_distribution,
        'velocity_dynamics': {
            'average_velocity': _mean_or_zero(velocity_dynamics['average_velocities']),
            'max_velocity': np.max(velocity_dynamics['max_velocities']) if velocity_dynamics['max_velocities'] else 0,
            'min_velocity': np.min(velocity_dynamics['min_velocities']) if velocity_dynamics['min_velocities'] else 0,
            'common_velocities': velocity_dynamics['velocity_distribution'].most_common(3),
        },
    }

def plot_aggregated_results(aggregated_results):
    """
    Draw bar charts of the note and instrument distributions.

    Parameters:
        aggregated_results (dict): Output of ``process_midi_dataset``; its
            ``note_distribution`` and ``instrument_distribution`` mappings are
            plotted (key on the x axis, count on the y axis), one figure each.
    """
    panels = (
        ('note_distribution', 'skyblue', 'MIDI Note Number', 'Overall Note Distribution'),
        ('instrument_distribution', 'lightgreen', 'MIDI Instrument Number', 'Overall Instrument Distribution'),
    )
    for result_key, bar_colour, x_label, chart_title in panels:
        distribution = aggregated_results[result_key]
        plt.figure(figsize=(12, 6))
        plt.bar(list(distribution.keys()), list(distribution.values()), color=bar_colour)
        plt.xlabel(x_label)
        plt.ylabel('Count')
        plt.title(chart_title)
        plt.show()

def process_and_plot_all_folders(parent_dir):
    """
    Analyse and plot each immediate subfolder of ``parent_dir`` separately.

    Plain files directly inside ``parent_dir`` are ignored. Each subfolder is
    passed to ``process_midi_dataset``, and its results are printed and then
    plotted with ``plot_aggregated_results``.
    """
    for entry in os.listdir(parent_dir):
        subfolder = os.path.join(parent_dir, entry)
        if not os.path.isdir(subfolder):
            continue
        print(f"Processing folder: {entry}")
        results = process_midi_dataset(subfolder)
        print(results)
        plot_aggregated_results(results)


# Example usage: dataset_path is the folder holding one subfolder per composer.
dataset_path = '/content/drive/MyDrive/DataFiles/'
# Uncomment one of the following to run the analysis:
# process_and_plot_all_folders(dataset_path)               # per-composer stats + plots
# aggregated_results = process_midi_dataset(dataset_path)  # whole dataset at once
# print(aggregated_results)
# plot_aggregated_results(aggregated_results)




## EDA
The following sections perform exploratory data analysis (EDA). It helps determine which features are useful and shows how the composers differ in instrumentation, amount of data, and summary statistics for each feature.

In [None]:
# process_midi_dataset re-parses every MIDI file it finds. Analyse each composer
# folder once, keep the result, and reuse it for all three reports below.
# Repeating the call per report triples the work and repeats every
# "Error processing" message.
composer_results = {}

# Basic Statistics
for folder_name in os.listdir(dataset_path):
    folder_path = os.path.join(dataset_path, folder_name)
    if os.path.isdir(folder_path):
        print(f"\n--- Analyzing {folder_name} ---")
        aggregated_results = process_midi_dataset(folder_path)
        composer_results[folder_name] = aggregated_results
        print("Basic Statistics:")
        for key, value in aggregated_results['basic_stats'].items():
            print(f"  {key}: {value}")

# Rhythmic Patterns
for folder_name, aggregated_results in composer_results.items():
    print(f"\n--- Analyzing {folder_name} ---")
    print("Rhythmic Patterns:")
    print(f"  Average Note Duration: {aggregated_results['rhythmic_patterns']['average_note_duration']}")
    print(f"  Total Tempo Changes: {aggregated_results['rhythmic_patterns']['total_tempo_changes']}")
    print("  Common Time Signatures:")
    for signature, count in aggregated_results['rhythmic_patterns']['common_time_signatures']:
        print(f"    {signature}: {count}")

# Velocity Dynamics
for folder_name, aggregated_results in composer_results.items():
    print(f"\n--- Analyzing {folder_name} ---")
    print("Velocity Dynamics:")
    print(f"  Average Velocity: {aggregated_results['velocity_dynamics']['average_velocity']}")
    print(f"  Max Velocity: {aggregated_results['velocity_dynamics']['max_velocity']}")
    print(f"  Min Velocity: {aggregated_results['velocity_dynamics']['min_velocity']}")
    print("  Common Velocities:")
    for velocity, count in aggregated_results['velocity_dynamics']['common_velocities']:
        print(f"    {velocity}: {count}")



--- Analyzing Mozart ---
Error processing /content/drive/MyDrive/DataFiles/Mozart/Piano Sonatas/Nueva carpeta/K281 Piano Sonata n03 3mov.mid: Could not decode key with 2 flats and mode 2
Basic Statistics:
  total_tracks: 2843
  total_length_ticks: 102645.93818738835
  total_notes: 2406373
  unique_instruments_count: 15
  average_instruments_per_file: 6.48828125

--- Analyzing Beethoven ---
Error processing /content/drive/MyDrive/DataFiles/Beethoven/Anhang 14-3.mid: Could not decode key with 3 flats and mode 255
Basic Statistics:
  total_tracks: 1904
  total_length_ticks: 107564.80470104773
  total_notes: 2710371
  unique_instruments_count: 15
  average_instruments_per_file: 5.502369668246446

--- Analyzing Chopin ---
Basic Statistics:
  total_tracks: 984
  total_length_ticks: 30084.67425522962
  total_notes: 577457
  unique_instruments_count: 8
  average_instruments_per_file: 1.75

--- Analyzing Bach ---
Basic Statistics:
  total_tracks: 6347
  total_length_ticks: 141267.42845187857
 

### Instrument Distribution
The following code counts how many times each instrument is used by each composer. This gives a sense of how important instrumentation is as a feature and shows the distribution of instruments each composer uses.

In [None]:
from pretty_midi import instrument_name_to_program, program_to_instrument_name

# For each composer folder, report how often each program number appears in
# program_change events, translated to pretty_midi's instrument names.
for folder_name in os.listdir(dataset_path):
    folder_path = os.path.join(dataset_path, folder_name)
    if not os.path.isdir(folder_path):
        continue  # skip stray files next to the composer folders
    print(f"\n--- Instruments used by {folder_name} ---")
    aggregated_results = process_midi_dataset(folder_path)
    for instrument, count in aggregated_results['instrument_distribution'].items():
        instrument_name = program_to_instrument_name(instrument)
        print(f"  Instrument {instrument_name}: Used {count} times")

## Data Preprocessing
The following section processes the data and extracts the features used to train the model. Key components include windowing, data augmentation, and feature selection.

Windowing lets the LSTM model perform much better. It feeds the data in as overlapping sequences, so the entire piece is covered, while each window captures patterns across a group of notes instead of learning note by note.

Data augmentation also helps the model generalize and keeps it from favoring any single composer over the others, which reduces overfitting.

#### Load Data
This initial function loads the data for each composer. It returns the pitch, duration, and velocity features of every instrument track in each processed file. A few corrupt files that do not match the expected format raise an exception; these are caught and skipped. These features are used in the rest of the data preprocessing steps.

In [None]:
def load_midi_features(directory, recursive=False):
    """
    Load ``.mid`` files from ``directory`` and extract per-instrument features.

    Each instrument of each file contributes one entry to each of the three
    returned lists, so the lists stay index-aligned.

    Parameters:
        directory (str): Folder containing MIDI files.
        recursive (bool): If True, also descend into subfolders. Defaults to
            False, which reads only the files directly inside ``directory``.

    Returns:
        tuple[list, list, list]: ``(pitch_sequences, duration_sequences,
        velocity_sequences)``, one inner list per instrument. They hold MIDI
        pitch numbers, note durations (``note.end - note.start``) and MIDI
        velocities.

    Files that fail to parse are reported and skipped; they do not abort the run.
    """
    if recursive:
        midi_paths = [
            os.path.join(root, name)
            for root, _, names in os.walk(directory)
            for name in names
            if name.endswith(".mid")
        ]
    else:
        midi_paths = [
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith(".mid")
        ]
    # os.listdir/os.walk give no ordering guarantee. Sort so the sequence order,
    # and therefore the later random_state-seeded split, is reproducible.
    midi_paths.sort()

    pitch_sequences = []
    duration_sequences = []
    velocity_sequences = []

    for midi_path in midi_paths:
        try:
            midi = pretty_midi.PrettyMIDI(midi_path)
        except (KeySignatureError, OSError, EOFError, ValueError, IndexError) as e:
            # Corrupt / non-standard files: report and move on instead of
            # losing the whole composer's load to one bad file.
            print(f"Skipping file {os.path.relpath(midi_path, directory)} due to {type(e).__name__}: {e}")
            continue
        for instrument in midi.instruments:
            pitch_sequences.append([note.pitch for note in instrument.notes])
            duration_sequences.append([note.end - note.start for note in instrument.notes])
            velocity_sequences.append([note.velocity for note in instrument.notes])

    return pitch_sequences, duration_sequences, velocity_sequences

# Raw (pitch, duration, velocity) features, one call per composer folder.
bach_pitches, bach_durations, bach_velocities = load_midi_features(
    "/content/drive/MyDrive/DataFiles/Bach/")
mozart_pitches, mozart_durations, mozart_velocities = load_midi_features(
    "/content/drive/MyDrive/DataFiles/Mozart/")
beethoven_pitches, beethoven_durations, beethoven_velocities = load_midi_features(
    "/content/drive/MyDrive/DataFiles/Beethoven/")
chopin_pitches, chopin_durations, chopin_velocities = load_midi_features(
    "/content/drive/MyDrive/DataFiles/Chopin/")


### Normalize Features
The following function normalizes the features, scaling the values to the range 0–1. This helps the LSTM model learn appropriate weights for each feature. Keeping all features on a common scale improves gradient descent and helps the model converge faster.

In [None]:
def normalize_features(pitch_sequences, duration_sequences, velocity_sequences, max_duration=None):
    """
    Scale pitches, durations and velocities into the 0-1 range.

    Pitches and velocities are divided by 127, the largest MIDI value.
    Durations are divided by ``max_duration``.

    Parameters:
        pitch_sequences, duration_sequences, velocity_sequences (list[list]):
            Per-instrument feature lists, as returned by ``load_midi_features``.
        max_duration (float | None): Duration that maps to 1.0. When None
            (default), the longest duration in ``duration_sequences`` is used.
            Pass one shared value for every composer so that the same note
            length gets the same normalised value across composers.

    Returns:
        tuple[list, list, list]: Normalised pitches, durations and velocities,
        with the same nesting as the inputs.
    """
    normalized_pitches = [[note / 127.0 for note in sequence] for sequence in pitch_sequences]

    if max_duration is None:
        max_duration = max(
            (max(durations) for durations in duration_sequences if durations),
            default=0.0,
        )
    # With no non-empty duration list, max() has nothing to compare. With an
    # all-zero maximum, the division would fail. Leave durations unscaled in
    # both cases.
    if max_duration <= 0:
        max_duration = 1.0
    normalized_durations = [[duration / max_duration for duration in sequence] for sequence in duration_sequences]

    normalized_velocities = [[velocity / 127.0 for velocity in sequence] for sequence in velocity_sequences]

    return normalized_pitches, normalized_durations, normalized_velocities

# Normalize each composer's features to the 0-1 range.
# NOTE(review): normalize_features is called once per composer without a shared
# scale. Each composer's durations are divided by that composer's own longest
# duration, so the same note length can map to different values per composer.
bach_pitches, bach_durations, bach_velocities = normalize_features(
    bach_pitches, bach_durations, bach_velocities)
mozart_pitches, mozart_durations, mozart_velocities = normalize_features(
    mozart_pitches, mozart_durations, mozart_velocities)
beethoven_pitches, beethoven_durations, beethoven_velocities = normalize_features(
    beethoven_pitches, beethoven_durations, beethoven_velocities)
chopin_pitches, chopin_durations, chopin_velocities = normalize_features(
    chopin_pitches, chopin_durations, chopin_velocities)

### Augmentation
The following functions combine the features and augment them before they are passed to the remaining preprocessing stages. Augmentation helps prevent overfitting and allows the model to generalize better.

Given the imbalance in the data, augmentation is also intended to keep Bach from dominating, by artificially generating additional data for the other composers from the existing files. Note, however, that the same augmentation is applied to every composer, including Bach, so the relative class imbalance is unchanged.

In [None]:
def augment_sequence(sequence, num_augmentations=2, pitch_scale=127.0, max_shift=5):
    """
    Return the original note sequence plus randomly transposed copies.

    Each note is a ``[pitch, duration, velocity]`` triple. In this notebook,
    pitches have already been divided by 127 by ``normalize_features``, so one
    semitone is ``1 / pitch_scale``. Adding a whole-number shift directly would
    push normalised pitches far outside 0-1 and create tokens that never occur
    in real data.

    Parameters:
        sequence (list[list[float]]): Notes as ``[pitch, duration, velocity]``.
        num_augmentations (int): Number of transposed copies to add.
        pitch_scale (float): Value the pitches were divided by during
            normalisation. Use 1.0 for raw MIDI pitch numbers.
        max_shift (int): Largest transposition in semitones, up or down.

    Returns:
        list: ``[sequence, copy_1, ..., copy_n]``. The first element is the
        input object itself; each copy uses one random shift for all its notes.
        Transposed pitches are clamped to the MIDI range 0-127.
    """
    augmented_sequences = [sequence]
    for _ in range(num_augmentations):
        shift = random.randint(-max_shift, max_shift)
        augmented_sequence = []
        for note in sequence:
            # Work in whole semitones so the result equals (pitch + shift) / scale,
            # i.e. the same float normalize_features would produce for that pitch.
            semitone = round(note[0] * pitch_scale) + shift
            semitone = min(max(semitone, 0), 127)
            augmented_sequence.append([semitone / pitch_scale, note[1], note[2]])
        augmented_sequences.append(augmented_sequence)
    return augmented_sequences

def combine_features(pitch_sequences, duration_sequences, velocity_sequences):
    """
    Zip per-instrument pitch, duration and velocity lists into note triples.

    Returns one list per instrument, each containing ``[pitch, duration,
    velocity]`` lists. If an instrument's three lists differ in length, the
    result is as long as the shortest one.
    """
    return [
        [[pitch, duration, velocity] for pitch, duration, velocity in zip(pitches, durations, velocities)]
        for pitches, durations, velocities in zip(pitch_sequences, duration_sequences, velocity_sequences)
    ]

# Merge the three per-note features into [pitch, duration, velocity] triples.
combined_bach = combine_features(bach_pitches, bach_durations, bach_velocities)
combined_mozart = combine_features(mozart_pitches, mozart_durations, mozart_velocities)
combined_beethoven = combine_features(beethoven_pitches, beethoven_durations, beethoven_velocities)
combined_chopin = combine_features(chopin_pitches, chopin_durations, chopin_velocities)

# Expand every sequence into itself plus its transposed copies (augment_sequence
# defaults to 2 copies), keeping the results flat per composer.
augmented_combined_bach = [copy for source in combined_bach for copy in augment_sequence(source)]
augmented_combined_mozart = [copy for source in combined_mozart for copy in augment_sequence(source)]
augmented_combined_beethoven = [copy for source in combined_beethoven for copy in augment_sequence(source)]
augmented_combined_chopin = [copy for source in combined_chopin for copy in augment_sequence(source)]

### More Augmentation
The next two functions apply further augmentation, focused on noise and tempo, respectively.

In [None]:
def augment_with_noise(sequence, noise_factor=0.01, perturb_pitch=False):
    """
    Return a copy of ``sequence`` with small uniform noise added to each note.

    Duration and velocity each get an independent offset drawn from
    ``[-noise_factor, noise_factor]``. Durations are clipped at 0 so the noise
    can never produce a negative note length.

    Pitch is left untouched by default. Later in the pipeline every distinct
    pitch value becomes its own token (``note_to_int``), so jittering the pitch
    would turn almost every noisy note into a new, one-off token.

    Parameters:
        sequence (list[list[float]]): Notes as ``[pitch, duration, velocity]``.
        noise_factor (float): Maximum absolute offset per value.
        perturb_pitch (bool): Also add noise to the pitch (drawn before the
            duration and velocity offsets).

    Returns:
        list[list[float]]: New note lists; the input is not modified.
    """
    noisy_sequence = []
    for note in sequence:
        pitch = note[0] + random.uniform(-noise_factor, noise_factor) if perturb_pitch else note[0]
        duration = max(0.0, note[1] + random.uniform(-noise_factor, noise_factor))
        velocity = note[2] + random.uniform(-noise_factor, noise_factor)
        noisy_sequence.append([pitch, duration, velocity])
    return noisy_sequence

def augment_with_tempo(sequence, tempo_factor_range=(0.9, 1.1)):
    """
    Stretch or compress every note duration by one random factor.

    A single factor is drawn uniformly from ``tempo_factor_range`` and applied
    to the duration (index 1) of every note. Pitch and velocity are copied
    unchanged into new note lists.
    """
    factor = random.uniform(*tempo_factor_range)
    scaled = []
    for note in sequence:
        scaled.append([note[0], note[1] * factor, note[2]])
    return scaled

# For each composer, in turn:
#   1. Append a noisy copy of every sequence (doubling the set).
#   2. Append a tempo-scaled copy of everything so far (doubling it again).
# The composers are visited in the order Bach, Mozart, Beethoven, Chopin.
# Each extend() argument is built as a full list first, so the newly appended
# items are not re-augmented within the same step.
for _composer_sequences in (augmented_combined_bach, augmented_combined_mozart,
                            augmented_combined_beethoven, augmented_combined_chopin):
    _composer_sequences.extend([augment_with_noise(seq) for seq in _composer_sequences])
    _composer_sequences.extend([augment_with_tempo(seq) for seq in _composer_sequences])
del _composer_sequences

### Windowing
The final function applies windowing to the data. In previous training iterations, the model struggled to get past 33% accuracy, regardless of the other preprocessing stages. Windowing lets the model iterate over sections of the data and learn patterns from them. This is an important component of LSTM models and of audio processing in general: although each note carries feature information, the important learning comes from patterns across the music, which is where windowing helps.

In [None]:
def create_windows(sequence, window_size, step_size=1):
    """
    Cut ``sequence`` into fixed-length, possibly overlapping slices.

    Windows start at 0, ``step_size``, ``2 * step_size``, ... Only full-length
    windows are kept, so a sequence shorter than ``window_size`` yields none.
    """
    last_start = len(sequence) - window_size
    return [sequence[offset:offset + window_size] for offset in range(0, last_start + 1, step_size)]


# Window geometry shared by every composer: 50-note windows starting every
# 20 notes, so consecutive windows overlap by 30 notes.
window_size, step_size = 50, 20

# Function to apply windowing to an entire dataset
def apply_windowing(sequences, window_size, step_size=1):
    """
    Window every sequence with ``create_windows`` and concatenate the results.

    Returns one flat list of windows, in input order (all windows of the first
    sequence, then of the second, ...).
    """
    return [
        window
        for sequence in sequences
        for window in create_windows(sequence, window_size, step_size)
    ]

# Window every composer's augmented sequences with the shared geometry.
windowed_bach = apply_windowing(augmented_combined_bach, window_size, step_size)
windowed_mozart = apply_windowing(augmented_combined_mozart, window_size, step_size)
windowed_beethoven = apply_windowing(augmented_combined_beethoven, window_size, step_size)
windowed_chopin = apply_windowing(augmented_combined_chopin, window_size, step_size)

# Stack all windows composer by composer...
all_windowed_data = windowed_bach + windowed_mozart + windowed_beethoven + windowed_chopin

# ...and give each window its composer id, in the same order:
# 0 = Bach, 1 = Mozart, 2 = Beethoven, 3 = Chopin.
windowed_labels = (
    [0] * len(windowed_bach)
    + [1] * len(windowed_mozart)
    + [2] * len(windowed_beethoven)
    + [3] * len(windowed_chopin)
)

### Final Processing Steps
The final lines of code combine the data and ensure its size and shape match what the model expects. This includes flattening, tokenizing, padding, one-hot encoding the labels, and splitting the data into training and testing sets.

In [None]:
# Only the pitch (index 0) of each note is used from here on; duration and
# velocity are not fed to the model.

# Vocabulary of distinct pitch values, collected straight from the windows
# instead of materialising a flattened list of every note first.
all_pitches = sorted({note[0] for sequence in all_windowed_data for note in sequence})
note_to_int = {note: number for number, note in enumerate(all_pitches)}

# Convert note sequences to integer token sequences
tokenized_data = [[note_to_int[note[0]] for note in sequence] for sequence in all_windowed_data]

# create_windows only emits full windows, so every sequence already has
# window_size tokens and pad_sequences mainly converts to a NumPy array.
# (Its pad value 0 is also the token of the lowest pitch.)
X = pad_sequences(tokenized_data, maxlen=window_size, padding='post')

# Fix the class count at 4 so the one-hot width always matches the model's
# Dense(4) output, even if some composer contributed no windows.
labels = to_categorical(windowed_labels, num_classes=4)

# NOTE(review): windows are split at random, so overlapping windows and the
# augmented copies of the same piece can land in both train and test. The
# reported test accuracy is therefore likely optimistic; splitting by source
# piece before augmentation/windowing would avoid this.
# stratify keeps each composer's share identical in both sets (classes are imbalanced).
X_train, X_test, y_train, y_test = train_test_split(
    X, labels, test_size=0.2, random_state=42, stratify=windowed_labels
)

## LSTM Model
This next section builds, compiles, and trains the LSTM model. Because training takes a long time, and to prevent overfitting, early stopping is applied. Checkpointing and model saving are also set up, both for further analysis and so that anyone picking up this notebook can reproduce the results.

The model is composed of 7 layers: an embedding layer, two LSTM layers each followed by dropout to reduce overfitting, and two dense layers. The ReLU activation adds non-linearity to the hidden dense layer, and the softmax output turns the final scores into a probability for each composer. Since there are 4 possible composers, the categorical_crossentropy loss function is used with the Adam optimizer.

Using 10 epochs tended to give high accuracy (98%), so no more epochs are needed, which also limits overfitting. The loss and accuracy are printed at the end for initial analysis.

In [None]:
# Define LSTM model: pitch-token embedding -> two stacked LSTMs (each followed by
# dropout) -> dense classifier over the 4 composers.
model_windowed = Sequential()
model_windowed.add(Embedding(len(all_pitches), 128, input_length=window_size))
model_windowed.add(LSTM(256, return_sequences=True))
model_windowed.add(Dropout(0.3))
model_windowed.add(LSTM(128))
model_windowed.add(Dropout(0.3))
model_windowed.add(Dense(64, activation='relu'))
model_windowed.add(Dense(4, activation='softmax'))

# Compile the model
model_windowed.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Early Stopping to prevent overfitting
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Create tf.data.Dataset for training
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_dataset = train_dataset.batch(512).prefetch(tf.data.AUTOTUNE)

# Create tf.data.Dataset for testing
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))
test_dataset = test_dataset.batch(512).prefetch(tf.data.AUTOTUNE)

# Save the full model whenever val_loss reaches a new minimum.
checkpoint_callback = ModelCheckpoint(
    filepath='/content/drive/MyDrive/windowed_model.keras',
    save_weights_only=False,
    monitor='val_loss',
    mode='min',
    save_best_only=True
)

# NOTE(review): test_dataset doubles as the validation set. Early stopping and
# checkpoint selection therefore see the test data, and the evaluation below is
# optimistic. A separate validation split carved from X_train would give an
# unbiased test score.
# checkpoint_callback must be registered here alongside early stopping,
# otherwise the best-epoch model is never written.
history = model_windowed.fit(
    train_dataset,
    epochs=10,
    validation_data=test_dataset,
    callbacks=[early_stopping, checkpoint_callback],
)

# Keras 3 (the default tf.keras from TF 2.16) only accepts .keras / .h5 paths in
# model.save(). Use a name distinct from the checkpoint file so the best-epoch
# checkpoint is not overwritten by the final weights.
model_windowed.save('/content/drive/MyDrive/windowed_model_final.keras')

# Evaluate the model
loss, accuracy = model_windowed.evaluate(X_test, y_test)
print(f"Test Loss: {loss}, Test Accuracy: {accuracy}")