In [None]:
# Install pretty_midi if needed
!pip install -q pretty_midi
!pip install -q music21

In [None]:
import mido
import os
import pretty_midi
from IPython.display import Audio
from collections import defaultdict
from music21 import converter, chord, stream, tempo
import numpy as np

In [None]:
# Get all .mid files in the target directory.
# os.listdir() returns entries in arbitrary order, so the listing is sorted to keep
# anything that indexes into it (e.g. "the first file") reproducible between runs.
all_midi_files = sorted(
    f for f in os.listdir('/root/class_stuff/musicml/assign2/midi_files/') if f.endswith('.mid')
)

In [None]:
# Sanity check on the directory listing: one sample file name, then the file count.
first_midi_name = all_midi_files[0]
midi_file_count = len(all_midi_files)
print(first_midi_name)
print(midi_file_count)

In [None]:
def get_instruments(midi_file):
    """
    Return the names of the non-drum instruments found in a MIDI file.

    Args:
        midi_file: Path passed straight to ``pretty_midi.PrettyMIDI``.

    Returns:
        list: ``instrument.name`` for every instrument whose ``is_drum`` flag
        is ``False``, in the order ``pretty_midi`` lists them.
    """
    parsed = pretty_midi.PrettyMIDI(midi_file)
    melodic_names = []
    for track in parsed.instruments:
        # Identity check on purpose: only an explicit False counts as "not a drum".
        if track.is_drum is False:
            melodic_names.append(track.name)
    return melodic_names

dataroot = '/root/class_stuff/musicml/assign2/midi_files/'

all_instruments = []   # lower-cased non-drum instrument names, pooled over all files
num_instruments = []   # number of non-drum instruments per file
only_plpr = []         # files that contain both a 'piano right' and a 'piano left' track
for midi_file in all_midi_files:
    full_path = os.path.join(dataroot, midi_file)
    instruments = get_instruments(full_path)
    format_instrument_names = [inst.lower() for inst in instruments]
    all_instruments.extend(format_instrument_names)
    num_instruments.append(len(instruments))
    if 'piano right' in format_instrument_names and 'piano left' in format_instrument_names:
        # Files with additional tracks are still kept; they are only reported below.
        only_plpr.append(midi_file)
        if len(format_instrument_names) != 2:
            print(f"File {midi_file} has more than two instruments: {format_instrument_names}")

print(f'Total number of MIDI files: {len(all_midi_files)}')
# How often each instrument name occurs across the whole dataset.
unique_instruments, counts = np.unique(all_instruments, return_counts=True)
for inst, count in zip(unique_instruments, counts):
    print(f"{inst}: {count}")

# Mean number of non-drum instruments per file, then the number of piano L/R files.
print(np.mean(num_instruments))
print(len(only_plpr))



In [None]:
from music21 import converter, chord, stream, tempo, pitch
import numpy as np
import pretty_midi
from IPython.display import Audio


def extract_chords_as_array(
    midi_path,
    interval=0.1,
    playback: str = None   # options: 'original', 'chords', 'both', or None
):
    """
    Sample the chordified version of a MIDI file on a fixed grid and return
    its chord events in several representations.

    The file is parsed with music21, chordified, and inspected every
    ``interval`` offset units from 0 up to (not including) the chordified
    stream's ``highestTime``. A new chord event starts whenever the sampled
    chord's ``normalOrder`` (compared as a set) differs from the previous one.

    Side effects: writes ``chord_progression_clean.mid`` and
    ``normalized_chord_progression.mid`` to the current working directory,
    overwriting any existing files with those names.

    Args:
        midi_path: Path of the MIDI file to parse.
        interval: Step between sampled offsets. The differences between
            sampled offsets are later assigned to ``quarterLength``, so the
            step is in the stream's offset units rather than seconds.
        playback: Currently unused; the code that consumed it is commented out.

    Returns:
        tuple: ``(chords_list, normalized_chords_list, interval_array,
        normalized_audio)`` where

        * ``chords_list`` is a list of dicts with keys ``offset``,
          ``duration``, ``pitches_midi`` and ``pitches_name``;
        * ``normalized_chords_list`` has the same layout, but every pitch of a
          chord is replaced by the rounded mean MIDI number of that chord;
        * ``interval_array`` holds, for each sampled offset, the list of
          ``pitches_midi`` lists of every event active at that offset;
        * ``normalized_audio`` is the waveform ``pretty_midi`` synthesizes
          (fs=44100) from the normalized-chord MIDI file.
    """
    score = converter.parse(midi_path)
    chords = score.chordify()
    # Use the first tempo mark found anywhere in the score; fall back to 120.
    mm = score.recurse().getElementsByClass(tempo.MetronomeMark).first()
    bpm = mm.number if mm else 120
    # print(f"Detected BPM: {bpm}")
    duration = chords.highestTime

    time_steps = np.arange(0, duration, interval)
    chord_stream = stream.Part()
    chord_stream.insert(0, tempo.MetronomeMark(number=bpm))
    chords_list = []

    # The chord event currently being extended, and the offset where it started.
    last_chord = None
    last_offset = 0.0

    all_pitches = []  # never read in this function
    for t in time_steps:
        # presumably also returns elements that began before t and are still sounding — confirm
        elems = chords.flat.getElementsByOffset(t, mustBeginInSpan=False)
        this_chord = next((e for e in elems if isinstance(e, chord.Chord)), None)
        if this_chord is not None:
            # A change in the set of normalOrder values marks the start of a new event.
            if last_chord is None or set(this_chord.normalOrder) != set(last_chord.normalOrder):
                if last_chord:
                    # Close the previous event: it lasted from last_offset until now.
                    dur = t - last_offset
                    last_chord.quarterLength = dur
                    chord_stream.insert(last_offset, last_chord)
                    chords_list.append({
                        'offset':      last_offset,
                        'duration':    dur,
                        'pitches_midi': [p.midi for p in last_chord.pitches],
                        'pitches_name': [p.nameWithOctave for p in last_chord.pitches]
                    })
                last_chord = chord.Chord(this_chord)  # copy
                last_offset = t

    # Flush the final event, which runs until the end of the piece.
    if last_chord:
        dur = duration - last_offset
        last_chord.quarterLength = dur
        chord_stream.insert(last_offset, last_chord)
        chords_list.append({
            'offset':      last_offset,
            'duration':    dur,
            'pitches_midi': [p.midi for p in last_chord.pitches],
            'pitches_name': [p.nameWithOctave for p in last_chord.pitches]
        })


    # Written relative to the current working directory.
    chordified_fp = 'chord_progression_clean.mid'
    chord_stream.write('midi', fp=chordified_fp)

    # Helper for the playback feature; not called by the active code below.
    def synth_and_display(midi_fp, label):
        pm = pretty_midi.PrettyMIDI(midi_fp)
        audio = pm.synthesize(fs=44100)
        print(f"Playing {label}...")
        display(Audio(audio, rate=44100, autoplay=True))
        return audio

    # if playback in ('original', 'both'):
    #     _ = synth_and_display(midi_path, 'original MIDI')
    # if playback in ('chords',   'both'):
    #     _ = synth_and_display(chordified_fp, 'chordified progression')

    # --- 3) Build normalized_chords_list ---
    # Every pitch of an event is replaced by the event's rounded mean MIDI number,
    # keeping the number of notes in the chord unchanged.
    normalized_chords_list = []
    for evt in chords_list:
        orig = evt['pitches_midi']
        if orig:
            m = int(round(np.mean(orig)))
            name = pitch.Pitch(m).nameWithOctave
            norm_midi = [m] * len(orig)
            norm_names = [name] * len(orig)
        else:
            norm_midi = []
            norm_names = []
        normalized_chords_list.append({
            'offset':       evt['offset'],
            'duration':     evt['duration'],
            'pitches_midi': norm_midi,
            'pitches_name': norm_names
        })

    # Per-time-step view: which events cover each sampled offset.
    interval_array = []
    for t in time_steps:
        # find all events covering t
        active = [
            evt['pitches_midi']
            for evt in chords_list
            if evt['offset'] <= t < evt['offset'] + evt['duration']
        ]
        interval_array.append(active)

    # Rebuild a stream from the normalized events so it can be written and synthesized.
    norm_stream = stream.Part()
    norm_stream.insert(0, tempo.MetronomeMark(number=bpm))
    for evt in normalized_chords_list:
        c = chord.Chord(evt['pitches_midi'])
        c.quarterLength = evt['duration']
        norm_stream.insert(evt['offset'], c)

    # Written relative to the current working directory.
    norm_fp = 'normalized_chord_progression.mid'
    norm_stream.write('midi', fp=norm_fp)
    normalized_audio = pretty_midi.PrettyMIDI(norm_fp).synthesize(fs=44100)
    # print("Playing normalized-pitch progression…")
    # display(Audio(normalized_audio, rate=44100, autoplay=True))

    return chords_list, normalized_chords_list, interval_array, normalized_audio

test_file = dataroot + only_plpr[0]

# extract_chords_as_array returns, in this order:
#   chord events, normalized chord events, per-time-step active chords, normalized audio.
# The target names follow that order (the second value is NOT an interval array).
# `playback` is accepted for future use only: the playback code inside the function is
# commented out, so 'original' / 'chords' / 'both' / None all behave the same today.
chord_array, normalized_chord_array, per_step, normalized_audio = extract_chords_as_array(
    test_file, playback='both'
)

In [None]:
print(chord_array[:10])

# Map each pitch name to the distinct MIDI numbers seen for it in the test file.
# The loop variables are deliberately not called `chord` / `pitch`: those names are the
# music21 modules that extract_chords_as_array reads as globals (chord.Chord,
# pitch.Pitch), and rebinding them at notebook scope breaks every later call to it.
chord_pitch_dict = {}
for chord_event in chord_array:
    for pitch_name, midi_number in zip(chord_event['pitches_name'], chord_event['pitches_midi']):
        seen_numbers = chord_pitch_dict.setdefault(pitch_name, [])
        if midi_number not in seen_numbers:
            seen_numbers.append(midi_number)

print(chord_pitch_dict)


In [None]:
# Synthesize the original test file at 44.1 kHz and embed an auto-playing audio widget.
midi_data = pretty_midi.PrettyMIDI(test_file)
audio_data = midi_data.synthesize(fs=44100)
Audio(data=audio_data, rate=44100, autoplay=True)

In [None]:
import json
# only_plpr_info = {}
# successful = 0
# for file in only_plpr:
#     try:
#         chord_array, interval_array, per_step, _ = extract_chords_as_array(dataroot + file, playback='both')
#         only_plpr_info[file] = {'chord_array': chord_array,
#                                 'interval_array': interval_array,
#                                 'per_step': per_step}
#         successful+= 1
#         if successful % 20 == 0 or successful == 1:
#             with open('data.json', 'w') as f:
#                 json.dump(only_plpr_info, f, indent=4)

#     except:
#         print(f'Extracting midi from {file} failed')

In [None]:
# Load the pre-computed extraction results. The cells below read it as a dict whose
# values carry 'chord_array' and 'per_step' entries.
with open('full_data.json', 'r') as json_handle:
    data = json.load(json_handle)


Count every unique chord pitch (keyed by its pitch name) and the number of times it appears in the dataset. To narrow the space of possible classes, we restrict our random forest to only those that appear more often than the overall mean; the code below implements this as a fixed cutoff of 7000 occurrences.

In [None]:
def stringify_array(arr):
    """Return a new list containing ``str(element)`` for each element of ``arr``, in order."""
    return [str(item) for item in arr]

# Minimum number of occurrences a pitch name needs to be kept as a class.
MIN_PITCH_COUNT = 7000

# pitch name -> [midi_number, occurrence_count, *any further distinct midi numbers]
# Loop variables avoid the names `chord` / `pitch`: those are the music21 modules that
# extract_chords_as_array reads as globals, and rebinding them here breaks later calls.
chord_pitch_dict = {}
for file_name, itms in data.items():
    chord_array = itms['chord_array']
    for chord_event in chord_array:
        midi_pitches = chord_event['pitches_midi']
        pitch_names = chord_event['pitches_name']

        for pitch_name, midi_number in zip(pitch_names, midi_pitches):
            if pitch_name not in chord_pitch_dict:
                # First sighting: remember the MIDI number and start the counter at 1.
                chord_pitch_dict[pitch_name] = [midi_number, 1]
            elif midi_number not in chord_pitch_dict[pitch_name]:
                chord_pitch_dict[pitch_name].append(midi_number)
            else:
                chord_pitch_dict[pitch_name][1] += 1

print(chord_pitch_dict)
print(len(chord_pitch_dict))
print("Filtered chord pitch dict:")
filteredChordPitch = {}
totalCounts = []
for pitch_name, entry in chord_pitch_dict.items():
    totalCounts.append(entry[1])
    if entry[1] < MIN_PITCH_COUNT:
        continue
    filteredChordPitch[pitch_name] = entry
print(filteredChordPitch)
print(len(filteredChordPitch))
# Mean occurrence count over all pitch names, for comparison with MIN_PITCH_COUNT.
print(np.mean(totalCounts))

Build the reversed dictionary (MIDI number → pitch name) for ease of use later.

In [None]:
# Invert the filtered mapping: MIDI number (entry[0]) -> pitch name.
# Loop variables avoid `chord` / `pitch` so the music21 modules of those names, which
# extract_chords_as_array reads as globals, are not rebound at notebook scope.
filteredPitchChord = {entry[0]: pitch_name for pitch_name, entry in filteredChordPitch.items()}

# Make sure they are the same size (fails if two pitch names share one MIDI number).
assert len(filteredPitchChord) == len(filteredChordPitch)

Now we look at the actual time-stepped progression of chords. First, let's gather some statistics about how long the arrays are in general.

In [None]:
chord_lens = []       # entries per time step before filtering, pooled over all files
chord_totals = []     # number of time steps per file before filtering

filtered_lens = []    # entries per time step after filtering, pooled over all files
filtered_totals = []  # number of time steps per file after filtering

# Loop variables avoid the name `chord`, which is the music21 module that
# extract_chords_as_array reads as a global.
for file_name, itms in data.items():
    chord_array = itms['per_step']
    # Keep only the first active chord of each time step ([] when none is active).
    simplified_chord_array = [step[0] if len(step) > 0 else [] for step in chord_array]

    # Drop every MIDI number that did not survive the frequency filter.
    filtered_simplified_chord_array = [
        [midi_number for midi_number in step_pitches if midi_number in filteredPitchChord]
        for step_pitches in simplified_chord_array
    ]

    chord_lens.extend(len(step_pitches) for step_pitches in simplified_chord_array)
    chord_totals.append(len(simplified_chord_array))

    filtered_lens.extend(len(step_pitches) for step_pitches in filtered_simplified_chord_array)
    filtered_totals.append(len(filtered_simplified_chord_array))

print(f'Unfiltered chords average length: {np.mean(chord_totals)}, average number of chords per time step: {np.mean(chord_lens)}')
print(f'Filtered chords average length: {np.mean(filtered_totals)}, average number of chords per time step: {np.mean(filtered_lens)}')

Similar to the chords, we make this a tractable task for a random forest by truncating the chord sequence lengths. Let's start with just 600 time steps so things run fast.

In [None]:
# Maximum number of time steps kept per piece.
# NOTE(review): assuming the data was extracted with the default interval of 0.1, the
# step differences are assigned to `quarterLength` in extract_chords_as_array, so 600
# steps span 60 of those units — one minute only if each unit lasts one second; confirm.
MAX_CHORD_LENGTH = 600 # Since we chunk by .1, this is about 1 minute (see note above).

In [None]:
allChordDatapoints = []  # one truncated, filtered per-step sequence per file
filtered_lens = []       # entries per kept time step, pooled over all files
filtered_totals = []     # number of kept time steps per file

# This cell only produces the filtered statistics and datapoints; it no longer appends to
# the unfiltered `chord_lens` / `chord_totals` lists of the previous cell, which made
# those lists grow on every re-run. Loop variables also avoid the name `chord`, which is
# the music21 module that extract_chords_as_array reads as a global.
for file_name, itms in data.items():
    chord_array = itms['per_step']
    # Keep only the first active chord of each time step ([] when none is active).
    simplified_chord_array = [step[0] if len(step) > 0 else [] for step in chord_array]

    # Truncate to MAX_CHORD_LENGTH steps and drop MIDI numbers outside the filtered set.
    filtered_simplified_chord_array = [
        [midi_number for midi_number in step_pitches if midi_number in filteredPitchChord]
        for step_pitches in simplified_chord_array[:MAX_CHORD_LENGTH]
    ]

    filtered_lens.extend(len(step_pitches) for step_pitches in filtered_simplified_chord_array)
    filtered_totals.append(len(filtered_simplified_chord_array))
    allChordDatapoints.append(filtered_simplified_chord_array)

print(f'Filtered chords average length: {np.mean(filtered_totals)}, average number of chords per time step: {np.mean(filtered_lens)}')

Now let's make some ear cancer. We mask out all chords except the immediate next chord (next-chord prediction, similar to teacher forcing).

In [None]:
# print the available chords: These are our classes
available_classes = filteredChordPitch.keys()
print(available_classes)
print(len(available_classes))
# Placeholder entry used to pad the not-yet-revealed tail of a sequence.
MASK_CHORD = [[-1]]

In [None]:
allData = []
targets = []
last_chord_added = ""
for datapoint in allChordDatapoints:
    chord_datapoint = []
    chord_masks = len(datapoint)
    for chord in datapoint:
        # Add a chord
        chord_datapoint.append(chord)
        chord_masks -= 1
        # Mask all the rest of the chords and add it to add data
        datapoint_copy = chord_datapoint.copy()
        datapoint_copy += MASK_CHORD * chord_masks
        # We skip duplicate chords to avoid letting the random forest just always output the most recent chord.
        if last_chord_added == chord:
            last_chord_added = ""
            continue
        allData.append(datapoint_copy)
        targets.append(chord)
        last_chord_added = chord
# print(allData[0])
# print(targets[0])
print(len(allData))

allData = []
targets = []
last_chord_added = ""
for datapoint in allChordDatapoints:
    chord_datapoint = []
    chord_masks = len(datapoint)
    for chord in datapoint:
        # Add a chord
        chord_datapoint.append(chord)
        chord_masks -= 1
        # Mask all the rest of the chords and add it to add data
        datapoint_copy = chord_datapoint.copy()
        datapoint_copy += MASK_CHORD * chord_masks
        # We skip duplicate chords to avoid letting the random forest just always output the most recent chord.
        if last_chord_added == chord:
            last_chord_added = ""
            continue
        allData.append(datapoint_copy)
        targets.append(chord)
        last_chord_added = chord
print(allData[101])
print(targets[101])
print(len(allData))



In [None]:
# Scratch check: how many MIDI numbers are in the hard-coded class list.
candidate_classes = [43, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]
print(len(candidate_classes))

In [None]:
from collections import Counter
from sklearn.feature_extraction import DictVectorizer
from sklearn.preprocessing   import MultiLabelBinarizer
from sklearn.multioutput     import MultiOutputClassifier
from sklearn.ensemble        import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics         import classification_report


# MIDI numbers used as prediction classes: 43, then every value from 48 through 81.
CLASS_LIST = [43, *range(48, 82)]
NUM_CLASSES = len(CLASS_LIST)

def extract_features(seqs):
    """
    Turn each sequence into a dict of per-class occurrence counts.

    Args:
        seqs: Iterable of sequences; each sequence is an iterable of time
            steps, each time step an iterable of values. The value ``-1`` is
            treated as a mask and not counted.

    Returns:
        list: One dict per sequence with a ``"count_<c>"`` key for every
        ``c`` in ``CLASS_LIST`` (0 when ``c`` never occurs). Values outside
        ``CLASS_LIST`` do not appear in the result.
    """
    feature_rows = []
    for seq in seqs:
        tally = Counter()
        for time_step in seq:
            for value in time_step:
                if value != -1:
                    tally[value] += 1
        feature_rows.append({f"count_{cls}": tally.get(cls, 0) for cls in CLASS_LIST})
    return feature_rows

def extract_features_decay(seqs, gamma=0.8):
    """
    Turn each sequence into a dict of exponentially decayed per-class scores.

    A value at time step ``t`` of a sequence of length ``T`` contributes
    ``gamma ** (T - 1 - t)`` to its class score, so the final position of the
    sequence has weight 1 and earlier positions count less.

    NOTE(review): ``T`` is the full sequence length, including any trailing
    mask entries, so the weights are measured from the end of the padded
    sequence rather than from the last unmasked step — confirm this is intended.

    Args:
        seqs: Iterable of sequences; each sequence is a list of time steps,
            each time step an iterable of values.
        gamma: Decay base applied per time step.

    Returns:
        list: One dict per sequence with a ``"decay_<c>"`` key for every ``c``
        in ``CLASS_LIST``. The mask value ``-1`` and values outside
        ``CLASS_LIST`` are ignored, matching ``extract_features``.
    """
    feat_dicts = []
    for seq in seqs:
        T = len(seq)
        # Initialize decay scores to zero
        decay_scores = dict.fromkeys(CLASS_LIST, 0.0)

        # For each time-slice t, add gamma^(T-1 - t) to each chord in that slice
        for t, slice_ in enumerate(seq):
            weight = gamma ** (T - 1 - t)
            for c in slice_:
                # Skip the mask, and skip unknown values instead of raising KeyError.
                if c != -1 and c in decay_scores:
                    decay_scores[c] += weight

        # Flatten into a feature dict
        feat_dicts.append({f"decay_{c}": decay_scores[c] for c in CLASS_LIST})

    return feat_dicts

# --- Features: decay-weighted class scores for every masked sequence ---
vec = DictVectorizer(sparse=False)
decay_feature_dicts = extract_features_decay(allData)
X = vec.fit_transform(decay_feature_dicts)

# --- Labels: multi-hot encoding of each target chord over CLASS_LIST ---
mlb = MultiLabelBinarizer(classes=CLASS_LIST)
Y = mlb.fit_transform(targets)

# NOTE(review): samples are prefixes of the same pieces, so a random split may place
# near-identical rows on both sides — confirm this is acceptable for evaluation.
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

# --- Model: one class-balanced random forest per output label ---
rf_params = dict(
    n_estimators=1500,
    max_depth=20,
    random_state=42,
    n_jobs=-1,
    class_weight='balanced',
)
base_rf = RandomForestClassifier(**rf_params)
clf = MultiOutputClassifier(base_rf)
clf.fit(X_train, y_train)

# --- Evaluation on the held-out rows ---
y_pred = clf.predict(X_test)
class_names = [str(c) for c in mlb.classes_]
print(classification_report(y_test, y_pred, target_names=class_names))

# Predicted label sets, converted back from multi-hot rows to class values.
pred_sets = mlb.inverse_transform(y_pred)

In [None]:
import random
def get_chord_staleness(last_chords):
    """
    Report whether the three most recent chords are identical.

    Only the last three entries are inspected, whatever their length, so
    earlier history (including empty chords) can neither hide a stale run nor
    raise an error.

    Args:
        last_chords: History of chords, oldest first; each chord is a
            sequence of values.

    Returns:
        bool: ``True`` when the history holds at least three chords and the
        last three have the same length and the same values position by
        position; ``False`` otherwise.
    """
    if len(last_chords) < 3:
        return False
    chord3, chord2, chord1 = last_chords[-3:]
    # List equality covers both the length check and the element-wise comparison.
    return list(chord1) == list(chord2) == list(chord3)
        
        
# Number of time steps to generate; also the length of the (padded) model input.
N_GENERATION_STEPS = 600

# Initialize an empty vector
# last_chords = [random.choice(CLASS_LIST)]
last_chords = [50, 52, 59, 64, 68]
sanity_check = [[66], [66], [66], [66], [66], [68], [68], [68], [69], [69], [50, 52, 59, 64, 68]]
# Seed followed by mask entries up to the full sequence length.
# NOTE(review): the loop below starts writing at index 0, so the seed is overwritten
# step by step while its later entries are still visible to the model — confirm intended.
X_seq = sanity_check + [[-1] for _ in range(N_GENERATION_STEPS - len(sanity_check))]
last_chords_added = []
for i in range(N_GENERATION_STEPS):
    X_seq[i] = last_chords
    last_chords_added.append(last_chords)
    print(X_seq[i])
    if get_chord_staleness(last_chords_added):
        # Stuck on the same chord: restart from a random single-note chord.
        last_chords = [random.choice(CLASS_LIST)]
        last_chords_added = []
        print("Last random here")
    else:
        # transform(), not fit_transform(): reuse the feature layout the classifier was
        # trained on instead of refitting the vectorizer at inference time, and keep the
        # global training matrix `X` untouched.
        step_features = vec.transform(extract_features_decay([X_seq]))
        proba_list = clf.predict_proba(step_features)
        # Keep every class whose positive probability exceeds 0.5 (may be empty).
        last_chords = [
            CLASS_LIST[j] for j, probs in enumerate(proba_list) if probs[0][1] > .5
        ]


In [None]:

def fill_full_mask_one_per_slot(clf, vec, mlb, MAX_SLOTS: int) -> list[list[int]]:
    """
    Fill a fully masked sequence left to right, one label per slot.

    At every slot the current sequence is converted with ``extract_features``
    and ``vec.transform``, and the class with the highest positive
    probability from ``clf.predict_proba`` is written into that slot.

    Args:
        clf: Fitted classifier whose ``predict_proba`` returns one
            ``(1, 2)`` array per class.
        vec: Fitted vectorizer used to transform the feature dicts.
            NOTE(review): it should have been fitted on ``extract_features``
            output, otherwise the ``count_*`` keys are presumably dropped — confirm.
        mlb: Fitted label binarizer; ``mlb.classes_`` maps index to class value.
        MAX_SLOTS: Number of slots to fill.

    Returns:
        The filled sequence: ``MAX_SLOTS`` single-element lists.
    """
    sequence = [[-1] for _ in range(MAX_SLOTS)]

    for position in range(MAX_SLOTS):
        features = vec.transform(extract_features([sequence]))
        per_class = clf.predict_proba(features)
        positive_probs = np.array([class_proba[0, 1] for class_proba in per_class])
        winner = mlb.classes_[np.argmax(positive_probs)]
        sequence[position] = [winner]

    return sequence


def fill_full_mask_multi_per_slot(clf, vec, mlb, MAX_SLOTS: int,
                                  threshold: float = 0.2) -> list[list[int]]:
    """
    Fill a fully masked sequence left to right, allowing several labels per slot.

    At every slot the current sequence is converted with ``extract_features``
    and ``vec.transform``; every class whose positive probability is at least
    ``threshold`` is written into the slot. When no class qualifies, the
    single most probable class is used instead.

    Args:
        clf: Fitted classifier whose ``predict_proba`` returns one
            ``(1, 2)`` array per class.
        vec: Fitted vectorizer used to transform the feature dicts.
            NOTE(review): it should have been fitted on ``extract_features``
            output, otherwise the ``count_*`` keys are presumably dropped — confirm.
        mlb: Fitted label binarizer; ``mlb.classes_`` maps index to class value.
        MAX_SLOTS: Number of slots to fill.
        threshold: Minimum positive probability for a class to be selected.

    Returns:
        The filled sequence: ``MAX_SLOTS`` lists of one or more labels.
    """
    sequence = [[-1] for _ in range(MAX_SLOTS)]

    for position in range(MAX_SLOTS):
        features = vec.transform(extract_features([sequence]))
        per_class = clf.predict_proba(features)
        positive_probs = np.array([class_proba[0, 1] for class_proba in per_class])

        picked = np.where(positive_probs >= threshold)[0].tolist()
        if not picked:
            # Nothing crossed the threshold: fall back to the most probable class.
            picked = [np.argmax(positive_probs)]

        sequence[position] = [mlb.classes_[k] for k in picked]

    return sequence


MAX_SLOTS = 60

# 1) Exactly one label in every slot
filled_one = fill_full_mask_one_per_slot(clf, vec, mlb, MAX_SLOTS)
print("Hypothesized sequence (one‐label per slot):", filled_one, sep="\n")

# 2) Any number of labels per slot, selected with threshold=0.2
filled_multi = fill_full_mask_multi_per_slot(clf, vec, mlb, MAX_SLOTS, threshold=0.2)
print("\nHypothesized sequence (multi‐label per slot, T=0.2):", filled_multi, sep="\n")

In [None]:
def fill_full_mask_multi_per_slot_fast(
    clf,
    counts_vec: np.ndarray,
    class_to_feat_idx: dict[int,int],
    all_classes: np.ndarray,
    MAX_SLOTS: int,
    threshold: float = 0.2
) -> list[list[int]]:
    """
    Fill ``MAX_SLOTS`` slots, keeping every class whose P(label=1) >= threshold.

    Instead of rebuilding features from a sequence, the running feature
    vector ``counts_vec`` is fed to the classifier directly and incremented
    for each label chosen. If no class reaches the threshold, the single
    class with the highest P(label=1) is chosen.

    Args:
        clf: Classifier whose ``predict_proba`` returns one ``(1, 2)`` array
            per class.
        counts_vec: 1-D feature vector; it is modified in place.
        class_to_feat_idx: Maps a class value to its column in ``counts_vec``.
        all_classes: Class values, indexed like the ``predict_proba`` output.
        MAX_SLOTS: Number of slots to fill.
        threshold: Minimum positive probability for a class to be selected.

    Returns:
        One list of chosen labels (as ``int``) per slot.
    """
    n_classes = len(all_classes)
    predict = clf.predict_proba
    filled: list[list[int]] = []

    for _ in range(MAX_SLOTS):
        per_class = predict(counts_vec.reshape(1, -1))
        positive = np.fromiter((proba[0, 1] for proba in per_class),
                               dtype=float, count=n_classes)

        picked = np.where(positive >= threshold)[0].tolist()
        if not picked:
            # No class reached the threshold, so take the most probable one.
            picked = [int(np.argmax(positive))]

        labels = [int(all_classes[k]) for k in picked]
        filled.append(labels)

        # Record the chosen labels in the running feature vector.
        for label in labels:
            counts_vec[class_to_feat_idx[label]] += 1

    return filled

def fill_full_mask_one_per_slot_fast(
    clf,
    counts_vec: np.ndarray,
    class_to_feat_idx: dict[int,int],
    all_classes: np.ndarray,
    MAX_SLOTS: int
) -> list[list[int]]:
    """
    Fill ``MAX_SLOTS`` slots, picking for each the class with the highest P(label=1).

    The running feature vector ``counts_vec`` is fed to the classifier
    directly and incremented for the label chosen at every slot, so no
    sequence has to be re-featurized.

    Args:
        clf: Classifier whose ``predict_proba`` returns one ``(1, 2)`` array
            per class (column 1 holding P(label=1)).
        counts_vec: 1-D feature vector; it is modified in place.
        class_to_feat_idx: Maps a class value (e.g. 43) to its column in
            ``counts_vec``.
        all_classes: Class values, indexed like the ``predict_proba`` output.
        MAX_SLOTS: Number of slots to fill.

    Returns:
        One single-element list ``[label]`` per slot, labels as ``int``.
    """
    completed_slots: list[list[int]] = []

    # Hoisted out of the loop: np.fromiter needs an integer element count, not the array.
    n_classes = len(all_classes)
    clf_proba = clf.predict_proba

    for slot_idx in range(MAX_SLOTS):
        # The classifier expects a 2-D (1, n_features) input.
        X0 = counts_vec.reshape(1, -1)

        # proba_list[i][0, 1] is P(label_i = 1) for class index i.
        proba_list = clf_proba(X0)
        pos_probs = np.fromiter((arr[0, 1] for arr in proba_list),
                                dtype=float, count=n_classes)

        # Most probable class (first one on ties) and its actual class value.
        best_idx = int(np.argmax(pos_probs))
        best_label = int(all_classes[best_idx])

        completed_slots.append([best_label])

        # Record the choice in the running feature vector.
        feat_col = class_to_feat_idx[best_label]
        counts_vec[feat_col] += 1

    return completed_slots

In [None]:
n_features = len(vec.vocabulary_)  # one column per feature the vectorizer was fitted on

# The "_fast" fillers increment raw occurrence counts, so they need a vectorizer fitted
# on extract_features() output ("count_<class>" columns). Fail with an actionable message
# instead of an opaque KeyError when `vec` was fitted on different features.
missing_count_features = [c for c in CLASS_LIST if f"count_{c}" not in vec.vocabulary_]
if missing_count_features:
    raise KeyError(
        "`vec` has no 'count_<class>' columns for classes "
        f"{missing_count_features}; fit `vec` on extract_features(...) output and "
        "retrain `clf` on those features before running the fast fillers."
    )

class_to_feat_idx = {
    c: vec.vocabulary_[f"count_{c}"]
    for c in CLASS_LIST
}
all_classes = mlb.classes_  # class values in the order used by the classifier outputs

counts_vec = np.zeros((n_features,), dtype=int)
MAX_SLOTS  = 600  # or however many subarray positions you expect


# (A) Exactly one label per slot
seq_one_label = fill_full_mask_one_per_slot_fast(
    clf,
    counts_vec.copy(),        # the fillers mutate their vector, so pass a fresh copy
    class_to_feat_idx,
    all_classes,
    MAX_SLOTS
)

# (B) Possibly multiple labels per slot (threshold = 0.2)
seq_multi_label = fill_full_mask_multi_per_slot_fast(
    clf,
    counts_vec.copy(),
    class_to_feat_idx,
    all_classes,
    MAX_SLOTS,
    threshold=0.2
)

In [None]:
import numpy as np

# 1) Number of features in the fitted DictVectorizer.
n_features = len(vec.vocabulary_)

# 2) Build a direct mapping: class value -> column index in vec.
#    This requires `vec` to have been fitted on extract_features() output, whose keys
#    look like "count_43", "count_48", ...; report clearly when that is not the case.
missing_count_features = [c for c in CLASS_LIST if f"count_{c}" not in vec.vocabulary_]
if missing_count_features:
    raise KeyError(
        "`vec` has no 'count_<class>' columns for classes "
        f"{missing_count_features}; fit `vec` on extract_features(...) output first."
    )
class_to_feat_idx = {
    c: vec.vocabulary_[f"count_{c}"]
    for c in CLASS_LIST
}

# 3) 1-D array of length n_features holding how many times each "count_<c>" feature
#    has been incremented so far (all zeros to start with).
counts_vec = np.zeros((n_features,), dtype=int)

# 4) For convenience, pull out mlb.classes_ just once
all_classes = mlb.classes_
n_classes   = len(all_classes)