# Music Generation with Dual Markov Chains

This notebook uses two synchronized Markov chains to generate music:
- One chain for pitch (note names)
- One chain for duration (note lengths)

The chains are trained on MIDI files and generate new melodies.


In [7]:
import numpy as np
from music21 import converter, note, stream, midi, instrument


In [8]:
class MarkovChain:
    """
    A generic Markov Chain implementation for a single attribute.
    Can be used for pitches, durations, or any discrete sequence.
    """
    
    def __init__(self, states):
        """
        Initialize the Markov Chain with a set of possible states.
        
        Args:
            states: List of possible states (e.g., ['C5', 'D5', 'E5'])
        """
        self.states = states
        self.initial_probabilities = np.zeros(len(states))
        self.transition_matrix = np.zeros((len(states), len(states)))
        self._state_indexes = {state: i for i, state in enumerate(states)}
    
    def train(self, sequence):
        """
        Train the Markov chain on a sequence of states.
        
        Args:
            sequence: List of states to learn from
        """
        self._calculate_initial_probabilities(sequence)
        self._calculate_transition_matrix(sequence)
    
    def _calculate_initial_probabilities(self, sequence):
        """Calculate probability distribution for initial states."""
        for state in sequence:
            if state in self._state_indexes:
                self.initial_probabilities[self._state_indexes[state]] += 1
        
        # Normalize
        total = np.sum(self.initial_probabilities)
        if total > 0:
            self.initial_probabilities /= total
        self.initial_probabilities = np.nan_to_num(self.initial_probabilities)
    
    def _calculate_transition_matrix(self, sequence):
        """Calculate transition probabilities between states."""
        for i in range(len(sequence) - 1):
            current_state = sequence[i]
            next_state = sequence[i + 1]
            
            if current_state in self._state_indexes and next_state in self._state_indexes:
                current_idx = self._state_indexes[current_state]
                next_idx = self._state_indexes[next_state]
                self.transition_matrix[current_idx, next_idx] += 1
        
        # Normalize each row
        self._normalize_transition_matrix()
    
    def _normalize_transition_matrix(self):
        """Normalize transition matrix so each row sums to 1."""
        row_sums = self.transition_matrix.sum(axis=1)
        
        with np.errstate(divide='ignore', invalid='ignore'):
            self.transition_matrix = np.where(
                row_sums[:, None],
                self.transition_matrix / row_sums[:, None],
                0
            )
    
    def generate_starting_state(self):
        """
        Generate a starting state based on initial probabilities.
        
        Returns:
            A state from the states list
        """
        if np.sum(self.initial_probabilities) == 0:
            return self.states[0]  # Fallback to first state
        
        initial_index = np.random.choice(
            list(self._state_indexes.values()),
            p=self.initial_probabilities
        )
        return self.states[initial_index]
    
    def generate_next_state(self, current_state):
        """
        Generate the next state based on the current state.
        
        Args:
            current_state: The current state in the chain
        
        Returns:
            The next state
        """
        if current_state not in self._state_indexes:
            return self.generate_starting_state()
        
        current_idx = self._state_indexes[current_state]
        transition_probs = self.transition_matrix[current_idx]
        
        # If no transitions from this state, generate a new starting state
        if np.sum(transition_probs) == 0:
            return self.generate_starting_state()
        
        next_index = np.random.choice(
            list(self._state_indexes.values()),
            p=transition_probs
        )
        return self.states[next_index]


In [None]:
class DualMarkovChainMusicGenerator:
    """
    A music generator using three synchronized Markov chains:
    one for pitch, one for duration, and one for instruments.
    """
    
    def __init__(self, pitch_states=None, duration_states=None, instrument_states=None):
        """
        Initialize the dual Markov chain music generator.
        
        Args:
            pitch_states: List of possible pitches (if None, will be extracted from training data)
            duration_states: List of possible durations (if None, will be extracted from training data)
            instrument_states: List of possible instruments (if None, will be extracted from training data)
        """
        self.pitch_states = pitch_states
        self.duration_states = duration_states
        self.instrument_states = instrument_states
        self.pitch_chain = None
        self.duration_chain = None
        self.instrument_chain = None
        self.instrument_objects = {}  # Map instrument names to music21 instrument objects
    
    def train_from_midi(self, midi_file_path, exclude_percussion=True):
        """
        Train the Markov chains from a MIDI file.
        
        Args:
            midi_file_path: Path to the MIDI file to train on
            exclude_percussion: If True, exclude percussion instruments (default: True)
        """
        print(f"Loading MIDI file: {midi_file_path}")
        
        # Parse the MIDI file
        score = converter.parse(midi_file_path)
        
        # Extract notes from all parts
        notes_list = self._extract_notes_from_score(score, exclude_percussion=exclude_percussion)
        
        if not notes_list:
            raise ValueError("No notes found in the MIDI file")
        
        print(f"Extracted {len(notes_list)} notes from MIDI file")
        
        # Extract pitch, duration, and instrument sequences
        pitch_sequence = [n.pitch.nameWithOctave for n in notes_list]
        duration_sequence = [n.duration.quarterLength for n in notes_list]
        instrument_sequence = [n.instrument_name for n in notes_list]
        
        # Determine states if not provided
        if self.pitch_states is None:
            self.pitch_states = sorted(list(set(pitch_sequence)))
        
        if self.duration_states is None:
            self.duration_states = sorted(list(set(duration_sequence)))
        
        if self.instrument_states is None:
            self.instrument_states = sorted(list(set(instrument_sequence)))
        
        print(f"Pitch states: {len(self.pitch_states)} unique pitches")
        print(f"Duration states: {len(self.duration_states)} unique durations")
        print(f"Instrument states: {len(self.instrument_states)} unique instruments")
        
        # Create and train the Markov chains
        self.pitch_chain = MarkovChain(self.pitch_states)
        self.duration_chain = MarkovChain(self.duration_states)
        self.instrument_chain = MarkovChain(self.instrument_states)
        
        print("Training pitch chain...")
        self.pitch_chain.train(pitch_sequence)
        
        print("Training duration chain...")
        self.duration_chain.train(duration_sequence)
        
        print("Training instrument chain...")
        self.instrument_chain.train(instrument_sequence)
        
        print("Training complete!")
    
    def _extract_notes_from_score(self, score, exclude_percussion=True):
        """
        Extract all notes from a music21 score, with optional filtering.
        
        Args:
            score: A music21 score object
            exclude_percussion: If True, exclude percussion instruments
        
        Returns:
            List of note.Note objects
        """
        notes_list = []
        
        # Get all parts in the score
        parts = score.parts
        
        print(f"\nFound {len(parts)} part(s) in the MIDI file:")
        
        for i, part in enumerate(parts):
            # Get the instrument for this part
            part_instrument = part.getInstrument()
            instrument_name = part_instrument.instrumentName if part_instrument else "Unknown"
            
            # Check if this is a percussion instrument
            is_percussion = False
            if part_instrument:
                # Check if it's a percussion instrument or uses MIDI channel 10 (percussion channel)
                is_percussion = (
                    isinstance(part_instrument, instrument.Percussion) or
                    isinstance(part_instrument, instrument.UnpitchedPercussion) or
                    part.getInstrument().midiChannel == 9  # MIDI channel 10 (0-indexed as 9)
                )
            
            # Count notes in this part
            part_notes = [el for el in part.flatten().notesAndRests if isinstance(el, note.Note)]
            
            status = "EXCLUDED (percussion)" if (is_percussion and exclude_percussion) else "included"
            print(f"  Part {i+1}: {instrument_name} - {len(part_notes)} notes - {status}")
            
            # Add notes if not excluded
            if not (is_percussion and exclude_percussion):
                # Store the instrument object for later use
                self.instrument_objects[instrument_name] = part_instrument
                
                # Attach instrument name to each note for tracking
                for n in part_notes:
                    n.instrument_name = instrument_name
                
                notes_list.extend(part_notes)
        
        return notes_list
    
    def generate(self, length):
        """
        Generate a melody using synchronized Markov chains.
        
        Args:
            length: Number of notes to generate
        
        Returns:
            List of tuples (pitch, duration, instrument_name)
        """
        if self.pitch_chain is None or self.duration_chain is None:
            raise ValueError("Model must be trained before generation")
        
        print(f"Generating {length} notes...")
        
        melody = []
        
        # Generate starting states for all three chains
        current_pitch = self.pitch_chain.generate_starting_state()
        current_duration = self.duration_chain.generate_starting_state()
        current_instrument = self.instrument_chain.generate_starting_state() if self.instrument_chain else "Piano"
        melody.append((current_pitch, current_duration, current_instrument))
        
        # Generate subsequent notes
        for _ in range(1, length):
            current_pitch = self.pitch_chain.generate_next_state(current_pitch)
            current_duration = self.duration_chain.generate_next_state(current_duration)
            if self.instrument_chain:
                current_instrument = self.instrument_chain.generate_next_state(current_instrument)
            melody.append((current_pitch, current_duration, current_instrument))
        
        print("Generation complete!")
        return melody
    
    def save_to_midi(self, melody, output_path):
        """
        Save a generated melody to a MIDI file with multiple instrument parts.
        
        Args:
            melody: List of tuples (pitch, duration, instrument_name)
            output_path: Path to save the MIDI file
        """
        print(f"Saving melody to {output_path}...")
        
        # Create a score to hold multiple parts
        score = stream.Score()
        
        # Group notes by instrument
        notes_by_instrument = {}
        for pitch, duration, instrument_name in melody:
            if instrument_name not in notes_by_instrument:
                notes_by_instrument[instrument_name] = []
            notes_by_instrument[instrument_name].append((pitch, duration))
        
        print(f"Creating {len(notes_by_instrument)} instrument part(s)...")
        
        # Create a part for each instrument
        for instrument_name, notes in notes_by_instrument.items():
            part = stream.Part()
            
            # Set the instrument
            if instrument_name in self.instrument_objects:
                part.insert(0, self.instrument_objects[instrument_name])
            else:
                # Fallback to a generic instrument
                part.insert(0, instrument.Piano())
            
            # Add notes to this part
            for pitch, duration in notes:
                new_note = note.Note(pitch)
                new_note.duration.quarterLength = duration
                part.append(new_note)
            
            score.append(part)
        
        # Write to MIDI file
        score.write('midi', fp=output_path)
        
        print(f"MIDI file saved successfully with {len(notes_by_instrument)} instrument(s)!")


In [10]:
def print_melody_info(melody):
    """
    Print information about the generated melody.
    
    Args:
        melody: List of tuples (pitch, duration)
    """
    print("\n=== Generated Melody Info ===")
    print(f"Total notes: {len(melody)}")
    
    # Count unique pitches and durations
    pitches = [p for p, d in melody]
    durations = [d for p, d in melody]
    
    print(f"Unique pitches used: {len(set(pitches))}")
    print(f"Unique durations used: {len(set(durations))}")
    
    # Print first few notes
    print("\nFirst 10 notes:")
    for i, (pitch, duration) in enumerate(melody[:10]):
        print(f"  {i+1}. {pitch} (duration: {duration})")
    
    # Calculate total duration in quarter notes
    total_duration = sum(durations)
    print(f"\nTotal duration: {total_duration} quarter notes")
    print(f"Estimated duration: ~{total_duration/2:.1f} seconds at 120 BPM")


In [12]:
# Main execution

# Initialize the generator
generator = DualMarkovChainMusicGenerator()

# Train on the existing MIDI file
# exclude_percussion=True (default) - ignores drums and percussion
# exclude_percussion=False - includes all instruments
generator.train_from_midi('LegendofZelda_Title.mid', exclude_percussion=True)

# Generate a new melody (50 notes)
generated_melody = generator.generate(50)

# Print information about the generated melody
print_melody_info(generated_melody)

# Save to a new MIDI file
generator.save_to_midi(generated_melody, 'generated_music.mid')

print("\n✓ Music generation complete!")
print("  You can play 'generated_music.mid' with any MIDI player.")


Loading MIDI file: LegendofZelda_Title.mid

Found 12 part(s) in the MIDI file:
  Part 1: None - 10 notes - included
  Part 2: None - 770 notes - included
  Part 3: None - 244 notes - included
  Part 4: None - 289 notes - included
  Part 5: None - 234 notes - included
  Part 6: None - 90 notes - included
  Part 7: None - 110 notes - included
  Part 8: None - 12 notes - included
  Part 9: None - 204 notes - included
  Part 10: None - 0 notes - included
  Part 11: None - 54 notes - included
  Part 12: None - 3 notes - included
Extracted 2020 notes from MIDI file
Pitch states: 56 unique pitches
Duration states: 18 unique durations
Training pitch chain...
Training duration chain...
Training complete!
Generating 50 notes...
Generation complete!

=== Generated Melody Info ===
Total notes: 50
Unique pitches used: 10
Unique durations used: 9

First 10 notes:
  1. C3 (duration: 0.25)
  2. C3 (duration: 0.5)
  3. F2 (duration: 0.5)
  4. F2 (duration: 1.75)
  5. F2 (duration: 1/3)
  6. G#1 (durati