In [None]:
import sys
sys.path.append("..")

import json
import random
import math
import itertools
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
import plotly.graph_objects as go
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"
from sklearn.manifold import TSNE
from sklearn.decomposition import IncrementalPCA

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
import torchaudio.transforms as AT
import torchaudio.functional as AF
from torchvision.utils import make_grid
from IPython.display import display, Audio
import torchaudio
from torchaudio.io import StreamReader

from src.datasets import *
from src.algo import GreedyLibrary
from src.util.image import *
from src.util import to_torch_device, iter_batches
from src.patchdb import PatchDB, PatchDBIndex
from src.models.encoder import *
from src.util.audio import *
from src.util.files import *
from src.util.embedding import *
from scripts import datasets
from src.algo import AudioUnderstander 

def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """Rescale the trailing (up to two) dimensions of `img` by `scale`.

    Each target extent is truncated to an integer and clamped to at least 1
    before being handed to `VF.resize` together with the interpolation `mode`.
    """
    target_size = []
    for extent in img.shape[-2:]:
        target_size.append(max(1, int(extent * scale)))
    return VF.resize(img, target_size, mode)

# model

In [None]:
# Load the pretrained AudioUnderstander checkpoint. The cells below read
# `au.slice_size`, `au.spectral_shape` and `au.sample_rate` from it.
au = AudioUnderstander.load("../models/au/au-1sec-3x256.pt")
# Disabled experiments, kept for reference:
#au.slice_size = au.sample_rate // 10
#au.spectral_shape = (au.spectral_shape[0], au.spectral_shape[1] // 10)
#au.drop_encoder(1)
#au.spectral_patch_shapes

In [None]:
# math.prod(SHAPE) (= 768) is used as the input width of the final encoder.
SHAPE = (1, 256 * 3)

# Set to True to use the VAE's `linear_mu` projection instead of the
# contrastive encoder.
USE_VAE_ENCODER = False

if USE_VAE_ENCODER:
    from scripts.train_vae_spectral import SimpleVAE

    vae = SimpleVAE(SHAPE, latent_dims=math.prod(SHAPE) // 12, kl_loss_weight=0.)
    data = torch.load("../checkpoints/spec6-final-vae/best.pt")
    print(f"inputs: {data['num_input_steps']:,}")
    vae.load_state_dict(data["state_dict"])
    # The notebook only runs inference: switch off training-mode behaviour
    # (dropout, batch-norm statistics updates) if the model has any.
    vae.eval()
    final_encoder = vae.encoder.linear_mu
else:
    from scripts.train_contrastive_ds import SimpleEncoder

    m = SimpleEncoder((math.prod(SHAPE), 64))
    # Alternative checkpoints. Exactly one `torch.load` should be active;
    # loading several in a row only wastes time because the last one wins.
    #data = torch.load("../checkpoints/contr-au5-64-shuff50k/best.pt")
    #data = torch.load("../checkpoints/contr-au6-64-mask/snapshot.pt")
    #data = torch.load("../checkpoints/contr-au7-64/snapshot.pt")
    data = torch.load("../checkpoints/contr-au9-64-cr03/best.pt")
    print(f"inputs: {data['num_input_steps']:,}")
    m.load_state_dict(data["state_dict"])
    # Inference only, see the note in the VAE branch.
    m.eval()
    final_encoder = m

final_encoder

# data

In [None]:
# Spectrogram slices read recursively from the local music folder. The stride
# is set equal to the slice size, and every item also carries its source
# filename and position.
ds = datasets.audio_slice_dataset(
    path="~/Music",
    recursive=True,
    interleave_files=1,
    mono=True,
    slice_size=au.slice_size,
    stride=au.slice_size,
    spectral_shape=au.spectral_shape,
    spectral_normalize=1,
    with_filename=True,
    with_position=True,
)

# Sanity check: print position (divided by the sample rate), filename and
# spectrogram shape of the first ten items.
for i, (spec, filename, pos) in enumerate(itertools.islice(ds, 10)):
    print(pos / au.sample_rate, "\t", filename, spec.shape)
    

# make embedding dataset

In [None]:
if 0:
    # Report progress every 100 MiB of collected embeddings, stop at 2 GiB.
    PRINT_EVERY_BYTES = 100 * 1024 ** 2
    MAX_DATASET_BYTES = 2 * 1024 ** 3

    embeddings = []
    filename_ids = []
    filename_map = dict()  # filename -> 1-based id, in order of first appearance
    size_in_bytes = 0
    last_print_size = 0
    # Embeddings are only stored, never back-propagated through, so do not let
    # autograd keep a graph alive for every collected tensor.
    with torch.no_grad():
        try:
            for spec, filename, pos in tqdm(ds):
                filename = str(filename)
                embedding = au.encode_spectrum(spec.squeeze(0))
                embeddings.append(embedding)
                if filename not in filename_map:
                    filename_map[filename] = len(filename_map) + 1
                filename_ids.append(filename_map[filename])
                # Use the tensor's real element size instead of assuming float32.
                size_in_bytes += embedding.element_size() * embedding.numel()

                if size_in_bytes - last_print_size > PRINT_EVERY_BYTES:
                    last_print_size = size_in_bytes
                    print(f"bytes {size_in_bytes:,}, files: {len(filename_map):,}")

                if size_in_bytes >= MAX_DATASET_BYTES:
                    break

        except KeyboardInterrupt:
            # Deliberate: interrupting the kernel keeps what was collected so far.
            pass

    embeddings = torch.concat(embeddings)
    print("embeddings", embeddings.shape)

## save dataset

In [None]:
if 0:
    # Common path prefix of the three files written below.
    fn = "../datasets/embeddings-au-1sec-3x256"
    torch.save(embeddings, f"{fn}.pt")
    # Build the id tensor directly as int64. `torch.Tensor(list)` would create
    # a float32 tensor first, which cannot represent every integer above 2**24.
    torch.save(torch.tensor(filename_ids, dtype=torch.int64), f"{fn}-ids.pt")
    # Store the inverse mapping (id -> filename); JSON turns the integer keys
    # into strings.
    Path(f"{fn}-filename-map.json").write_text(json.dumps({v: k for k, v in filename_map.items()}))

# create embeddings for testing

In [None]:
# Number of dataset items to embed for the similarity experiments.
MAX_TEST_SLICES = 20_000

embeddings = []
filenames = []
positions = []
with torch.inference_mode():
    try:
        for spec, filename, pos in tqdm(ds, total=MAX_TEST_SLICES):
            bow = au.encode_spectrum(spec.squeeze(0))
            embedding = final_encoder(bow)
            # The three lists are appended together so index i always refers
            # to the same dataset item in each of them.
            embeddings.append(embedding)
            filenames.append(filename)
            positions.append(pos)
            if len(embeddings) >= MAX_TEST_SLICES:
                break
    except KeyboardInterrupt:
        # Deliberate: interrupting the kernel keeps what was collected so far.
        pass

embeddings = torch.concat(embeddings)
# L2-normalise each row so that dot products are cosine similarities.
# F.normalize clamps the norm with a small eps, so an all-zero embedding
# yields zeros instead of NaNs.
embeddings_n = F.normalize(embeddings, dim=1)
print(embeddings.shape)

In [None]:
# Keep at most the first 20,000 rows of the raw and the normalised embeddings.
embeddings, embeddings_n = embeddings[:20_000], embeddings_n[:20_000]

In [None]:
# Pairwise dot products of all normalised embedding rows (an N x N matrix).
similarity = torch.matmul(embeddings_n, embeddings_n.T)

In [None]:
# Interactive heatmap of the top-left 300 x 300 corner of the similarity matrix.
px.imshow(
    similarity[:300, :300],
    height=1000,
    title="Pairwise similarity of the first 300 embeddings",
)

In [None]:
# Min/max are taken over the whole matrix so the grey scale is global.
s_min, s_max = similarity.min(), similarity.max()
# Crop before normalising: the values are identical, but this avoids two
# temporary copies of the full N x N matrix.
img = ((similarity[:5000, :5000] - s_min) / (s_max - s_min)).unsqueeze(0)
# Downscale to 20 % and render as a PIL image.
VF.to_pil_image(resize(img, .2, VF.InterpolationMode.BILINEAR))

In [None]:
def get_similars(idx: int, count: int = 10, normalized: Optional[torch.Tensor] = None):
    """Return the `count` rows most similar to row `idx`.

    Args:
        idx: Row index of the query embedding.
        count: Maximum number of neighbours to return. It is clamped to the
            number of available rows; the query row itself is included.
        normalized: 2-D tensor of row-normalised embeddings. Defaults to the
            notebook-level `embeddings_n`.

    Returns:
        A tuple `(indices, similarities)` of 1-D tensors, ordered by
        descending similarity (dot product with the query row).
    """
    if normalized is None:
        normalized = embeddings_n
    sim = normalized[idx] @ normalized.T
    # topk avoids sorting all rows when only the best few are needed; k must
    # not exceed the number of rows.
    k = max(0, min(count, sim.shape[0]))
    best_sims, best_indices = sim.topk(k)
    return best_indices, best_sims

def show_similars(idx: int, count: int = 10):
    """Print, plot and play the `count` slices most similar to slice `idx`.

    For every neighbour returned by `get_similars` the similarity, the position
    (integer-divided by the sample rate) and the filename are printed. The
    corresponding audio chunk is then re-read from the file and shown as a
    plot and as an audio player.
    """
    best_indices, distances = get_similars(idx, count)

    # One reader per file, reused when several neighbours share a file.
    reader_map = {}
    for bi, dist in zip(best_indices.tolist(), distances.tolist()):
        fn = filenames[bi]
        print(dist, positions[bi] // au.sample_rate, fn)
        reader = reader_map.get(fn)
        if reader is None:
            reader = reader_map[fn] = StreamReader(str(fn))
            # The second positional argument of `add_audio_stream` is the
            # number of buffered chunks, not a sample rate. Use the "basic"
            # variant, which resamples the stream to `sample_rate`, so that a
            # chunk of `slice_size` frames matches the playback rate below.
            reader.add_basic_audio_stream(
                frames_per_chunk=au.slice_size,
                sample_rate=au.sample_rate,
            )
        reader.seek(positions[bi] / au.sample_rate)
        # First chunk of the first (only) configured output stream.
        audio = next(iter(reader.stream()))[0]
        if audio.dtype != torch.float32:
            audio = audio.to(torch.float32) / 32767
        # Average over dim 1 to mix down to mono.
        audio = audio.mean(1)
        display(plot_audio(audio, (128, 386)))
        display(Audio(audio, rate=au.sample_rate))
        
# Show and play the nearest neighbours of one embedded slice. The argument
# must be a valid row index of `embeddings_n`.
#show_similars(9224)
show_similars(9000)

In [None]:
# NOTE(review): leftover interactive help lookup (IPython `?` syntax);
# consider removing this cell before sharing the notebook.
StreamReader?
#.add_audio_stream?
