In [None]:
import os
import random
import pandas as pd
import pickle
from collections import defaultdict, Counter
import mido
from mido import MidiFile, MidiTrack, Message

class MelodyToChordGenerator:
    """Melody-conditioned chord generator backed by bigram counts.

    Training reads MIDI files, picks a "melody" track and a "chord" track per
    file, quantizes both onto a grid of ``GRID`` ticks (one beat at
    ``TICKS_PER_BEAT`` resolution) and counts, for every pair of consecutive
    melody symbols, which chord symbol occupied the second bin. Generation
    samples chords from those counts.
    """

    # Resolution every input file is rescaled to, and the resolution of the
    # MIDI written by harmony_sequence_to_midi.
    TICKS_PER_BEAT = 480
    # Quantization bin size in ticks (one beat at TICKS_PER_BEAT).
    GRID = 480

    # Pitch-class sets that receive a readable chord name. Any other set of
    # simultaneous notes is encoded as its raw MIDI note numbers.
    CHORD_TEMPLATES = {
        (0, 4, 7): "Cmaj",   # C  E  G
        (2, 5, 9): "Dmin",   # D  F  A
        (4, 8, 11): "Emaj",  # E  G# B
    }
    # Concrete voicings used when a named chord is rendered back to MIDI;
    # they must agree with CHORD_TEMPLATES.
    CHORD_VOICINGS = {
        "Cmaj": (60, 64, 67),
        "Dmin": (62, 65, 69),
        "Emaj": (64, 68, 71),
    }
    # Velocity written for each coarse dynamic label.
    VELOCITY_BY_LEVEL = {"pp": 35, "mp": 55, "mf": 75, "f": 95, "ff": 115}

    def __init__(self):
        # (previous melody symbol, melody symbol) -> Counter of chord symbols.
        self.conditional_model = defaultdict(Counter)
        self.chord_symbols = set()
        self.melody_symbols = set()
        # Every set_tempo value and "num/den" time signature seen while extracting.
        self.tempos = []
        self.time_signatures = []

    @classmethod
    def _chord_symbol(cls, notes):
        """Return the chord symbol for a group of simultaneous MIDI notes.

        Triads listed in CHORD_TEMPLATES are matched on pitch classes. Any
        other group (including a single note) becomes its sorted note numbers
        joined with "_", e.g. ``"60_64"``.
        """
        pitch_classes = tuple(sorted({n % 12 for n in notes}))
        name = cls.CHORD_TEMPLATES.get(pitch_classes)
        if name is not None:
            return name
        return "_".join(str(n) for n in sorted(notes))

    @staticmethod
    def _velocity_level(velocity):
        """Map a MIDI velocity to a coarse dynamic label (pp/mp/mf/f/ff)."""
        if velocity < 40:
            return "pp"
        if velocity < 70:
            return "mp"
        if velocity < 90:
            return "mf"
        if velocity < 110:
            return "f"
        return "ff"

    def extract_melody_and_chords(self, midi_path):
        """Load a MIDI file and pull out a melody line and a chord line.

        The melody track is the one with the highest ratio of single-note
        onsets to 3+-note onsets; the chord track is the reverse. Both may be
        the same track.

        Side effects: appends every tempo / time signature found to
        ``self.tempos`` / ``self.time_signatures`` and every chord symbol to
        ``self.chord_symbols``.

        Returns:
            ``(melody_events, chord_events)``. ``melody_events`` is a list of
            ``(tick, note, velocity)`` and ``chord_events`` a list of
            ``(tick, chord_symbol)``, both sorted by tick. Ticks are rescaled
            to ``TICKS_PER_BEAT`` resolution. Both lists are empty if the file
            cannot be loaded or has no tracks.
        """
        try:
            mid = MidiFile(midi_path)
        except Exception as e:
            # Best-effort: a broken file is reported and skipped, not fatal.
            print(f"Error loading {midi_path}: {e}")
            return [], []
        if not mid.tracks:
            return [], []

        # Input files may use any resolution; rescale so GRID is one beat for
        # every file, matching the 480-PPQ file written on output.
        source_tpb = mid.ticks_per_beat or self.TICKS_PER_BEAT
        scale = self.TICKS_PER_BEAT / source_tpb

        all_track_events = []
        for track in mid.tracks:
            abs_ticks = 0
            track_events = []
            for msg in track:
                abs_ticks += msg.time
                # Round the accumulated time once so per-delta rounding can't drift.
                t = round(abs_ticks * scale)
                if msg.type == 'set_tempo':
                    track_events.append(('tempo', t, msg.tempo))
                elif msg.type == 'time_signature':
                    track_events.append(('time_sig', t, f"{msg.numerator}/{msg.denominator}"))
                elif msg.type == 'note_on' and msg.velocity > 0:
                    track_events.append(('note_on', t, msg.note, msg.velocity))
                elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                    # note_on with velocity 0 is the conventional note_off.
                    track_events.append(('note_off', t, msg.note))
            all_track_events.append(track_events)

        def track_profile(events):
            """Return (#onset times with >=3 notes, #onset times with 1 note)."""
            notes_per_onset = Counter(ev[1] for ev in events if ev[0] == 'note_on')
            chord_count = sum(1 for n in notes_per_onset.values() if n >= 3)
            single_count = sum(1 for n in notes_per_onset.values() if n == 1)
            return chord_count, single_count

        profiles = [track_profile(ev) for ev in all_track_events]
        indices = range(len(profiles))
        melody_idx = max(indices, key=lambda i: profiles[i][1] / (profiles[i][0] + 1))
        chord_idx = max(indices, key=lambda i: profiles[i][0] / (profiles[i][1] + 1))

        melody_events = [
            (ev[1], ev[2], ev[3])
            for ev in all_track_events[melody_idx]
            if ev[0] == 'note_on'
        ]

        onset_notes = defaultdict(list)
        for ev in all_track_events[chord_idx]:
            if ev[0] == 'note_on':
                onset_notes[ev[1]].append(ev[2])
        chord_events = []
        for t, notes in onset_notes.items():
            # Every onset becomes a chord event, single notes included.
            chord_sym = self._chord_symbol(notes)
            chord_events.append((t, chord_sym))
            self.chord_symbols.add(chord_sym)

        for track_events in all_track_events:
            for ev in track_events:
                if ev[0] == 'tempo':
                    self.tempos.append(ev[2])
                elif ev[0] == 'time_sig':
                    self.time_signatures.append(ev[2])

        melody_events.sort(key=lambda x: x[0])
        chord_events.sort(key=lambda x: x[0])
        return melody_events, chord_events

    def create_symbolic_representation_for_harmony(self, melody_events, chord_events):
        """Quantize melody and chord events into parallel per-bin symbol lists.

        Each bin of ``GRID`` ticks gets one melody symbol ``"M{pc}_O{octave}_{dyn}"``
        taken from its highest note (``"REST"`` if empty) and one chord symbol
        (the last chord starting in the bin, ``"NOC"`` if none). New melody
        symbols are added to ``self.melody_symbols``.

        Returns:
            ``(melody_seq, chord_seq)`` of equal length, covering bins
            ``0..max_bin`` inclusive.
        """
        grid = self.GRID

        notes_by_bin = defaultdict(list)
        for t, note, vel in melody_events:
            notes_by_bin[t // grid].append((note, vel))

        melody_bins = {}
        for b, notes in notes_by_bin.items():
            # The highest pitch in the bin stands in for the melody.
            top_note, top_vel = max(notes, key=lambda nv: nv[0])
            sym = f"M{top_note % 12}_O{top_note // 12}_{self._velocity_level(top_vel)}"
            self.melody_symbols.add(sym)
            melody_bins[b] = sym

        chord_bins = {}
        for t, chord_sym in chord_events:
            # Later onsets in the same bin overwrite earlier ones.
            chord_bins[t // grid] = chord_sym

        max_bin = max([0, *melody_bins, *chord_bins])
        melody_seq = [melody_bins.get(b, "REST") for b in range(max_bin + 1)]
        chord_seq = [chord_bins.get(b, "NOC") for b in range(max_bin + 1)]
        return melody_seq, chord_seq

    def build_conditional_model(self, melody_seq, chord_seq):
        """Count chord_seq[i] under the key (melody_seq[i-1], melody_seq[i])."""
        for prev_m, cur_m, chord in zip(melody_seq, melody_seq[1:], chord_seq[1:]):
            self.conditional_model[(prev_m, cur_m)][chord] += 1

    def train_on_dataset(self, dataset_root, csv_path, max_files=500, split="train"):
        """Train on MIDI files listed in a MAESTRO-style CSV.

        Reads ``csv_path`` and keeps rows whose ``split`` column equals
        ``split``. It resolves ``midi_filename`` against ``dataset_root`` and
        skips missing files. At most ``max_files`` files are sampled at random.
        Files with no melody/chord events or fewer than 10 bins are skipped.
        """
        print("Loading MAESTRO CSV…")
        df = pd.read_csv(csv_path)
        df = df[df["split"] == split].reset_index(drop=True)
        print(f"Total '{split}' rows in CSV: {len(df)}")

        all_midi_paths = []
        for rel_path in df["midi_filename"].tolist():
            full_path = os.path.join(dataset_root, rel_path)
            if os.path.exists(full_path):
                all_midi_paths.append(full_path)
            else:
                print(f" WARNING: {full_path} not found on disk!")
        print(f"Found {len(all_midi_paths)} files on disk (after filtering by split).")

        if len(all_midi_paths) > max_files:
            all_midi_paths = random.sample(all_midi_paths, max_files)
            print(f"Processing {max_files} randomly selected files…")

        processed_count = 0
        for midi_file in all_midi_paths:
            mel_events, chord_events = self.extract_melody_and_chords(midi_file)
            if not mel_events or not chord_events:
                continue

            melody_seq, chord_seq = self.create_symbolic_representation_for_harmony(
                mel_events, chord_events
            )
            if len(melody_seq) < 10:
                continue

            self.build_conditional_model(melody_seq, chord_seq)
            processed_count += 1
            if processed_count % 20 == 0:
                print(f"Trained on {processed_count}/{len(all_midi_paths)} files…")

        print(f"\n➤ Training complete on {processed_count} files.")
        print(f"Learned {len(self.conditional_model)} melody‐bigram → chord counters.")
        print(f"Distinct melody symbols: {len(self.melody_symbols)}")
        print(f"Distinct chord symbols: {len(self.chord_symbols)}")

    def generate_harmony(self, melody_sequence, seed_chord=None):
        """Sample one chord per melody symbol.

        The first chord is ``seed_chord`` if it is a known chord symbol,
        otherwise the overall most frequent chord. Each later chord is drawn
        in proportion to the counts for (previous melody, current melody).
        When that bigram is unseen, the previous chord is repeated.

        Returns:
            A list of chord symbols, ``[]`` for an empty melody.

        Raises:
            ValueError: if the model has no learned counts.
        """
        if not self.conditional_model:
            raise ValueError("The model hasn’t been trained yet.")
        if not melody_sequence:
            return []

        length = len(melody_sequence)
        chord_sequence = ["NOC"] * length

        if seed_chord and seed_chord in self.chord_symbols:
            chord_sequence[0] = seed_chord
        else:
            total_counts = Counter()
            for ctr in self.conditional_model.values():
                total_counts.update(ctr)
            chord_sequence[0] = total_counts.most_common(1)[0][0] if total_counts else "Cmaj"

        for i in range(1, length):
            key = (melody_sequence[i - 1], melody_sequence[i])
            # .get avoids inserting empty counters into the defaultdict.
            cnt = self.conditional_model.get(key)
            if cnt and sum(cnt.values()) > 0:
                items, counts = zip(*cnt.items())
                chord_sequence[i] = random.choices(items, weights=counts, k=1)[0]
            else:
                chord_sequence[i] = chord_sequence[i - 1]

        return chord_sequence

    def harmony_sequence_to_midi(self, melody_sequence, chord_sequence, output_path="harmonized_output_fixed.mid"):
        """Render parallel melody/chord symbol sequences to a MIDI file.

        Each index is one ``GRID``-tick bin. Melody notes go on channel 0 and
        chords on channel 1. Both last exactly one bin, so a chord repeated in
        the next bin is released before it is struck again instead of
        overlapping itself. The tempo is a random pick from ``self.tempos``
        (500000 if empty). The time signature is the first one recorded (4/4
        if none).

        Args:
            melody_sequence: symbols like ``"M4_O5_mf"`` or ``"REST"``.
            chord_sequence: chord symbols or ``"NOC"``. It is zipped with the
                melody, so the shorter sequence sets the length.
            output_path: destination file.
        """
        from mido import MidiFile, MidiTrack, Message, MetaMessage

        mid = MidiFile(ticks_per_beat=self.TICKS_PER_BEAT)
        track = MidiTrack()
        mid.tracks.append(track)

        chosen_tempo = random.choice(self.tempos) if self.tempos else 500000
        track.append(MetaMessage('set_tempo', tempo=int(chosen_tempo), time=0))

        if self.time_signatures:
            num, den = map(int, self.time_signatures[0].split('/'))
        else:
            num, den = 4, 4
        track.append(MetaMessage('time_signature', numerator=num, denominator=den, time=0))

        grid = self.GRID
        events = []
        for i, (m_sym, c_sym) in enumerate(zip(melody_sequence, chord_sequence)):
            start, end = i * grid, (i + 1) * grid

            if m_sym != "REST":
                parts = m_sym.split('_')
                pc = int(parts[0][1:])
                octv = int(parts[1][1:])
                velocity = self.VELOCITY_BY_LEVEL.get(parts[2], 70)
                note = max(0, min(127, octv * 12 + pc))
                events.append((start, Message('note_on', channel=0, note=note, velocity=velocity)))
                events.append((end, Message('note_off', channel=0, note=note, velocity=0)))

            if c_sym != "NOC":
                chord_pitches = self.CHORD_VOICINGS.get(c_sym)
                if chord_pitches is None:
                    # Raw chord symbols are note numbers joined with "_".
                    chord_pitches = [int(x) for x in c_sym.split('_')]
                for p in chord_pitches:
                    events.append((start, Message('note_on', channel=1, note=p, velocity=60)))
                for p in chord_pitches:
                    events.append((end, Message('note_off', channel=1, note=p, velocity=0)))

        # At equal times emit note_offs before note_ons, so a re-struck pitch
        # is released before it is played again.
        events.sort(key=lambda ev: (ev[0], ev[1].type != 'note_off'))

        last_time = 0
        for abs_time, msg in events:
            msg.time = int(abs_time - last_time)
            track.append(msg)
            last_time = abs_time

        mid.save(output_path)
        print(f"✅ Harmonized MIDI (aligned) saved as: {output_path}")

    def save_model(self, filepath):
        """Pickle the learned counts, symbol sets, tempos and time signatures."""
        data = {
            "conditional_model": dict(self.conditional_model),
            "melody_symbols": list(self.melody_symbols),
            "chord_symbols": list(self.chord_symbols),
            "tempos": self.tempos,
            "time_signatures": self.time_signatures
        }
        with open(filepath, "wb") as f:
            pickle.dump(data, f)
        print(f"Model saved to: {filepath}")

    def load_model(self, filepath):
        """Restore state written by save_model, replacing the current state."""
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # model files you created yourself or otherwise trust.
        with open(filepath, "rb") as f:
            data = pickle.load(f)
        self.conditional_model = defaultdict(Counter, data["conditional_model"])
        self.melody_symbols    = set(data["melody_symbols"])
        self.chord_symbols     = set(data["chord_symbols"])
        self.tempos            = data["tempos"]
        self.time_signatures   = data["time_signatures"]
        print(f"Model loaded from: {filepath}")

# Train on a random subset of the MAESTRO v3 "train" split and persist the model.
dataset_root = "maestro-v3.0.0"
csv_path = f"{dataset_root}/maestro-v3.0.0.csv"

harm_gen = MelodyToChordGenerator()
harm_gen.train_on_dataset(
    dataset_root=dataset_root,
    csv_path=csv_path,
    max_files=300,
    split="train",
)
harm_gen.save_model(filepath="mel2chord_model.pkl")

Loading MAESTRO CSV…
Total 'train' rows in CSV: 962
Found 962 files on disk (after filtering by split).
Processing 300 randomly selected files…
  Trained on 20/300 files…
  Trained on 40/300 files…
  Trained on 60/300 files…
  Trained on 80/300 files…
  Trained on 100/300 files…
  Trained on 120/300 files…
  Trained on 140/300 files…
  Trained on 160/300 files…
  Trained on 180/300 files…
  Trained on 200/300 files…
  Trained on 220/300 files…
  Trained on 240/300 files…

➤ Training complete on 257 files.
Learned 27784 melody‐bigram → chord counters.
Distinct melody symbols: 415
Distinct chord symbols:  5769
Model saved to: mel2chord_model.pkl


In [None]:
# Harmonize a new melody: extract it, quantize it, sample chords, write MIDI.
mel_ev, _ = harm_gen.extract_melody_and_chords("task1.mid")
# No reference chords at inference time, so the chord event list is empty.
melody_seq, _ = harm_gen.create_symbolic_representation_for_harmony(
    mel_ev,
    [],
)
# The seed is only used if "Cmaj" was seen during training.
chord_seq = harm_gen.generate_harmony(
    melody_sequence=melody_seq,
    seed_chord="Cmaj",
)
harm_gen.harmony_sequence_to_midi(
    melody_sequence=melody_seq,
    chord_sequence=chord_seq,
    output_path="harmonized_output2.mid",
)


✅ Harmonized MIDI (aligned) saved as: harmonized_output2.mid
