In [1]:
import os
import pickle
import keras as K
import numpy as np
import tensorflow as tf
import jams
import builtins
from tqdm import tqdm
from decimal import ROUND_HALF_UP, Decimal
import pandas as pd

builtins.tf = tf

## IMPORTANT set your working directory

In [None]:
working = "/Users/theo/School/2/MIR/final-project/FinalProjectMIR/working/chords_andrea"

In [3]:
def load_model(model_spec_path, weights_path):
    # Load the serialized model spec
    with open(model_spec_path, "rb") as fd:
        model_spec = pickle.load(fd)
    # Reconstruct the model from the serialized spec
    model = K.utils.deserialize_keras_object(model_spec, safe_mode=False)
    # Load the weights from the saved file
    model.load_weights(weights_path)
    return model

In [6]:
output_path = os.path.join(working, "model_deep")
split = 0
epochs = 100

model_spec_path = os.path.join(
    output_path, "fold{:02d}_model_{:03d}_epochs.pkl".format(split, epochs)
)
weights_path = os.path.join(
    output_path, "fold{:02d}_weights_{:03d}_epochs.keras".format(split, epochs)
)

model = load_model(model_spec_path, weights_path)
model.summary()

In [7]:
def rename_slashes_in_op_fields(op):
    """
    In-place rename of all slash-laden keys inside op.fields
    so that the new keys replace '/' with '_'.
    """
    # Must check that op.fields is actually a mutable dict
    if hasattr(op, "fields") and isinstance(op.fields, dict):
        new_dict = {}
        for old_key, old_val in op.fields.items():
            new_key = old_key.replace("/", "_")  # e.g. 'cqt/mag' -> 'cqt_mag'
            new_dict[new_key] = old_val
        # Now overwrite op.fields with slash-free keys
        op.fields = new_dict


def rename_slashes_in_pump_opmap(pump):
    """
    Go through pump.opmap, rename slash-based keys in each operator's fields.
    """
    for op_name, op in pump.opmap.items():
        # rename slashes in the operator's .fields
        rename_slashes_in_op_fields(op)

        # If you also need to rename the op_name itself if it had a slash,
        # do it here (though 'chord_struct', 'chord_tag', 'cqt' do not have slashes):
        # new_op_name = op_name.replace('/', '_')
        # if new_op_name != op_name:
        #     pump.opmap[new_op_name] = op


def rename_slashes_in_pump_ops_list(pump):
    """
    pump.ops is a list of the same operators, rename slash-based keys in each.
    """
    for op in pump.ops:
        rename_slashes_in_op_fields(op)

In [8]:
with open(
    os.path.join(working + "/pump.pkl"),
    "rb",
) as fd:
    pump = pickle.load(fd)

rename_slashes_in_pump_opmap(pump)
rename_slashes_in_pump_ops_list(pump)

In [9]:
def round_observation_times(annotation, precision=10, snap_tol=1e-6):
    """
    Create new Observation objects with times and durations rounded using Decimal
    arithmetic, then force them to be consecutive by snapping boundaries that are
    within snap_tol.

    Args:
        annotation (JAMS Annotation): A JAMS-style chord annotation with Observation objects.
        precision (int): Decimal places to round to.
        snap_tol (float): Tolerance under which boundaries are forced equal.

    Returns:
        JAMS Annotation: The adjusted annotation.
    """

    # Define quantizer string and quant
    quant_str = "1." + "0" * precision
    quant = Decimal(quant_str)

    # First pass: convert observation times/durations to Decimal.
    obs_list = []
    for obs in annotation.data:
        rt = Decimal(str(obs.time)).quantize(quant, rounding=ROUND_HALF_UP)
        rd = Decimal(str(obs.duration)).quantize(quant, rounding=ROUND_HALF_UP)
        obs_list.append((rt, rd, obs.value, obs.confidence))

    # Sort by start time.
    obs_list.sort(key=lambda tup: tup[0])

    # Second pass: force consecutive intervals.
    fixed = []
    # Start with the first observation.
    prev_start, prev_dur, val, conf = obs_list[0]
    prev_end = prev_start + prev_dur
    fixed.append((prev_start, prev_dur, val, conf))

    for current in obs_list[1:]:
        current_start, current_dur, val, conf = current
        # Force the current observation to start at prev_end
        new_start = prev_end
        # Calculate original end of current observation.
        current_end = current_start + current_dur
        # New duration is calculated as difference.
        new_dur = current_end - new_start
        if new_dur < Decimal("0"):
            new_dur = Decimal("0")
        fixed.append((new_start, new_dur, val, conf))
        prev_end = new_start + new_dur  # update end

    # Convert fixed intervals back to floats with snapping.
    fixed_obs = []
    # We'll build the new observations, and whenever the gap is below snap_tol, snap them.
    prev_end_float = None
    for start, dur, val, conf in fixed:
        start_float = float(start)
        dur_float = float(dur)
        end_float = start_float + dur_float
        if (
            prev_end_float is not None
            and abs(start_float - prev_end_float) < snap_tol
        ):
            # snap start exactly to previous end.
            start_float = prev_end_float
            # Adjust duration based on the original end.
            end_float = float(start + dur)
            dur_float = max(0, end_float - start_float)
        obs_new = jams.Observation(
            time=start_float,
            duration=dur_float,
            value=val,
            confidence=conf,
        )
        fixed_obs.append(obs_new)
        prev_end_float = start_float + dur_float

    annotation.data = fixed_obs
    return annotation

In [10]:
def score_model(pump, model, idx, working, refs):
    results = {}
    for item in tqdm(idx, desc="Evaluating the model"):
        jam = jams.load(os.path.join(refs, f"{item}.jams"), validate=False)
        datum = np.load(os.path.join(working, "pump", f"{item}.npz"))[
            "cqt_mag"
        ]

        output = model.predict(datum)[0]

        ann = pump["chord_tag"].inverse(output)
        ann = round_observation_times(ann)

        ref_ann = round_observation_times(
            jam.annotations["chord", 0], precision=10
        )

        try:
            results[item] = jams.eval.chord(ref_ann, ann)
        except Exception as e:
            print(f"Error evaluating {item}: {e}")

    return pd.DataFrame.from_dict(results, orient="index")[
        ["root", "thirds", "triads", "tetrads", "mirex", "majmin", "sevenths"]
    ]

In [None]:
test_dataset_path = "/Users/theo/School/2/MIR/final-project/FinalProjectMIR/working/beethoven/dataset.csv"
idx = pd.read_csv(
    test_dataset_path,
    header=None,
    names=["id"],
)
pump_path = (
    "/Users/theo/School/2/MIR/final-project/FinalProjectMIR/working/beethoven"
)
refs = "/Users/theo/School/2/MIR/final-project/datasets/Beethoven_Piano_Sonata_Dataset_v2/2_Annotations/ann_audio_chord"
scores = score_model(pump, model, idx["id"], working, refs)

Evaluating the model:   0%|          | 0/128 [00:00<?, ?it/s]




FileNotFoundError: [Errno 2] No such file or directory: '/Users/theo/School/2/MIR/final-project/FinalProjectMIR/working/chords_andrea/pump/Beethoven_Op002No1-01_AS35.npz'

In [28]:
scores.mean()

root        0.355713
thirds      0.197155
triads      0.149103
tetrads     0.027053
mirex       0.316645
majmin      0.155019
sevenths    0.024112
dtype: float64

In [7]:
chord_file_name = "B-4-sus2-chord-1"
with open(
    working + f"jazznet/clean_dataset/jams/test/{chord_file_name}.jams"
) as fd:
    jam = jams.load(fd)

In [8]:
input = np.load(working + f"chords/pump/{chord_file_name}.npz")["cqt_mag"]

predictions = model.predict(input)[0]

ann = pump["chord_tag"].inverse(predictions)
print(ann)
results = jams.eval.chord(jam.annotations["chord", 0], ann)

print(results)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 312ms/step
{
  "annotation_metadata": {
    "curator": {
      "name": "",
      "email": ""
    },
    "annotator": {},
    "version": "",
    "corpus": "",
    "annotation_tools": "",
    "annotation_rules": "",
    "validation": "",
    "data_source": ""
  },
  "namespace": "chord",
  "data": [
    {
      "time": 0.0,
      "duration": 1.1145578231292517,
      "value": "F#:sus4",
      "confidence": 0.795343279838562
    },
    {
      "time": 1.1145578231292517,
      "duration": 0.8359183673469388,
      "value": "C#:sus4",
      "confidence": 0.7688261270523071
    },
    {
      "time": 1.9504761904761905,
      "duration": 0.18575963718820843,
      "value": "D:maj",
      "confidence": 0.7014773488044739
    },
    {
      "time": 2.136235827664399,
      "duration": 0.09287981859410444,
      "value": "G:maj",
      "confidence": 0.24560080468654633
    },
    {
      "time": 2.2291156462585033,
      "duration":

