<a href="https://colab.research.google.com/github/pnandini-sdu/AAI-511-Final-Project/blob/main/FinalTeam_Project_AAI511_CNN_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pretty_midi

Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m50.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pretty_midi
  Building wheel for pretty_midi (setup.py) ... [?25l[?25hdone
  Created wheel for pretty_midi: filename=pretty_midi-0.2.10-py3-none-any.whl size=5592286 sha256=c7df02adc8452537bfc15a87086a0bb236be4cc8312b58475ba222c4ccd3a908
  Stored in directory: /root/.cache/pip/wheels/e6/95/ac/15ceaeb2823b04d8e638fd1495357adb8d26c00ccac9d7782e
Successfully built pretty_midi
Installing collected packages: mido, pretty_midi
Successf

In [None]:
import os, glob, itertools, math
import numpy as np
import tensorflow as tf
from tensorflow import keras
import pretty_midi
from collections import defaultdict, Counter
from pathlib import Path

# Mounting content from my google drive
from google.colab import drive
drive.mount('/content/drive',force_remount=True)

# ======================
# Config (from EDA)
# ======================
DATA_DIR = "/content/drive/My Drive/music_dataset"
CLASSES    = ["Bach", "Beethoven", "Chopin", "Mozart"]
MIDI_EXTS  = [".mid", ".midi"]

FS         = 8            # frames/sec for piano roll
WIN_SECS   = 10           # window length
HOP_SECS   = 5            # hop
WIN        = WIN_SECS * FS
HOP        = HOP_SECS * FS

PITCH_LOW  = 25           # crop inclusive
PITCH_HIGH = 100          # crop inclusive
PITCHES    = PITCH_HIGH - PITCH_LOW + 1  # = 75

BATCH_SIZE = 16
EPOCHS     = 40
LR         = 4e-4
SEED       = 1337

rng = np.random.default_rng(SEED)
tf.keras.utils.set_random_seed(SEED)


def iter_midi_files(root: Path):
    # Recursively yield all .mid/.midi (case‑insensitive), skipping hidden dirs
    for p in root.rglob("*"):
        if p.is_file() and p.suffix.lower() in MIDI_EXTS and not any(part.startswith(".") for part in p.parts):
            yield str(p)

def list_files_recursive(data_dir, classes):
    # Ensure it's a Path object
    data_dir = Path(data_dir)

    files, labels = [], []
    for idx, comp in enumerate(classes):
        comp_dir = data_dir / comp
        if not comp_dir.exists():
            print(f"[warn] missing composer dir: {comp_dir}")
            continue

        for file_path in comp_dir.rglob("*"):
            if file_path.suffix.lower() in (".mid", ".midi"):
                files.append(str(file_path))
                labels.append(idx)

    return files, labels

# ======================
# Split (reuse or create)
# ======================


files, labels = list_files_recursive(DATA_DIR, CLASSES)

# Stratified piece-level split (feel free to replace with your saved LSTM split)
def make_split(files, labels, val_ratio=0.15, test_ratio=0.15):
    by_class = defaultdict(list)
    for p, y in zip(files, labels):
        by_class[y].append(p)
    train, val, test = [], [], []
    for y, items in by_class.items():
        rng.shuffle(items)
        n=len(items); n_test=round(n*test_ratio); n_val=round(n*val_ratio)
        test.extend((p,y) for p in items[:n_test])
        val .extend((p,y) for p in items[n_test:n_test+n_val])
        train.extend((p,y) for p in items[n_test+n_val:])
    rng.shuffle(train); rng.shuffle(val); rng.shuffle(test)
    return train, val, test

train, val, test = make_split(files, labels)
print(f"Pieces -> train: {len(train)}, val: {len(val)}, test: {len(test)}")

# ======================
# MIDI -> Piano-roll utils
# ======================
def midi_to_roll(path, fs=FS):
    """Return (128, T) binary piano-roll; robust to odd files."""
    try:
        pm = pretty_midi.PrettyMIDI(path)
        T = int(np.ceil(pm.get_end_time() * fs)) + 1
        roll = np.zeros((128, T), dtype=np.uint8)
        for inst in pm.instruments:
            for note in inst.notes:
                s = max(0, int(np.floor(note.start * fs)))
                e = min(T, int(np.ceil(note.end * fs)))
                if e > s:
                    roll[note.pitch, s:e] = 1
        return roll
    except Exception as e:
        print(f"[warn] failed {path}: {e}")
        return None

def cap_sustain(roll, max_secs=4.0, fs=FS):
    """Cap continuous '1' runs per pitch to reduce pedal-induced blobs."""
    max_len = int(max_secs * fs)
    R = roll.copy()
    for p in range(R.shape[0]):
        row = R[p]
        i = 0
        while i < row.size:
            if row[i] == 1:
                j = i
                while j < row.size and row[j] == 1:
                    j += 1
                if (j - i) > max_len:
                    row[i+max_len:j] = 0
                i = j
            else:
                i += 1
    return R

def crop_roll(roll, low=PITCH_LOW, high=PITCH_HIGH):
    return roll[low:high+1]

def window_roll(roll, win=WIN, hop=HOP):
    T = roll.shape[1]
    if T < win:
        pad = np.zeros((roll.shape[0], win - T), dtype=roll.dtype)
        return [np.concatenate([roll, pad], axis=1)]
    return [roll[:, s:s+win] for s in range(0, T - win + 1, hop)]

# ======================
# Augmentation (conservative from EDA)
# ======================
def augment_roll(roll, max_semitones=2, time_scales=(0.9, 1.0, 1.1)):
    # pitch shift
    st = rng.integers(-max_semitones, max_semitones+1)
    aug = roll
    if st != 0:
        aug = np.roll(aug, shift=st, axis=0)
        if st > 0:  aug[:st, :] = 0
        else:       aug[st:, :] = 0
    # time stretch
    scale = rng.choice(time_scales)
    if scale != 1.0:
        T = aug.shape[1]
        new_T = int(round(T * scale))
        idx = np.clip((np.arange(new_T) / scale).astype(int), 0, T-1)
        aug = aug[:, idx]
    return aug

def to_input(chunk):
    x = chunk.astype(np.float32)  # keep density info
    return x[..., None]           # (H, W, 1)

# ======================
# tf.data pipeline
# ======================
def gen_examples(pairs, augment=False):
    for path, y in pairs:
        roll = midi_to_roll(path)
        if roll is None:
            continue
        roll = cap_sustain(roll, max_secs=4.0, fs=FS)
        roll = crop_roll(roll)
        for ch in window_roll(roll):
            if augment:
                ch = augment_roll(ch)
                # re-fit to exact window after time-stretch
                if ch.shape[1] < WIN:
                    pad = np.zeros((PITCHES, WIN - ch.shape[1]), dtype=ch.dtype)
                    ch = np.concatenate([ch, pad], axis=1)
                elif ch.shape[1] > WIN:
                    ch = ch[:, :WIN]
            yield to_input(ch), y

def make_dataset(pairs, shuffle=True, augment=False):
    spec = (
        tf.TensorSpec(shape=(PITCHES, WIN, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32)
    )
    ds = tf.data.Dataset.from_generator(lambda: gen_examples(pairs, augment=augment),
                                        output_signature=spec)
    if shuffle:
        ds = ds.shuffle(8192, seed=SEED, reshuffle_each_iteration=True)
    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

train_ds = make_dataset(train, shuffle=True,  augment=True)
val_ds   = make_dataset(val,   shuffle=False, augment=False)
test_ds  = make_dataset(test,  shuffle=False, augment=False)

# ======================
# Class weights from WINDOW COUNTS (post-slicing)
# ======================
def count_windows(pairs):
    counts = Counter()
    for path, y in pairs:
        roll = midi_to_roll(path)
        if roll is None:
            continue
        n = len(window_roll(crop_roll(cap_sustain(roll))))
        counts[y] += n
    return counts

win_counts = count_windows(train)
total = sum(win_counts.values())
class_weight = {}
for k in range(len(CLASSES)):
    # inverse frequency
    freq = win_counts.get(k, 1)
    class_weight[k] = total / (len(CLASSES) * freq)
print("Window counts per class:", {CLASSES[k]: v for k, v in win_counts.items()})
print("Class weights:", {CLASSES[k]: round(v, 3) for k, v in class_weight.items()})

# ======================
# CNN (EDA-informed: mixed kernels, dropout, BN)
# ======================
def build_cnn(input_shape=(PITCHES, WIN, 1), num_classes=len(CLASSES)):
    inp = keras.layers.Input(shape=input_shape)
    x = inp

    # Block 1: small textures
    x = keras.layers.Conv2D(32, (3,3), padding="same", activation="relu")(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Conv2D(32, (3,3), padding="same", activation="relu")(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.MaxPooling2D((2,2))(x)         # H/2, W/2
    x = keras.layers.Dropout(0.2)(x)

    # Block 2: widen features
    x = keras.layers.Conv2D(64, (3,3), padding="same", activation="relu")(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Conv2D(64, (5,3), padding="same", activation="relu")(x)  # slightly taller
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.MaxPooling2D((2,2))(x)         # H/4, W/4
    x = keras.layers.Dropout(0.25)(x)

    # Block 3: capture large leaps / thick chords
    x = keras.layers.Conv2D(128, (7,3), padding="same", activation="relu")(x) # tall kernel
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Conv2D(128, (3,3), padding="same", activation="relu")(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.MaxPooling2D((2,2))(x)         # H/8, W/8
    x = keras.layers.Dropout(0.3)(x)

    # Head
    x = keras.layers.Conv2D(256, (3,3), padding="same", activation="relu")(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(0.4)(x)
    x = keras.layers.Dense(128, activation="relu")(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Dropout(0.35)(x)
    out = keras.layers.Dense(num_classes, activation="softmax")(x)

    model = keras.Model(inp, out)
    model.compile(optimizer=tf.keras.optimizers.Adam(LR),
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model

model = build_cnn()
model.summary()

# ======================
# Train
# ======================
callbacks = [
    tf.keras.callbacks.ReduceLROnPlateau(monitor="val_accuracy", factor=0.5, patience=4, min_lr=1e-5, verbose=1),
    tf.keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=8, restore_best_weights=True, verbose=1)
]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    class_weight=class_weight,
    callbacks=callbacks
)


Mounted at /content/drive
Pieces -> train: 1140, val: 245, test: 245




Window counts per class: {'Beethoven': 15806, 'Mozart': 13972, 'Bach': 21660, 'Chopin': 4183}
Class weights: {'Bach': 0.642, 'Beethoven': 0.88, 'Chopin': 3.324, 'Mozart': 0.995}


Epoch 1/40
   3477/Unknown [1m199s[0m 46ms/step - accuracy: 0.5164 - loss: 1.2112



[warn] failed /content/drive/My Drive/music_dataset/Beethoven/Anhang 14-3.mid: Could not decode key with 3 flats and mode 255
[1m3477/3477[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m295s[0m 73ms/step - accuracy: 0.5164 - loss: 1.2112 - val_accuracy: 0.5894 - val_loss: 1.0033 - learning_rate: 4.0000e-04
Epoch 2/40
[1m3477/3477[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 44ms/step - accuracy: 0.6689 - loss: 0.7819[warn] failed /content/drive/My Drive/music_dataset/Beethoven/Anhang 14-3.mid: Could not decode key with 3 flats and mode 255
[1m3477/3477[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m210s[0m 53ms/step - accuracy: 0.6689 - loss: 0.7819 - val_accuracy: 0.5072 - val_loss: 1.2423 - learning_rate: 4.0000e-04
Epoch 3/40
[1m3474/3477[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 44ms/step - accuracy: 0.7151 - loss: 0.6797[warn] failed /content/drive/My Drive/music_dataset/Beethoven/Anhang 14-3.mid: Could not decode key with 3 flats and mode 255
[1m3477/3

**Model tunning**

In [None]:
import os, glob, itertools, math
import numpy as np
import tensorflow as tf
from tensorflow import keras
import pretty_midi
from collections import defaultdict, Counter
from pathlib import Path

# Mounting content from my google drive
from google.colab import drive
drive.mount('/content/drive',force_remount=True)

# ======================
# Config (from EDA)
# ======================
DATA_DIR = "/content/drive/My Drive/music_dataset"
CLASSES    = ["Bach", "Beethoven", "Chopin", "Mozart"]
MIDI_EXTS  = [".mid", ".midi"]

FS         = 8            # frames/sec for piano roll
WIN_SECS   = 10           # window length
HOP_SECS   = 5            # hop
WIN        = WIN_SECS * FS
HOP        = HOP_SECS * FS

PITCH_LOW  = 25           # crop inclusive
PITCH_HIGH = 100          # crop inclusive
PITCHES    = PITCH_HIGH - PITCH_LOW + 1  # = 75

BATCH_SIZE = 16
EPOCHS     = 40
LR         = 4e-4
SEED       = 1337

rng = np.random.default_rng(SEED)
tf.keras.utils.set_random_seed(SEED)


def iter_midi_files(root: Path):
    # Recursively yield all .mid/.midi (case‑insensitive), skipping hidden dirs
    for p in root.rglob("*"):
        if p.is_file() and p.suffix.lower() in MIDI_EXTS and not any(part.startswith(".") for part in p.parts):
            yield str(p)

def list_files_recursive(data_dir, classes):
    # Ensure it's a Path object
    data_dir = Path(data_dir)

    files, labels = [], []
    for idx, comp in enumerate(classes):
        comp_dir = data_dir / comp
        if not comp_dir.exists():
            print(f"[warn] missing composer dir: {comp_dir}")
            continue

        for file_path in comp_dir.rglob("*"):
            if file_path.suffix.lower() in (".mid", ".midi"):
                files.append(str(file_path))
                labels.append(idx)

    return files, labels

# ======================
# Split (reuse or create)
# ======================


files, labels = list_files_recursive(DATA_DIR, CLASSES)

# Stratified piece-level split (feel free to replace with your saved LSTM split)
def make_split(files, labels, val_ratio=0.15, test_ratio=0.15):
    by_class = defaultdict(list)
    for p, y in zip(files, labels):
        by_class[y].append(p)
    train, val, test = [], [], []
    for y, items in by_class.items():
        rng.shuffle(items)
        n=len(items); n_test=round(n*test_ratio); n_val=round(n*val_ratio)
        test.extend((p,y) for p in items[:n_test])
        val .extend((p,y) for p in items[n_test:n_test+n_val])
        train.extend((p,y) for p in items[n_test+n_val:])
    rng.shuffle(train); rng.shuffle(val); rng.shuffle(test)
    return train, val, test

train, val, test = make_split(files, labels)
print(f"Pieces -> train: {len(train)}, val: {len(val)}, test: {len(test)}")


# --- Multi-channel MIDI loader ---
def midi_to_roll_multich(path, fs=FS):
    try:
        pm = pretty_midi.PrettyMIDI(path)
        T = int(np.ceil(pm.get_end_time() * fs)) + 1
        on = np.zeros((128, T), dtype=np.float32)
        vel = np.zeros((128, T), dtype=np.float32)
        for inst in pm.instruments:
            for note in inst.notes:
                s = max(0, int(np.floor(note.start * fs)))
                e = min(T, int(np.ceil(note.end * fs)))
                on[note.pitch, s:e] = 1.0
                vel[note.pitch, s:e] = max(vel[note.pitch, s:e].max(), note.velocity / 127.0)
        onset = np.zeros_like(on)
        onset[:, 1:] = np.maximum(0.0, on[:, 1:] - on[:, :-1])
        return on, vel, onset
    except Exception as e:
        print(f"[warn] failed {path}: {e}")
        return None

# --- Multi-channel input stacker ---
def to_input_multich(on, vel, onset):
    return np.stack([on, vel, onset], axis=-1)  # (H, W, 3)

# --- Balanced training dataset generator ---
STEPS_PER_CLASS = 1200  # tune as needed

def balanced_train_ds_multich(pairs):
    per_class = {k: [] for k in range(len(CLASSES))}
    for path, y in pairs:
        per_class[y].append(path)
    for k in per_class:
        rng.shuffle(per_class[k])

    def gen():
        counts = {k: 0 for k in per_class}
        iters = {k: iter(per_class[k]) for k in per_class}
        while min(counts.values()) < STEPS_PER_CLASS:
            for k in per_class:
                if counts[k] >= STEPS_PER_CLASS:
                    continue
                try:
                    p = next(iters[k])
                except StopIteration:
                    continue
                m = midi_to_roll_multich(p)
                if m is None:
                    continue
                on, vel, onset = m
                on = crop_roll(cap_sustain(on))
                vel = crop_roll(cap_sustain(vel))
                onset = crop_roll(cap_sustain(onset))
                on_windows = window_roll(on)
                vel_windows = window_roll(vel)
                onset_windows = window_roll(onset)
                for chunk_on, chunk_vel, chunk_onset in zip(on_windows, vel_windows, onset_windows):
                    yield to_input_multich(chunk_on, chunk_vel, chunk_onset), k
                    counts[k] += 1
                    if counts[k] >= STEPS_PER_CLASS:
                        break

    spec = (tf.TensorSpec((PITCHES, WIN, 3), tf.float32),
            tf.TensorSpec((), tf.int32))
    return tf.data.Dataset.from_generator(gen, output_signature=spec).shuffle(8192).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# --- Standard dataset generator for val/test ---
def gen_examples_multich(pairs, augment=False):
    for path, y in pairs:
        m = midi_to_roll_multich(path)
        if m is None:
            continue
        on, vel, onset = m
        on = cap_sustain(on)
        vel = cap_sustain(vel)
        onset = cap_sustain(onset)

        on = crop_roll(on)
        vel = crop_roll(vel)
        onset = crop_roll(onset)

        on_windows = window_roll(on)
        vel_windows = window_roll(vel)
        onset_windows = window_roll(onset)

        for chunk_on, chunk_vel, chunk_onset in zip(on_windows, vel_windows, onset_windows):
            if augment:
                # Optional augmentation code here (same as before) if you want
                pass
            yield to_input_multich(chunk_on, chunk_vel, chunk_onset), y

def make_dataset_multich(pairs, shuffle=True, augment=False):
    spec = (
        tf.TensorSpec(shape=(PITCHES, WIN, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32)
    )
    ds = tf.data.Dataset.from_generator(lambda: gen_examples_multich(pairs, augment=augment),
                                        output_signature=spec)
    if shuffle:
        ds = ds.shuffle(8192, seed=SEED, reshuffle_each_iteration=True)
    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# --- Create datasets ---
train_ds = balanced_train_ds_multich(train)
val_ds   = make_dataset_multich(val, shuffle=False, augment=False)
test_ds  = make_dataset_multich(test, shuffle=False, augment=False)

# --- Build model with updated input shape ---
model = build_cnn(input_shape=(PITCHES, WIN, 3))
model.summary()

# --- Train without class weights ---
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=callbacks
)
