In [2]:
from pathlib import Path
from music_data_analysis import Dataset


first_song = Dataset(Path('../../pop80k_k')).songs()[0]
# Convert the first song's 'pianoroll' data to MIDI (path 'a.mid'); the result's repr is the cell output.
first_song.read_pianoroll('pianoroll').to_midi('a.mid')

ticks per beat: 480
max tick: 0
tempo changes: 1
time sig: 0
key sig: 0
markers: 0
lyrics: False
instruments: 1

In [None]:
from music_data_analysis import Pianoroll


first_song = Dataset(Path('../../pop80k_k')).songs()[0]
synced_midi = first_song.read_midi('synced_midi')
# Peek at the first ten notes recovered from the 'synced_midi' MIDI data.
Pianoroll.from_midi(synced_midi).notes[:10]


[Note(14,51,76,30),
 Note(15,63,72,18),
 Note(16,58,55,30),
 Note(16,61,66,28),
 Note(16,66,72,30),
 Note(17,70,92,30),
 Note(20,72,93,30),
 Note(24,73,96,30),
 Note(28,61,34,30),
 Note(30,53,79,60)]

In [2]:
import torch
from torch import nn
from einops import rearrange
from torch import arange, stack, autocast

def exists(val):
    """Return True for any value other than ``None`` (falsy values such as 0 or '' still count)."""
    if val is None:
        return False
    return True

class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) angle table, with an optional xPos-style scale.

    ``forward`` maps integer positions to per-position rotation angles. Applying the
    rotation to features is not done in this class.

    Args:
        dim: rotary feature size. One frequency is built per entry of ``arange(0, dim, 2)``
            and each is repeated twice in the output, so an odd ``dim`` yields ``dim + 1``
            features.
        use_xpos: if True, ``forward`` also returns a per-position, per-feature scale.
        scale_base: divisor applied to (centred) positions before exponentiating the
            xPos scale; only stored when ``use_xpos`` is True.
        interpolation_factor: angles are divided by this (must be >= 1).
        base: RoPE wavelength base.
        base_rescale_factor: NTK-aware rescaling of ``base``; ``1.`` disables it. Must be
            ``1.`` when ``dim == 2`` because the rescale exponent ``dim / (dim - 2)`` is
            undefined there.

    Raises:
        ValueError: if ``base_rescale_factor != 1.`` and ``dim == 2``.
        AssertionError: if ``interpolation_factor < 1.``.
    """

    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        # Skip it when it is a no-op: the exponent dim / (dim - 2) divides by zero for dim == 2,
        # which would otherwise make RotaryEmbedding(2) fail even with the default factor.
        if base_rescale_factor != 1.:
            if dim == 2:
                raise ValueError('base_rescale_factor must be 1. when dim == 2')
            base *= base_rescale_factor ** (dim / (dim - 2))

        # One inverse wavelength per feature pair: base ** (-k / dim) for k = 0, 2, 4, ...
        inv_freq = 1. / (base ** (arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            # Registered as None so `self.scale` always exists; forward() checks for it.
            self.register_buffer('scale', None)
            return

        # Per-frequency decay rates in [0.4/1.4, 1), one per inv_freq entry.
        scale = (arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len):
        """Embed positions ``0 .. seq_len - 1``, created on the same device as ``inv_freq``."""
        device = self.inv_freq.device

        t = arange(seq_len, device = device)
        return self.forward(t)

    @autocast('cuda', enabled = False)
    def forward(self, t):
        """Compute rotation angles (and the optional xPos scale) for positions ``t``.

        CUDA autocast is disabled so the angle computation is not downcast by autocast.

        Args:
            t: 1-D ``(n,)`` or 2-D ``(b, n)`` tensor of positions; 1-D input gets a
                leading batch axis of size 1.

        Returns:
            ``(freqs, scale)``. ``freqs`` has shape ``(b, n, 2 * len(inv_freq))``, with
            each frequency repeated on adjacent features (``f0, f0, f1, f1, ...``).
            ``scale`` is the float ``1.`` when xPos is disabled, otherwise a tensor with
            the same shape as ``freqs``.
        """
        # Taken over every position in t, before any reshaping.
        max_pos = t.max() + 1

        if t.ndim == 1:
            t = rearrange(t, 'n -> 1 n')

        freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
        # Duplicate each frequency onto two adjacent features.
        freqs = stack((freqs, freqs), dim = -1)
        freqs = rearrange(freqs, '... d r -> ... (d r)')

        if not exists(self.scale):
            return freqs, 1.

        # Exponent is centred on the middle position, so the scale is 1 there.
        power = (t - (max_pos // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, '... n -> ... n 1')
        scale = stack((scale, scale), dim = -1)
        scale = rearrange(scale, '... d r -> ... (d r)')

        return freqs, scale



In [62]:
# Build the embedding here so this cell does not depend on `pe` from a later cell.
# With use_xpos left at False the `scale` buffer is None.
RotaryEmbedding(64, base=2000).scale

In [None]:
import matplotlib.pyplot as plt

pe = RotaryEmbedding(64, base=2000)
# Call the module (not .forward) so nn.Module hooks run; the second value is the
# xPos scale, which is the float 1. here and unused.
freqs, _ = pe(torch.arange(200))
# For a 1-D position tensor freqs is (1, 200, 64): drop the batch axis and place
# cos and sin features side by side -> (200, 128).
a = torch.cat((freqs[0].cos(), freqs[0].sin()), dim=1)

fig, ax = plt.subplots()
ax.imshow(a)
ax.set(title='RoPE features (cos | sin), dim=64, base=2000',
       xlabel='feature (cos block, then sin block)', ylabel='position');


In [None]:
# Each row of `a` is [cos(p*theta), sin(p*theta)], so row_i . row_j = sum cos((i - j)*theta):
# the matrix should depend only on the offset i - j (constant along diagonals).
inner_prod = torch.einsum('i d, j d -> i j', a, a)

fig, ax = plt.subplots()
im = ax.imshow(inner_prod)
fig.colorbar(im, ax=ax)
ax.set(title='Inner product between RoPE feature rows', xlabel='position j', ylabel='position i');

In [None]:
fig, ax = plt.subplots()
ax.plot(inner_prod[0])
ax.set(title='Inner product of position 0 with every position', xlabel='position', ylabel='inner product');

In [None]:
fig, ax = plt.subplots()
ax.plot(inner_prod[199])
ax.set(title='Inner product of position 199 with every position', xlabel='position', ylabel='inner product');

In [None]:
from music_data_analysis import Pianoroll, Note

from segment_full_song.models.representation import SymbolicRepresentation


pr = Pianoroll([
    Note(onset=0, pitch=60, velocity=100),
    Note(onset=3, pitch=60, velocity=100),
    Note(onset=3, pitch=68, velocity=100),
    Note(onset=4, pitch=61, velocity=100),
    Note(onset=4, pitch=62, velocity=100),
    Note(onset=5, pitch=63, velocity=100),
    Note(onset=5, pitch=64, velocity=100),
    Note(onset=6, pitch=65, velocity=100),
    Note(onset=6, pitch=66, velocity=100),
], duration=8)

# Reference: slice the pianoroll directly.
notes_from_pianoroll = pr.slice(0, 4).notes

# Same slice taken on the SymbolicRepresentation built from the pianoroll, converted back.
# Descriptive names (instead of `a`/`b`) avoid clobbering the RoPE feature matrix `a`
# that the inner-product cell reads.
rep = SymbolicRepresentation.from_pianorolls([pr], device='cuda', max_tokens_rate=4.5, need_frame_tokens=False)
notes_from_tokens = rep.slice_pos(0, 4).to_pianoroll(21).notes

# True when both routes yield the same notes.
notes_from_pianoroll == notes_from_tokens


In [51]:
from music_data_analysis import Pianoroll, Note

# Seven rising notes (pitches 60..71: a C-major scale if these are MIDI note numbers).
# Note i gets onset i, velocity 100, and its fourth positional field also set to i.
scale_pitches = [60, 62, 64, 65, 67, 69, 71]
pr = Pianoroll([Note(step, pitch, 100, step) for step, pitch in enumerate(scale_pitches)])

In [None]:
# Convert the scale pianoroll above to MIDI and show the result as the cell output.
scale_midi = pr.to_midi()
scale_midi