In [2]:
from pathlib import Path
from music_data_analysis import Dataset


DATASET_ROOT = Path('../../pop80k_k')  # relative path: resolved against the notebook's working directory
ds = Dataset(DATASET_ROOT)

In [29]:
import numpy as np
import random
from typing import TypedDict
def _segment_nearest_middle(segments: list[dict], label, middle: int):
    """Return the segment carrying ``label`` whose start is closest to ``middle``.

    Ties keep the earliest such segment in ``segments``; returns None when no
    segment carries ``label``.
    """
    selected = None
    for segment in segments:
        if segment["label"] != label:
            continue
        if selected is None or abs(segment["start"] - middle) < abs(
            selected["start"] - middle
        ):
            selected = segment
    return selected


def get_compose_order(segments: list[dict]):
    """Reorder ``segments`` to simulate the order in which a human might compose them.

    The order is:
      1. the seed segment (possibly the chorus): among segments whose label covers
         the most bars in total, the one whose start is closest to the middle of
         the song;
      2. the same pick for the label with the second-most bars, when at least two
         label ids exist and that label actually occurs;
      3. every remaining segment, shuffled with the global ``random`` state.

    Args:
        segments: dicts with integer ``"start"``, ``"end"`` and ``"label"`` keys.
            Bar counts use 32 time steps per bar (``// 32``).

    Returns:
        A new list containing every input segment exactly once ([] for empty input).
    """
    if not segments:
        return []

    duration = max(segment["end"] for segment in segments)
    middle = duration // 2

    # Total number of whole 32-step bars covered by each label id.
    n_bars_per_label = [0] * (max(segment["label"] for segment in segments) + 1)
    for segment in segments:
        n_bars_per_label[segment["label"]] += (segment["end"] - segment["start"]) // 32

    label_sorted_by_n_bars = np.argsort(n_bars_per_label)

    segment_compose_order = []
    # Most common label first, then the second most common one. The slice also
    # covers the single-label case. Label ids need not be contiguous, so a picked
    # id may have no segment; skip it rather than inserting None.
    for label in label_sorted_by_n_bars[::-1][:2]:
        selected_segment = _segment_nearest_middle(segments, label, middle)
        if selected_segment is not None:
            segment_compose_order.append(selected_segment)

    # Randomly permute the remaining segments. Compare by identity so that two
    # equal-valued segments are not both dropped.
    chosen_ids = {id(segment) for segment in segment_compose_order}
    remaining_segments = [
        segment for segment in segments if id(segment) not in chosen_ids
    ]
    random.shuffle(remaining_segments)

    segment_compose_order.extend(remaining_segments)
    return segment_compose_order



# One cropped window per context slot, as built by sample_training_segments.
SampleTrainingSegmentsResultItem = TypedDict(
    "SampleTrainingSegmentsResultItem",
    {
        # Window bounds, in the same time units as the input segment dicts.
        "start": int,
        "end": int,
        # Window start minus the start of the uncropped segment.
        "shift_from_segment_start": int,
        # end - start of the uncropped segment (0 for an empty slot).
        "segment_duration": int,
        # Segment label; -1 marks a slot that has no segment.
        "label": int,
    },
)


def get_context_for_target_segment(
    segments: list[dict],
    target_segment: dict,
) -> "dict[str, dict | None]":
    """Select the context segments available when composing ``target_segment``.

    ``segments`` is a composition order (e.g. from ``get_compose_order``). Only
    segments placed *before* the target in that order count as already composed.

    Returns a dict with these keys:
        "target":    ``target_segment`` itself.
        "left":      already-composed segment ending at or before the target's
                     start with the smallest gap (ties: most recently composed).
        "right":     already-composed segment starting at or after the target's
                     end with the smallest gap (ties: earliest composed).
        "seed":      ``segments[0]``, or None when the target is the seed itself.
        "reference": earliest-composed segment with the target's label.
    Values are the raw (uncropped) segment dicts, or None when no candidate exists.

    Raises:
        ValueError: if ``target_segment`` is not in ``segments``.
    """
    target_index = segments.index(target_segment)
    already_composed_segments = segments[:target_index]
    target_start = target_segment["start"]
    target_end = target_segment["end"]

    # min() keeps the first minimum it meets, so scanning left candidates in
    # reverse compose order makes the most recently composed one win ties.
    nearest_left_segment = min(
        (seg for seg in reversed(already_composed_segments) if seg["end"] <= target_start),
        key=lambda seg: target_start - seg["end"],
        default=None,
    )
    nearest_right_segment = min(
        (seg for seg in already_composed_segments if seg["start"] >= target_end),
        key=lambda seg: seg["start"] - target_end,
        default=None,
    )
    reference_segment = next(
        (seg for seg in already_composed_segments if seg["label"] == target_segment["label"]),
        None,
    )
    seed_segment = segments[0] if target_index > 0 else None

    return {
        "target": target_segment,
        "left": nearest_left_segment,
        "right": nearest_right_segment,
        "seed": seed_segment,
        "reference": reference_segment,
    }


def sample_training_segments(
    segments: list[dict],
    max_context_duration: dict[str, int],
) -> tuple[dict[str, SampleTrainingSegmentsResultItem], list[dict]]:
    """Sample one training example: a random target segment plus its context windows.

    The segments are put in simulated compose order (``get_compose_order``). A
    target is drawn uniformly from that order, and its context is looked up with
    ``get_context_for_target_segment``. Each slot that has a segment is cropped to
    at most ``max_context_duration[slot]``:

      * "target":                     random bar-aligned (multiple of 32) window;
      * "left":                       rightmost window (the end nearest the target);
      * "seed", "reference", "right": leftmost window.

    Slots without a segment become a zero-length placeholder with ``label == -1``.

    Args:
        segments: dicts with integer ``"start"``, ``"end"`` and ``"label"`` keys.
        max_context_duration: maximum window length, looked up by slot name for
            every slot that has a segment.

    Returns:
        ``(result, segment_compose_order)``: the window per slot and the compose
        order that was used.
    """
    segment_compose_order = get_compose_order(segments)

    target_index = random.randint(0, len(segment_compose_order) - 1)
    target_segment = segment_compose_order[target_index]

    selected_segments = get_context_for_target_segment(
        segment_compose_order, target_segment
    )

    result: dict[str, SampleTrainingSegmentsResultItem] = {}
    for k, full_seg in selected_segments.items():
        if full_seg is None:
            result[k] = {
                "start": 0,
                "end": 0,
                "shift_from_segment_start": 0,
                "segment_duration": 0,
                "label": -1,
            }
            continue

        segment_duration = full_seg["end"] - full_seg["start"]
        max_duration = max_context_duration[k]
        if segment_duration > max_duration:
            if k == "target":
                # Draw the crop offset in whole bars. Every bar-aligned window that
                # fits is equally likely, including the one ending exactly at the
                # segment end (random.randint includes both bounds).
                n_extra_bars = (segment_duration - max_duration) // 32
                shift = 32 * random.randint(0, n_extra_bars)
                start = full_seg["start"] + shift
                end = start + max_duration
            elif k == "left":
                # right most: keep the part adjacent to the target
                start = full_seg["end"] - max_duration
                end = full_seg["end"]
            elif k in ["seed", "reference", "right"]:
                # left most
                start = full_seg["start"]
                end = start + max_duration
            else:
                raise ValueError(f"Unknown segment type: {k}")
        else:
            start = full_seg["start"]
            end = full_seg["end"]
        # Fails e.g. for an empty segment or a max duration of 0.
        assert start < end
        result[k] = {
            "start": start,
            "end": end,
            "shift_from_segment_start": start - full_seg["start"],
            "segment_duration": segment_duration,
            "label": full_seg["label"],
        }

    return result, segment_compose_order



In [30]:
song = ds.get_song('@pianotutorial7630/TmZp_sMNnHk/167_229')
s = song.read_json('segmentation')

In [31]:
# Indices 2 and 3 both hold the maximum 6; argmax reports the first one.
np.asarray([3, 4, 6, 6, 4]).argmax()

np.int64(2)

In [39]:
compose_order = get_compose_order(s)
compose_order

[{'start': 416, 'end': 608, 'label': 3},
 {'start': 224, 'end': 416, 'label': 2},
 {'start': 0, 'end': 96, 'label': 0},
 {'start': 96, 'end': 224, 'label': 1},
 {'start': 608, 'end': 736, 'label': 4}]

In [11]:
s_ticks = [{'start': 0, 'end': 96, 'label': 0},
           {'start': 96, 'end': 224, 'label': 1},
           {'start': 224, 'end': 416, 'label': 2},
           {'start': 416, 'end': 608, 'label': 3},
           {'start': 608, 'end': 736, 'label': 4}]
# Convert positions to bars (32 steps per bar) without mutating the tick data.
s = [
    {**segment, 'start': segment['start'] // 32, 'end': segment['end'] // 32}
    for segment in s_ticks
]

s






[{'start': 0, 'end': 3, 'label': 0},
 {'start': 3, 'end': 7, 'label': 1},
 {'start': 7, 'end': 13, 'label': 2},
 {'start': 13, 'end': 19, 'label': 3},
 {'start': 19, 'end': 23, 'label': 4}]

In [2]:
import torch
from torch import nn
from einops import rearrange
from torch import arange, stack, autocast

def exists(val):
    """Return True unless ``val`` is None (falsy values such as 0 or '' still count)."""
    if val is None:
        return False
    return True

class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) angles, with optional xpos scaling.

    ``forward(t)`` returns ``(freqs, scale)``:
      * ``freqs``: angles ``t * inv_freq / interpolation_factor`` shaped
        ``(batch, n, 2 * len(inv_freq))``; each frequency appears twice in a row
        along the last axis (last axis == dim for even dim);
      * ``scale``: the float ``1.`` when xpos is disabled, otherwise a tensor of
        the same shape holding ``self.scale ** ((t - max_pos // 2) / scale_base)``
        with ``max_pos = t.max() + 1``.
    """

    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.
    ):
        """
        Args:
            dim: rotary dimension; one inverse frequency per even index below dim.
            use_xpos: also compute the xpos scale tensor.
            scale_base: divisor applied to the position offset in the xpos exponent.
            interpolation_factor: angles are divided by this; must be >= 1.
            base: RoPE frequency base.
            base_rescale_factor: NTK-aware rescaling of ``base``; 1. disables it.
        """
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        # A factor of 1 is a no-op. Skipping it also avoids the division by zero
        # in dim / (dim - 2) when dim == 2.
        if base_rescale_factor != 1.:
            base *= base_rescale_factor ** (dim / (dim - 2))

        inv_freq = 1. / (base ** (arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            self.register_buffer('scale', None)
            return

        scale = (arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len):
        """Return ``forward`` for positions 0 .. seq_len - 1 on the buffers' device."""
        device = self.inv_freq.device

        t = arange(seq_len, device = device)
        return self.forward(t)

    @autocast('cuda', enabled = False)
    def forward(self, t):
        """Compute rotary angles (and the xpos scale) for positions ``t``.

        CUDA autocast is disabled for this call.

        Args:
            t: position indices, shape ``(n,)`` or ``(batch, n)``; a 1-D input gets
               a leading batch axis of size 1.

        Returns:
            ``(freqs, scale)`` as described on the class.
        """
        max_pos = t.max() + 1

        if t.ndim == 1:
            t = rearrange(t, 'n -> 1 n')

        freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
        # Duplicate each frequency so adjacent channels share one angle.
        freqs = stack((freqs, freqs), dim = -1)
        freqs = rearrange(freqs, '... d r -> ... (d r)')

        if not exists(self.scale):
            return freqs, 1.

        power = (t - (max_pos // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, '... n -> ... n 1')
        scale = stack((scale, scale), dim = -1)
        scale = rearrange(scale, '... d r -> ... (d r)')

        return freqs, scale



In [62]:
# Self-contained: `pe` is only created in a later cell. With use_xpos=False
# (the default) the `scale` buffer is registered as None.
RotaryEmbedding(64, base=2000).scale

In [None]:
import matplotlib.pyplot as plt

pe = RotaryEmbedding(64, base=2000)
freqs, _ = pe.forward(torch.arange(200))
# One row per position 0..199; columns are cos of every rotary angle, then sin.
a = torch.cat((freqs[0].cos(), freqs[0].sin()), dim=1)

fig, ax = plt.subplots()
ax.imshow(a)
ax.set(title="RoPE features (cos | sin), dim=64, base=2000", xlabel="feature", ylabel="position");


In [None]:
# inner_prod[i, j] is the dot product of the feature rows of positions i and j.
inner_prod = torch.einsum('i d, j d -> i j', a, a)

fig, ax = plt.subplots()
im = ax.imshow(inner_prod)
fig.colorbar(im, ax=ax)
ax.set(title="Inner product between position features", xlabel="position j", ylabel="position i");

In [None]:
fig, ax = plt.subplots()
ax.plot(inner_prod[0])
ax.set(title="Inner product of position 0 with every position", xlabel="position", ylabel="inner product");

In [None]:
fig, ax = plt.subplots()
ax.plot(inner_prod[199])
ax.set(title="Inner product of position 199 with every position", xlabel="position", ylabel="inner product");

In [None]:
from music_data_analysis import Pianoroll, Note

from segment_full_song.models.representation import SymbolicRepresentation


pr = Pianoroll([
    Note(onset=0, pitch=60, velocity=100),
    Note(onset=3, pitch=60, velocity=100),
    Note(onset=3, pitch=68, velocity=100),
    Note(onset=4, pitch=61, velocity=100),
    Note(onset=4, pitch=62, velocity=100),
    Note(onset=5, pitch=63, velocity=100),
    Note(onset=5, pitch=64, velocity=100),
    Note(onset=6, pitch=65, velocity=100),
    Note(onset=6, pitch=66, velocity=100),

], duration=8)

a = pr.slice(0, 4).notes

# Fall back to CPU so this check also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
s = SymbolicRepresentation.from_pianorolls([pr], device=device, max_tokens_rate=4.5, need_frame_tokens=False)
# NOTE(review): 21 is presumably the lowest MIDI pitch (A0) of the token vocabulary — confirm.
b = s.slice_pos(0, 4).to_pianoroll(21).notes

# Round-trip check: slicing the token representation should give the same notes
# as slicing the pianoroll directly.
a == b


In [51]:
from music_data_analysis import Pianoroll, Note

# Seven notes with pitch values 60, 62, 64, 65, 67, 69, 71 (a C major scale in MIDI
# numbering) and third argument 100; note i gets i as its first and fourth argument.
# NOTE(review): positional order presumably (onset, pitch, velocity, ...) — confirm
# against Note's signature.
pr = Pianoroll([
    Note(index, pitch, 100, index)
    for index, pitch in enumerate([60, 62, 64, 65, 67, 69, 71])
])

In [None]:
pr_midi = pr.to_midi()
pr_midi