# 🎼 Task 2: Conditioned Symbolic Music Generation with LSTM
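This notebook extends the Task 1 LSTM-based symbolic music generator by conditioning generation on chord tokens. For now the conditioning is a placeholder: each training file is prefixed with a randomly chosen chord token (C, G, Am or F) rather than one derived from the music itself.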

In [None]:
import os
import pickle
import random
from pathlib import Path

import numpy as np
from music21 import chord, converter, instrument, note, stream
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import LSTM, Dense, Embedding
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import to_categorical


## 🧩 Helper Class for Conditioning

In [None]:
class ConditionedMidiLSTM:
    """Chord-conditioned next-token LSTM over symbolic music tokens.

    Token vocabulary:
      * single notes as music21 pitch names with octave (e.g. ``'C4'``),
      * chords as ``'.'``-joined pitch classes from ``normalOrder`` (e.g. ``'0.4.7'``),
      * one placeholder chord-condition token per MIDI file (see ``CHORD_TOKENS``),
        prepended to that file's tokens.
    """

    # Placeholder condition tokens; currently chosen at random per file, not
    # derived from the file's harmony.
    CHORD_TOKENS = ('C', 'G', 'Am', 'F')

    def __init__(self):
        self.note_to_int = {}  # token -> integer index
        self.int_to_note = {}  # integer index -> token
        self.vocab_size = 0

    def parse_midi(self, file_path):
        """Parse one MIDI file into a flat list of note/chord tokens.

        When music21 can partition the file by instrument, only the first part
        is used; otherwise every note/chord in the score is used. Rests and
        other elements are ignored.
        """
        midi = converter.parse(file_path)
        parts = instrument.partitionByInstrument(midi)
        # partitionByInstrument can return None (no instrument parts); fall back
        # to the whole score instead of silently returning no notes.
        elements = parts.parts[0].recurse() if parts else midi.recurse()

        notes = []
        for element in elements:
            if isinstance(element, note.Note):
                notes.append(str(element.pitch))
            elif isinstance(element, chord.Chord):
                notes.append('.'.join(str(n) for n in element.normalOrder))
        return notes

    def preprocess_midi_files(self, directory, max_files=None):
        """Tokenize up to ``max_files`` ``*.mid`` files found under ``directory``.

        Each non-empty file contributes one randomly chosen placeholder
        condition token followed by its note/chord tokens; all files are
        concatenated into a single token list. ``max_files=None`` uses all files.
        """
        all_notes = []
        # rglob order depends on the filesystem; sort so max_files selects a
        # stable subset across runs and machines.
        files = sorted(Path(directory).rglob("*.mid"))[:max_files]
        for file in files:
            notes = self.parse_midi(file)
            if notes:
                # Fake chord condition (could be improved): not derived from the file.
                chord_token = random.choice(self.CHORD_TOKENS)
                all_notes.extend([chord_token] + notes)
        return all_notes

    def create_vocabulary(self, notes):
        """Build the token<->index maps from ``notes`` and return ``notes`` unchanged.

        Indices follow sorted token order, so the mapping is deterministic for
        a given set of tokens.
        """
        unique_notes = sorted(set(notes))
        self.note_to_int = {token: i for i, token in enumerate(unique_notes)}
        self.int_to_note = dict(enumerate(unique_notes))
        self.vocab_size = len(unique_notes)
        return notes

    def create_sequences(self, notes, seq_length=50):
        """Slice ``notes`` into (context window, next token) training pairs.

        Returns:
            inputs: int array of shape ``(len(notes) - seq_length, seq_length)``;
                ``(0, seq_length)`` when there are not enough tokens.
            targets: one-hot array of shape ``(n_windows, vocab_size)``.

        Raises:
            KeyError: if a token is missing from the vocabulary.
        """
        # Encode each token once instead of once per window it appears in.
        encoded = [self.note_to_int[n] for n in notes]
        inputs = [encoded[i:i + seq_length] for i in range(len(encoded) - seq_length)]
        # The target of window i is encoded[i + seq_length].
        targets = encoded[seq_length:]
        return (np.array(inputs, dtype=int).reshape(-1, seq_length),
                to_categorical(targets, num_classes=self.vocab_size))

    def build_model(self, seq_length):
        """Build and compile a 2-layer LSTM next-token classifier.

        Embedding(100) -> LSTM(256, sequences) -> LSTM(256) -> softmax over the
        vocabulary; compiled with categorical cross-entropy and Adam.
        """
        model = Sequential()
        model.add(Embedding(input_dim=self.vocab_size, output_dim=100, input_length=seq_length))
        model.add(LSTM(256, return_sequences=True))
        model.add(LSTM(256))
        model.add(Dense(self.vocab_size, activation='softmax'))
        model.compile(loss='categorical_crossentropy', optimizer='adam')
        return model

    def sample(self, preds, temperature=1.0):
        """Draw an index from probability vector ``preds`` after temperature scaling.

        Lower temperature sharpens the distribution; higher flattens it.

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        # float64 keeps the renormalized probabilities summing to 1 within
        # np.random.choice's tolerance even if the model emits float32.
        logits = np.log(np.asarray(preds, dtype=np.float64) + 1e-9) / temperature
        # Shift by the max so exp() cannot underflow to all zeros (-> NaN) at
        # low temperature; the resulting distribution is unchanged.
        weights = np.exp(logits - logits.max())
        probs = weights / weights.sum()
        return int(np.random.choice(len(probs), p=probs))

    def generate(self, model, seed_seq, length=100, temperature=0.9):
        """Autoregressively sample ``length`` token indices from ``model``.

        ``seed_seq`` is a sequence of token indices (list or 1-D array) used as
        the initial context; the window slides by one after each step and keeps
        its length. Returns the sampled indices only (the seed is excluded).
        """
        # list() so a NumPy seed is concatenated below, not broadcast-added.
        window = list(seed_seq)
        generated = []
        for _ in range(length):
            probs = model.predict(np.array([window]), verbose=0)[0]
            next_index = self.sample(probs, temperature=temperature)
            generated.append(next_index)
            window = window[1:] + [next_index]
        return generated


## 🚀 Training

In [None]:
# Context window length; must match between create_sequences and build_model.
SEQ_LENGTH = 50

midi_lstm = ConditionedMidiLSTM()

# Load and process data
notes = midi_lstm.preprocess_midi_files("nottingham-dataset/MIDI_cleaned", max_files=30)
filtered_notes = midi_lstm.create_vocabulary(notes)
X, y = midi_lstm.create_sequences(filtered_notes, seq_length=SEQ_LENGTH)
if len(X) == 0:
    raise ValueError(
        f"Need more than {SEQ_LENGTH} tokens to build training sequences; "
        f"got {len(filtered_notes)}."
    )

# Split and train
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.1, random_state=42)
model = midi_lstm.build_model(seq_length=SEQ_LENGTH)
model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=20, batch_size=64)
model.save("conditioned_lstm_model.h5")

# Persist the token <-> index mapping alongside the model: without it the
# saved model's output indices cannot be decoded back into notes.
with open("conditioned_lstm_vocab.pkl", "wb") as vocab_file:
    pickle.dump(
        {"note_to_int": midi_lstm.note_to_int, "int_to_note": midi_lstm.int_to_note},
        vocab_file,
    )


## 🎼 Generation

In [None]:
from music21 import stream, note, chord, instrument

# Placeholder chord-condition tokens prepended by preprocess_midi_files.
# They mark conditioning rather than playable notes (music21 cannot parse
# 'Am' as a note), so they seed generation but are skipped when rendering.
# Keep this set in sync with the tokens used during preprocessing.
CONDITION_TOKENS = {'C', 'G', 'Am', 'F'}
condition_ids = {midi_lstm.note_to_int[t] for t in CONDITION_TOKENS if t in midi_lstm.note_to_int}

# Pick a random seed sequence with a chord token at start (fall back to any
# window if no window starts with one).
candidates = [i for i, window in enumerate(X) if int(window[0]) in condition_ids]
start = random.choice(candidates) if candidates else random.randrange(len(X))
seed = list(X[start])

# Generate indices
generated = midi_lstm.generate(model, seed, length=100)
generated_notes = [midi_lstm.int_to_note[idx] for idx in generated]

# Convert to MIDI
output_stream = stream.Stream()
output_stream.append(instrument.Piano())
for token in generated_notes:
    if token in CONDITION_TOKENS:
        continue  # conditioning marker, not a note
    if '.' in token or token.isdigit():
        # Chords are '.'-joined pitch classes; a one-pitch-class chord has no '.'.
        output_stream.append(chord.Chord([int(n) for n in token.split('.')]))
    else:
        output_stream.append(note.Note(token))
output_stream.write("midi", fp="task2_generated.mid")
