# Music Generation with Higher-Order Markov Chains

This notebook uses **higher-order Markov chains** to generate music with better coherence:
- **Pitch chain** (note names like C5, D5, E5) - uses n-gram context
- **Duration chain** (note lengths like 1.0, 0.5, 2.0)
- **Instrument chain** (Piano, Bass, Strings, etc.)

## Key Improvements:
1. **Higher-order Markov chains** - considers sequences of notes, not just the previous one
2. **Proper chord detection** - analyzes note timing to identify simultaneous notes
3. **Correct initial probabilities** - tracks actual starting states
4. **Joint pitch-duration states** - optional correlation between pitch and rhythm


In [1]:
import numpy as np
from collections import defaultdict, Counter
from music21 import converter, note, chord, stream, midi, instrument
from typing import List, Tuple, Dict, Any, Optional


In [2]:
class HigherOrderMarkovChain:
    """
    A higher-order Markov Chain that considers n previous states (n-gram context).
    This produces more coherent sequences than a simple order-1 chain.
    """
    
    def __init__(self, order: int = 2):
        """
        Initialize the Higher-Order Markov Chain.
        
        Args:
            order: Number of previous states to consider (default: 2)
                   order=1 is a standard Markov chain
                   order=2 looks at the last 2 states to predict the next
        """
        self.order = order
        self.transitions: Dict[tuple, Counter] = defaultdict(Counter)
        self.starting_sequences: List[tuple] = []  # Actual starting n-grams
        self.all_states: set = set()
    
    def train(self, sequence: List[Any]) -> None:
        """
        Train the Markov chain on a sequence of states.
        
        Args:
            sequence: List of states to learn from
        """
        if len(sequence) < self.order + 1:
            print(f"Warning: Sequence too short for order {self.order}")
            return
        
        # Record the starting sequence (first n states)
        starting_seq = tuple(sequence[:self.order])
        self.starting_sequences.append(starting_seq)
        
        # Record all states
        self.all_states.update(sequence)
        
        # Build transition counts: context (n-gram) -> next state
        for i in range(len(sequence) - self.order):
            context = tuple(sequence[i:i + self.order])
            next_state = sequence[i + self.order]
            self.transitions[context][next_state] += 1
    
    def train_multiple(self, sequences: List[List[Any]]) -> None:
        """Train on multiple sequences (e.g., multiple tracks or songs)."""
        for seq in sequences:
            self.train(seq)
    
    def _get_transition_probs(self, context: tuple) -> Dict[Any, float]:
        """Get normalized transition probabilities for a context."""
        counts = self.transitions[context]
        total = sum(counts.values())
        if total == 0:
            return {}
        return {state: count / total for state, count in counts.items()}
    
    def _fallback_context(self, context: tuple) -> tuple:
        """
        Try shorter contexts if the full context is unseen (back-off smoothing).
        """
        for length in range(len(context) - 1, 0, -1):
            shorter = context[-length:]
            # Check if any key ends with this shorter context
            for key in self.transitions:
                if key[-length:] == shorter and self.transitions[key]:
                    return key
        return None
    
    def generate_starting_sequence(self) -> List[Any]:
        """
        Generate a starting sequence based on observed starting n-grams.
        
        Returns:
            List of states to start the generation
        """
        if not self.starting_sequences:
            # Fallback: return random states
            return list(np.random.choice(list(self.all_states), size=self.order))
        
        idx = np.random.randint(len(self.starting_sequences))
        return list(self.starting_sequences[idx])
    
    def generate_next_state(self, context: List[Any]) -> Any:
        """
        Generate the next state based on the context (last n states).
        
        Args:
            context: List of the last n states
        
        Returns:
            The next state
        """
        context_tuple = tuple(context[-self.order:])
        probs = self._get_transition_probs(context_tuple)
        
        # If context not found, try back-off to shorter context
        if not probs:
            fallback = self._fallback_context(context_tuple)
            if fallback:
                probs = self._get_transition_probs(fallback)
        
        # If still nothing, pick a random state
        if not probs:
            return np.random.choice(list(self.all_states))
        
        states = list(probs.keys())
        probabilities = list(probs.values())
        return np.random.choice(states, p=probabilities)
    
    def generate_sequence(self, length: int) -> List[Any]:
        """
        Generate a sequence of the specified length.
        
        Args:
            length: Number of states to generate
        
        Returns:
            List of generated states
        """
        if length <= 0:
            return []
        
        sequence = self.generate_starting_sequence()
        
        while len(sequence) < length:
            next_state = self.generate_next_state(sequence)
            sequence.append(next_state)
        
        return sequence[:length]


class JointMarkovChain:
    """
    A Markov chain that tracks joint states (e.g., pitch+duration together).
    This ensures pitch and duration are correlated as in real music.
    """
    
    def __init__(self, order: int = 2):
        self.order = order
        self.transitions: Dict[tuple, Counter] = defaultdict(Counter)
        self.starting_sequences: List[tuple] = []
        self.all_states: set = set()
    
    def train(self, joint_sequence: List[Tuple[Any, Any]]) -> None:
        """
        Train on a sequence of joint states (pitch, duration) tuples.
        """
        if len(joint_sequence) < self.order + 1:
            return
        
        self.starting_sequences.append(tuple(joint_sequence[:self.order]))
        self.all_states.update(joint_sequence)
        
        for i in range(len(joint_sequence) - self.order):
            context = tuple(joint_sequence[i:i + self.order])
            next_state = joint_sequence[i + self.order]
            self.transitions[context][next_state] += 1
    
    def generate_starting_sequence(self) -> List[Tuple[Any, Any]]:
        if not self.starting_sequences:
            return [list(self.all_states)[0]] * self.order
        idx = np.random.randint(len(self.starting_sequences))
        return list(self.starting_sequences[idx])
    
    def generate_next_state(self, context: List[Tuple[Any, Any]], temperature: float = 1.0) -> Tuple[Any, Any]:
        """
        Generate next state with temperature-controlled randomness.
        
        Args:
            context: Previous states
            temperature: Controls randomness (0.1=deterministic, 1.0=normal, 2.0+=very random)
        """
        context_tuple = tuple(context[-self.order:])
        counts = self.transitions[context_tuple]
        
        if not counts:
            # Fallback: pick random from all seen states
            return list(self.all_states)[np.random.randint(len(self.all_states))]
        
        total = sum(counts.values())
        states = list(counts.keys())
        probs = np.array([c / total for c in counts.values()])
        
        # Apply temperature: higher = more random, lower = more deterministic
        if temperature != 1.0 and len(probs) > 1:
            # Temperature scaling: raise probs to power of 1/temp, then renormalize
            probs = np.power(probs, 1.0 / temperature)
            probs = probs / probs.sum()
        
        return states[np.random.choice(len(states), p=probs)]
    
    def generate_sequence(self, length: int, temperature: float = 1.0) -> List[Tuple[Any, Any]]:
        """Generate sequence with temperature-controlled randomness."""
        sequence = self.generate_starting_sequence()
        while len(sequence) < length:
            sequence.append(self.generate_next_state(sequence, temperature))
        return sequence[:length]


In [3]:
class MusicGenerator:
    """
    An improved music generator using higher-order Markov chains with:
    - **Per-instrument modeling** - each instrument has its own chain
    - Proper chord detection based on timing analysis
    - Joint pitch-duration modeling for natural music
    - Back-off smoothing for unseen contexts
    """
    
    def __init__(self, order: int = 2, use_joint_model: bool = True):
        """
        Initialize the music generator.
        
        Args:
            order: Order of the Markov chain (2-3 recommended for music)
                   Higher = more coherent but needs more training data
            use_joint_model: If True, model pitch+duration together (more realistic)
                            If False, model them independently (more variety)
        """
        self.order = order
        self.use_joint_model = use_joint_model
        
        # Per-instrument Markov chains: instrument_name -> chain
        self.chains_by_instrument: Dict[str, JointMarkovChain] = {}
        self.pitch_chains_by_instrument: Dict[str, HigherOrderMarkovChain] = {}
        self.duration_chains_by_instrument: Dict[str, HigherOrderMarkovChain] = {}
        
        # Store instrument objects for MIDI output
        self.instrument_objects: Dict[str, Any] = {}
        
        # Store event counts per instrument for proportional generation
        self.instrument_event_counts: Dict[str, int] = {}
        
        # Training statistics
        self.stats: Dict[str, Any] = {}
    
    def train_from_midi(self, midi_file_path: str, exclude_percussion: bool = True,
                        chord_threshold: float = 0.05) -> None:
        """
        Train the Markov chains from a MIDI file.
        
        Args:
            midi_file_path: Path to the MIDI file
            exclude_percussion: If True, exclude percussion instruments
            chord_threshold: Max time difference (in quarter notes) to consider 
                           notes as part of the same chord (default: 0.05)
        """
        print(f"Loading MIDI file: {midi_file_path}")
        score = converter.parse(midi_file_path)
        
        # Extract events with proper chord detection
        events = self._extract_events_with_timing(score, exclude_percussion, chord_threshold)
        
        if not events:
            raise ValueError("No musical events found in the MIDI file")
        
        print(f"\nExtracted {len(events)} musical events")
        
        # Separate sequences by instrument for better training
        events_by_instrument = self._group_by_instrument(events)
        
        # Collect statistics
        all_pitches = [e['pitch'] for e in events]
        all_durations = [e['duration'] for e in events]
        all_instruments = [e['instrument'] for e in events]
        
        unique_pitches = set(all_pitches)
        unique_durations = set(all_durations)
        unique_instruments = set(all_instruments)
        
        chord_count = sum(1 for p in all_pitches if '.' in str(p))
        
        self.stats = {
            'total_events': len(events),
            'unique_pitches': len(unique_pitches),
            'unique_durations': len(unique_durations),
            'unique_instruments': len(unique_instruments),
            'chord_count': chord_count,
            'note_count': len(events) - chord_count
        }
        
        print(f"Unique pitches/chords: {self.stats['unique_pitches']}")
        print(f"Unique durations: {self.stats['unique_durations']}")
        print(f"Unique instruments: {self.stats['unique_instruments']}")
        print(f"Single notes: {self.stats['note_count']}, Chords: {self.stats['chord_count']}")
        
        # Train the chains
        self._train_chains(events_by_instrument)
        
        print("\n‚úì Training complete!")
    
    def _extract_events_with_timing(self, score, exclude_percussion: bool,
                                    chord_threshold: float) -> List[Dict]:
        """
        Extract musical events by analyzing timing to properly detect chords.
        
        This method groups notes that occur at the same time (within threshold)
        into chords, which is more reliable than relying on music21's chord detection.
        """
        events = []
        parts = score.parts
        
        print(f"\nFound {len(parts)} part(s):")
        
        # Track instrument name counts to make them unique
        name_counts = {}
        
        for i, part in enumerate(parts):
            part_instrument = part.getInstrument()
            base_name = self._get_instrument_name(part_instrument, i)
            
            # Make instrument name unique if duplicate
            if base_name in name_counts:
                name_counts[base_name] += 1
                instrument_name = f"{base_name} {name_counts[base_name]}"
            else:
                name_counts[base_name] = 1
                instrument_name = base_name
            
            # Check if percussion
            is_percussion = self._is_percussion(part_instrument)
            
            if is_percussion and exclude_percussion:
                print(f"  Part {i+1}: {instrument_name} - EXCLUDED (percussion)")
                continue
            
            # Store instrument object for later (clone it to avoid issues)
            if part_instrument:
                # Create a proper instrument object based on MIDI program
                program = getattr(part_instrument, 'midiProgram', 0)
                try:
                    # Try to create the correct instrument type
                    inst_obj = instrument.instrumentFromMidiProgram(program)
                except:
                    inst_obj = instrument.Piano()
                self.instrument_objects[instrument_name] = inst_obj
            else:
                self.instrument_objects[instrument_name] = instrument.Piano()
            
            # Get all notes with their offsets (timing)
            part_notes = []
            for el in part.flatten().notesAndRests:
                if isinstance(el, note.Note):
                    part_notes.append({
                        'offset': float(el.offset),
                        'pitches': [el.pitch.nameWithOctave],
                        'duration': float(el.duration.quarterLength),
                        'instrument': instrument_name
                    })
                elif isinstance(el, chord.Chord):
                    # Already a chord - extract all pitches
                    pitches = sorted([p.nameWithOctave for p in el.pitches])
                    part_notes.append({
                        'offset': float(el.offset),
                        'pitches': pitches,
                        'duration': float(el.duration.quarterLength),
                        'instrument': instrument_name
                    })
            
            # Group notes by timing to detect chords
            grouped_events = self._group_notes_by_timing(part_notes, chord_threshold)
            
            chord_count = sum(1 for e in grouped_events if len(e['pitches']) > 1)
            program_info = f" [prog={getattr(part_instrument, 'midiProgram', '?')}]" if part_instrument else ""
            print(f"  Part {i+1}: {instrument_name}{program_info} - {len(grouped_events)} events "
                  f"({len(grouped_events) - chord_count} notes, {chord_count} chords)")
            
            # Convert to final event format
            for event in grouped_events:
                if len(event['pitches']) == 1:
                    pitch_str = event['pitches'][0]
                else:
                    pitch_str = '.'.join(sorted(event['pitches']))
                
                events.append({
                    'pitch': pitch_str,
                    'duration': event['duration'],
                    'instrument': event['instrument'],
                    'offset': event['offset']
                })
        
        # Sort by offset to maintain temporal order
        events.sort(key=lambda x: x['offset'])
        
        return events
    
    def _group_notes_by_timing(self, notes: List[Dict], threshold: float) -> List[Dict]:
        """
        Group notes that occur at nearly the same time into chords.
        """
        if not notes:
            return []
        
        # Sort by offset
        notes = sorted(notes, key=lambda x: x['offset'])
        
        grouped = []
        current_group = {
            'offset': notes[0]['offset'],
            'pitches': list(notes[0]['pitches']),
            'duration': notes[0]['duration'],
            'instrument': notes[0]['instrument']
        }
        
        for n in notes[1:]:
            # Check if this note is within threshold of current group
            if abs(n['offset'] - current_group['offset']) <= threshold:
                # Add to current group (chord)
                current_group['pitches'].extend(n['pitches'])
                # Use the longest duration
                current_group['duration'] = max(current_group['duration'], n['duration'])
            else:
                # Start new group
                grouped.append(current_group)
                current_group = {
                    'offset': n['offset'],
                    'pitches': list(n['pitches']),
                    'duration': n['duration'],
                    'instrument': n['instrument']
                }
        
        grouped.append(current_group)
        return grouped
    
    def _get_instrument_name(self, inst, part_index: int = 0) -> str:
        """
        Get a readable instrument name using multiple fallback strategies.
        
        Priority:
        1. instrumentName property
        2. Class name (e.g., Violin, Flute)
        3. MIDI program number -> General MIDI name
        4. Part index as fallback
        """
        # General MIDI program names (0-127)
        GM_INSTRUMENTS = [
            "Acoustic Grand Piano", "Bright Acoustic Piano", "Electric Grand Piano", "Honky-tonk Piano",
            "Electric Piano 1", "Electric Piano 2", "Harpsichord", "Clavinet",
            "Celesta", "Glockenspiel", "Music Box", "Vibraphone", "Marimba", "Xylophone", "Tubular Bells", "Dulcimer",
            "Drawbar Organ", "Percussive Organ", "Rock Organ", "Church Organ", "Reed Organ", "Accordion", "Harmonica", "Tango Accordion",
            "Acoustic Guitar (nylon)", "Acoustic Guitar (steel)", "Electric Guitar (jazz)", "Electric Guitar (clean)",
            "Electric Guitar (muted)", "Overdriven Guitar", "Distortion Guitar", "Guitar Harmonics",
            "Acoustic Bass", "Electric Bass (finger)", "Electric Bass (pick)", "Fretless Bass",
            "Slap Bass 1", "Slap Bass 2", "Synth Bass 1", "Synth Bass 2",
            "Violin", "Viola", "Cello", "Contrabass", "Tremolo Strings", "Pizzicato Strings", "Orchestral Harp", "Timpani",
            "String Ensemble 1", "String Ensemble 2", "Synth Strings 1", "Synth Strings 2", "Choir Aahs", "Voice Oohs", "Synth Choir", "Orchestra Hit",
            "Trumpet", "Trombone", "Tuba", "Muted Trumpet", "French Horn", "Brass Section", "Synth Brass 1", "Synth Brass 2",
            "Soprano Sax", "Alto Sax", "Tenor Sax", "Baritone Sax", "Oboe", "English Horn", "Bassoon", "Clarinet",
            "Piccolo", "Flute", "Recorder", "Pan Flute", "Blown Bottle", "Shakuhachi", "Whistle", "Ocarina",
            "Lead 1 (square)", "Lead 2 (sawtooth)", "Lead 3 (calliope)", "Lead 4 (chiff)", "Lead 5 (charang)",
            "Lead 6 (voice)", "Lead 7 (fifths)", "Lead 8 (bass + lead)",
            "Pad 1 (new age)", "Pad 2 (warm)", "Pad 3 (polysynth)", "Pad 4 (choir)", "Pad 5 (bowed)",
            "Pad 6 (metallic)", "Pad 7 (halo)", "Pad 8 (sweep)",
            "FX 1 (rain)", "FX 2 (soundtrack)", "FX 3 (crystal)", "FX 4 (atmosphere)",
            "FX 5 (brightness)", "FX 6 (goblins)", "FX 7 (echoes)", "FX 8 (sci-fi)",
            "Sitar", "Banjo", "Shamisen", "Koto", "Kalimba", "Bagpipe", "Fiddle", "Shanai",
            "Tinkle Bell", "Agogo", "Steel Drums", "Woodblock", "Taiko Drum", "Melodic Tom", "Synth Drum", "Reverse Cymbal",
            "Guitar Fret Noise", "Breath Noise", "Seashore", "Bird Tweet", "Telephone Ring", "Helicopter", "Applause", "Gunshot"
        ]
        
        if inst is None:
            return f"Part {part_index + 1}"
        
        # Try 1: instrumentName property
        name = inst.instrumentName
        if name and name != "None" and name.strip():
            return name
        
        # Try 2: Class name
        class_name = type(inst).__name__
        if class_name not in ["Instrument", "UnpitchedPercussion"]:
            return class_name
        
        # Try 3: MIDI program number
        program = getattr(inst, 'midiProgram', None)
        if program is not None and 0 <= program < len(GM_INSTRUMENTS):
            return GM_INSTRUMENTS[program]
        
        # Fallback: Part number
        return f"Part {part_index + 1}"
    
    def _is_percussion(self, inst) -> bool:
        """Check if an instrument is percussion."""
        if inst is None:
            return False
        return (
            isinstance(inst, instrument.Percussion) or
            isinstance(inst, instrument.UnpitchedPercussion) or
            getattr(inst, 'midiChannel', -1) == 9
        )
    
    def _group_by_instrument(self, events: List[Dict]) -> Dict[str, List[Dict]]:
        """Group events by instrument."""
        grouped = defaultdict(list)
        for event in events:
            grouped[event['instrument']].append(event)
        return dict(grouped)
    
    def _train_chains(self, events_by_instrument: Dict[str, List[Dict]]) -> None:
        """
        Train SEPARATE Markov chains for each instrument.
        
        This is the key improvement for multi-instrument pieces:
        - Each instrument learns its own patterns (bass learns bass lines, melody learns melodies)
        - Generation preserves the character of each instrument
        """
        print(f"\nüéº Training separate chains per instrument (order={self.order})...")
        
        for inst_name, events in events_by_instrument.items():
            if len(events) < self.order + 1:
                print(f"   ‚ö†Ô∏è  {inst_name}: Too few events ({len(events)}), skipping")
                continue
            
            # Store event count for proportional generation
            self.instrument_event_counts[inst_name] = len(events)
            
            if self.use_joint_model:
                # Joint pitch+duration model for this instrument
                chain = JointMarkovChain(order=self.order)
                joint_seq = [(e['pitch'], e['duration']) for e in events]
                chain.train(joint_seq)
                self.chains_by_instrument[inst_name] = chain
                
                n_transitions = sum(len(c) for c in chain.transitions.values())
                print(f"   ‚úì {inst_name}: {len(events)} events, {n_transitions} transitions")
            else:
                # Independent chains for this instrument
                pitch_chain = HigherOrderMarkovChain(order=self.order)
                duration_chain = HigherOrderMarkovChain(order=self.order)
                
                pitch_seq = [e['pitch'] for e in events]
                duration_seq = [e['duration'] for e in events]
                
                pitch_chain.train(pitch_seq)
                duration_chain.train(duration_seq)
                
                self.pitch_chains_by_instrument[inst_name] = pitch_chain
                self.duration_chains_by_instrument[inst_name] = duration_chain
                
                print(f"   ‚úì {inst_name}: {len(events)} events")
        
        print(f"\n   Trained {len(self.instrument_event_counts)} instrument(s)")
    
    def generate(self, length: int, instruments: Optional[List[str]] = None, 
                 temperature: float = 1.0) -> List[Tuple[str, float, str]]:
        """
        Generate music with separate sequences for each instrument.
        
        Each instrument generates its own coherent sequence, preserving
        the musical character of that instrument from the training data.
        
        Args:
            length: Total number of musical events to generate (distributed across instruments)
            instruments: Optional list of instruments to use. If None, uses all trained instruments.
            temperature: Controls randomness/creativity (default 1.0)
                - 0.5 = More deterministic, closer to training data
                - 1.0 = Normal (follows learned probabilities)
                - 1.5 = More creative, more unexpected choices
                - 2.0+ = Very random, might sound chaotic
        
        Returns:
            List of tuples (pitch, duration, instrument_name)
        """
        # Determine which instruments to generate
        if instruments is None:
            instruments = list(self.instrument_event_counts.keys())
        
        if not instruments:
            raise ValueError("No instruments available. Train the model first.")
        
        # Calculate events per instrument (proportional to training data)
        total_training_events = sum(self.instrument_event_counts.get(i, 1) for i in instruments)
        
        events_per_instrument = {}
        remaining = length
        
        for i, inst in enumerate(instruments):
            if i == len(instruments) - 1:
                # Last instrument gets remaining events
                events_per_instrument[inst] = remaining
            else:
                # Proportional allocation
                proportion = self.instrument_event_counts.get(inst, 1) / total_training_events
                count = max(1, int(length * proportion))
                events_per_instrument[inst] = count
                remaining -= count
        
        temp_desc = "deterministic" if temperature < 0.8 else "normal" if temperature < 1.3 else "creative"
        print(f"üéµ Generating {length} events (temperature={temperature}, {temp_desc}):")
        
        melody = []
        
        for inst_name in instruments:
            inst_length = events_per_instrument[inst_name]
            
            if self.use_joint_model:
                chain = self.chains_by_instrument.get(inst_name)
                if chain is None:
                    print(f"   ‚ö†Ô∏è  {inst_name}: No chain available, skipping")
                    continue
                
                # Generate sequence for this instrument WITH TEMPERATURE
                joint_seq = chain.generate_sequence(inst_length, temperature=temperature)
                
                # Add to melody with instrument name
                for pitch, duration in joint_seq:
                    melody.append((pitch, duration, inst_name))
                
                print(f"   ‚úì {inst_name}: {inst_length} events")
            else:
                pitch_chain = self.pitch_chains_by_instrument.get(inst_name)
                duration_chain = self.duration_chains_by_instrument.get(inst_name)
                
                if pitch_chain is None or duration_chain is None:
                    print(f"   ‚ö†Ô∏è  {inst_name}: No chains available, skipping")
                    continue
                
                pitch_seq = pitch_chain.generate_sequence(inst_length)
                duration_seq = duration_chain.generate_sequence(inst_length)
                
                for pitch, duration in zip(pitch_seq, duration_seq):
                    melody.append((pitch, duration, inst_name))
                
                print(f"   ‚úì {inst_name}: {inst_length} events")
        
        print(f"\n   Total: {len(melody)} events")
        return melody
    
    def save_to_midi(self, melody: List[Tuple[str, float, str]], output_path: str,
                     polyphonic: bool = True) -> None:
        """
        Save generated music to a MIDI file.
        
        Args:
            melody: List of tuples (pitch, duration, instrument_name)
            output_path: Path to save the MIDI file
            polyphonic: If True, instruments play simultaneously (proper multi-track)
                       If False, instruments play sequentially (old behavior)
        """
        print(f"Saving to {output_path}...")
        
        score = stream.Score()
        
        # Group notes by instrument
        notes_by_instrument = defaultdict(list)
        for pitch, duration, inst_name in melody:
            notes_by_instrument[inst_name].append((pitch, duration))
        
        print(f"Creating {len(notes_by_instrument)} instrument part(s)...")
        
        for inst_name, notes in notes_by_instrument.items():
            part = stream.Part()
            part.id = inst_name
            part.partName = inst_name
            
            # Set instrument with proper MIDI program
            if inst_name in self.instrument_objects and self.instrument_objects[inst_name]:
                inst_obj = self.instrument_objects[inst_name]
                # Clone the instrument to avoid modifying the original
                try:
                    new_inst = type(inst_obj)()
                    if hasattr(inst_obj, 'midiProgram'):
                        new_inst.midiProgram = inst_obj.midiProgram
                except:
                    new_inst = inst_obj
                part.insert(0, new_inst)
                print(f"   ‚Ä¢ {inst_name}: program {getattr(new_inst, 'midiProgram', 0)}")
            else:
                part.insert(0, instrument.Piano())
                print(f"   ‚Ä¢ {inst_name}: program 0 (Piano fallback)")
            
            # Track offset for positioning notes
            current_offset = 0.0
            
            # Add notes/chords
            for pitch, duration in notes:
                # Ensure duration is valid
                duration = max(0.125, float(duration))  # Minimum 32nd note
                
                if '.' in str(pitch):
                    # It's a chord
                    chord_pitches = pitch.split('.')
                    new_element = chord.Chord(chord_pitches)
                else:
                    # Single note
                    new_element = note.Note(pitch)
                
                new_element.duration.quarterLength = duration
                
                if polyphonic:
                    # Insert at specific offset for parallel playback
                    part.insert(current_offset, new_element)
                    current_offset += duration
                else:
                    # Append sequentially
                    part.append(new_element)
            
            score.insert(0, part)
        
        score.write('midi', fp=output_path)
        
        # Calculate total duration
        total_dur = max(
            sum(d for _, d in notes) 
            for notes in notes_by_instrument.values()
        ) if notes_by_instrument else 0
        
        print(f"‚úì Saved with {len(notes_by_instrument)} instrument(s)")
        print(f"   Duration: ~{total_dur:.1f} quarter notes ({total_dur/2:.1f}s at 120 BPM)")
    
    def get_instruments(self) -> List[str]:
        """Get list of available trained instruments."""
        return list(self.instrument_event_counts.keys())
    
    def get_instrument_stats(self) -> Dict[str, int]:
        """Get event counts per instrument."""
        return dict(self.instrument_event_counts)


In [4]:
def print_melody_info(melody: List[Tuple[str, float, str]]) -> None:
    """
    Print detailed information about the generated melody.
    
    Args:
        melody: List of tuples (pitch, duration, instrument_name)
    """
    print("\n" + "=" * 50)
    print("           GENERATED MELODY INFO")
    print("=" * 50)
    
    pitches = [p for p, d, i in melody]
    durations = [d for p, d, i in melody]
    instruments = [i for p, d, i in melody]
    
    # Count single notes vs chords
    single_notes = [p for p in pitches if '.' not in str(p)]
    chords = [p for p in pitches if '.' in str(p)]
    
    print(f"\nüìä Summary:")
    print(f"   Total events: {len(melody)}")
    print(f"   Single notes: {len(single_notes)}")
    print(f"   Chords: {len(chords)}")
    print(f"   Unique pitches/chords: {len(set(pitches))}")
    print(f"   Unique durations: {len(set(durations))}")
    print(f"   Unique instruments: {len(set(instruments))}")
    
    # Count notes per instrument
    instrument_counts = Counter(instruments)
    print(f"\nüéπ Events per instrument:")
    for inst, count in instrument_counts.most_common():
        pct = 100 * count / len(melody)
        print(f"   {inst}: {count} ({pct:.1f}%)")
    
    # Duration statistics
    total_duration = sum(durations)
    avg_duration = np.mean(durations)
    print(f"\n‚è±Ô∏è Duration:")
    print(f"   Total: {total_duration:.1f} quarter notes")
    print(f"   Average per event: {avg_duration:.2f} quarter notes")
    print(f"   Estimated time: ~{total_duration/2:.1f}s at 120 BPM")
    
    # Print first few notes/chords
    print(f"\nüéµ First 10 events:")
    for i, (pitch, duration, inst) in enumerate(melody[:10]):
        symbol = "üéº" if '.' in str(pitch) else "üéµ"
        print(f"   {i+1}. {symbol} {pitch} (dur={duration:.2f}) [{inst}]")
    
    print("\n" + "=" * 50)


In [12]:
# ============================================================
# MAIN EXECUTION - Multi-Instrument Music Generation
# ============================================================

# === INPUT/OUTPUT ===

MIDI_FILE = 'midis/ZeldaFantasy_1_.mid'    # Source file to learn from
OUTPUT_FILE = f'./generated_music/{MIDI_FILE}-generated.mid' # Where to save generated music

# === KEY PARAMETERS ===
NUM_EVENTS = 200    # Total notes/chords to generate (spread across all instruments)

ORDER = 2           # Context length: how many previous notes to consider
                    #   1 = very random, lots of variety
                    #   2 = balanced (recommended)
                    #   3 = more coherent, needs more training data

# üå°Ô∏è TEMPERATURE - THIS IS THE KEY TO AVOIDING COPIES!
TEMPERATURE = 1.3   # Controls randomness/creativity:
                    #   0.5 = Very similar to input (deterministic)
                    #   1.0 = Normal (follows learned probabilities exactly)
                    #   1.3 = Slightly creative (recommended for variety) ‚≠ê
                    #   1.5 = More experimental
                    #   2.0+ = Chaotic/very random

# === OTHER SETTINGS ===
USE_JOINT_MODEL = True   # True = pitch+duration learned together (more natural)
CHORD_THRESHOLD = 0.05   # How to detect chords (lower = stricter)

# ============================================================

print("=" * 60)
print("   üéµ Multi-Instrument Music Generator")
print("=" * 60)
print(f"\n   üìÇ Input: {MIDI_FILE}")
print(f"   üéØ Order: {ORDER} (context = {ORDER} previous notes)")
print(f"   üå°Ô∏è  Temperature: {TEMPERATURE} ({'copies input' if TEMPERATURE < 0.8 else 'balanced' if TEMPERATURE < 1.2 else 'creative' if TEMPERATURE < 1.6 else 'very random'})")
print(f"   üìä Events: {NUM_EVENTS}")
print()

# Initialize generator
generator = MusicGenerator(order=ORDER, use_joint_model=USE_JOINT_MODEL)

# Train on the MIDI file
generator.train_from_midi(
    MIDI_FILE, 
    exclude_percussion=True,
    chord_threshold=CHORD_THRESHOLD
)

# Show available instruments
print("\nüìã Available instruments:")
for inst, count in generator.get_instrument_stats().items():
    print(f"   ‚Ä¢ {inst}: {count} training events")

# Generate new music WITH TEMPERATURE
generated_melody = generator.generate(NUM_EVENTS, temperature=TEMPERATURE)

# Print info
print_melody_info(generated_melody)

# Save to MIDI file (polyphonic=True means instruments play simultaneously)
generator.save_to_midi(generated_melody, OUTPUT_FILE, polyphonic=True)

print(f"\n‚úì Generation complete! Output: {OUTPUT_FILE}")
print(f"üí° Tip: Adjust TEMPERATURE (currently {TEMPERATURE}) to control how different the output sounds.")


   üéµ Multi-Instrument Music Generator

   üìÇ Input: midis/ZeldaFantasy_1_.mid
   üéØ Order: 2 (context = 2 previous notes)
   üå°Ô∏è  Temperature: 1.3 (creative)
   üìä Events: 200

Loading MIDI file: midis/ZeldaFantasy_1_.mid

Found 3 part(s):
  Part 1: Part 1 [prog=None] - 442 events (363 notes, 79 chords)
  Part 2: Part 2 [prog=None] - 1351 events (1350 notes, 1 chords)
  Part 3: Part 3 [prog=None] - 47 events (0 notes, 47 chords)

Extracted 1840 musical events
Unique pitches/chords: 88
Unique durations: 12
Unique instruments: 3
Single notes: 1713, Chords: 127

üéº Training separate chains per instrument (order=2)...
   ‚úì Part 2: 1351 events, 288 transitions
   ‚úì Part 1: 442 events, 183 transitions
   ‚úì Part 3: 47 events, 35 transitions

   Trained 3 instrument(s)

‚úì Training complete!

üìã Available instruments:
   ‚Ä¢ Part 2: 1351 training events
   ‚Ä¢ Part 1: 442 training events
   ‚Ä¢ Part 3: 47 training events
üéµ Generating 200 events (temperature=1.3, crea

FileNotFoundError: [Errno 2] No such file or directory: './generated_music/midis/ZeldaFantasy_1_.mid-generated.mid'

In [6]:
# ============================================================
# MIDI FILE ANALYSIS - See what data is available for training
# ============================================================

import glob

def analyze_midi_files(chord_threshold: float = 0.05):
    """Analyze all MIDI files in the current directory."""
    midi_files = sorted(glob.glob('*.mid'))
    print(f"üìÅ Found {len(midi_files)} MIDI files\n")
    print(f"{'File':<35} {'Parts':>6} {'Notes':>7} {'Chords':>7} {'Total':>7}")
    print("-" * 70)
    
    results = []
    
    for midi_file in midi_files:
        try:
            score = converter.parse(midi_file)
            
            # Count using timing-based chord detection (like the improved generator)
            total_notes = 0
            total_chords = 0
            
            for part in score.parts:
                # Get notes with offsets
                notes_with_offset = []
                for el in part.flatten().notesAndRests:
                    if isinstance(el, note.Note):
                        notes_with_offset.append(float(el.offset))
                    elif isinstance(el, chord.Chord):
                        notes_with_offset.append(float(el.offset))
                        total_chords += 1  # Count existing chord objects
                
                # Count notes that might form chords (same offset)
                if notes_with_offset:
                    notes_with_offset.sort()
                    for i, offset in enumerate(notes_with_offset):
                        if i == 0 or abs(offset - notes_with_offset[i-1]) > chord_threshold:
                            total_notes += 1
            
            num_parts = len(score.parts)
            total = total_notes + total_chords
            
            print(f"{midi_file:<35} {num_parts:>6} {total_notes:>7} {total_chords:>7} {total:>7}")
            
            results.append({
                'file': midi_file,
                'parts': num_parts,
                'notes': total_notes,
                'chords': total_chords
            })
            
        except Exception as e:
            print(f"{midi_file:<35} ERROR: {str(e)[:30]}")
    
    print("-" * 70)
    
    # Recommend best files for training
    best_for_chords = max(results, key=lambda x: x['chords'])
    best_total = max(results, key=lambda x: x['notes'] + x['chords'])
    
    print(f"\nüí° Recommendations:")
    print(f"   Best for chords: {best_for_chords['file']} ({best_for_chords['chords']} chords)")
    print(f"   Most data: {best_total['file']} ({best_total['notes'] + best_total['chords']} events)")

analyze_midi_files()


üìÅ Found 14 MIDI files

File                                 Parts   Notes  Chords   Total
----------------------------------------------------------------------
LegendofZelda_Title.mid                 12    2930    1136    4066
MARIOBRO.mid                             4     109       1     110
MariobrosPhase1.mid                      1      18       0      18
SMB_-_Castle_-Remix-.mid                 9    2704    2380    5084
Starman-1.mid                            2      96      80     176
Tetris_Title_Screen.mid                  4     641       0     641
a_thousand.mid                           5     336     159     495
generated_music.mid                      1     114       0     114
generated_music1.mid                     1     107       0     107
generated_music2.mid                     1     100      34     134
generated_musichehe.mid                  1     218     102     320
generated_musicnow.mid                   1     160      77     237
music.mid                       

## Understanding the Improved Generator

### Multi-Instrument Support üéπüé∏üé∫

The key improvement is **per-instrument modeling**:

| Old Approach | New Approach |
|--------------|--------------|
| One chain for ALL instruments mixed | Separate chain per instrument |
| Random instrument assignment | Each instrument generates its own coherent part |
| Bass patterns mixed with melody | Bass learns bass, melody learns melody |
| Sequential output | **Polyphonic output** (instruments play simultaneously) |

### How It Works

1. **Training**: Each instrument gets its own Markov chain trained only on that instrument's notes
2. **Generation**: Each instrument generates a separate coherent sequence
3. **Output**: All instrument parts are combined with correct timing (polyphonic MIDI)

### Configuration Guide

```python
# For multi-instrument pieces (recommended):
generator = MusicGenerator(order=2, use_joint_model=True)

# Generate with all instruments:
melody = generator.generate(200)

# Generate with specific instruments only:
melody = generator.generate(100, instruments=['Piano', 'Acoustic Bass'])

# See available instruments:
print(generator.get_instruments())
```

### Polyphonic vs Sequential Output

```python
# Polyphonic: instruments play SIMULTANEOUSLY (correct for multi-track)
generator.save_to_midi(melody, 'output.mid', polyphonic=True)

# Sequential: instruments play one after another (simpler)
generator.save_to_midi(melody, 'output.mid', polyphonic=False)
```

### Best Files for Multi-Instrument Generation

| File | Instruments | Notes |
|------|-------------|-------|
| `LegendofZelda_Title.mid` | Multiple (strings, brass, etc.) | Best for orchestral |
| `SMB_-_Castle_-Remix-.mid` | Multiple with many chords | Great variety |
| `a_thousand.mid` | Piano only | Good for single-instrument |
| `MARIOBRO.mid` | Simple | Good for melody-focused |


In [None]:
# ============================================================
# EXPERIMENT: Generate with Specific Instruments
# ============================================================
# This cell lets you generate music using only selected instruments

def generate_with_selected_instruments(midi_file: str, selected_instruments: List[str] = None,
                                       length: int = 100, order: int = 2):
    """
    Generate music using only selected instruments from the training file.
    
    Args:
        midi_file: Source MIDI file
        selected_instruments: List of instrument names to use (None = all)
        length: Number of events to generate
        order: Markov chain order
    """
    print(f"üéº Generating with selected instruments from: {midi_file}\n")
    
    gen = MusicGenerator(order=order, use_joint_model=True)
    gen.train_from_midi(midi_file, exclude_percussion=True)
    
    # Show available instruments
    print("\nüìã All trained instruments:")
    for inst, count in gen.get_instrument_stats().items():
        marker = "‚úì" if (selected_instruments is None or inst in selected_instruments) else "‚úó"
        print(f"   {marker} {inst}: {count} events")
    
    # Generate
    if selected_instruments:
        # Filter to only valid instruments
        valid = [i for i in selected_instruments if i in gen.get_instruments()]
        if not valid:
            print("\n‚ö†Ô∏è  No valid instruments selected!")
            return None
        melody = gen.generate(length, instruments=valid)
    else:
        melody = gen.generate(length)
    
    # Save
    output_file = "generated_selected.mid"
    gen.save_to_midi(melody, output_file, polyphonic=True)
    
    return melody

# Example: Generate using only strings and woodwinds from Zelda
# Uncomment and modify the instrument list based on what's available:

# melody = generate_with_selected_instruments(
#     'LegendofZelda_Title.mid',
#     selected_instruments=['Violin', 'Flute', 'Oboe'],  # Adjust based on available instruments
#     length=150
# )


In [None]:
# ============================================================
# ADVANCED: Train on Multiple MIDI Files
# ============================================================
# This combines data from multiple files for richer training

def train_on_multiple_files(midi_files: list, order: int = 2, use_joint: bool = True):
    """
    Train a generator on multiple MIDI files for richer patterns.
    
    Args:
        midi_files: List of MIDI file paths
        order: Markov chain order
        use_joint: Use joint pitch-duration model
    
    Returns:
        Trained MusicGenerator
    """
    print(f"üéµ Training on {len(midi_files)} MIDI files\n")
    
    # We'll collect all events and train on the combined data
    generator = MusicGenerator(order=order, use_joint_model=use_joint)
    
    all_events = []
    
    for midi_file in midi_files:
        try:
            print(f"   Loading: {midi_file}")
            score = converter.parse(midi_file)
            events = generator._extract_events_with_timing(score, True, 0.05)
            all_events.extend(events)
            print(f"      ‚Üí {len(events)} events")
        except Exception as e:
            print(f"      ‚úó Error: {e}")
    
    # Now train on combined data
    print(f"\n   Total events: {len(all_events)}")
    events_by_instrument = generator._group_by_instrument(all_events)
    generator._train_chains(events_by_instrument)
    
    # Update stats
    generator.stats = {
        'total_events': len(all_events),
        'unique_pitches': len(set(e['pitch'] for e in all_events)),
        'unique_durations': len(set(e['duration'] for e in all_events)),
        'source_files': len(midi_files)
    }
    
    print(f"\n‚úì Training complete!")
    print(f"   Combined: {generator.stats['unique_pitches']} unique pitches from {len(midi_files)} files")
    
    return generator

# Example usage - uncomment to try:
# multi_gen = train_on_multiple_files([
#     'MARIOBRO.mid',
#     'LegendofZelda_Title.mid',
#     'Tetris_Title_Screen.mid'
# ], order=2, use_joint=True)
# 
# melody = multi_gen.generate(200)
# multi_gen.save_to_midi(melody, 'generated_mixed.mid')
