In [27]:
import os
import random
import pickle
from collections import defaultdict, Counter
import mido
from mido import MidiFile, MidiTrack, Message
import glob

class SymbolicMusicGenerator:
    """Trigram (second-order Markov) model over symbolic note tokens learned from MIDI files.

    Each note onset is encoded as the token ``P{pitch_class}_O{octave}_{dynamic}_{timing}``
    (see ``create_symbolic_representation``). The model counts which token follows each pair
    of consecutive tokens and samples from those counts to generate new token sequences,
    which ``sequence_to_midi`` decodes back into a MIDI file.
    """

    # Tick resolution used throughout this class: source files are rescaled to it on load
    # and generated files are written with it.
    TICKS_PER_BEAT = 480

    # Decoder tables: timing token -> note length in ticks (at TICKS_PER_BEAT),
    # dynamic token -> MIDI velocity.
    TIMING_TICKS = {
        "T1": 240,   # Eighth note
        "T2": 480,   # Quarter note
        "T3": 720,   # Dotted quarter
        "T4": 960,   # Half note
        "T5": 1440,  # Dotted half
    }
    VELOCITY_MAP = {"pp": 35, "mp": 55, "mf": 75, "f": 95, "ff": 115}

    def __init__(self):
        # (token_a, token_b) -> Counter of tokens observed immediately after that pair.
        self.trigram_model = defaultdict(Counter)
        # NOTE(review): these three are never populated by this class; kept for compatibility.
        self.tempo_transitions = defaultdict(list)
        self.note_durations = defaultdict(list)
        self.velocity_patterns = defaultdict(list)
        # "N/D" time-signature strings and set_tempo values collected during training.
        self.time_signatures = []
        self.tempos = []

    def extract_musical_events(self, midi_file_path):
        """Extract musical events from a MIDI file including timing information.

        Times are absolute tick positions within each track, rescaled from the file's own
        ``ticks_per_beat`` to ``TICKS_PER_BEAT``. This keeps the tick thresholds used by
        ``create_symbolic_representation`` consistent across files of any resolution.

        Args:
            midi_file_path: path of the MIDI file to read.

        Returns:
            A list of tuples in per-track message order (not globally time-sorted):
                ('tempo', tempo, time)
                ('time_sig', "numerator/denominator", time)
                ('note_on', note, time, velocity, tempo_in_effect)
                ('note_off', note, time, duration, velocity_of_matching_note_on)
            Any exception while reading is printed and ``[]`` is returned, so one bad file
            does not abort a training run.
        """
        try:
            mid = MidiFile(midi_file_path)
            # Guard against a zero/missing resolution instead of dividing by zero.
            scale = self.TICKS_PER_BEAT / (mid.ticks_per_beat or self.TICKS_PER_BEAT)
            events = []
            current_tempo = 500000  # Default tempo (120 BPM)

            for track in mid.tracks:
                track_ticks = 0
                # Notes currently sounding in THIS track, keyed by (channel, note). Reset per
                # track so a dangling note_on is never closed by another track's note_off.
                active_notes = {}
                for msg in track:
                    track_ticks += msg.time  # per-track messages carry delta ticks
                    current_time = round(track_ticks * scale)

                    if msg.type == 'set_tempo':
                        # NOTE(review): tempo is tracked in message order across tracks, not by
                        # time, so later tracks see the last tempo of earlier tracks.
                        current_tempo = msg.tempo
                        events.append(('tempo', msg.tempo, current_time))

                    elif msg.type == 'time_signature':
                        events.append(('time_sig', f"{msg.numerator}/{msg.denominator}", current_time))

                    elif msg.type == 'note_on' and msg.velocity > 0:
                        active_notes[(msg.channel, msg.note)] = {
                            'start_time': current_time,
                            'velocity': msg.velocity,
                            'tempo': current_tempo
                        }
                        events.append(('note_on', msg.note, current_time, msg.velocity, current_tempo))

                    elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                        # note_on with velocity 0 is the common shorthand for note_off.
                        note_info = active_notes.pop((msg.channel, msg.note), None)
                        if note_info is not None:
                            duration = current_time - note_info['start_time']
                            events.append(('note_off', msg.note, current_time, duration, note_info['velocity']))

            return events

        except Exception as e:
            print(f"Error processing {midi_file_path}: {e}")
            return []

    def create_symbolic_representation(self, events):
        """Convert extracted events into one token per note onset for the Markov chain.

        Token format ``P{pitch_class}_O{octave}_{dynamic}_{timing}`` where
            pitch_class = note % 12, octave = note // 12,
            dynamic     = velocity bucket: pp < 40 <= mp < 70 <= mf < 90 <= f < 110 <= ff,
            timing      = bucket of the tick gap since the previous onset (480 ticks/beat):
                          T1 < 120 <= T2 < 240 <= T3 < 480 <= T4 < 960 <= T5.
        Only 'note_on' events are used. Notes starting together (chords) get gap 0, i.e. T1.

        Args:
            events: event tuples as returned by ``extract_musical_events``.

        Returns:
            Token strings in onset order; simultaneous onsets are ordered by pitch.
        """
        symbols = []
        prev_time = 0

        # Sorting (time, event) tuples orders by time, then by the event tuple itself,
        # i.e. by pitch for notes that start at the same tick.
        note_events = sorted((e[2], e) for e in events if e[0] == 'note_on')

        for time_pos, event in note_events:
            note = event[1]
            velocity = event[3]

            time_delta = max(0, time_pos - prev_time)
            prev_time = time_pos

            # Bucket the inter-onset gap (ticks at 480 per beat)
            if time_delta < 120:
                timing = "T1"  # shorter than a sixteenth (includes chord tones)
            elif time_delta < 240:
                timing = "T2"  # sixteenth up to eighth
            elif time_delta < 480:
                timing = "T3"  # eighth up to quarter
            elif time_delta < 960:
                timing = "T4"  # quarter up to half
            else:
                timing = "T5"  # half note or longer

            if velocity < 40:
                vel_level = "pp"  # pianissimo
            elif velocity < 70:
                vel_level = "mp"  # mezzo-piano
            elif velocity < 90:
                vel_level = "mf"  # mezzo-forte
            elif velocity < 110:
                vel_level = "f"   # forte
            else:
                vel_level = "ff"  # fortissimo

            symbols.append(f"P{note % 12}_O{note // 12}_{vel_level}_{timing}")

        return symbols

    def build_trigram_model(self, symbols):
        """Add every (a, b) -> c transition in ``symbols`` to ``self.trigram_model``."""
        for first, second, third in zip(symbols, symbols[1:], symbols[2:]):
            self.trigram_model[(first, second)][third] += 1

    @staticmethod
    def _parse_note_symbol(symbol):
        """Return (pitch_class, octave) for a ``P{pc}_O{oct}_...`` token, else None."""
        if not isinstance(symbol, str) or not symbol.startswith('P'):
            return None
        parts = symbol.split('_')
        if len(parts) < 2:
            return None
        try:
            return int(parts[0][1:]), int(parts[1][1:])
        except ValueError:
            return None

    def _contextual_weight(self, symbol, count, context):
        """Scale a raw trigram ``count`` for ``symbol`` given the previous note's (pitch_class, octave).

        Favors octave buckets >= 4 (note // 12) and penalizes <= 2. Mildly favors pitch-class
        steps/small skips and staying within one octave bucket of ``context``. Non-note tokens,
        or a missing context, keep their raw count.
        """
        weight = float(count)
        parsed = self._parse_note_symbol(symbol)
        if parsed is None or context is None:
            return weight
        pitch_class, octave = parsed
        last_pitch_class, last_octave = context

        # Strongly favor higher octaves, heavily discourage low ones
        if octave >= 5:
            weight *= 2.0
        elif octave >= 4:
            weight *= 1.5
        elif octave <= 2:
            weight *= 0.3

        # Pitch-class distance; the >= 10 / >= 8 arms catch the same intervals across the octave wrap
        interval = abs(pitch_class - last_pitch_class)
        if interval <= 2 or interval >= 10:
            weight *= 1.2
        elif interval <= 4 or interval >= 8:
            weight *= 1.1

        # Favor staying in a similar register
        if abs(octave - last_octave) <= 1:
            weight *= 1.1
        return weight

    def train_on_dataset(self, dataset_path, max_files=50, seed=None):
        """Train the model on MIDI files found recursively under ``dataset_path``.

        Args:
            dataset_path: root directory to scan for ``.mid`` / ``.midi`` files.
            max_files: if more files are found, a random subset of this size is used.
            seed: optional seed for the subset selection. ``None`` uses the global ``random`` state.

        Files that fail to parse, or that yield 10 or fewer note tokens, are skipped. Tempo and
        time-signature values from the trained files are appended to ``self.tempos`` and
        ``self.time_signatures``. Progress and an octave histogram are printed.
        """
        print("Scanning for MIDI files...")
        # os.walk order depends on the filesystem; sorting makes a seeded selection reproducible.
        midi_files = sorted(
            os.path.join(root, file)
            for root, _dirs, files in os.walk(dataset_path)
            for file in files
            if file.lower().endswith(('.mid', '.midi'))
        )

        print(f"Found {len(midi_files)} MIDI files")

        # Limit the number of files to process for reasonable training time
        if len(midi_files) > max_files:
            rng = random.Random(seed) if seed is not None else random
            midi_files = rng.sample(midi_files, max_files)
            print(f"Processing {max_files} randomly selected files")

        processed_count = 0
        total_notes = 0
        octave_distribution = Counter()

        for index, midi_file in enumerate(midi_files, start=1):
            # The loop index keeps the counter advancing even for files that fail to parse.
            print(f"Processing: {os.path.basename(midi_file)} ({index}/{len(midi_files)})")

            events = self.extract_musical_events(midi_file)
            if not events:
                continue
            symbols = self.create_symbolic_representation(events)
            if len(symbols) <= 10:  # Only use files with sufficient content
                continue

            self.build_trigram_model(symbols)
            processed_count += 1

            # Debug: track octave distribution
            for symbol in symbols:
                parsed = self._parse_note_symbol(symbol)
                if parsed is not None:
                    octave_distribution[parsed[1]] += 1
                    total_notes += 1

            self.tempos.extend(e[1] for e in events if e[0] == 'tempo')
            self.time_signatures.extend(e[1] for e in events if e[0] == 'time_sig')

        print(f"Training completed on {processed_count} files")
        print(f"Learned {len(self.trigram_model)} trigram patterns")
        print(f"Total notes processed: {total_notes}")
        print("Octave distribution in training data:")
        for octave in sorted(octave_distribution):
            percentage = (octave_distribution[octave] / total_notes) * 100
            print(f"  Octave {octave}: {octave_distribution[octave]} notes ({percentage:.1f}%)")

    def generate_sequence(self, length=200, seed=None):
        """Generate ``length`` tokens by sampling the trigram model.

        Args:
            length: number of tokens to return (``<= 0`` returns ``[]``).
            seed: optional token sequence (not an RNG seed). When it has at least two items,
                generation starts from ``(seed[0], seed[1])`` and those two tokens begin the
                output. Otherwise a random learned pair is used.

        Each next token is drawn with probability proportional to its trigram count,
        re-weighted by ``_contextual_weight`` relative to the previously emitted note. If the
        current pair was never seen as a context, generation restarts from a random learned
        pair and both of its tokens are emitted.

        Raises:
            ValueError: if the model has not been trained or loaded.
        """
        if len(self.trigram_model) == 0:
            raise ValueError("Model not trained yet!")

        bigrams = list(self.trigram_model.keys())  # reused by every restart

        if seed and len(seed) >= 2:
            current_bigram = (seed[0], seed[1])
        else:
            current_bigram = random.choice(bigrams)

        sequence = list(current_bigram)

        while len(sequence) < length:
            # .get avoids inserting empty Counters into the defaultdict
            next_counts = self.trigram_model.get(current_bigram)
            if next_counts and sum(next_counts.values()) > 0:
                context = self._parse_note_symbol(current_bigram[1])
                candidates = list(next_counts)
                weights = [self._contextual_weight(sym, next_counts[sym], context) for sym in candidates]
                next_symbol = random.choices(candidates, weights=weights)[0]
                sequence.append(next_symbol)
                current_bigram = (current_bigram[1], next_symbol)
            else:
                # Dead end: restart from a random learned pair
                current_bigram = random.choice(bigrams)
                sequence.extend(current_bigram)

        # A restart may overshoot by one token (and a seed pair may exceed tiny lengths)
        return sequence[:max(length, 0)]

    def _sequence_to_events(self, sequence):
        """Decode tokens into time-ordered note events.

        Returns:
            ``(events, end_time)``. ``events`` is a list of
            ``('note_on' | 'note_off', abs_tick, note, velocity)`` sorted by tick, with
            note_offs before note_ons at the same tick. ``end_time`` is the onset cursor after
            the last token.
        """
        events = []
        current_time = 0
        pending_off = {}  # note -> index in `events` of that pitch's latest note_off

        for symbol in sequence:
            if not symbol.startswith('P'):
                current_time += 240  # Shorter default for non-notes
                continue
            parts = symbol.split('_')
            if len(parts) < 4:
                continue  # malformed note token: ignored, time does not advance
            try:
                pitch_class = int(parts[0][1:])
                octave = int(parts[1][1:])
            except ValueError:
                current_time += 480  # Default quarter note spacing
                continue
            vel_level, timing = parts[2], parts[3]

            # Constrain octave to 3-7, then the note to the valid MIDI range 0-127
            note = max(0, min(127, max(3, min(7, octave)) * 12 + pitch_class))
            velocity = self.VELOCITY_MAP.get(vel_level, 70)
            duration = self.TIMING_TICKS.get(timing, 480)

            # A note_off releases that pitch on the channel whichever note_on started it. An
            # earlier, still-sounding note of the same pitch would therefore cut this one short
            # and leave a stray note_off. Instead, end the earlier note where this one begins.
            previous = pending_off.get(note)
            if previous is not None and events[previous][1] > current_time:
                events[previous] = ('note_off', current_time, note, 0)

            events.append(('note_on', current_time, note, velocity))
            pending_off[note] = len(events)
            events.append(('note_off', current_time + int(duration * 0.9), note, 0))

            # Advance time for next note (70% spacing lets consecutive notes overlap)
            current_time += int(duration * 0.7)

        # Order by tick; at equal ticks put note_off first so a released pitch can be re-struck
        events.sort(key=lambda e: (e[1], e[0] != 'note_off'))
        return events, current_time

    def sequence_to_midi(self, sequence, output_path="generated_music.mid"):
        """Write a token sequence to a single-track, 4/4 MIDI file at ``output_path``.

        The tempo is a random value seen in training, or 500000 (120 BPM) if none was seen.
        NOTE(review): the timing token is reinterpreted here as the note's own length (via
        ``TIMING_TICKS``), whereas the encoder derived it from the gap *before* the note.
        Decoding is therefore not an exact inverse of ``create_symbolic_representation``.
        """
        mid = MidiFile(ticks_per_beat=self.TICKS_PER_BEAT)
        track = MidiTrack()
        mid.tracks.append(track)

        default_tempo = random.choice(self.tempos) if self.tempos else 500000
        track.append(mido.MetaMessage('set_tempo', tempo=int(default_tempo), time=0))
        track.append(mido.MetaMessage('time_signature', numerator=4, denominator=4, time=0))

        events, end_time = self._sequence_to_events(sequence)

        # MIDI messages carry delta times, so convert from absolute ticks
        last_time = 0
        for event_type, abs_time, note, velocity in events:
            delta_time = abs_time - last_time
            if event_type == 'note_on':
                track.append(mido.Message('note_on', channel=0, note=note,
                                          velocity=velocity, time=delta_time))
            else:
                track.append(mido.Message('note_off', channel=0, note=note,
                                          velocity=64, time=delta_time))
            last_time = abs_time

        # One beat of trailing silence (a no-op note_off that only carries the delta time)
        track.append(mido.Message('note_off', channel=0, note=60, velocity=0, time=self.TICKS_PER_BEAT))

        mid.save(output_path)
        print(f"Generated MIDI saved as: {output_path}")
        print(f"Total events: {len(events)}")
        print(f"Duration: ~{end_time / self.TICKS_PER_BEAT:.1f} beats")

    def save_model(self, filepath):
        """Pickle the trigram counts, tempos and time signatures to ``filepath``."""
        model_data = {
            'trigram_model': dict(self.trigram_model),
            'tempos': self.tempos,
            'time_signatures': self.time_signatures
        }
        with open(filepath, 'wb') as f:
            pickle.dump(model_data, f)
        print(f"Model saved to: {filepath}")

    def load_model(self, filepath):
        """Load a model written by ``save_model``, replacing the current state.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only load files you created
        or otherwise trust.
        """
        with open(filepath, 'rb') as f:
            model_data = pickle.load(f)

        self.trigram_model = defaultdict(Counter, model_data['trigram_model'])
        self.tempos = model_data['tempos']
        self.time_signatures = model_data['time_signatures']
        print(f"Model loaded from: {filepath}")

# Test function to create a simple MIDI file for verification
def create_test_midi(output_path="test_scale.mid"):
    """Write a one-octave C major scale (quarter notes, 120 BPM) to verify MIDI writing works.

    Args:
        output_path: destination file; defaults to "test_scale.mid".
    """
    mid = MidiFile(ticks_per_beat=480)
    track = MidiTrack()
    mid.tracks.append(track)

    # Tempo 500000 (120 BPM) and 4/4 time
    track.append(mido.MetaMessage('set_tempo', tempo=500000, time=0))
    track.append(mido.MetaMessage('time_signature', numerator=4, denominator=4, time=0))

    # C4 to C5 as MIDI note numbers, each held for one beat (480 ticks)
    for note in (60, 62, 64, 65, 67, 69, 71, 72):
        track.append(mido.Message('note_on', channel=0, note=note, velocity=80, time=0))
        track.append(mido.Message('note_off', channel=0, note=note, velocity=0, time=480))

    mid.save(output_path)
    print(f"Test MIDI file created: {output_path}")

In [29]:
# Seed the global RNG so file sampling and generation are reproducible run-to-run
# (assuming the dataset directory listing is stable across runs).
SEED = 42
random.seed(SEED)

# Initialize the generator
generator = SymbolicMusicGenerator()

# Train on your dataset
dataset_path = "maestro-v3.0.0"
generator.train_on_dataset(dataset_path, max_files=300)  # Adjust max_files as needed

# Save the trained model
generator.save_model("music_model.pkl")

# Generate new music
print("\nGenerating new music...")
sequence = generator.generate_sequence(length=500)
generator.sequence_to_midi(sequence, "task1.mid")
print("Music generation completed!")

Scanning for MIDI files...
Found 1276 MIDI files
Processing 300 randomly selected files
Processing: MIDI-Unprocessed_22_R2_2006_01_ORIG_MID--AUDIO_22_R2_2006_04_Track04_wav.midi (1/300)
Processing: MIDI-Unprocessed_02_R1_2009_01-02_ORIG_MID--AUDIO_02_R1_2009_02_R1_2009_02_WAV.midi (2/300)
Processing: MIDI-UNPROCESSED_06-08_R1_2014_MID--AUDIO_08_R1_2014_wav--2.midi (3/300)
Processing: MIDI-Unprocessed_XP_04_R2_2004_01_ORIG_MID--AUDIO_04_R2_2004_01_Track01_wav.midi (4/300)
Processing: MIDI-UNPROCESSED_16-18_R1_2014_MID--AUDIO_18_R1_2014_wav--4.midi (5/300)
Processing: MIDI-Unprocessed_XP_21_R1_2004_02_ORIG_MID--AUDIO_21_R1_2004_02_Track02_wav.midi (6/300)
Processing: MIDI-Unprocessed_23_R3_2011_MID--AUDIO_R3-D8_05_Track05_wav.midi (7/300)
Processing: MIDI-Unprocessed_17_R2_2008_01-04_ORIG_MID--AUDIO_17_R2_2008_wav--2.midi (8/300)
Processing: MIDI-Unprocessed_XP_08_R1_2004_01-02_ORIG_MID--AUDIO_08_R1_2004_01_Track01_wav.midi (9/300)
Processing: MIDI-Unprocessed_08_R1_2011_MID--AUDIO_R1-D3