In [5]:
import mido

# Dump every track of one segmented MIDI file to inspect its raw messages.
segment_path = "/media/nova/Datasets/sageev-midi/20250110/segmented/20240305-050-09/20240305-050-09_0172-0182.mid"
segment_midi = mido.MidiFile(segment_path)
segment_midi.print_tracks()

=== Track 0
MetaMessage('set_tempo', tempo=1200000, time=0)
MetaMessage('set_tempo', tempo=1200000, time=0)
MetaMessage('time_signature', numerator=4, denominator=4, clocks_per_click=24, notated_32nd_notes_per_beat=8, time=0)
MetaMessage('end_of_track', time=1760)
=== Track 1
MetaMessage('track_name', name='20240305-050-09_0172-0182', time=0)
Message('program_change', channel=0, program=0, time=0)
Message('note_on', channel=0, note=41, velocity=73, time=215)
Message('note_on', channel=0, note=41, velocity=0, time=30)
Message('note_on', channel=0, note=41, velocity=60, time=25)
Message('note_on', channel=0, note=41, velocity=0, time=30)
Message('note_on', channel=0, note=38, velocity=70, time=28)
Message('note_on', channel=0, note=38, velocity=0, time=32)
Message('note_on', channel=0, note=35, velocity=45, time=23)
Message('note_on', channel=0, note=35, velocity=0, time=16)
Message('note_on', channel=0, note=32, velocity=50, time=43)
Message('note_on', channel=0, note=62, velocity=44, t

In [3]:
import os
import h5py
import torch
import pandas as pd
from diffusers.pipelines.deprecated.spectrogram_diffusion.notes_encoder import (
    SpectrogramNotesEncoder,
)
from diffusers import MidiProcessor

# Disable gradient tracking globally; the cells below only run the encoder
# forward to produce embeddings.
torch.set_grad_enabled(False)

  from .autonotebook import tqdm as notebook_tqdm


<torch.autograd.grad_mode.set_grad_enabled at 0x7fd970b13a40>

In [4]:
# Peek at the token store: first five token rows, first five filename rows,
# then the very first filename decoded from bytes to text.
with h5py.File("data/test-tokens.h5", "r") as tokens_file:
    print(tokens_file["tokens"][:5])
    stored_names = tokens_file["filenames"]
    print(stored_names[:5])
    print(str(stored_names[0][0], "utf-8"))

[[1134.   65. 1135. ...    0.    0.    0.]
 [1134.  103. 1135. ...    0.    0.    0.]
 [1134.   46. 1135. ...    0.    0.    0.]
 [1134.   53. 1135. ...    0.    0.    0.]
 [1134.   53. 1135. ...    0.    0.    0.]]
[[b'alternating-060-02_baba-1-16shift_t00s00']
 [b'alternating-060-02_baba-1-4shift_t00s00']
 [b'alternating-060-02_baba-microshift_t00s00']
 [b'alternating-060-02_baba-octdown_t00s00']
 [b'alternating-060-02_baba-octup_t00s00']]
alternating-060-02_baba-1-16shift_t00s00


In [5]:
# Same peek for the embedding store: first five embedding rows, first five
# filename rows, then the very first filename decoded from bytes to text.
with h5py.File("data/test-embeddings.h5", "r") as embeddings_file:
    print(embeddings_file["embeddings"][:5])
    embedded_names = embeddings_file["filenames"]
    print(embedded_names[:5])
    print(str(embedded_names[0][0], "utf-8"))

FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = 'data/test-embeddings.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

In [6]:
from rich.progress import Progress


def build_neighbor_table(all_files: list[str], output_path: str) -> None:
    """Write a parquet table listing each file's neighbors within its own track.

    The table has one row per entry of ``all_files`` (indexed by that entry)
    and the columns ``prev_2, prev, current, next, next_2``, holding the
    entries found at positions ``i-2 .. i+2`` of ``all_files``. A slot is
    ``None`` when that position is out of range or the entry there belongs to
    a different track.

    "Neighbor" means adjacency inside ``all_files``, so the caller must pass
    the list already in the intended order.

    The track of an entry is its directory together with the part of its
    basename before the first underscore. Basenames may therefore contain any
    number of underscores, and underscores in directory names do not affect
    the grouping.

    Args:
        all_files: File paths (or bare file names) in segment order.
        output_path: Where the parquet file is written.
    """
    column_names = ["prev_2", "prev", "current", "next", "next_2"]

    def track_key(entry: str) -> tuple[str, str]:
        # Split the basename only: "<track>_<rest...>" -> "<track>".
        directory, basename = os.path.split(entry)
        return directory, basename.split("_")[0]

    # Compute every track key once instead of re-splitting inside the loop.
    track_keys = [track_key(entry) for entry in all_files]
    total = len(all_files)
    rows = []

    progress = Progress()
    update_task = progress.add_task("gathering neighbors", total=total)
    with progress:
        for i, curr_track in enumerate(track_keys):
            row = []
            for offset in range(-2, 3):
                idx = i + offset
                same_track = 0 <= idx < total and track_keys[idx] == curr_track
                row.append(all_files[idx] if same_track else None)
            rows.append(row)
            progress.advance(update_task)

    # Build the frame in one go rather than assigning row by row.
    n_table = pd.DataFrame(rows, index=all_files, columns=column_names, dtype=object)
    n_table.to_parquet(output_path)


path = "/media/nova/Datasets/sageev-midi/test"
# os.listdir returns names in arbitrary order, but build_neighbor_table derives
# neighbors from list adjacency, so sort the names to get a stable order.
build_neighbor_table(
    [os.path.join(path, filename) for filename in sorted(os.listdir(path))],
    "data/test-neighbors.parquet",
)

In [10]:
class EmbeddingGenerator:
    """Turn MIDI files (or pre-tokenized batches) into note-encoder embeddings.

    The generator owns a ``MidiProcessor`` for tokenization and a
    ``SpectrogramNotesEncoder`` loaded from a state-dict file. Both the model
    and the inputs are placed on ``self.device`` (``"cpu"`` when no device is
    given).
    """

    # Default encoder hyper-parameters. This dict is shared by the class and is
    # never mutated; every instance works on its own merged copy.
    encoder_config = {
        "d_ff": 2048,
        "d_kv": 64,
        "d_model": 768,
        "dropout_rate": 0.1,
        "feed_forward_proj": "gated-gelu_pytorch_tanh",
        "is_decoder": False,
        "max_length": 2048,
        "num_heads": 12,
        "num_layers": 12,
        "vocab_size": 1536,
    }

    def __init__(
        self,
        encoder_weights: str,
        device: str = None,
        config: dict = None,
    ):
        """Build the processor and the encoder, then load the encoder weights.

        Args:
            encoder_weights: Path to the encoder state dict readable by
                ``torch.load``.
            device: Torch device string such as ``"cuda:1"``; defaults to
                ``"cpu"`` when omitted or empty.
            config: Optional overrides for entries of ``encoder_config``.
        """
        self.processor = MidiProcessor()
        self.device = device if device else "cpu"

        # build encoder (per-instance config so overrides never leak to other instances)
        self.encoder_config = self._merged_config(config)
        # .to() works for CPU and CUDA alike, unlike .cuda(device="cpu").
        self.midi_encoder = SpectrogramNotesEncoder(**self.encoder_config).to(
            self.device
        )
        self.midi_encoder.eval()
        # map_location keeps loading independent of the device the weights were saved from.
        sd = torch.load(encoder_weights, map_location=self.device, weights_only=True)
        self.midi_encoder.load_state_dict(sd)

    @classmethod
    def _merged_config(cls, config: dict = None) -> dict:
        """Return a fresh copy of the class defaults with ``config`` applied on top."""
        merged = dict(cls.encoder_config)
        if config:
            merged.update(config)
        return merged

    def process(self, file_path: str) -> list:
        """Tokenize one MIDI file with the ``MidiProcessor``.

        Returns whatever the processor returns; ``get_embeddings`` iterates it
        and treats each item as one token sequence (presumably one fixed-length
        segment of the file; confirm against ``MidiProcessor``).
        """
        return self.processor(file_path)

    def get_embeddings_tokenized(self, input_tokens):
        """Encode an already-tokenized batch.

        Positions whose token id is greater than 0 are treated as valid; the
        resulting boolean mask is passed to the encoder together with the tokens.

        Returns:
            The ``(tokens_encoded, tokens_mask)`` pair produced by the encoder.
        """
        tokens_mask = input_tokens > 0
        tokens_encoded, tokens_mask = self.midi_encoder(
            encoder_input_tokens=input_tokens, encoder_inputs_mask=tokens_mask
        )
        return tokens_encoded, tokens_mask

    def get_embeddings(self, file_path: str) -> list[torch.Tensor]:
        """Embed a MIDI file as the mean of its encoded, unmasked tokens.

        Returns:
            A list with one CPU tensor for the first token sequence of the
            file, or an empty list when the processor yields nothing.
        """
        embeddings = []
        for segment_tokens in self.process(file_path):
            input_tokens = torch.IntTensor(segment_tokens).view(1, -1).to(self.device)
            tokens_encoded, tokens_mask = self.get_embeddings_tokenized(input_tokens)
            # Average only over the positions the mask marks as valid.
            embeddings.append(tokens_encoded[tokens_mask].mean(dim=0).cpu().detach())
            break  # NOTE: we're only using the first 5.12 seconds!

        return embeddings

In [None]:
# --- run configuration -------------------------------------------------------
# input: which dataset to embed and which files count as MIDI
dataset_name = "20250110-segmented"
in_path = "/media/nova/Datasets/sageev-midi/20250110/segmented"
supported_extensions = (".mid", ".midi")

# output: one HDF5 file per dataset
out_path = f"data/{dataset_name}.h5"

# model: encoder weights, target device and inference batch size
encoder = "data/note_encoder.bin"
device = "cuda:1"
batch_size = 8

print("initializing embedding generator")
generator = EmbeddingGenerator(encoder_weights=encoder, device=device)
print("initialization complete")

initializing embedding generator
initialization complete


In [12]:
# Tokenize MIDI files found under in_path, keyed by file name without extension.
midi_tokens = dict()
for dir_path, _, files in os.walk(in_path):
    midi_names = [name for name in files if name.endswith(supported_extensions)]
    # NOTE: only the first MIDI file of each directory is tokenized.
    for name in midi_names[:1]:
        # splitext handles ".mid" and ".midi" alike; a fixed [:-4] slice would
        # leave a trailing "." on ".midi" names.
        stem = os.path.splitext(name)[0]
        midi_tokens[stem] = generator.process(os.path.join(dir_path, name))

In [13]:
# Stack the first token sequence of every file into one (n_files, seq_len)
# tensor and save it next to a text file listing the keys in the same row order.
keys = list(midi_tokens)
first_sequences = [
    torch.IntTensor(midi_tokens[name][0]).view(1, -1) for name in keys
]
all_tokens = torch.cat(first_sequences)
print(f"all tokens shape {all_tokens.shape}, saving...")
torch.save(all_tokens, f"data/{dataset_name}.pt")
with open(f"data/{dataset_name}.txt", "w") as names_file:
    names_file.writelines(f"{name}\n" for name in keys)

all tokens shape torch.Size([120, 2048]), saving...


In [14]:
# One row per tokenized file, indexed by its key; the "embeddings" column
# starts empty and is filled in by the embedding-generation loop.
df = pd.DataFrame(index=keys, columns=["embeddings"])
df.head()

Unnamed: 0,embeddings
20240117-064-2b_0044-0052,
20240117-064-03_0194-0202,
20240305-050-04_0038-0047,
20240213-100-02_0153-0158,
20231227-080-02_0071-0077,


In [15]:
# generate embeddings: encode the tokens batch by batch and store, for every
# file, the mean of its encoded tokens over the positions the mask keeps
for i in range(0, all_tokens.shape[0], batch_size):
    # .to() also works when the generator runs on its default "cpu" device,
    # where .cuda(device=...) would fail.
    batch = all_tokens[i : i + batch_size].to(generator.device)

    with torch.autocast("cuda"):
        tokens, tokens_mask = generator.get_embeddings_tokenized(batch)

    for idx in range(batch.shape[0]):
        file = keys[i + idx]  # rows of all_tokens follow the order of keys

        avg_embedding = tokens[idx][tokens_mask[idx]].mean(0).cpu().detach()

        # tolist() converts the whole vector at once instead of one .item() per value
        df.loc[file, "embeddings"] = avg_embedding.tolist()
df.head()

Unnamed: 0,embeddings
20240117-064-2b_0044-0052,"[0.006497884634882212, 0.012574161402881145, -..."
20240117-064-03_0194-0202,"[-0.04920037463307381, 0.013055611401796341, 0..."
20240305-050-04_0038-0047,"[0.03366667777299881, 0.013924905098974705, -0..."
20240213-100-02_0153-0158,"[-0.07525605708360672, 0.08745083957910538, -0..."
20231227-080-02_0071-0077,"[-0.03562545403838158, 0.010814672335982323, 0..."


In [16]:
# Persist the embedding table, then report the outcome.
df.to_hdf(out_path, key=dataset_name)

saved = os.path.exists(out_path)
if saved:
    print(f"saved embeddings to {out_path}")
else:
    print("error saving embeddings!")

print("embedding generation complete")
# Only report the size when the file exists; getsize raises on a missing path.
if saved:
    print(f"wrote {os.path.getsize(out_path)} bytes")

saved embeddings to data/20250110segmented.h5
embedding generation complete
wrote 1891128 bytes


  check_attribute_name(name)
your performance may suffer as PyTables will pickle object types that it cannot
map directly to c-types [inferred_type->mixed,key->block0_values] [items->Index(['embeddings'], dtype='object')]

  df.to_hdf(out_path, key=dataset_name)


In [12]:
# Scratch check: append a frame with a list-valued column to an HDF5 table
# after turning each list into its string form.
first_batch = [{"a": 0, "b": [0, 0, 0]}, {"a": 1, "b": [0, 0, 1]}]
# alternative batch for a second append:
#   [{"a": 2, "b": [0, 1, 0]}, {"a": 3, "b": [0, 1, 1]}]
df = pd.DataFrame(first_batch)
df["b"] = df["b"].map(str)
df.to_hdf("test.h5", key="test", mode="a", format="table", append=True)

TypeError: Cannot serialize the column [b]
because its data contents are not [string] but [mixed] object dtype

In [4]:
# Load the stored embeddings of the augmented dataset and display them as the
# cell output (one row per file key, one column per embedding dimension).
pd.read_hdf("data/20250110-augmented.h5")

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
20231220-080-01_0000-0005_t00s00,-0.05374003201723099,-0.1369369626045227,0.06174706295132637,0.0841943547129631,-0.17261245846748352,0.11589771509170532,-0.09340494126081467,0.06125115975737572,0.08776569366455078,-0.06082483381032944,...,-0.07109753042459488,0.05602120980620384,0.02119760401546955,0.0708494782447815,-0.009089718572795391,-0.0425434447824955,-0.07553980499505997,0.10298950970172882,-0.015250486321747303,-0.10468641668558121
20231220-080-01_0000-0005_t00s01,-0.038339756429195404,-0.1342100203037262,0.09470196068286896,0.09810271114110947,-0.17174844443798065,0.09985591471195221,-0.07501658797264099,0.05304262414574623,0.07501006871461868,-0.04961874336004257,...,-0.07254941016435623,0.041400086134672165,0.008511332795023918,0.06827718764543533,-0.015496845357120037,-0.045479707419872284,-0.0636056512594223,0.059915002435445786,-0.010466128587722778,-0.09437036514282227
20231220-080-01_0000-0005_t00s02,-0.05657316744327545,-0.15519671142101288,0.05781048536300659,0.0674133449792862,-0.16535654664039612,0.10068245232105255,-0.10294418781995773,0.0594639889895916,0.13015171885490417,-0.08925369381904602,...,-0.07949882000684738,0.06631117314100266,0.03909671679139137,0.07320185750722885,-0.0517890527844429,-0.06620056182146072,-0.09595770388841629,0.0935552641749382,-0.023065408691763878,-0.11531564593315125
20231220-080-01_0000-0005_t00s03,-0.0431549996137619,-0.13562242686748505,0.0913548693060875,0.07467231899499893,-0.14855173230171204,0.1182851791381836,-0.08895932883024216,0.067657470703125,0.07845760136842728,-0.06035539507865906,...,-0.0783281922340393,0.052579037845134735,0.009989548474550247,0.06871314346790314,-0.00746879680082202,-0.05118381232023239,-0.0709441751241684,0.07933220267295837,-0.02274744026362896,-0.10365426540374756
20231220-080-01_0000-0005_t00s04,-0.05730858072638512,-0.15270911157131195,0.07435664534568787,0.06725263595581055,-0.15380841493606567,0.10258445143699646,-0.10784383118152618,0.08174601197242737,0.12598533928394318,-0.08983059227466583,...,-0.08303553611040115,0.0686410591006279,0.03741278871893883,0.07390012592077255,-0.04493352025747299,-0.06163519620895386,-0.08583448082208633,0.08797251433134079,-0.023278703913092613,-0.11473135650157928
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
20231220-080-01_0053-0059_t05s03,-0.06875725090503693,-0.04103611409664154,0.11414047330617905,0.011352252215147018,-0.15800300240516663,-0.002407602034509182,-0.03119153156876564,0.015505052171647549,0.09793587028980255,-0.038123808801174164,...,0.04546653851866722,0.02875274047255516,-0.18045999109745026,0.06853660196065903,-0.1815134733915329,0.03369107469916344,-0.05071718990802765,-0.0065003358758986,-0.07546187937259674,-0.006707934197038412
20231220-080-01_0053-0059_t05s04,-0.0704573467373848,-0.040392786264419556,0.1155763566493988,0.010905946604907513,-0.16281193494796753,0.020425371825695038,-0.023078151047229767,0.05271013826131821,0.09052076190710068,-0.023340968415141106,...,0.046312205493450165,0.01853790506720543,-0.16320231556892395,0.07389026135206223,-0.18047258257865906,0.010222106240689754,-0.05881476402282715,0.0060934945940971375,-0.08300969004631042,-0.013403835706412792
20231220-080-01_0053-0059_t05s05,-0.05435103178024292,-0.049290817230939865,0.10865814238786697,0.027649959549307823,-0.17027336359024048,0.024690626189112663,-0.006107950583100319,0.04813861474394798,0.06113281100988388,0.028078727424144745,...,0.06228329613804817,0.006776917725801468,-0.1625426858663559,0.07488423585891724,-0.16369161009788513,0.011404530145227909,-0.050045810639858246,-0.007715838495641947,-0.08749676495790482,0.015223273076117039
20231220-080-01_0053-0059_t05s06,-0.06292832642793655,-0.04788944497704506,0.12196234613656998,0.019506486132740974,-0.14955517649650574,0.007080752868205309,-0.015167568810284138,0.007508762180805206,0.07859722524881363,-0.015099816955626011,...,0.03664637356996536,0.04132010415196419,-0.1822945773601532,0.06495873630046844,-0.18330901861190796,0.02601979672908783,-0.032848991453647614,-0.005492504686117172,-0.0923195630311966,0.003674736013635993


In [13]:
# Sanity check: the start offsets visited when stepping through 50 items
# in batches of batch_size.
batch_size = 8
batch_starts = range(0, 50, batch_size)
for i in batch_starts:
    print(i)

0
8
16
24
32
40
48
