In [7]:
import music_data_analysis as analysis
from music_data_analysis.data.pianoroll import Pianoroll
from music_data_analysis import Note
from pathlib import Path

In [8]:
from collections import defaultdict


def get_num_overlaps2(a: Pianoroll, b: Pianoroll, pitch_shift: int, onset_tolerance: int = 1):
    """Fraction of notes in `a` and `b` that match in pitch and (approximately) onset.

    Notes are matched one-to-one per pitch with a greedy two-pointer sweep: a note of `a`
    and a note of `b` match when their pitches are equal (after adding `pitch_shift` to
    every pitch of `b`) and their onsets differ by at most `onset_tolerance`.

    Args:
        a, b: pianorolls to compare; only `.notes` (each with `.pitch` / `.onset`) is read.
        pitch_shift: semitones added to each pitch of `b` before matching.
        onset_tolerance: largest onset difference that still counts as a match.
            Defaults to 1, the value that used to be hard-coded.

    Returns:
        matched pairs / max(len(a.notes), len(b.notes)), or 0 when both are empty.
    """
    a_pitch_to_notes: dict[int, list[Note]] = defaultdict(list)
    for note in a.notes:
        a_pitch_to_notes[note.pitch].append(note)
    b_pitch_to_notes: dict[int, list[Note]] = defaultdict(list)
    for note in b.notes:
        b_pitch_to_notes[note.pitch + pitch_shift].append(note)

    num_overlaps = 0
    # A pitch that occurs on only one side cannot produce a match, so visit shared pitches only.
    for pitch in a_pitch_to_notes.keys() & b_pitch_to_notes.keys():
        # The sweep needs onset order; sort defensively (stable, so already-ordered lists are unchanged).
        a_notes = sorted(a_pitch_to_notes[pitch], key=lambda n: n.onset)
        b_notes = sorted(b_pitch_to_notes[pitch], key=lambda n: n.onset)
        i = j = 0
        while i < len(a_notes) and j < len(b_notes):
            onset_a = a_notes[i].onset
            onset_b = b_notes[j].onset
            if abs(onset_a - onset_b) <= onset_tolerance:
                num_overlaps += 1
                i += 1
                j += 1
            elif onset_a < onset_b:
                i += 1
            else:
                j += 1

    denom = max(len(a.notes), len(b.notes))
    if denom == 0:
        return 0
    return num_overlaps / denom


def get_num_overlaps(a: Pianoroll, b: Pianoroll, pitch_shift: int):
    """Older exact-onset overlap score between two pianorolls (get_overlap_sim uses get_num_overlaps2).

    From each pianoroll only the first two notes of every pitch are kept. The filtered
    lists are swept in onset order (assumed sorted by onset); at every onset present on
    both sides, the notes are matched as multisets of pitches, with `b`'s pitches shifted
    by `pitch_shift`.

    Returns:
        matches / mean(filtered note counts); 1 when both pianorolls are empty.
    """
    from collections import Counter

    def first_two_per_pitch(notes):
        # Keep at most two notes per MIDI pitch (0..127), in their original order.
        kept = []
        seen_pitch = [0] * 128
        for note in notes:
            if seen_pitch[note.pitch] < 2:
                kept.append(note)
                seen_pitch[note.pitch] += 1
        return kept

    a_notes = first_two_per_pitch(a.notes)
    b_notes = first_two_per_pitch(b.notes)
    i_max = len(a_notes)
    j_max = len(b_notes)

    num_overlaps = 0
    i = j = 0
    while i < i_max and j < j_max:
        onset_a = a_notes[i].onset
        onset_b = b_notes[j].onset
        if onset_a < onset_b:
            i += 1
        elif onset_a > onset_b:
            j += 1
        else:
            # Consume every note at this onset on both sides. Comparing single notes and
            # advancing both pointers (as before) missed matches inside chords, e.g.
            # a = {C, E} and b = {E} at the same onset counted 0 instead of 1.
            a_pitches = Counter()
            while i < i_max and a_notes[i].onset == onset_a:
                a_pitches[a_notes[i].pitch] += 1
                i += 1
            b_pitches = Counter()
            while j < j_max and b_notes[j].onset == onset_a:
                b_pitches[b_notes[j].pitch + pitch_shift] += 1
                j += 1
            num_overlaps += sum((a_pitches & b_pitches).values())

    denom = (i_max + j_max) / 2

    if denom == 0:
        if len(a.notes) == 0 and len(b.notes) == 0:
            return 1
        else:
            return 0
    return num_overlaps / denom


def get_overlap_sim(a: Pianoroll, b: Pianoroll):
    """Best `get_num_overlaps2` ratio of `a` vs `b` over pitch shifts 0, -12 and +12 semitones.

    Trying the octave shifts lets a passage and its octave transposition score as similar.
    """
    return max(
        get_num_overlaps2(a, b, shift)
        for shift in (0, -12, 12)
    )


In [9]:
from matplotlib import pyplot as plt
import torch



In [10]:
def get_skyline(pr: Pianoroll, max_slope:float=1, intercept:float=0):
    '''
    Estimate the melody line ("skyline") of `pr` as a monophonic Pianoroll.

    1. For every onset keep only the last note listed at that onset (presumably the
       highest, if `pr.notes` is ordered by onset then pitch — TODO confirm).
    2. Forward pass: keep a note only if, relative to the previously kept note, the
       pitch (plus `intercept`) does not fall faster than `max_slope`.
    3. Backward pass: the same test against the next kept note, walking from the end.

    max_slope: octaves per beat
    intercept: semitones added to every pitch difference before the slope test
    Returns a new Pianoroll of copied notes with pr's beats_per_bar, frames_per_beat and
    duration. An empty `pr` yields an empty Pianoroll (previously `notes[-1]` raised IndexError).
    '''
    notes = pr.notes
    if len(notes) == 0:
        return Pianoroll([], beats_per_bar=pr.beats_per_bar, frames_per_beat=pr.frames_per_beat, duration=pr.duration)

    max_slope_semitones_per_frame = max_slope * 12 / pr.frames_per_beat

    # 1. keep the last-listed note of each onset
    result1 = []
    for i in range(len(notes)-1):
        if notes[i].onset != notes[i+1].onset:
            result1.append(notes[i])
    result1.append(notes[-1])

    # 2. forward pass; the far-past sentinel onset lets the first note pass in practice.
    # Onsets in result1 are distinct, so the denominator is never zero.
    result2: list[Note] = []
    last_onset = -2147483648
    last_pitch = 0
    for note in result1:
        if (note.pitch - last_pitch + intercept) / (note.onset - last_onset) >= -max_slope_semitones_per_frame:
            result2.append(note.copy())
            last_onset = note.onset
            last_pitch = note.pitch

    # 3. backward pass with a far-future sentinel onset
    result3: list[Note] = []
    last_onset = 2147483647
    last_pitch = 0
    for note in reversed(result2):
        if (note.pitch - last_pitch + intercept) / (last_onset - note.onset) >= -max_slope_semitones_per_frame:
            result3.append(note.copy())
            last_onset = note.onset
            last_pitch = note.pitch

    result3.reverse()

    return Pianoroll(result3, beats_per_bar=pr.beats_per_bar, frames_per_beat=pr.frames_per_beat, duration=pr.duration)


In [11]:
ds = analysis.Dataset(Path("../dataset/pop80k_k"))
# song = ds.get_song("@Animenzzz/1zKejX-up-k/0_554")
# song = ds.get_song("@AnCoongPiano/_MwrgFgL5wo/0_270")
# data/pop80k_k/segmentation/@0AdRiaNleE0/4Ne_JADL0Yc/0_234.json
song = ds.get_song("@0AdRiaNleE0/4Ne_JADL0Yc/0_234")
# song = ds.songs()[42398]
# song = ds.songs()[69145]
# song = ds.songs()[35923]

# "W:\piano-ai\output\@Animenzzz\4UEnnIChm8U\0_473.mid"
# pr = ds.get_song("@Animenzzz/-liXLunc-JQ/0_365").read_pianoroll("pianoroll")


# is_pop = open("../dataset/pop80k_k/is_pop.txt").read().splitlines()
# i = 30789
# songs = ds.songs()
# song = songs[i]
# while song.song_name.split("/")[0] not in is_pop:
#     i+=1
#     song = songs[i]

chords = song.read_json("chords")
pr = song.read_pianoroll("pianoroll")

FRAMES_PER_BAR = 32  # bar length in pianoroll frames; `mat` gets one row/column per bar
n_bars = pr.duration // FRAMES_PER_BAR
skyline = get_skyline(pr)

# Slice every bar once up front (assumes Pianoroll.slice has no side effects). The previous
# version re-sliced both bars for every (i, j) pair: O(n_bars^2) slice calls instead of O(n_bars).
skyline_bars = [skyline.slice(b * FRAMES_PER_BAR, (b + 1) * FRAMES_PER_BAR) for b in range(n_bars)]
pr_bars = [pr.slice(b * FRAMES_PER_BAR, (b + 1) * FRAMES_PER_BAR) for b in range(n_bars)]

# Symmetric bar-to-bar similarity: equal-weight mix of melody (skyline) and full-texture overlap.
mat = torch.zeros((n_bars, n_bars))
for i in range(n_bars):
    for j in range(i, n_bars):
        sim = (
            get_overlap_sim(skyline_bars[i], skyline_bars[j]) * 0.5
            + get_overlap_sim(pr_bars[i], pr_bars[j]) * 0.5
        )
        mat[i, j] = sim
        mat[j, i] = sim

In [None]:
# Number of whole 32-frame bars in `pr` (the side length of `mat`).
pr.duration // 32

In [None]:
import numpy as np
from sklearn.cluster import KMeans


import pickle

file = song.get_old_path("triplet_predictions")

# NOTE: pickle.load can execute arbitrary code — only load prediction files produced locally.
# `with` closes the handle (the previous bare open() leaked it).
with open(file, "rb") as fh:
    a = pickle.load(fh)
beats = song.read_json("beats")
beats_in_second = torch.tensor(beats['beats'])

def second_to_beat(second):
    """Index of the first entry of `beats_in_second` that is >= `second` (searchsorted, left side)."""
    return int(torch.searchsorted(beats_in_second, second))
def second_to_bar(second, beats_per_bar: int = 4):
    """Convert a time in seconds to the nearest bar index.

    `beats_per_bar` defaults to 4, the value previously hard-coded. Python's `round`
    rounds exact halves to the even integer.
    """
    return round(second_to_beat(second) / beats_per_bar)
split_points, labels = a[3]
# Drop the first boundary and keep column 0 (presumably boundary times in seconds — TODO confirm pickle layout).
split_points = split_points[1:,0]
split_points = [second_to_bar(split_point) for split_point in split_points]

# [start, end) bar range of every segment; the last segment ends at the final bar.
segment_bounds = list(zip([0] + split_points, split_points + [mat.shape[0]]))

# plot the mat with the split points
names = ["A", "B", "C", "D", "E", "F", "G"]
plt.imshow(mat, vmax=0.5, vmin=0.2)
plt.colorbar()
for i, (split, next_split) in enumerate(segment_bounds):
    plt.plot([split - 0.5, split - 0.5], [0, mat.shape[1]], "r--", alpha=0.5)
    plt.plot([0, mat.shape[0]], [split - 0.5, split - 0.5], "r--", alpha=0.5)
    plt.text(
        split + (next_split - split) / 2,
        split + (next_split - split) / 2,
        names[labels[i]],  # these labels are per segment, hence labels[i]
        ha="center",
        va="center",
        fontsize=20,
        color="#000000",
    )

# Segments in pianoroll frames (32 frames per bar, as used to build `mat`).
segments = [
    {'start': split * 32, 'end': next_split * 32, 'label': labels[i]}
    for i, (split, next_split) in enumerate(segment_bounds)
]




In [None]:
# NOTE(review): `A` is the affinity matrix built in the spectral-clustering cell below, so this
# cell fails on Restart & Run All unless that cell ran first. The ncut experiment cell further
# down also reassigns `A` (to an index tensor), which changes what this shows.
A.numpy().max()

In [None]:
import numpy as np
import scipy
from sklearn.cluster import KMeans

ignores = []  # currently unused

# A: similarity matrix blended with a band adjacency (each bar linked to itself and its
# immediate neighbours) so that clusters favour runs of consecutive bars.
adj = torch.zeros_like(mat)
adj.diagonal(0).fill_(1)  # Main diagonal
adj.diagonal(1).fill_(1)  # Diagonal +1
adj.diagonal(-1).fill_(1)  # Diagonal -1
adj_weight = 0.7
mat_clamped = torch.clamp(mat, min=0.2)
A = mat_clamped * (1 - adj_weight) + adj_weight * adj
D = torch.diag(A.sum(dim=1))
L = D - A  # unnormalised graph Laplacian

# mat and adj are both symmetric, hence so is L: use the symmetric eigensolver, which returns
# real eigenvalues in ascending order with matching real eigenvectors (torch.linalg.eig
# returned complex tensors whose imaginary parts had to be discarded).
eigvals, eigvecs = torch.linalg.eigh(L)

max_k = 5
assert len(eigvals) >= 2, "need at least two bars to segment"

# Eigengap heuristic: eigval_diff[d] is the gap between eigenvalues d and d+1; a large gap
# there suggests k = d + 1 clusters.
eigvals_sorted = eigvals  # eigh already returns them in ascending order
eigval_diff = eigvals_sorted[1:] - eigvals_sorted[:-1]

eigval_diff[0] = 0  # we don't consider k=1

k = int(torch.argmax(eigval_diff[:max_k]) + 1)

# Spectral embedding: eigenvectors of the k smallest eigenvalues.
eigvecs_first_k = eigvecs[:, :k]

# Fixed seed and explicit n_init so the clustering is reproducible across runs and sklearn versions.
labels = KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(eigvecs_first_k.numpy())

# Renumber clusters in order of first appearance, so the opening section is always 0 ("A").
first_seen = {}
for lab in labels:
    first_seen.setdefault(lab, len(first_seen))
labels = [first_seen[lab] for lab in labels]

# A segment boundary wherever consecutive bars change label.
split_points = [i for i in range(1, len(labels)) if labels[i] != labels[i - 1]]

# [start, end) bar range of every segment.
segment_bounds = list(zip([0] + split_points, split_points + [mat.shape[0]]))

# plot the mat with the split points
names = ["A", "B", "C", "D", "E", "F", "G"]
plt.imshow(mat)
plt.colorbar()
for split, next_split in segment_bounds:
    plt.plot([split - 0.5, split - 0.5], [0, mat.shape[1]], "r--", alpha=0.5)
    plt.plot([0, mat.shape[0]], [split - 0.5, split - 0.5], "r--", alpha=0.5)
    plt.text(
        split + (next_split - split) / 2,
        split + (next_split - split) / 2,
        names[labels[split]],  # labels are per bar here: use the segment's first bar
        ha="center",
        va="center",
        fontsize=20,
        color="#000000",
    )

# Segments in pianoroll frames (32 frames per bar, as used to build `mat`).
segments = [
    {'start': split * 32, 'end': next_split * 32, 'label': labels[split]}
    for split, next_split in segment_bounds
]

def calinski_harabasz_index(eigvecs, labels, k):
    """Calinski–Harabasz score of a clustering of the rows of `eigvecs`.

    CH = [BCSS / (k - 1)] / [WCSS / (n - k)] with
        BCSS = sum_i n_i * ||c_i - c||^2   (between-cluster, weighted by cluster size n_i)
        WCSS = sum_i sum_{x in i} ||x - c_i||^2.
    The previous version left out the n_i weight in BCSS and printed debug output.

    Args:
        eigvecs: tensor of n points along dim 0 (here spectral embeddings).
        labels: length-n cluster ids in range(k) (list, ndarray or tensor).
        k: number of clusters.
    Returns:
        0-d tensor; higher means better-separated clusters.
    Raises:
        ValueError: if k < 2 or k >= n, where the score is undefined.
    """
    labels = torch.as_tensor(labels)
    n = len(eigvecs)
    if k < 2 or k >= n:
        raise ValueError(f"Calinski-Harabasz index needs 2 <= k < n (got k={k}, n={n})")
    global_centroid = eigvecs.mean(dim=0)
    bcss = 0
    wcss = 0
    for i in range(k):
        members = eigvecs[labels == i]
        if len(members) == 0:
            continue  # an empty cluster contributes nothing (its mean would be NaN)
        centroid = members.mean(dim=0)
        bcss += len(members) * (centroid - global_centroid).pow(2).sum()
        wcss += (members - centroid).pow(2).sum()
    return (bcss / (k - 1)) / (wcss / (n - k))

# List the segments: start and end in pianoroll frames (bar index * 32), then the cluster label.
for segment in segments:
    print(segment['start'], segment['end'], segment['label'])

In [None]:
# Label of the first bar of every segment after the first one.
[labels[i] for i in split_points]

In [None]:
max_k = 9

# Eigengap plot: eigval_diff[d] is the gap between eigenvalues d and d+1 (ascending), so a
# large gap at index d suggests k = d + 1 clusters.
eigvals_sorted = torch.sort(eigvals.real)[0]
eigval_diff = eigvals_sorted[1:] - eigvals_sorted[:-1]

eigval_diff[0]=0 # we don't consider k=1

# Short songs have fewer than max_k gaps; clip so argmax and both plot axes use the same length
# (previously range(1, max_k+1) vs a shorter eigval_diff made plt.plot raise).
n_gaps = min(max_k, len(eigval_diff))
optimal_k = int(torch.argmax(eigval_diff[:n_gaps])) + 1
print(f"optimal_k: {optimal_k}")

plt.plot(range(1, n_gaps + 1), eigval_diff[:n_gaps])

In [5]:
from music_data_analysis.processors.segmentation import SegmentationProcessor
# Run the library's SegmentationProcessor on the current `song`.
proc = SegmentationProcessor()
proc.process_impl(song)


In [None]:
# The (split_points, labels) entry of the loaded triplet predictions (unpacked in the plotting cell above).
a[3]

In [None]:
# Simulate the order in which a human composer might write this piece.

# First, the composer comes up with the seed segment (possibly the chorus).
# To identify the seed segment, the model looks for the segment label with the most bars in total.
# Among segments with this label, it selects the one whose start is closest to the middle of the song.
import random
segment_compose_order = []

duration = pr.duration
SHUFFLE_SEED = 0  # fixed so the simulated compose order is reproducible


def segment_nearest_middle(segments, label, duration):
    """Segment with `label` whose start is closest to duration // 2 (first on ties); None if there is none."""
    candidates = [segment for segment in segments if segment['label'] == label]
    if not candidates:
        return None
    return min(candidates, key=lambda segment: abs(segment['start'] - duration // 2))


n_bars_per_label = [0] * k
for segment in segments:
    n_bars_per_label[segment['label']] += (segment['end'] - segment['start']) // 32

print('n_bars_per_label', n_bars_per_label)

# Labels ordered by total bar count, most first. The stable sort keeps ties in label order, so
# the top entry equals np.argmax. Taking np.argmax and np.argsort(...)[-2] separately (as before)
# could yield the same label twice on ties and duplicate the seed segment.
ranked_labels = sorted(range(k), key=lambda lab: -n_bars_per_label[lab])

# Seed segment: from the label with the most bars.
label = ranked_labels[0]
print('label', label)
selected_segment = segment_nearest_middle(segments, label, duration)
if selected_segment is not None:
    segment_compose_order.append(selected_segment)

# Next, the composer writes a segment of the label with the second-most bars.
if len(ranked_labels) > 1:
    label = ranked_labels[1]
    print('label', label)
    selected_segment = segment_nearest_middle(segments, label, duration)
    if selected_segment is not None:
        segment_compose_order.append(selected_segment)

print('segment_compose_order', segment_compose_order)

# The remaining segments follow in a random (seeded) order.
remaining_segments = [segment for segment in segments if segment not in segment_compose_order]
random.Random(SHUFFLE_SEED).shuffle(remaining_segments)

segment_compose_order.extend(remaining_segments)

print('segment_compose_order', segment_compose_order)



In [None]:
# for training, sample a segment from the segment_compose_order
# target_index = random.randint(0, len(segment_compose_order) - 1)
# Fixed index for inspection; must be < len(segment_compose_order).
target_index = 2
target_segment = segment_compose_order[target_index]


# Segments already "written" before the target in the simulated compose order.
already_composed_segments = segment_compose_order[:target_index]

# Closest already-composed segment that ends at or before the target's start (None if none).
nearest_left_segment = None
nearest_left_segment_distance = float('inf')
for segment in reversed(already_composed_segments):
    if segment['end'] > target_segment['start']:
        continue
    left_segment_distance = target_segment['start'] - segment['end']
    if left_segment_distance < nearest_left_segment_distance:
        nearest_left_segment_distance = left_segment_distance
        nearest_left_segment = segment

# Closest already-composed segment that starts at or after the target's end (None if none).
nearest_right_segment = None
nearest_right_segment_distance = float('inf')
for segment in already_composed_segments:
    if segment['start'] < target_segment['end']:
        continue
    right_segment_distance = segment['start'] - target_segment['end']
    if right_segment_distance < nearest_right_segment_distance:
        nearest_right_segment_distance = right_segment_distance
        nearest_right_segment = segment

# First already-composed segment (in compose order) that shares the target's label.
reference_segment = None
for segment in already_composed_segments:
    if segment['label'] == target_segment['label']:
        reference_segment = segment
        break

# The seed is the first composed segment, unless the target is the seed itself.
if target_index == 0:
    seed_segment = None
else:
    seed_segment = segment_compose_order[0]

print('target_index', target_index)
print('left_segment', nearest_left_segment)
print('right_segment', nearest_right_segment)
print('seed_segment', seed_segment)
print('reference_segment', reference_segment)

# plot on the mat
plt.figure(figsize=(8, 8))
plt.gca().set_facecolor('black')
# plt.gcf().set_facecolor('black')

plt.imshow(mat, cmap='gray', alpha=0.5)

# Comma-joined role tags per segment (a segment can have several roles); [1:] below drops the leading comma.
annotation_per_segment = ['']*len(segments)
for name, segment in zip(['tar', 'l', 'r', 'seed', 'ref'], [target_segment, nearest_left_segment, nearest_right_segment, seed_segment, reference_segment]):
    if segment is None:
        continue
    annotation_per_segment[segments.index(segment)] += f',{name}'
    # plt.text((segment['start'] + (segment['end'] - segment['start']) / 2) / 32, (segment['start'] + (segment['end'] - segment['start']) / 2) / 32+2, name, ha='center', va='center', fontsize=20, color="#ff7700")

# Role tags at each segment's centre on the diagonal (frames / 32 -> bar coordinates).
for i, annotation in enumerate(annotation_per_segment):
    plt.text((segments[i]['start'] + segments[i]['end']) / 2 / 32, (segments[i]['start'] + segments[i]['end']) / 2 / 32, annotation[1:], ha='center', va='center', fontsize=20, color="#00ffff")

# Segment boundaries, plus each segment's label and its position in the compose order.
for i, (split, next_split) in enumerate(zip([0] + split_points, split_points + [mat.shape[0]])):
    plt.plot([split - 0.5, split - 0.5], [0, mat.shape[1]], "r--", alpha=0.5)
    plt.plot([0, mat.shape[0]], [split - 0.5, split - 0.5], "r--", alpha=0.5)
    plt.text(
        split + (next_split - split) / 2,
        split + (next_split - split) / 10,
        names[labels[split]]+f'({segment_compose_order.index(segments[i])})',
        ha="center",
        va="center",
        fontsize=15,
        color="#00ff00",
    )


In [None]:
# Raw contents of the triplet_predictions pickle loaded earlier.
a

In [None]:
# Segment boundaries (bar indices) from whichever segmentation cell ran most recently.
split_points

In [None]:


# Sanity check: a time presumably past the song's end maps to len(beats_in_second),
# the insertion index after the last beat.
print(second_to_beat(500000))



In [None]:
# Number of beat timestamps for this song.
len(beats_in_second)

In [523]:
from numpy import random


# Export 20 random songs to MIDI for listening checks.
# `random` here is numpy.random (imported just above), whose randint EXCLUDES `high`:
# the previous randint(0, len(ds) - 1) could never pick the last song.
# NOTE: this overwrites the notebook-global `song` and `pr`.
SAMPLE_OUTPUT_DIR = "/home/eri24816/segment_full_song/ignore/output/sample"
all_songs = ds.songs()  # fetched once instead of once per sample
for _ in range(20):
    i = random.randint(0, len(ds))
    song = all_songs[i]
    pr = song.read_pianoroll("pianoroll")
    pr.to_midi(f"{SAMPLE_OUTPUT_DIR}/{i}.mid")


In [33]:
def cut(W: torch.Tensor, A: torch.Tensor, B: torch.Tensor):
    """Graph cut between node sets: sum of W[u, v] over u in A, v in B.

    Args:
        W: 2-D weight matrix.
        A, B: 1-D integer index tensors (or int sequences) selecting rows / columns.
    Returns:
        0-d tensor with the total weight (0 when either set is empty).
    """
    A = torch.as_tensor(A, dtype=torch.long)
    B = torch.as_tensor(B, dtype=torch.long)
    # One vectorised submatrix sum instead of |A| * |B| scalar reads in Python.
    return W[A][:, B].sum()


def assoc(W: torch.Tensor, A: torch.Tensor):
    """Association of node set `A` with the whole graph: the summed row totals of W over rows A."""
    degree = torch.sum(W, dim=1)
    selected = degree[A]
    return torch.sum(selected)


def ncut(W: torch.Tensor, A: torch.Tensor, B: torch.Tensor):
    """Normalized cut (Shi & Malik): cut(A, B) / assoc(A, V) + cut(A, B) / assoc(B, V).

    When a side is empty its assoc is 0, giving 0/0 -> NaN.
    """
    cut_ab = cut(W, A, B)  # computed once; it used to be evaluated twice
    return cut_ab / assoc(W, A) + cut_ab / assoc(W, B)


def search_cut(W: torch.Tensor, eigvec: torch.Tensor, num_search: int = 10) -> tuple[torch.Tensor, torch.Tensor, float]:
    """Threshold `eigvec` at the value that minimises the normalized cut of W.

    Tries `num_search` evenly spaced thresholds t in [eigvec.min(), eigvec.max()] and splits
    nodes into A = {eigvec < t}, B = {eigvec >= t}. Thresholds leaving A or B empty are
    skipped explicitly (before, they were skipped only because their NaN ncut compared False).

    Returns:
        (A, B, best_ncut) as index tensors and a float. If no threshold gives two non-empty
        sides (e.g. constant eigvec), returns (all indices, empty, inf) instead of the old
        arbitrary split at threshold 0.
    """
    best_A = torch.arange(len(eigvec))
    best_B = torch.empty(0, dtype=torch.long)
    best_ncut = float("inf")

    for cut_point in torch.linspace(eigvec.min(), eigvec.max(), num_search):
        A = torch.where(eigvec < cut_point)[0]
        B = torch.where(eigvec >= cut_point)[0]
        if len(A) == 0 or len(B) == 0:
            continue
        ncut_for_this_cut_point = float(ncut(W, A, B))
        if ncut_for_this_cut_point < best_ncut:
            best_ncut = ncut_for_this_cut_point
            best_A, best_B = A, B
    return best_A, best_B, best_ncut


In [None]:
# W = mat[:27, :27] + 0.1


def ncut_eig(W):
    """Eigenvector of the second-smallest eigenvalue of (D - W) y = lambda D y, D = diag(row sums of W).

    Solved via the symmetric normalized Laplacian D^{-1/2} (D - W) D^{-1/2} with
    torch.linalg.eigh, then mapped back with y = D^{-1/2} z. This gives the same eigenvector
    (up to scale and sign) as decomposing D^{-1}(D - W) with the general `eig`, but stays
    real-valued and numerically stable. Assumes W is symmetric with positive row sums
    (eigh only reads one triangle).
    """
    degree = W.sum(dim=1)
    d_inv_sqrt = degree.rsqrt()
    L_sym = d_inv_sqrt[:, None] * (torch.diag(degree) - W) * d_inv_sqrt[None, :]
    eigvals, eigvecs = torch.linalg.eigh(L_sym)  # eigenvalues in ascending order
    # use the second smallest eigenvalue's eigenvector
    return d_inv_sqrt * eigvecs[:, 1]


def remove_index(mat, index):
    """
    Return a new 2-D tensor equal to `mat` without row `index` and column `index`.
    """
    before, after = slice(None, index), slice(index + 1, None)
    rows_kept = torch.cat([mat[before, :], mat[after, :]])
    return torch.cat([rows_kept[:, before], rows_kept[:, after]], dim=1)


def remove_indices(mat, indices):
    """
    Remove the rows and columns listed in `indices` from a 2-D tensor.

    Builds boolean keep-masks and indexes once, instead of copying the whole matrix for
    every removed index. Each distinct index is removed once: the old loop removed the same
    position again for a duplicate index, silently dropping a neighbouring row/column.
    Indices outside the matrix now raise IndexError instead of being ignored.

    Args:
        mat: 2-D tensor.
        indices: iterable of ints or a 1-D integer tensor (e.g. from torch.where).
    Returns:
        New tensor with those rows and columns removed.
    """
    idx = torch.as_tensor(indices, dtype=torch.long, device=mat.device)
    row_keep = torch.ones(mat.shape[0], dtype=torch.bool, device=mat.device)
    col_keep = torch.ones(mat.shape[1], dtype=torch.bool, device=mat.device)
    row_keep[idx] = False
    col_keep[idx] = False
    return mat[row_keep][:, col_keep]


# Plot the ncut eigenvector of the full bar-similarity matrix (`W` aliases `mat`; no copy is made).
W = mat
eigvec = ncut_eig(W)
plt.plot(eigvec)

In [None]:
# Same experiment with the first two bars removed. The cut must be searched on the same
# reduced matrix the eigenvector came from: previously search_cut received the full `W`, so
# the partition referred to a matrix two rows larger than the eigenvector.
W_sub = remove_indices(W, [0, 1])
eigvec = ncut_eig(W_sub)
plt.plot(eigvec)
plt.show()  # separate figure instead of drawing the line over the heatmap below

# Distinct names so the affinity matrix `A` from the spectral-clustering cell is not overwritten.
part_A, part_B, ncut_val = search_cut(W_sub, eigvec)
plt.imshow(W_sub, vmax=0.5)
plt.plot(part_A, torch.zeros(len(part_A)), "ro")
plt.plot(part_B, torch.zeros(len(part_B)), "bo")
plt.show()
print(ncut_val)


In [39]:
import pprint


def recursive_ncut(W: torch.Tensor, real_indices: torch.Tensor | None = None, stop_ncut: float = 1000):
    """
    Recursively bipartition a graph with normalized cuts.

    Each call splits W with `ncut_eig` + `search_cut` (20 thresholds) and recurses on both
    sides. Recursion stops and returns a flat list of indices when W has at most 2 nodes,
    when the best ncut value exceeds `stop_ncut`, or when a side comes back empty (which
    would otherwise recurse forever on the same node set).

    Args:
        W (torch.Tensor): The weight matrix.
        real_indices (torch.Tensor, optional): original index of each row of W. Defaults to arange(n).
        stop_ncut (float, optional): The ncut value above which recursion stops. Defaults to 1000.

    Returns:
        A flat list of original indices (leaf) or a 2-element list [left, right] of subtrees.
    """
    if real_indices is None:
        real_indices = torch.arange(W.shape[0])

    if W.shape[0] <= 2:
        return real_indices.tolist()

    eigvec = ncut_eig(W)
    A, B, ncut_val = search_cut(W, eigvec, num_search=20)
    if ncut_val > stop_ncut or len(A) == 0 or len(B) == 0:
        return real_indices.tolist()
    A_indices = real_indices[A]
    B_indices = real_indices[B]
    return [
        recursive_ncut(remove_indices(W, B), A_indices, stop_ncut),
        recursive_ncut(remove_indices(W, A), B_indices, stop_ncut),
    ]


# Hierarchical ncut over bars. The +0.1 presumably keeps every degree strictly positive,
# since ncut_eig divides by the row sums — TODO confirm.
res = recursive_ncut(mat + 0.1)

In [None]:
# Nested [left, right] lists of bar indices returned by recursive_ncut.
pprint.pprint(res)

In [None]:
from typing import Callable


class Node:
    def __init__(
        self,
        parent: "Node|None" = None,
        tag: str | None = None,
        tree: "Tree|None" = None,
    ):
        self.tag = tag
        self.parent = parent
        self.tree = tree
        self.children = []
        if self.tree is None:
            assert self.parent is not None
            self.tree = self.parent.tree
        if self.tag is not None:
            self.tree.nodes[self.tag] = self
        if self.parent is not None:
            self.parent.children.append(self)

    def __str__(self):
        indent_first = "│" + "  " * 1
        indent = "│" + "  " * 1
        indent_last = "   "
        if self.tag is None:
            res = ""
        else:
            res = self.tag
        if len(self.children) > 0:
            first_child_str = self.children[0].__str__()
            if len(self.children) > 1:
                first_child_prefix = "┬──"
            else:
                first_child_prefix = "───"
            first_child_line = first_child_prefix + first_child_str.replace("\n", "\n" + indent_first) + "\n"

            res += first_child_line
            for child in self.children[1:-1]:
                child_str = child.__str__()
                res += "├──" + child_str.replace("\n", "\n" + indent)
                res += "\n"
            last_child_str = self.children[-1].__str__()
            res += "└──" + last_child_str.replace("\n", "\n" + indent_last)
        return res


class Tree:
    def __init__(self):
        self.nodes = {}
        self.root = Node(None, None, self)

    def get_depth(self, node: Node):
        depth = 0
        while node.parent is not None:
            node = node.parent
            depth += 1
        return depth


# Tree that will hold the recursive_ncut result (populated by add_to_tree below).
tree = Tree()

# Build a tree from the result (nested lists of bar indices) and print it.


def add_to_tree(
    tree: Tree,
    parent: Node,
    indices: list,
    name_func: Callable[[int], str] = lambda x: str(x),
):
    """Attach the nested index structure produced by `recursive_ncut` under `parent`.

    - an int becomes a tagged leaf named `name_func(index)`;
    - a one-element list becomes that single leaf (no extra internal node);
    - an empty list adds nothing;
    - any longer list becomes an untagged internal node with one subtree per element.
      Only 2-element lists used to be accepted, so a leaf group of 3+ indices (returned by
      recursive_ncut when it stops early) raised ValueError.
    """
    if isinstance(indices, int):
        Node(parent, name_func(indices))
    elif len(indices) == 0:
        return
    elif len(indices) == 1:
        Node(parent, name_func(indices[0]), tree)
    else:
        node = Node(parent, None, tree)
        for child in indices:
            add_to_tree(tree, node, child, name_func)


# Tags are 1-based bar numbers (index + 1); tree_distance / merge_score below use this convention.
add_to_tree(tree, tree.root, res, lambda x: str(x + 1))
print(tree.root)


In [None]:
def tree_distance(tree: Tree, i: int | str, j: int | str):
    """Number of edges on the path between the nodes tagged `i` and `j` in `tree`.

    Int arguments are converted to their string tag. The deeper node is walked up until
    both meet at the lowest common ancestor. Depths are computed once and decremented while
    walking, instead of calling `tree.get_depth` (an O(depth) walk) on every step.
    Raises KeyError if a tag is not in `tree.nodes`.
    """
    a = tree.nodes[str(i)]
    b = tree.nodes[str(j)]
    depth_a = tree.get_depth(a)
    depth_b = tree.get_depth(b)
    a_search, b_search = a, b
    da, db = depth_a, depth_b
    while a_search is not b_search:
        if da > db:
            a_search = a_search.parent
            da -= 1
        else:
            b_search = b_search.parent
            db -= 1
    # da is now the depth of the lowest common ancestor
    return depth_a + depth_b - 2 * da


# Example: tree distance between bars 21 and 95 (1-based tags).
print(tree_distance(tree, 21, 95))

In [None]:
def merge_score(tree: Tree, i: int, j: int, sim: torch.Tensor | None = None, base: float = 20):
    """Score how similarly bars `i` and `j` (1-based) relate to the rest of the piece.

    For every offset d != 0 for which both i + d and j + d are valid bars in 1..n
    (n = len(tree.nodes)), adds base ** (sim[i-1, i+d-1] + sim[j-1, j+d-1] - 1).

    Args:
        tree: only len(tree.nodes) is used, as the number of bars n.
        i, j: 1-based bar numbers.
        sim: bar similarity matrix (0-based). Defaults to the notebook-global `mat`,
            which the function previously read implicitly.
        base: exponent base; defaults to 20 as before.
    """
    if sim is None:
        sim = mat
    s = 0
    for d in range(-min(i, j) + 1, len(tree.nodes) - max(i, j) + 1):
        if d == 0:
            continue
        s += base ** ((sim[i - 1, i + d - 1] + sim[j - 1, j + d - 1]) - 1)
    return s


# Example: merge score of adjacent bars 63 and 64 (1-based).
print(merge_score(tree, 63, 64))

In [None]:
# Per-offset merge_score contributions (base 10) for each adjacent bar pair (i, i+1), 1-based.
# Stored 0-based: row i-1 = pair (i, i+1), column i+d-1 = compared bar i+d, so the plot lines up
# with `mat`. Previously the 1-based i and i+d were used directly, leaving row/column 0 empty
# and shifting everything by one relative to the `mat` plots.
n_leaves = len(tree.nodes)
base = 10
m = torch.zeros((n_leaves, n_leaves))
for i in range(1, n_leaves):
    j = i + 1
    for d in range(-min(i, j) + 1, n_leaves - max(i, j) + 1):
        if d != 0:
            m[i - 1, i + d - 1] = base ** ((mat[i - 1, i + d - 1] + mat[j - 1, j + d - 1]) - 1)

plt.imshow(m[:30, :30])
plt.colorbar()

In [None]:
# Example: tree distance between bars 2 and 8 (1-based tags).
tree_distance(tree, 2, 8)

In [None]:
# merge_score for every adjacent pair (i, i+1), i = 1..n-1 (1-based bars); plot the first 30.
l = []
for i in range(1, len(tree.nodes) + 1 - 1):
    j = i + 1
    l.append(merge_score(tree, i, j))
plt.plot(l[:30])


In [None]:
# Example: merge score of adjacent bars 2 and 3 (1-based).
merge_score(tree, 2, 3)

In [None]:
# Top-left 30x30 corner of the bar similarity matrix, for comparison with the plots above.
plt.imshow(mat[:30, :30], vmax=0.5)

In [None]:
# search_cut returns (A, B, ncut_value), not a cut point. The eigenvector is recomputed on the
# full `W`: the last global `eigvec` comes from the reduced-matrix experiment above and has two
# fewer entries than W, so pairing it with W gave mismatched indices.
full_eigvec = ncut_eig(W)
best_cut = search_cut(W, full_eigvec)
print(best_cut)