In [None]:
import sys
sys.path.append("..")

import random
import math
import itertools
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
import plotly.graph_objects as go
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"
from sklearn.manifold import TSNE
from sklearn.decomposition import IncrementalPCA

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
import torchaudio.transforms as AT
import torchaudio.functional as AF
from torchvision.utils import make_grid
from IPython.display import display, Audio
import torchaudio
from torchaudio.io import StreamReader

from src.datasets import *
from src.algo import GreedyLibrary
from src.util.image import *
from src.util import to_torch_device, iter_batches
from src.patchdb import PatchDB, PatchDBIndex
from src.models.encoder import *
from src.util.audio import *
from src.util.files import *
from src.util.embedding import *
from scripts import datasets

def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """
    Rescale the two trailing dimensions of `img` by the factor `scale`.

    Each target dimension is ``int(dim * scale)`` but never smaller than 1,
    so strong down-scaling cannot produce an empty image.

    :param img: object with a `.shape` whose last two entries are the
        dimensions to scale; it is passed on to `VF.resize`
    :param scale: multiplicative factor applied to both dimensions
    :param mode: interpolation mode handed to `VF.resize` (nearest by default)
    :return: whatever `VF.resize` returns for the computed size
    """
    height, width = img.shape[-2:]
    target_size = [
        max(1, int(height * scale)),
        max(1, int(width * scale)),
    ]
    return VF.resize(img, target_size, mode)

# create dataset

In [None]:
# Geometry of one spectrogram slice and of the patches cut out of it.
SPEC_SHAPE = (128, 128)
SHAPE = (8, 8)
PATCHES_PER_SLICE = math.prod(
    spec_dim // patch_dim
    for spec_dim, patch_dim in zip(SPEC_SHAPE, SHAPE)
)
# Hard-coded slice count; in this notebook it only feeds the progress-bar totals.
TOTAL_SLICES = 20103
TOTAL_PATCHES = TOTAL_SLICES * PATCHES_PER_SLICE
print(PATCHES_PER_SLICE, "patches per slice")

ds = datasets.audio_slice_dataset(
    interleave_files=200,
    mono=True,
    spectral_size=SPEC_SHAPE[-1],
    spectral_patch_shape=SHAPE,
    spectral_normalize=1,
)

# ds = ImageFilterIterableDataset(ds, ImageFilter(min_std=.1))

# Preview: the first 32 * 32 items of the dataset as a 32-column grid, enlarged 3x.
grid = list(itertools.islice(ds, 32 * 32))
img = make_grid(grid, nrow=32, normalize=True)
VF.to_pil_image(resize(img, 3))

In [None]:
# Largest value found in any of the previewed patches.
max(g.max() for g in grid)

# PCA

In [None]:
# Incremental PCA with one component per patch pixel, fitted batch by batch.
pca = IncrementalPCA(math.prod(SHAPE), batch_size=1024)

try:
    with tqdm(total=TOTAL_PATCHES) as progress:
        patch_loader = DataLoader(ds, batch_size=1024)
        for batch in patch_loader:
            # one row per patch, all pixels of a patch flattened into columns
            flat_patches = batch.flatten(1)
            pca.partial_fit(flat_patches)
            progress.update(len(batch))
except KeyboardInterrupt:
    # Interrupting the kernel stops the pass early; `pca` keeps what it has
    # been fitted on so far.
    pass

In [None]:
# Render every principal component as a single-channel, patch-shaped image,
# tiled into a grid and enlarged 4x. The view is derived from SHAPE instead of
# the literals 64/8/8 so it stays valid when SHAPE is changed above.
VF.to_pil_image(resize(
    make_grid(torch.Tensor(pca.components_).view(-1, 1, *SHAPE), normalize=True),
    4,
))

# greedy lib

In [None]:
# Greedy patch library, starting with a single entry of one patch's shape.
lib = GreedyLibrary(1, shape=(1, *SHAPE))

# Options shared by every `lib.fit` call below.
fit_options = dict(
    metric="corr",
    max_entries=math.prod(SHAPE),
    grow_if_distance_above=.1,
)

try:
    with tqdm(total=TOTAL_PATCHES) as progress:
        for batch in DataLoader(ds, batch_size=1024):
            lib.fit(batch, **fit_options)
            progress.update(len(batch))
except KeyboardInterrupt:
    # Interrupting the kernel stops fitting early and keeps the library as is.
    pass

In [None]:
# Plot the entries of the fitted library, sorted by "hits"
# (presumably how often each entry was matched - confirm in GreedyLibrary).
lib.plot_entries(sort_by="hits")

# multiple autoencoders

In [None]:
from scripts.train_vae_spectral import SimpleVAE

# One pre-trained autoencoder per spectral patch shape: (checkpoint file, patch shape).
# `vaes` collects (model, patch shape) pairs that the following cells iterate over.
vaes = []
for filename, shape in (
        ("../checkpoints/spec2/best.pt", (8, 8)),
        ("../checkpoints/spec3-8x128/best.pt", (8, 128)),
        ("../checkpoints/spec4-128x8/best.pt", (128, 8)),
):
    # latent size is 1/8 of the number of pixels in a patch
    vae = SimpleVAE(shape, latent_dims=math.prod(shape) // 8)
    # The models are only ever used on the CPU in this notebook, so map the
    # checkpoint tensors to CPU; otherwise a checkpoint written on a GPU
    # machine fails to load where no GPU is available.
    data = torch.load(filename, map_location="cpu")
    print(f"{filename} inputs: {data['num_input_steps']:,}")
    vae.load_state_dict(data["state_dict"])
    vaes.append((vae, shape))
    print(vae)
    # Show value range and the first 16 rows of the encoder weights and of the
    # (transposed) first decoder layer, each row reshaped to a patch.
    for W in (vae.encoder.linear_mu.weight, vae.decoder[0].weight.permute(1, 0)):
        print(float(W.min()), float(W.max()))
        display(VF.to_pil_image(resize(signed_to_image(make_grid(W.view(-1, 1, *shape)[:16])), 4)))

# create a KMeans clusterer for each autoencoder / spec-patch size

In [None]:
NUM_CLUSTERS = 256

from sklearn.cluster import MiniBatchKMeans

# Fit one MiniBatchKMeans per autoencoder on the embeddings of its patch shape.
# `clusterers[i]` belongs to `vaes[i]`.
clusterers = []
for vae, shape in vaes:
    ds_local = datasets.audio_slice_dataset(
        interleave_files=200,
        mono=True,
        spectral_size=SPEC_SHAPE[-1],
        spectral_normalize=1_000,
        spectral_patch_shape=shape,
    )
    clusterer = MiniBatchKMeans(
        n_clusters=NUM_CLUSTERS,
        batch_size=1024,
        random_state=23,
        n_init=1,
        reassignment_ratio=.1,
    )
    clusterers.append(clusterer)

    # only used for the progress bar: total and description
    patches_per_slice = math.prod(s1 // s2 for s1, s2 in zip(SPEC_SHAPE, shape))
    W = vae.encoder.linear_mu.weight

    try:
        with torch.no_grad(), tqdm(
                total=TOTAL_SLICES * patches_per_slice,
                desc=f"cluster spec-shape: {tuple(shape)}, encoder: {W.shape[-1]}->{W.shape[-2]}"
        ) as progress:
            for batch in DataLoader(ds_local, batch_size=1024):
                # NOTE(review): the clusterer is fitted on `vae.encoder(batch)`, while
                # all later cells predict on `vae.encoder.linear_mu(patches.flatten(1))`.
                # Confirm that both yield the same embedding (e.g. that the encoder
                # does not sample); otherwise fit and predict see different inputs.
                embeddings = vae.encoder(batch)
                clusterer.partial_fit(embeddings.flatten(1))
                progress.update(batch.shape[0])

    except KeyboardInterrupt:
        # Interrupting stops fitting for this clusterer only; the loop then
        # continues with the next autoencoder.
        pass

## test single clusterer

In [None]:
# Sanity check of one clusterer: histogram of predicted cluster labels for a
# number of batches. Each row of the image is one batch, each column one cluster.
idx = 1
with torch.no_grad():
    ds_local = datasets.audio_slice_dataset(
        #interleave_files=200,
        mono=True,
        spectral_size=SPEC_SHAPE[-1],
        spectral_normalize=1_000,
        spectral_patch_shape=vaes[idx][1],
        #seek_offset=30,
    )
    hist_sum = None
    # informational only: largest value within the first 1000 dataset items
    max_v = 0
    for i, spec in zip(tqdm(range(1000)), ds_local):
        max_v = max(max_v, float(spec.max()))
    print("MAX", max_v)
    hists = []
    for i, patches in zip(tqdm(range(20)), DataLoader(ds_local, batch_size=1000)):
        embeddings = vaes[idx][0].encoder.linear_mu(patches.flatten(1))
        labels = clusterers[idx].predict(embeddings)
        # Count per cluster id. np.histogram without an explicit `range` would
        # spread its bins between the smallest and largest label of *this*
        # batch, so columns would not line up with cluster ids across batches.
        hist = np.bincount(labels, minlength=clusterers[idx].n_clusters)
        hists.append(torch.Tensor(hist).unsqueeze(0))
        hist_sum = hist if hist_sum is None else hist_sum + hist

hists = torch.concat(hists)
hists = hists / hists.max()
display(VF.to_pil_image(hists))
px.line(hists.sum(0))

In [None]:
# Scratch check of np.histogram binning: 4 bins over the range (0, 4) have the
# edges 0, 1, 2, 3, 4, so the values 0, 1, 1, 1, 2 are counted as [1, 3, 1, 0].
np.histogram(np.array([0, 1, 1, 1, 2]), 4, range=(0, 4))

In [None]:
# Cluster-label histograms for every one-second slice of a single audio file,
# using the autoencoder/clusterer pair number `idx`.
idx = 2
# These equal the AudioUnderstander defaults. They are spelled out here because
# `au` is only created in a later cell; reading `au.sample_rate` etc. at this
# point raises a NameError on a fresh kernel (Restart & Run All).
SAMPLE_RATE = 44100
SLICE_SIZE = 44100
with torch.no_grad():
    # NOTE(review): absolute, machine-specific path
    audio, sr = torchaudio.load("/home/bergi/Music/theDropper/05 CD Track 05.mp3")
    if sr != SAMPLE_RATE:
        # the spectrogram settings below assume SAMPLE_RATE
        audio = AF.resample(audio, sr, SAMPLE_RATE)
    audio = audio.mean(0)  # average the channels
    slices = list(iter_audio_slices((audio, ), SLICE_SIZE))
    print(len(slices), "slices")

    speccer = AT.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=1024 * 2,
        win_length=SAMPLE_RATE // 30,
        hop_length=SAMPLE_RATE // SPEC_SHAPE[-1],
        n_mels=SPEC_SHAPE[-2],
        power=1.,
    )
    # only full-length slices are converted
    specs = [
        speccer(audio_slice)
        for audio_slice in slices
        if audio_slice.shape[-1] == SLICE_SIZE
    ]

    hist_sum = None
    # global maximum over all spectrograms, used to normalize them to [0, 1]
    max_v = 0
    for spec in specs:
        max_v = max(max_v, float(spec.max()))
    print("MAX", max_v)
    hists = []
    embeddingss = []
    for spec in specs:
        if max_v:  # avoid 0 / 0 = nan for completely silent input
            spec = spec / max_v
        spec = spec.clamp(0, 1)
        patches = torch.concat(list(iter_image_patches(spec.unsqueeze(0), vaes[idx][1])))
        embeddings = vaes[idx][0].encoder.linear_mu(patches.flatten(1))
        embeddingss.append(embeddings)
        labels = clusterers[idx].predict(embeddings)
        # one count per cluster id
        hist = np.bincount(labels, minlength=clusterers[idx].n_clusters)
        hists.append(torch.Tensor(hist).unsqueeze(0))
        hist_sum = hist if hist_sum is None else hist_sum + hist

embeddingss = torch.concat(embeddingss)
embeddingss = embeddingss / embeddingss.max()
# display(VF.to_pil_image(embeddingss))
hists = torch.concat(hists)
hists = hists / hists.max()
display(VF.to_pil_image(hists))
px.line(hists.sum(0))

# create AudioUnderstander

In [None]:
class AudioUnderstander:
    """
    Turns audio into per-slice histograms of clustered spectrogram-patch embeddings.

    Each slice of audio is converted to a mel spectrogram. For every configured
    patch shape the spectrogram is cut into patches, each patch is projected by
    a linear encoder, the embedding is assigned to a k-means cluster and the
    relative frequency of each cluster within the slice is recorded. The
    histograms of all patch shapes are concatenated into one feature vector
    per slice.

    Based on: 

    Attention is All You Need?
    Good Embeddings with Statistics are Enough: Large Scale Audio Understanding without Transformers/ Convolutions/ BERTs/ Mixers/ Attention/ RNNs or ....
    Prateek Verma
    https://browse.arxiv.org/pdf/2110.03183.pdf
    """
    def __init__(
            self,
            sample_rate: int = 44100,
            slice_size: int = 44100,
            spectral_shape: Tuple[int, int] = (128, 128), 
            spectral_patch_shapes: Iterable[Tuple[int, int]] = ((8, 8), ),
            encoder_ratios: Iterable[int] = (8, ),
            num_clusters: Iterable[int] = (256, ),
    ):
        """
        :param sample_rate: sample rate the incoming audio is expected to have
        :param slice_size: number of audio samples per slice
        :param spectral_shape: (mel bands, time steps) of the spectrogram of one slice
        :param spectral_patch_shapes: one patch shape per encoder/clusterer pair
        :param encoder_ratios: per patch shape, the factor by which the encoder
            reduces the number of patch pixels
        :param num_clusters: per patch shape, the number of k-means clusters
        :raises ValueError: if `encoder_ratios` or `num_clusters` does not have
            the same length as `spectral_patch_shapes`
        """
        self.sample_rate = sample_rate
        self.slice_size = slice_size
        self.spectral_shape = tuple(spectral_shape)
        self.spectral_patch_shapes = [tuple(s) for s in spectral_patch_shapes]
        self.encoder_ratios = list(encoder_ratios)
        self.num_clusters = list(num_clusters)

        # Validate before anything is zipped together: zip() would silently
        # drop the surplus entries of the longer list.
        for check_attribute in ("encoder_ratios", "num_clusters"):
            if len(self.spectral_patch_shapes) != len(getattr(self, check_attribute)):
                raise ValueError(
                    f"`{check_attribute}` must be same length as `spectral_patch_shapes`"
                    f", expected {len(self.spectral_patch_shapes)}, got {len(getattr(self, check_attribute))}"
                )

        self._spectrogrammer = AT.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=1024 * 2,
            win_length=self.sample_rate // 30,
            hop_length=self.sample_rate // self.spectral_shape[-1],
            n_mels=self.spectral_shape[-2],
            power=1.,
        )
        self.encoders = [
            nn.Linear(math.prod(spectral_patch_shape), math.prod(spectral_patch_shape) // encoder_ratio)
            for spectral_patch_shape, encoder_ratio in zip(self.spectral_patch_shapes, self.encoder_ratios)
        ]
        self.clusterers = [
            MiniBatchKMeans(
                n_clusters=num_clusters,
                batch_size=1024,
                random_state=23,
                n_init=1,
            )
            for num_clusters in self.num_clusters
        ]

    @torch.inference_mode()
    def encode_audio(self, audio: torch.Tensor, normalize_spec: bool = True) -> torch.Tensor:
        """
        Encode a waveform into one histogram vector per slice.

        :param audio: 1-dim waveform, or 2-dim (channels, samples) which is
            averaged over the channels
        :param normalize_spec: divide the spectrogram of each slice by its maximum
        :return: tensor of shape (number of slices, sum of `num_clusters`).
            Audio shorter than one slice is zero-padded to one slice;
            trailing samples that do not fill a whole slice are dropped.
        :raises ValueError: if `audio` is neither 1- nor 2-dimensional
        """
        if audio.ndim == 1:
            pass
        elif audio.ndim == 2:
            if audio.shape[0] > 1:
                audio = audio.mean(0)
            else:
                audio = audio.squeeze(0)
        else:
            raise ValueError(f"Need audio.ndim == 1 or 2, got {audio.ndim}")

        if audio.shape[-1] < self.slice_size:
            # new_zeros keeps dtype and device of the input
            padding = audio.new_zeros(self.slice_size - audio.shape[-1])
            audio = torch.concat([audio, padding])

        histograms = []
        for start in range(0, audio.shape[-1] - self.slice_size + 1, self.slice_size):
            spec = self._get_spec(audio[start:start + self.slice_size], normalize=normalize_spec)
            histograms.append(self._get_histogram(spec))

        return torch.concat(histograms)

    @torch.inference_mode()
    def encode_spectrum(self, spectrum: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        """
        Encode an already computed spectrogram into one histogram vector per slice.

        :param spectrum: 2-dim tensor (mel bands, time steps) whose number of
            bands equals `spectral_shape[0]`
        :param normalize: divide the whole spectrum by its (global) maximum.
            Note that `encode_audio` normalizes each slice separately.
        :return: tensor of shape (number of slices, sum of `num_clusters`).
            A spectrum narrower than `spectral_shape[-1]` is zero-padded along
            the time axis; trailing time steps that do not fill a whole slice
            are dropped.
        :raises ValueError: if `spectrum` is not 2-dimensional or its number of
            bands does not match
        """
        if spectrum.ndim != 2:
            raise ValueError(f"Need spectrum.ndim == 2, got {spectrum.ndim}")

        if spectrum.shape[-2] != self.spectral_shape[0]:
            raise ValueError(f"spectrum.shape must fit `spectral_shape`, expected {self.spectral_shape}, got {spectrum.shape}")

        slice_width = self.spectral_shape[-1]

        # The width has to be compared with the spectral slice width (not the
        # audio `slice_size`) and the padding is appended on the time axis.
        if spectrum.shape[-1] < slice_width:
            padding = spectrum.new_zeros(spectrum.shape[-2], slice_width - spectrum.shape[-1])
            spectrum = torch.concat([spectrum, padding], dim=-1)

        if normalize:
            spec_max = spectrum.max()
            if spec_max:  # leave an all-zero spectrum untouched
                spectrum = spectrum / spec_max

        histograms = []
        for start in range(0, spectrum.shape[-1] - slice_width + 1, slice_width):
            spec = spectrum[:, start:start + slice_width]
            histograms.append(self._get_histogram(spec))

        return torch.concat(histograms)

    def _get_histogram(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Build the concatenated cluster histogram for the spectrogram of one slice.

        For each (patch shape, encoder, clusterer) triple the patches of `spec`
        are embedded and clustered; the label counts are divided by the number
        of patches so each part sums to 1.

        :param spec: 2-dim spectrogram of one slice
        :return: tensor of shape (1, sum of the clusterers' `n_clusters`)
        """
        one_full_hist = []
        for spectral_patch_shape, encoder, clusterer in zip(self.spectral_patch_shapes, self.encoders, self.clusterers):
            # iter_image_patches is expected to yield the patches of the
            # (1, H, W) image in the given shape - see src.util.image
            patches = torch.concat(list(iter_image_patches(spec.unsqueeze(0), spectral_patch_shape)))
            embeddings = encoder(patches.flatten(1))

            # hand a plain numpy array to scikit-learn; detach() makes this
            # also work when called outside of inference/no_grad mode
            labels = clusterer.predict(embeddings.flatten(1).detach().cpu().numpy())

            # one count per cluster id, also for clusters that do not occur
            hist = np.bincount(labels, minlength=clusterer.n_clusters)

            one_full_hist.append(torch.Tensor(hist) / patches.shape[0])

        return torch.concat(one_full_hist).unsqueeze(0)

    def _get_spec(self, audio: torch.Tensor, normalize: bool):
        """
        Mel spectrogram of one audio slice, cropped to `spectral_shape[-1]`
        time steps and clamped to [0, 1].

        :param audio: waveform of one slice
        :param normalize: divide by the maximum of the spectrogram (if non-zero)
        """
        spec = self._spectrogrammer(audio)[:, :self.spectral_shape[-1]]
        if normalize:
            spec_max = spec.max()
            if spec_max:
                spec = spec / spec_max
        return spec.clamp(0, 1)

    def save(self, file) -> None:
        """
        Store configuration, encoder weights and the fitted clusterers with `torch.save`.

        :param file: anything `torch.save` accepts (path or file-like object).
            The clusterers are scikit-learn objects and are therefore pickled.
        """
        data = {
            "sample_rate": self.sample_rate,
            "slice_size": self.slice_size,
            "spectral_shape": self.spectral_shape, 
            "spectral_patch_shapes": self.spectral_patch_shapes,
            "encoder_ratios": self.encoder_ratios,
            "num_clusters": self.num_clusters,
        }
        for i in range(len(self.clusterers)):
            data.update({
                # plain copies without autograd history
                f"encoder.{i}.weight": self.encoders[i].weight.detach().clone(),
                f"encoder.{i}.bias": self.encoders[i].bias.detach().clone(),
                f"clusterer.{i}": self.clusterers[i]
            })
        torch.save(data, file)

    @classmethod
    def load(cls, fp):
        """
        Re-create an instance from a file written by `save`.

        SECURITY: the file contains pickled scikit-learn objects, so loading
        executes arbitrary code from the file. Only load files you trust.

        :param fp: anything `torch.load` accepts (path or file-like object)
        :return: the restored `AudioUnderstander`
        """
        # `weights_only=False` is required because the clusterers are arbitrary
        # pickled objects; recent torch versions default to `weights_only=True`
        # and would refuse to load them.
        data = torch.load(fp, weights_only=False)
        c = cls(
            sample_rate=data["sample_rate"],
            slice_size=data["slice_size"],
            spectral_shape=data["spectral_shape"],
            spectral_patch_shapes=data["spectral_patch_shapes"],
            encoder_ratios=data["encoder_ratios"],
            num_clusters=data["num_clusters"],
        )
        with torch.no_grad():
            for i in range(len(c.clusterers)):
                c.encoders[i].weight[:] = data[f"encoder.{i}.weight"]
                c.encoders[i].bias[:] = data[f"encoder.{i}.bias"]
                c.clusterers[i] = data[f"clusterer.{i}"]
        return c
    
# One encoder/clusterer pair per trained autoencoder.
au = AudioUnderstander(
    spectral_patch_shapes=[patch_shape for _, patch_shape in vaes],
    encoder_ratios=[8] * len(vaes),
    num_clusters=[NUM_CLUSTERS] * len(vaes),
)

# Copy the trained `linear_mu` layers and (copies of) the fitted clusterers into `au`.
with torch.no_grad():
    for i, ((vae, shape), clusterer) in enumerate(zip(vaes, clusterers)):
        trained_layer = vae.encoder.linear_mu
        target_layer = au.encoders[i]
        target_layer.weight[:] = trained_layer.weight
        target_layer.bias[:] = trained_layer.bias
        au.clusterers[i] = deepcopy(clusterer)

# Round trip through an in-memory file: reports the serialized size and
# continues with the re-loaded instance.
import io
fp = io.BytesIO()
au.save(fp)
print(f"filesize: {fp.tell():,} bytes")
fp.seek(0)
au = AudioUnderstander.load(fp)

In [None]:
# SAVE
# Create the target directory first; saving into a missing directory fails.
au_file = Path("../models/au/au-1sec-3x256.pt")
au_file.parent.mkdir(parents=True, exist_ok=True)
au.save(au_file)

In [None]:
def plot_features(filename, histograms: bool = True):
    """
    Load an audio file, encode it with the notebook-level `au` and display the result.

    Prints channel count, duration and sample rate of the file, resamples to
    `au.sample_rate` if necessary and then shows

    - optionally an image of all per-slice histograms (scaled by the overall
      maximum, brightened by a factor of 10 and clipped to [0, 1]),
    - a line plot of the histogram averaged over all slices.

    :param filename: audio file passed to `torchaudio.load`
    :param histograms: also display the histogram image
    """
    waveform, file_rate = torchaudio.load(filename)
    num_channels, num_samples = waveform.shape[-2], waveform.shape[-1]
    print(f"{num_channels} x {num_samples / file_rate:.2f} secs @ {file_rate:,}Hz")

    if file_rate != au.sample_rate:
        waveform = AF.resample(waveform, file_rate, au.sample_rate)

    hists = au.encode_audio(waveform)

    if histograms:
        # one column per slice, one row per histogram bin
        brightened = (hists.T.unsqueeze(0) / hists.max() * 10).clamp(0, 1)
        display(VF.to_pil_image(resize(brightened, 3)))

    display(px.line(hists.mean(0)))

# Histogram image plus mean line plot for one file.
# NOTE(review): absolute, machine-specific path - will not exist elsewhere.
plot_features("/home/bergi/Music/theDropper/05 CD Track 05.mp3")

In [None]:
# Mean line plot only (histogram image switched off).
# NOTE(review): absolute, machine-specific path - will not exist elsewhere.
plot_features("/home/bergi/Music/Scirocco/07 Bebop.mp3", histograms=False)

In [None]:
# Histogram image plus mean line plot for another file.
# NOTE(review): absolute, machine-specific path - will not exist elsewhere.
plot_features("/home/bergi/Music/Hitchhiker's Guide - Radio Play/Hitchhiker'sGuideEpisode-03.mp3")

In [None]:
# Histogram image plus mean line plot for another file.
# NOTE(review): absolute, machine-specific path - will not exist elsewhere.
plot_features("/home/bergi/Music/Ray Kurzweil The Age of Spiritual Machines/(audiobook) Ray Kurzweil - The Age of Spiritual Machines - 1 of 4.mp3")