# embedding analytics

some functions to assess the usefulness of embeddings

## imports

In [1]:
import os
import mido
import torch
import numpy as np
import pretty_midi
import matplotlib.pyplot as plt

# switch matplotlib to its built-in dark theme; this sets global rcParams, so it
# applies to every figure drawn afterwards in this session
plt.style.use("dark_background")

In [2]:
# file-name suffixes accepted as MIDI segments when scanning a segment directory
VALID_EXTENSIONS = (".mid", ".midi")

# forwarded to get_pitch_class_histogram as both use_duration and use_velocity
WEIGHT_PRS = True

# query segment for the similarity analysis; an alternate test case is kept below
# test_file = "/media/nova/Datasets/sageev-midi/20250110/segmented/20240511-088-03/20240511-088-03_0169-0174.mid"
test_file = (
    "/media/nova/Datasets/sageev-midi/20250110/segmented/20240121-070-01/"
    "20240121-070-01_0041-0047.mid"
)

## representation functions

currently supported: pitch-class histogram (weighted & unweighted), spectrogram diffusion embeddings

to add: clamp embeddings, blurred piano rolls

In [3]:
def load_embedding(
    midi_path: str,
    embedding_dir: str = os.path.join("..", "data", "embeddings"),
) -> torch.Tensor:
    """Load the precomputed embedding of a MIDI file, flattened to 1-D.

    The embedding is looked up by the MIDI file's base name with its extension
    stripped, i.e. ``<embedding_dir>/<name>.pt``.

    Args:
        midi_path: path to the MIDI file whose embedding should be loaded. Only
            its base name is used.
        embedding_dir: directory holding the ``.pt`` files. The default
            ``../data/embeddings`` is resolved relative to the current working
            directory.

    Returns:
        The loaded tensor, flattened with ``.flatten()``.

    Raises:
        FileNotFoundError: if no embedding file exists for ``midi_path``.
    """
    midi_name = os.path.splitext(os.path.basename(midi_path))[0]
    embedding_path = os.path.join(embedding_dir, midi_name + ".pt")
    # SECURITY: weights_only=False unpickles arbitrary objects and can execute
    # code embedded in the file; only load embeddings from trusted sources.
    return torch.load(embedding_path, weights_only=False).flatten()

In [4]:
def load_pr(midi_path: str) -> np.ndarray:
    """Return the pitch-class histogram of the MIDI file at ``midi_path``.

    Despite the ``pr`` in the name, this is not a piano roll: it is
    pretty_midi's ``get_pitch_class_histogram``. Notes are weighted by both
    duration and velocity when the module-level ``WEIGHT_PRS`` flag is true,
    and simply counted otherwise.
    """
    weighted = WEIGHT_PRS
    midi = pretty_midi.PrettyMIDI(midi_path)
    return midi.get_pitch_class_histogram(
        use_duration=weighted,
        use_velocity=weighted,
    )

In [16]:
# pitch-class histogram of a segment (duration/velocity weighted when WEIGHT_PRS is True)
def representation_function(x):
    """Map a MIDI file path to the vector used for similarity comparisons."""
    # calls load_pr by name at call time, so re-running the load_pr cell is picked up
    return load_pr(x)


# alternative: spectrogram diffusion embeddings
# def representation_function(x):
#     """Map a MIDI file path to the vector used for similarity comparisons."""
#     return load_embedding(x)

## similarity functions

currently supported: cosine similarity

to add: euclidean, manhattan, ...?

In [17]:
# cosine similarity
def similarity_metric(x, y):
    """Return the cosine similarity between two vectors as a Python float.

    Both inputs are converted with ``np.asarray`` and flattened, so lists, numpy
    arrays and CPU torch tensors of any shape are accepted.

    Cosine similarity is undefined when either vector has zero norm, e.g. the
    histogram of a segment with no notes. In that case 0.0 is returned instead
    of NaN, so that the downstream min/max normalization and sorting still work.
    """
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    denominator = np.linalg.norm(x) * np.linalg.norm(y)
    if denominator == 0:
        return 0.0
    return float(np.dot(x, y) / denominator)

## load filesystem

In [None]:
# list every MIDI segment that lives in the same directory (track) as test_file
p_midi = os.path.dirname(test_file)
midi_file = os.path.basename(test_file)
# os.listdir returns entries in arbitrary order, but later cells map
# similarities[i] onto the i-th time bin of the track, so sort the paths.
# This assumes the zero-padded ranges in the names (e.g. "_0041-0047") sort
# chronologically.
midi_segments = sorted(
    os.path.join(p_midi, segment)
    for segment in os.listdir(p_midi)
    if segment.lower().endswith(VALID_EXTENSIONS)
)
print(f"midi path is {p_midi}")
print(f"midi file is {midi_file}")
print(f"midi segments ({len(midi_segments)}) is {midi_segments[:3]}")

try:
    segment_index = midi_segments.index(test_file)
except ValueError:
    # later cells index with segment_index; stop here with a clear message
    # instead of printing and failing later with a NameError
    raise ValueError(
        f"ERROR: couldn't find key midi file {midi_file} in segment list for {p_midi}"
    ) from None
print(f"segment index is {segment_index}")

## calculate the similarities

In [None]:
# representation of every segment, in the same order as midi_segments
segment_representations = [
    representation_function(segment) for segment in midi_segments
]

# get the representation of the specified segment
target_representation = representation_function(test_file)

# similarity of the target to every segment; the target representation is
# computed once above instead of being reloaded for each segment
similarities = [
    similarity_metric(target_representation, representation)
    for representation in segment_representations
]

# quick validity test: the target compared with its own entry should be ~1.0
key_sim = similarity_metric(
    target_representation, segment_representations[segment_index]
)
print(f"self-similarity for key midi file is {key_sim:.05f}")

# min-max normalize similarities to [0, 1]. If every similarity is equal
# (e.g. a single segment), the range is zero; map all of them to 1.0 instead
# of dividing by zero.
min_similarity, max_similarity = min(similarities), max(similarities)
similarity_range = max_similarity - min_similarity
if similarity_range > 0:
    normalized_similarities = [
        (sim - min_similarity) / similarity_range for sim in similarities
    ]
else:
    normalized_similarities = [1.0 for _ in similarities]

## similarity plot

visualize the similarity of the chosen segment against the entire track

In [None]:
# Draw the full, unsegmented track as a piano roll, with dashed red lines every
# 8th beat and one translucent bar per segment whose height is that segment's
# (raw, not normalized) similarity to test_file. The key segment is drawn in blue.

# path of the unsegmented source track: same directory layout with "segmented"
# swapped for "unsegmented", file named after the segment directory.
# NOTE(review): str.replace swaps every occurrence of "segmented", and
# split("/") assumes POSIX separators (os.path.basename would be portable).
track_path = os.path.join(
    p_midi.replace("segmented", "unsegmented"), p_midi.split("/")[-1] + ".mid"
)

# load file
midi_mido = mido.MidiFile(track_path)
midi_pm = pretty_midi.PrettyMIDI(track_path)

# make piano roll (rows are MIDI pitches, columns are time frames per pretty_midi);
# dividing by 128.0 presumably scales velocities toward [0, 1] for display
piano_roll = midi_pm.get_piano_roll() / 128.0

# trim piano roll to remove rows below lowest and above highest notes
# NOTE(review): raises IndexError for a track with no notes. The trimmed roll is
# only used for its width below; the plot itself draws the untrimmed piano_roll.
row_sums = piano_roll.sum(axis=1)
non_zero_rows = np.where(row_sums > 0)[0]
min_row, max_row = non_zero_rows[0], non_zero_rows[-1]
trimmed_piano_roll = piano_roll[min_row : max_row + 1]

# calculate dimensions for plotting
pr_width = trimmed_piano_roll.shape[1]

# calculate pixel-tick conversion ratio
# tempo is parsed from the second "-"-separated field of the track file name
# (e.g. "20240121-070-01.mid" -> 70 bpm), assumed constant for the whole track.
# NOTE(review): despite its name, ticks_per_pixel is pixels per tick
# (frames / ticks); it is multiplied by a tick count below to get an x position.
bpm = int(os.path.basename(track_path).split("-")[1])
num_ticks = mido.second2tick(
    midi_pm.get_end_time(), midi_mido.ticks_per_beat, mido.bpm2tempo(bpm)
)
ticks_per_pixel = pr_width / num_ticks

# extract tick positions for every 8th beat
# Only tracks whose first message is a "track_name" meta message named "tick" are
# read; every message in such a track counts as one beat.
# NOTE(review): that count includes the track_name message itself and any other
# meta messages (e.g. end_of_track), so num_beats may be inflated — confirm the
# tick track layout. track[0] raises IndexError on an empty track.
beat_positions = [0]
num_beats = 0
for track in midi_mido.tracks:
    first_msg = track[0]
    if first_msg.is_meta and first_msg.type == "track_name":
        if first_msg.name != "tick":
            continue

        current_tick = 0
        for msg in track:
            # msg.time is a delta in ticks; accumulate to absolute ticks
            current_tick += msg.time
            num_beats += 1
            if num_beats % 8 == 0:  # only include every 8th beat
                tick_pixel = int(current_tick * ticks_per_pixel)
                if 0 <= tick_pixel < pr_width:
                    beat_positions.append(tick_pixel)

# plot piano roll
plt.figure(figsize=(12, 6))
plt.title(f"{test_file}")
plt.imshow(piano_roll, aspect="auto", origin="lower", cmap="gray")

# plot vertical lines for every 8th beat
for beat in beat_positions:
    plt.axvline(x=beat, color="red", linestyle="--", linewidth=0.5)

# similarity histogram: a guide line at the top edge of the piano roll, then one
# bar per segment rising from y=0 with height similarity * number of pitch rows
plt.axhline(y=piano_roll.shape[0], xmax=0.95, color="white", alpha=0.3)
# NOTE(review): assumes one bin per 8 beats and len(similarities) <= num_beats // 8,
# otherwise bin_edges[i + 1] raises IndexError. Bars are matched to segments by
# position, so midi_segments must be in temporal order.
bin_edges = np.linspace(0, pr_width, num_beats // 8 + 1)
for i, value in enumerate(similarities):
    bin_center = (bin_edges[i] + bin_edges[i + 1]) / 2
    color = "blue" if i == segment_index else "green"
    plt.bar(
        bin_center,
        value * piano_roll.shape[0],
        width=(bin_edges[i + 1] - bin_edges[i]),
        color=color,
        alpha=0.5,
        align="center",
        edgecolor="none",
    )

plt.axis("off")
plt.show()

## file playback

listen to the original file, its 3 best matches, and its worst match. matches are drawn only from segments of the same track, not from the entire dataset.

In [21]:
from midi_player import MIDIPlayer
from midi_player.stylers import dark

# (segment path, similarity) pairs. The key segment (test_file) is always placed
# first, even if another segment ties with it, so the cells below can treat
# index 0 as the original and 1.. as its matches. The remaining segments follow
# in descending similarity; ties keep midi_segments order.
matched_sims = sorted(
    zip(midi_segments, similarities),
    key=lambda pair: (pair[0] == test_file, pair[1]),
    reverse=True,
)

### original file

In [None]:
# play the key segment itself rather than relying on it being ranked first
print(
    f"{midi_file} has similarity {similarities[segment_index]:.03f} to {midi_file}"
)
MIDIPlayer(test_file, 300, styler=dark)

### best match

In [None]:
# highest-ranked segment after the original
match_path, match_sim = matched_sims[1]
match_name = os.path.basename(match_path)
print(f"{match_name} has similarity {match_sim:.03f} to {midi_file}")
MIDIPlayer(match_path, 300, styler=dark)

### second best match

In [None]:
# second-ranked segment after the original
match_path, match_sim = matched_sims[2]
match_name = os.path.basename(match_path)
print(f"{match_name} has similarity {match_sim:.03f} to {midi_file}")
MIDIPlayer(match_path, 300, styler=dark)

### third best match

In [None]:
# third-ranked segment after the original
match_path, match_sim = matched_sims[3]
match_name = os.path.basename(match_path)
print(f"{match_name} has similarity {match_sim:.03f} to {midi_file}")
MIDIPlayer(match_path, 300, styler=dark)

### worst match

In [None]:
# lowest-ranked segment of the track
match_path, match_sim = matched_sims[-1]
match_name = os.path.basename(match_path)
print(f"{match_name} has similarity {match_sim:.03f} to {midi_file}")
MIDIPlayer(match_path, 400, styler=dark)