In [None]:
import numpy as np
import tensorflow as tf
import music21
from music21 import converter, instrument, note, chord, stream
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import glob
import pickle
from tqdm import tqdm

## Data Preprocessing

In [None]:
def is_valid_midi(file_path):
    """
    Return True when ``file_path`` looks like a MIDI file that music21 can parse.

    The file must begin with the standard ``MThd`` header chunk and must also be
    accepted by ``converter.parse``. Any exception raised while reading or
    parsing is reported and treated as "not valid".

    Parameters:
    file_path: String path to the MIDI file

    Returns:
    bool: True if the header matches and parsing succeeds, otherwise False.
    """
    try:
        with open(file_path, 'rb') as midi_file:
            magic = midi_file.read(4)
        if magic != b'MThd':
            return False
        # Header alone is not enough: make sure music21 can actually read it.
        converter.parse(file_path)
    except Exception as e:
        print(f"Error validating {file_path}: {str(e)}")
        return False
    return True

def data_extractor(directory):
    """
    Convert every MIDI file matching a glob pattern into three parallel lists.

    Each note or chord found in a file contributes one entry to each list:
    - notes: ``str(element.pitch)`` for a single note (e.g. "C4"), or the chord's
      ``normalOrder`` pitch classes joined with '.' (e.g. "0.4.7");
    - offsets: distance from the previous element's offset in the same file,
      as a float rounded to 3 decimals (the first element of a file is
      measured from 0);
    - durations: ``str(element.quarterLength)`` (may be a fraction such as "1/3").

    Files that fail validation, contain no notes/chords, or raise while being
    processed are skipped and contribute nothing (data from a file is merged
    only after the whole file has been read).

    Parameters:
    directory: glob pattern passed to ``glob.glob`` (e.g. "midi/*.mid").

    Returns:
    [notes, offsets, durations] - three lists of equal length.

    Raises:
    ValueError: if no note data could be extracted from any file.
    """
    notes = []
    offsets = []
    durations = []
    processed_files = 0
    skipped_files = 0

    # Get all matching files
    files = glob.glob(directory)
    total_files = len(files)

    for file_no, file in enumerate(files, start=1):
        try:
            # Verify the file is a valid MIDI file
            if not is_valid_midi(file):
                print(f"Skipping invalid MIDI file: {file}")
                skipped_files += 1
                continue

            mid = converter.parse(file)
            print(f"Processing FILE: {file}, no. {file_no}/{total_files}")

            try:
                s2 = instrument.partitionByInstrument(mid)
                if s2 and s2.parts:  # Check if parts exist
                    # Only the first instrument part is used.
                    notes_to_parse = s2.parts[0].recurse()
                else:
                    notes_to_parse = mid.flat.notes
            except Exception:
                # Partitioning failed; fall back to every note in the score.
                notes_to_parse = mid.flat.notes

            # Collect per file and merge only on success, so an error halfway
            # through a file cannot leave partial data in the result lists.
            file_notes = []
            file_offsets = []
            file_durations = []
            prev_offset = 0
            # NOTE(review): with .recurse(), element.offset may be relative to the
            # element's enclosing container (e.g. a Measure) rather than the whole
            # part -- confirm whether getOffsetInHierarchy() is needed here.
            for element in notes_to_parse:
                if isinstance(element, note.Note):
                    file_notes.append(str(element.pitch))
                elif isinstance(element, chord.Chord):
                    file_notes.append('.'.join(str(n) for n in element.normalOrder))
                else:
                    continue  # other element types are ignored
                file_durations.append(str(element.quarterLength))
                file_offsets.append(round(float(element.offset - prev_offset), 3))
                prev_offset = element.offset

            if not file_notes:  # Skip if no notes found
                print(f"No notes found in {file}, skipping...")
                skipped_files += 1
                continue

            notes.extend(file_notes)
            offsets.extend(file_offsets)
            durations.extend(file_durations)
            processed_files += 1

        except Exception as e:
            print(f"Error processing {file}: {str(e)}")
            skipped_files += 1
            continue

    print(f"\nProcessing complete!")
    print(f"Total files: {total_files}")
    print(f"Successfully processed: {processed_files}")
    print(f"Skipped/Invalid files: {skipped_files}")

    if not notes:  # Check if we got any data
        raise ValueError("No valid MIDI data was extracted from any files")

    return [notes, offsets, durations]

In [None]:
def save_extracted_data(directory, output_file="music_data.pkl"):
    """
    Extract note/offset/duration data from the MIDI files matching ``directory``
    (via ``data_extractor``) and pickle the resulting list to ``output_file``.

    Any error is reported with a message and then re-raised unchanged.
    """
    try:
        extracted = data_extractor(directory)
        # pickle needs a file opened in binary write mode
        with open(output_file, 'wb') as pkl_file:
            pickle.dump(extracted, pkl_file)
        print(f"Data saved to {output_file}")
    except Exception as e:
        print(f"Error saving data: {str(e)}")
        raise

# To load the data later on
def load_extracted_data(file_path="music_data.pkl"):
    """
    Load and return the object pickled in ``file_path`` (e.g. the
    ``[notes, offsets, durations]`` list written by ``save_extracted_data``).

    Security note: ``pickle.load`` can execute arbitrary code, so only load
    files you created yourself.
    """
    with open(file_path, 'rb') as pkl_file:
        return pickle.load(pkl_file)

In [None]:
# load classical data from processed pkl file; layout is [notes, offsets, durations]
classical_data = load_extracted_data("pkl/classical_kaggle.pkl")
classical_note_data = classical_data[0]
classical_offset_data = classical_data[1]
classical_duration_data = classical_data[2]

# Sorted vocabularies of distinct values, then their sizes
classical_unique_notes = sorted(set(classical_note_data))
classical_unique_offsets = sorted(set(classical_offset_data))
classical_unique_durations = sorted(set(classical_duration_data))
classical_unique_note_number = len(classical_unique_notes)
classical_unique_offset_number = len(classical_unique_offsets)
classical_unique_duration_number = len(classical_unique_durations)

In [None]:
# load jazz data from processed pkl file; layout is [notes, offsets, durations]
jazz_data = load_extracted_data("pkl/jazz.pkl")
jazz_note_data = jazz_data[0]
jazz_offset_data = jazz_data[1]
jazz_duration_data = jazz_data[2]

# Sorted vocabularies of distinct values, then their sizes
jazz_unique_notes = sorted(set(jazz_note_data))
jazz_unique_offsets = sorted(set(jazz_offset_data))
jazz_unique_durations = sorted(set(jazz_duration_data))
jazz_unique_note_number = len(jazz_unique_notes)
jazz_unique_offset_number = len(jazz_unique_offsets)
jazz_unique_duration_number = len(jazz_unique_durations)

In [None]:
# load rock data from processed pkl file; layout is [notes, offsets, durations]
rock_data = load_extracted_data("pkl/rock.pkl")
rock_note_data = rock_data[0]
rock_offset_data = rock_data[1]
rock_duration_data = rock_data[2]

# Sorted vocabularies of distinct values, then their sizes
rock_unique_notes = sorted(set(rock_note_data))
rock_unique_offsets = sorted(set(rock_offset_data))
rock_unique_durations = sorted(set(rock_duration_data))
rock_unique_note_number = len(rock_unique_notes)
rock_unique_offset_number = len(rock_unique_offsets)
rock_unique_duration_number = len(rock_unique_durations)

In [None]:
class MusicDataset(Dataset):
    """
    Sliding-window dataset over parallel note/offset/duration/genre sequences.

    Item ``i`` is a tuple of four 1-D tensors of length ``sequence_length``:
    note indices, offset indices, duration indices and genre labels for
    positions ``i .. i + sequence_length - 1``.

    Vocabularies map each distinct value to an index. By default they are built
    from the *sorted* distinct values, so the same data always produces the same
    mapping. Iterating a ``set`` of strings depends on hash randomisation and
    differs between Python processes, which would make the indices used for
    training and for generation disagree. Pre-built vocabularies can be passed
    in to share one index space between several datasets (e.g. genres). Every
    value in the data must then be present in them, otherwise indexing raises
    KeyError.
    """

    def __init__(self, notes, offsets, durations, genre_labels, sequence_length=32,
                 note_to_idx=None, offset_to_idx=None, duration_to_idx=None):
        # Store inputs and create mappings from musical elements to indices
        self.notes = notes
        self.offsets = offsets
        self.durations = durations
        self.genre_labels = genre_labels
        self.sequence_length = sequence_length

        # Create (or reuse) vocabularies
        self.note_to_idx = note_to_idx if note_to_idx is not None else self._build_vocab(notes)
        self.offset_to_idx = offset_to_idx if offset_to_idx is not None else self._build_vocab(offsets)
        self.duration_to_idx = (duration_to_idx if duration_to_idx is not None
                                else self._build_vocab(durations))

    @staticmethod
    def _build_vocab(values):
        """Map each distinct value to its position in sorted order (deterministic)."""
        return {value: idx for idx, value in enumerate(sorted(set(values)))}

    def __len__(self):
        # Number of windows; clamped so data shorter than a window gives 0, not a negative length.
        return max(0, len(self.notes) - self.sequence_length)

    def __getitem__(self, idx):
        window = slice(idx, idx + self.sequence_length)
        return (
            torch.tensor([self.note_to_idx[n] for n in self.notes[window]]),
            torch.tensor([self.offset_to_idx[o] for o in self.offsets[window]]),
            torch.tensor([self.duration_to_idx[d] for d in self.durations[window]]),
            torch.tensor(self.genre_labels[window])
        )

# The LSTM itself
class GenreAwareMusicLSTM(nn.Module):
    """
    Two-layer LSTM that predicts the next note, offset and duration tokens.

    At every time step the note, offset, duration and genre indices are each
    embedded into ``embedding_dim`` dimensions and concatenated (LSTM input
    size ``4 * embedding_dim``). Three linear heads map the LSTM output to
    logits over the note, offset and duration vocabularies.
    """

    def __init__(self, note_vocab_size, offset_vocab_size, duration_vocab_size, 
                 num_genres, embedding_dim, hidden_dim):
        super().__init__()
        # Store vocab sizes as attributes (training code uses them to reshape logits)
        self.note_vocab_size = note_vocab_size
        self.offset_vocab_size = offset_vocab_size
        self.duration_vocab_size = duration_vocab_size

        self.note_embedding = nn.Embedding(note_vocab_size, embedding_dim)
        self.offset_embedding = nn.Embedding(offset_vocab_size, embedding_dim)
        self.duration_embedding = nn.Embedding(duration_vocab_size, embedding_dim)
        self.genre_embedding = nn.Embedding(num_genres, embedding_dim)

        # Combined embedding dimension
        combined_dim = embedding_dim * 4  # notes + offsets + durations + genre

        # LSTM layer for sequence processing
        self.lstm = nn.LSTM(
            input_size=combined_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=0.2
        )

        # Output layers for predicting next musical elements
        self.note_fc = nn.Linear(hidden_dim, note_vocab_size)
        self.offset_fc = nn.Linear(hidden_dim, offset_vocab_size)
        self.duration_fc = nn.Linear(hidden_dim, duration_vocab_size)

    def forward(self, notes, offsets, durations, genre_labels):
        """
        Compute next-token logits for every input position.

        All four inputs are integer index tensors of the same shape, used as
        ``(batch, seq_len)`` (the LSTM is ``batch_first``).

        Returns:
        (note_logits, offset_logits, duration_logits), each shaped
        ``(batch, seq_len, vocab_size)`` for its own vocabulary.
        """
        note_embeds = self.note_embedding(notes)
        offset_embeds = self.offset_embedding(offsets)
        duration_embeds = self.duration_embedding(durations)
        genre_embeds = self.genre_embedding(genre_labels)

        # Concatenate the four embeddings feature-wise
        combined = torch.cat(
            [note_embeds, offset_embeds, duration_embeds, genre_embeds],
            dim=-1
        )

        lstm_out, _ = self.lstm(combined)

        note_logits = self.note_fc(lstm_out)
        offset_logits = self.offset_fc(lstm_out)
        duration_logits = self.duration_fc(lstm_out)

        return note_logits, offset_logits, duration_logits

    def generate(self, genre_id, seed_sequence, max_length=250, temperature=1.0):
        """
        Sample a new sequence autoregressively for one genre.

        Parameters:
        genre_id: genre index fed at every step.
        seed_sequence: (notes, offsets, durations) index tensors with one row
            each (``.item()`` is used on the sampled values, so batch size 1).
            The window slides forward one step per sampled token, so its length
            stays constant.
        max_length: number of tokens to sample.
        temperature: > 0. Higher = more random/creative, lower = more
            conservative/predictable.

        Returns:
        Three lists of vocabulary indices (notes, offsets, durations), each
        ``max_length`` long.

        Raises:
        ValueError: if temperature is not positive (it divides the logits).
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                current_sequence = seed_sequence
                genre_labels = torch.full_like(current_sequence[0], genre_id)

                generated_notes = []
                generated_offsets = []
                generated_durations = []

                for _ in range(max_length):
                    note_logits, offset_logits, duration_logits = self(
                        current_sequence[0],
                        current_sequence[1],
                        current_sequence[2],
                        genre_labels
                    )

                    # Only the prediction after the last position is used; temperature
                    # rescales the logits before turning them into probabilities.
                    note_probs = F.softmax(note_logits[:, -1, :] / temperature, dim=-1)
                    offset_probs = F.softmax(offset_logits[:, -1, :] / temperature, dim=-1)
                    duration_probs = F.softmax(duration_logits[:, -1, :] / temperature, dim=-1)

                    next_note = torch.multinomial(note_probs, 1)
                    next_offset = torch.multinomial(offset_probs, 1)
                    next_duration = torch.multinomial(duration_probs, 1)

                    generated_notes.append(next_note.item())
                    generated_offsets.append(next_offset.item())
                    generated_durations.append(next_duration.item())

                    # Slide the window: drop the oldest step, append the sampled one
                    current_sequence = (
                        torch.cat([current_sequence[0][:, 1:], next_note], dim=1),
                        torch.cat([current_sequence[1][:, 1:], next_offset], dim=1),
                        torch.cat([current_sequence[2][:, 1:], next_duration], dim=1)
                    )

                return generated_notes, generated_offsets, generated_durations
        finally:
            # Restore the caller's mode (e.g. when sampling in the middle of training)
            self.train(was_training)


In [None]:
def convert_fraction_to_float(fraction_str):
    """
    Parse a number string such as "0.5", "2" or a fraction like "1/3" into a float.

    Fractions appear because music21 durations are stringified Fractions.
    """
    if '/' not in fraction_str:
        return float(fraction_str)
    numerator, denominator = fraction_str.split('/')
    return float(numerator) / float(denominator)

In [None]:
def save_model_checkpoint(model, optimizer, epoch, genre, path):
    """
    Write a training checkpoint to ``path`` so training can resume after a crash.

    The saved dict holds 'epoch', 'model_state_dict', 'optimizer_state_dict'
    and 'genre'.
    """
    checkpoint = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        genre=genre,
    )
    torch.save(checkpoint, path)

def load_model_checkpoint(model, optimizer, path, map_location='cpu'):
    """
    Restore model and optimizer state from a checkpoint written by save_model_checkpoint.

    The checkpoint is first loaded onto ``map_location`` (CPU by default), so a
    file saved on a GPU can be opened on a CPU-only machine. ``load_state_dict``
    then copies the values into the existing model parameters on whatever
    device they live.

    Parameters:
    model: module whose architecture matches the saved 'model_state_dict'.
    optimizer: optimizer over ``model``'s parameters.
    path: checkpoint file path.
    map_location: forwarded to ``torch.load``.

    Returns:
    (epoch, genre) as stored in the checkpoint.
    """
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['genre']

## Model Training

In [None]:
def train_model(model, data, batch_size=48, num_epochs=25, sequence_length=32, 
                learning_rate=0.001, genre_id=0, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """
    Train ``model`` on one genre with next-token cross-entropy.

    The loss at each position t compares the prediction at t with the token at
    t+1, summed over the note, offset and duration heads. Training stops early
    when the epoch's average loss has not improved for 3 epochs.

    Parameters:
    model: GenreAwareMusicLSTM (needs note/offset/duration_vocab_size attributes).
    data: (notes, offsets, durations) sequences.
    batch_size, num_epochs, sequence_length, learning_rate: training settings.
    genre_id: genre label attached to every position of ``data``.
    device: device to train on.

    Side effects:
    Writes "best_model.pt" whenever the epoch loss improves, and
    "checkpoint_epoch_{N}.pt" every 5 epochs.

    Returns:
    (model, optimizer). A fresh Adam optimizer is always created here.

    Raises:
    ValueError: if ``data`` has too few notes to form one training window.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # Early stopping parameters
    patience = 3
    min_loss = float('inf')
    no_improve_count = 0

    notes, offsets, durations = data
    if len(notes) <= sequence_length:
        raise ValueError(
            f"Need more than sequence_length={sequence_length} notes to train, got {len(notes)}"
        )
    genre_labels = [genre_id] * len(notes)

    dataset = MusicDataset(
        notes=notes,
        offsets=offsets,
        durations=durations,
        genre_labels=genre_labels,
        sequence_length=sequence_length
    )

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for batch_idx, (note_seq, offset_seq, duration_seq, genre_seq) in enumerate(progress_bar, start=1):
            note_seq = note_seq.to(device)
            offset_seq = offset_seq.to(device)
            duration_seq = duration_seq.to(device)
            genre_seq = genre_seq.to(device)

            optimizer.zero_grad()

            note_logits, offset_logits, duration_logits = model(
                note_seq, offset_seq, duration_seq, genre_seq
            )

            # Position t predicts token t+1: drop the first target and the last prediction.
            target_notes = note_seq[:, 1:].reshape(-1)
            target_offsets = offset_seq[:, 1:].reshape(-1)
            target_durations = duration_seq[:, 1:].reshape(-1)

            note_logits = note_logits[:, :-1, :].reshape(-1, model.note_vocab_size)
            offset_logits = offset_logits[:, :-1, :].reshape(-1, model.offset_vocab_size)
            duration_logits = duration_logits[:, :-1, :].reshape(-1, model.duration_vocab_size)

            note_loss = criterion(note_logits, target_notes)
            offset_loss = criterion(offset_logits, target_offsets)
            duration_loss = criterion(duration_logits, target_durations)

            loss = note_loss + offset_loss + duration_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            # Running mean over the batches processed so far in this epoch.
            progress_bar.set_postfix({'loss': total_loss / batch_idx})

        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}')

        # Early stopping check
        if avg_loss < min_loss:
            min_loss = avg_loss
            no_improve_count = 0
            # Save best model
            save_model_checkpoint(model, optimizer, epoch, f"genre_{genre_id}", 
                                  "best_model.pt")
        else:
            no_improve_count += 1
            print(f"Loss did not improve for {no_improve_count} epochs")

        if no_improve_count >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

        # Regular checkpoint saving
        if (epoch + 1) % 5 == 0:
            save_model_checkpoint(model, optimizer, epoch, f"genre_{genre_id}", 
                                  f"checkpoint_epoch_{epoch+1}.pt")

    return model, optimizer

In [None]:
# Build the model sized to the classical vocabulary; 3 genre slots
# (0 = classical, 1 = jazz, 2 = rock).
model = GenreAwareMusicLSTM(
    num_genres=3,
    embedding_dim=128,
    hidden_dim=256,
    note_vocab_size=classical_unique_note_number,
    offset_vocab_size=classical_unique_offset_number,
    duration_vocab_size=classical_unique_duration_number,
)

# Pack the classical sequences as (notes, offsets, durations) for training with genre_id = 0
classical_data = (
    classical_note_data,
    classical_offset_data,
    classical_duration_data,
)

In [None]:
# added this because the different genres have different amounts of weights/unique notes
def expand_embedding_layer(old_model, new_vocab_size, embedding_name):
    """
    Return a new embedding with ``new_vocab_size`` rows that keeps the trained rows.

    Rows ``0 .. old_vocab_size-1`` are copied from ``old_model.<embedding_name>``.
    Any extra rows are drawn from a normal distribution using the mean and
    standard deviation of the existing weights, so new tokens start on a
    similar scale.

    Parameters:
    old_model: module with an ``nn.Embedding`` attribute named ``embedding_name``.
    new_vocab_size: row count of the returned embedding (>= current row count).
    embedding_name: attribute name, e.g. 'note_embedding'.

    Raises:
    ValueError: if new_vocab_size is smaller than the current vocabulary
        (shrinking would silently discard trained rows).
    """
    old_embedding = getattr(old_model, embedding_name)
    old_vocab_size, embedding_dim = old_embedding.weight.shape
    if new_vocab_size < old_vocab_size:
        raise ValueError(
            f"Cannot shrink {embedding_name} from {old_vocab_size} to {new_vocab_size} rows"
        )

    # Create new embedding layer with larger vocabulary
    new_embedding = nn.Embedding(new_vocab_size, embedding_dim)

    with torch.no_grad():
        # Copy old weights
        new_embedding.weight[:old_vocab_size] = old_embedding.weight

        # Initialize new embeddings with mean and std of existing embeddings
        if new_vocab_size > old_vocab_size:
            mean = old_embedding.weight.mean().item()
            std = old_embedding.weight.std().item()
            nn.init.normal_(new_embedding.weight[old_vocab_size:], mean=mean, std=std)

    return new_embedding

In [None]:
# added this because the different genres have different amounts of weights/unique notes
def expand_model_vocabulary(model, new_note_size, new_offset_size, new_duration_size):
    """
    Build a GenreAwareMusicLSTM with larger vocabularies from a trained one.

    The genre embedding and LSTM weights are copied unchanged. For the note,
    offset and duration embeddings (via expand_embedding_layer) and output
    heads, existing rows are copied. New head rows (weights and biases) are
    drawn from a normal distribution with the mean/std of the copied head weights.

    Returns the new model; ``model`` itself is only read from.
    """
    embedding_dim = model.note_embedding.embedding_dim
    hidden_dim = model.lstm.hidden_size
    num_genres = model.genre_embedding.num_embeddings

    # Expanded embeddings are built before the new model, in note/offset/duration order
    expanded_embeddings = {
        attr: expand_embedding_layer(model, size, attr)
        for attr, size in (
            ('note_embedding', new_note_size),
            ('offset_embedding', new_offset_size),
            ('duration_embedding', new_duration_size),
        )
    }

    new_model = GenreAwareMusicLSTM(
        note_vocab_size=new_note_size,
        offset_vocab_size=new_offset_size,
        duration_vocab_size=new_duration_size,
        num_genres=num_genres,
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim
    )

    for attr, layer in expanded_embeddings.items():
        setattr(new_model, attr, layer)

    # These layers do not depend on vocabulary size
    new_model.genre_embedding.load_state_dict(model.genre_embedding.state_dict())
    new_model.lstm.load_state_dict(model.lstm.state_dict())

    with torch.no_grad():
        for head_name in ('note_fc', 'offset_fc', 'duration_fc'):
            old_head = getattr(model, head_name)
            new_head = getattr(new_model, head_name)
            kept_rows = old_head.weight.shape[0]

            new_head.weight[:kept_rows] = old_head.weight
            new_head.bias[:kept_rows] = old_head.bias

            if new_head.weight.shape[0] > kept_rows:
                mean = new_head.weight[:kept_rows].mean().item()
                std = new_head.weight[:kept_rows].std().item()
                nn.init.normal_(new_head.weight[kept_rows:], mean=mean, std=std)
                nn.init.normal_(new_head.bias[kept_rows:], mean=mean, std=std)

    return new_model

In [None]:
# First training stage: classical music (genre_id = 0) on the freshly built model
model, optimizer = train_model(
    model,
    classical_data,
    genre_id=0,
    sequence_length=32,
    num_epochs=25,
    batch_size=48,
)

In [None]:
# Second training stage: continue from the best classical checkpoint on jazz (genre_id = 1).
# The expanded vocabularies must cover the classical values the model was built
# with plus the jazz values it is about to see. If the training order changes,
# change the genres combined here to match.
combined_note_number = len(set(classical_note_data) | set(jazz_note_data))
combined_offset_number = len(set(classical_offset_data) | set(jazz_offset_data))
combined_duration_number = len(set(classical_duration_data) | set(jazz_duration_data))

# Rebuild the classical-sized model so the saved weights fit
model = GenreAwareMusicLSTM(
    note_vocab_size=classical_unique_note_number,
    offset_vocab_size=classical_unique_offset_number,
    duration_vocab_size=classical_unique_duration_number,
    num_genres=3,
    embedding_dim=128,
    hidden_dim=256
)

# load_model_checkpoint needs an optimizer; train_model creates its own afterwards
optimizer = optim.Adam(model.parameters(), lr=0.001)
epoch, genre = load_model_checkpoint(model, optimizer, "best_model.pt")

# Expand model vocabulary (trained rows are kept, rows for new values are added)
model = expand_model_vocabulary(
    model,
    new_note_size=combined_note_number,
    new_offset_size=combined_offset_number,
    new_duration_size=combined_duration_number
)

# NOTE(review): train_model builds its MusicDataset from the jazz data alone, so
# jazz index i need not denote the same token as classical index i. Sharing one
# vocabulary across genres would keep the expanded embeddings meaningful.
model, optimizer = train_model(
    model=model,
    data=jazz_data,
    genre_id=1
)

In [None]:
# Third training stage: rock (genre_id = 2).
# Vocabulary sizes must cover every value from all three genres.
combined_note_number = len(set(classical_note_data) | set(jazz_note_data) | set(rock_note_data))
combined_offset_number = len(set(classical_offset_data) | set(jazz_offset_data) | set(rock_offset_data))
combined_duration_number = len(set(classical_duration_data) | set(jazz_duration_data) | set(rock_duration_data))

# Grow the model's vocabularies before training on rock; otherwise the combined
# sizes above are never applied. The expanded model is built on the CPU, so
# move the trained one there first; train_model moves it back to the training device.
model = expand_model_vocabulary(
    model.cpu(),
    new_note_size=combined_note_number,
    new_offset_size=combined_offset_number,
    new_duration_size=combined_duration_number
)

# train on rock_data
model, optimizer = train_model(
    model=model,
    data=rock_data,
    genre_id=2
)

## Music Generation

In [None]:
def create_midi_from_prediction(generated_notes, generated_offsets, generated_durations, 
                              idx_to_note, idx_to_offset, idx_to_duration, 
                              output_file="generated_music.mid", max_chord_size=2): 
    """
    Convert generated index sequences back into notes/chords and write a MIDI file.

    Step i decodes ``idx_to_note[generated_notes[i]]`` (and likewise offset and
    duration). The decoded offset is added to a running time position, since
    offsets are stored as the gap from the previous element. The element is
    inserted at that position with the decoded duration.

    Note tokens are either a pitch name (e.g. "C4") for single notes, or
    '.'-joined pitch-class numbers for chords (e.g. "0.4.7"). Every pitch is
    folded by octaves into MIDI 48 (C3) .. 83 (B5), because raw pitch classes
    would otherwise sound extremely low.

    Parameters:
    generated_notes/offsets/durations: index lists from the model.
    idx_to_note/offset/duration: dictionaries from index back to token.
    output_file: path of the MIDI file to write.
    max_chord_size: keep at most this many pitches per chord (1 gives
        single-note lines; default 2 avoids dense 5-note chords).

    Returns:
    output_file.

    Raises:
    ValueError: if no note or chord could be decoded.
    """
    from music21 import pitch  # needed to convert MIDI numbers / names to pitches

    MIN_MIDI = 48  # C3
    MAX_MIDI = 83  # B5

    def to_pitch_name(token):
        """Fold a MIDI-number string or pitch name into [MIN_MIDI, MAX_MIDI]; None if unparseable."""
        try:
            token = str(token)
            midi_number = int(token) if token.isdigit() else pitch.Pitch(token).midi
            while midi_number < MIN_MIDI:
                midi_number += 12
            while midi_number > MAX_MIDI:
                midi_number -= 12
            folded = pitch.Pitch()
            folded.midi = midi_number
            return folded.nameWithOctave
        except Exception:
            return None

    output_notes = stream.Stream()
    current_offset = 0.0

    for i in range(len(generated_notes)):
        note_str = None  # defined up front so the error message below can always be printed
        try:
            note_str = str(idx_to_note[generated_notes[i]])
            offset_val = convert_fraction_to_float(str(idx_to_offset[generated_offsets[i]]))
            duration_val = convert_fraction_to_float(str(idx_to_duration[generated_durations[i]]))

            current_offset += offset_val

            if '.' in note_str:  # It's a chord of pitch classes
                tokens = note_str.split('.')[:max_chord_size]
                pitch_names = [to_pitch_name(t) for t in tokens if t.isdigit()]
                pitch_names = [name for name in pitch_names if name]
                if not pitch_names:
                    continue
                element = chord.Chord([note.Note(name) for name in pitch_names])
            else:  # It's a single note (pitch name, or a one-pitch-class chord token)
                pitch_name = to_pitch_name(note_str)
                if not pitch_name:
                    continue
                element = note.Note(pitch_name)

            element.quarterLength = duration_val
            # insert() honours current_offset; append() would place every element
            # at the current end of the stream and ignore the generated offsets.
            output_notes.insert(current_offset, element)

        except Exception:
            print(f"Skipping problematic note/chord: {note_str}")
            continue

    if len(output_notes) == 0:
        raise ValueError("No valid notes were generated")

    output_notes.write('midi', fp=output_file)
    return output_file

In [None]:
def generate_music(model, genre_id, sequence_length, idx_to_note, idx_to_offset, idx_to_duration,
                  seed_sequence=None, max_length=250, temperature=1.0,
                  output_file="generated_music.mid", device='cuda' if torch.cuda.is_available() else 'cpu'):
    """
    Generate music using the trained model and save it as a MIDI file

    Parameters:
    model: trained GenreAwareMusicLSTM model (moved to ``device`` in place)
    genre_id: int, which genre to generate (0=classical, 1=jazz, 2=rock)
    sequence_length: int, length of the random seed window (match training)
    idx_to_note/offset/duration: dictionaries to convert indices back to musical values
    seed_sequence: optional tuple of (notes, offsets, durations) index tensors or
        nested lists, each with one row, to start generation; moved to ``device``
    max_length: int, number of notes to generate
    temperature: float > 0; higher = more random, lower = more predictable
    output_file: string, where to save the MIDI file
    device: device used for generation

    Returns:
    Path of the written MIDI file.
    """
    # Model and every seed tensor must be on the same device.
    model = model.to(device)
    model.eval()

    if seed_sequence is None:
        # Random seed window of vocabulary indices, shape (1, sequence_length)
        seed_sequence = tuple(
            torch.randint(0, len(vocab), (1, sequence_length), device=device)
            for vocab in (idx_to_note, idx_to_offset, idx_to_duration)
        )
    else:
        seed_sequence = tuple(torch.as_tensor(part).to(device) for part in seed_sequence)

    # Generate the music
    with torch.no_grad():
        generated_notes, generated_offsets, generated_durations = model.generate(
            genre_id=genre_id,
            seed_sequence=seed_sequence,
            max_length=max_length,
            temperature=temperature,
        )

    # Convert to MIDI and save
    midi_file = create_midi_from_prediction(
        generated_notes, 
        generated_offsets, 
        generated_durations,
        idx_to_note,
        idx_to_offset,
        idx_to_duration,
        output_file
    )

    return midi_file

In [None]:
# Initialize the model with sizes read from the saved checkpoint, so the
# architecture always matches the weights being loaded (trained in other notebook).
checkpoint_state = torch.load("best_model.pt", map_location="cpu")["model_state_dict"]
model = GenreAwareMusicLSTM(
    note_vocab_size=checkpoint_state["note_embedding.weight"].shape[0],
    offset_vocab_size=checkpoint_state["offset_embedding.weight"].shape[0],
    duration_vocab_size=checkpoint_state["duration_embedding.weight"].shape[0],
    num_genres=checkpoint_state["genre_embedding.weight"].shape[0],
    embedding_dim=checkpoint_state["note_embedding.weight"].shape[1],
    # weight_hh_l0 has shape (4 * hidden_dim, hidden_dim)
    hidden_dim=checkpoint_state["lstm.weight_hh_l0"].shape[1]
)

# Initialize optimizer (needed for loading the checkpoint)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Load the saved model from pt file
epoch, genre = load_model_checkpoint(model, optimizer, "best_model.pt")

# Create dataset to get the mappings
# NOTE(review): these mappings must match the ones used when the checkpoint was
# trained. They are built from classical data only, so indices the model samples
# beyond this vocabulary (e.g. after jazz/rock expansion) cannot be decoded and
# are skipped when writing the MIDI file.
dataset = MusicDataset(
    notes=classical_note_data,
    offsets=classical_offset_data,
    durations=classical_duration_data,
    genre_labels=[0] * len(classical_note_data),
    sequence_length=32
)

# Get the inverse mappings
idx_to_note = {idx: note for note, idx in dataset.note_to_idx.items()}
idx_to_offset = {idx: offset for offset, idx in dataset.offset_to_idx.items()}
idx_to_duration = {idx: duration for duration, idx in dataset.duration_to_idx.items()}

# Generate music with loaded model
generated_file = generate_music(
    model=model,
    genre_id=0, # Classical 0, Jazz 1, Rock 2
    sequence_length=32,
    idx_to_note=idx_to_note,
    idx_to_offset=idx_to_offset,
    idx_to_duration=idx_to_duration,
    temperature=1.5,
    output_file="my_piece.mid"
)

# output should be saved as output_file above
# MIDIs can be viewed and played with applications such as Musescore or Finale