In [None]:
import sys
sys.path.append("..")

import random
import math
import itertools
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
import plotly.graph_objects as go
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"
from sklearn.manifold import TSNE
from sklearn.decomposition import IncrementalPCA

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
import torchaudio.transforms as AT
import torchaudio.functional as AF
from torchvision.utils import make_grid
from IPython.display import display, Audio
import torchaudio
from torchaudio.io import StreamReader

from src.datasets import *
from src.algo import GreedyLibrary
from src.util.image import *
from src.util import to_torch_device, iter_batches
from src.patchdb import PatchDB, PatchDBIndex
from src.models.encoder import *
from src.models.cnn import *
from src.models.util import *
from src.models.transform import *
from src.util.audio import *
from src.util.files import *
from src.util.embedding import *
from scripts import datasets
from src.algo import AudioUnderstander 

def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """Rescale the two trailing dimensions of `img` by the factor `scale`.

    Every trailing size is multiplied by `scale`, truncated to an int and
    clamped to at least 1. The result is passed to `VF.resize` together with
    `mode`; antialiasing is switched off.
    """
    target_size = []
    for extent in img.shape[-2:]:
        target_size.append(max(1, int(extent * scale)))
    return VF.resize(img, target_size, mode, antialias=False)

In [None]:
# Load one reference track: `audio` is the decoded waveform, `sr` the sample
# rate reported for the file.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
audio, sr = torchaudio.load("/home/bergi/Music/Scirocco/07 Bebop.mp3")

In [None]:
SAMPLE_RATE = 44_100
SPEC_SHAPE = (64, 64)

# Window and hop share one length: SAMPLE_RATE divided by the number of
# spectrogram columns, so one second of audio spans about SPEC_SHAPE[-1] frames.
frame_len = SAMPLE_RATE // SPEC_SHAPE[-1]
speccer = AT.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=2 * 1024,
    win_length=frame_len,
    hop_length=frame_len,
    n_mels=SPEC_SHAPE[-2],
    power=1.,
)

# Spectrogram of the channel-averaged audio between second 40 and 41,
# shown upside down (vflip), normalised by its maximum and enlarged 4x.
one_second = audio[:, 40 * sr:41 * sr]
spec = speccer(one_second.mean(0))
flipped = VF.vflip(spec.unsqueeze(0))
VF.to_pil_image(resize(flipped / spec.max(), 4))

In [None]:
# Dataset of mono audio slices read from the local music folder (searched
# recursively). File name and slice position are requested as well; the
# caching loop further down unpacks each item as (audio, fn, pos).
# NOTE(review): slice length, stride and sample rate are left at the defaults
# of `datasets.audio_slice_dataset`, which are not visible here — confirm.
slice_ds = datasets.audio_slice_dataset(
    path="~/Music/", recursive=True,
    interleave_files=1000,
    mono=True,
    #shuffle_slices=1_000,
    #shuffle_files=True,
    with_filename=True,
    with_position=True,
)

In [None]:
# Cache one half-precision mel spectrogram per audio slice in `spec_map`,
# keyed by file name and then by slice position, until about 4 GiB of
# spectrogram data have been collected. Interrupting the kernel stops the loop
# early; everything gathered so far stays in `spec_map`.

# Use the GPU when one is present, otherwise fall back to the CPU instead of
# failing on an unconditional `.cuda()`.
spec_device = "cuda" if torch.cuda.is_available() else "cpu"
speccer = speccer.to(spec_device)

spec_map = {}
num_bytes = 0
last_num_bytes = 0
max_second = 0.
try:
    with torch.inference_mode():
        for i, (audio, fn, pos) in enumerate(tqdm(slice_ds)):
            fn = str(fn)
            # Keep at most SPEC_SHAPE[-1] time frames and store as float16.
            spec = speccer(audio.to(spec_device)).cpu().squeeze(0)[:, :SPEC_SHAPE[-1]].to(torch.half)
            if i == 0:
                # 2 bytes per float16 element; the size of the first
                # spectrogram serves as the per-item estimate for all others.
                spec_size_bytes = math.prod(spec.shape) * 2
                print(f"audio shape: {audio.shape}")
                print(f"spec shape:  {spec.shape} ({spec_size_bytes} bytes)")
                display(VF.to_pil_image(VF.vflip(spec) / spec.max()))

            spec_map.setdefault(fn, {})[str(pos)] = spec

            max_second = max(max_second, pos / SAMPLE_RATE)
            num_bytes += spec_size_bytes
            # Progress report roughly every 500 MiB of collected data.
            if num_bytes - last_num_bytes > 1024**2 * 500:
                last_num_bytes = num_bytes
                print(f"bytes: {num_bytes:,} files: {len(spec_map):,} max-sec: {max_second:.2f}")

            # Stop once the 4 GiB budget is reached.
            if num_bytes >= 1024**3 * 4:
                break

except KeyboardInterrupt:
    # Deliberate: a manual interrupt ends collection but keeps the partial result.
    pass

In [None]:
# Persist the nested {file name: {position: spectrogram}} dict built by the
# caching loop.
torch.save(spec_map, "../datasets/audio-file-pos-spec-1sec-dict.pt")

In [None]:
class AudioSpecIterableDataset(IterableDataset):
    """Iterable dataset yielding aligned (audio, mel-spectrogram) windows.

    Audio slices come from `datasets.audio_slice_dataset`. Every slice is
    converted to a mel spectrogram once and then cut into windows of
    `spec_shape[-1]` frames, advanced by `spec_stride` frames. Each window is
    paired with the `spec_slice_size` audio samples that start at the
    proportionally corresponding sample offset inside the slice.

    Every item is the tuple `(audio_slice, spec_slice, filename, position)`,
    where `position` is the position reported by the slice dataset plus the
    sample offset of the window within that slice.

    Args:
        path, recursive, slice_size, stride, interleave_files, shuffle_files,
        shuffle_slices, mono, seek_offset:
            Forwarded unchanged to `datasets.audio_slice_dataset`.
        sample_rate: Sample rate handed to the mel-spectrogram transform
            (it is not forwarded to the slice dataset).
        with_filename, with_position: Stored on the instance only. The slice
            dataset is always asked for both and `__iter__` always yields the
            full 4-tuple, whatever these flags say.
        spec_slice_size: Number of audio samples per yielded audio window;
            also determines window and hop length of the spectrogram
            (`spec_slice_size // spec_shape[-1]`).
        spec_shape: `(n_mels, frames)` of a yielded spectrogram window.
        spec_stride: Step, in spectrogram frames, between consecutive windows.
    """

    # Not referenced anywhere inside this class; kept for compatibility.
    _spec_maps = {}

    def __init__(
            self,
            path: Union[str, Path] = "~/Music/",
            recursive: bool = False,
            sample_rate: int = 44100,
            slice_size: int = 44100,
            stride: Optional[int] = None,
            interleave_files: Optional[int] = None,
            shuffle_files: bool = False,
            shuffle_slices: Optional[int] = None,
            mono: bool = False,
            seek_offset: float = 0.,
            with_filename: bool = False,
            with_position: bool = False,

            spec_slice_size: int = 44100,
            spec_shape: Tuple[int, int] = (64, 64),
            spec_stride: int = 1,
    ):
        super().__init__()
        self.spec_shape = spec_shape
        self.spec_slice_size = spec_slice_size
        self.spec_stride = spec_stride
        self.sample_rate = sample_rate
        self.with_filename = with_filename
        self.with_position = with_position
        self.slice_ds = datasets.audio_slice_dataset(
            path=path,
            # Honour the caller's choice (this used to be hard-coded to True).
            recursive=recursive,
            interleave_files=interleave_files,
            slice_size=slice_size,
            stride=stride,
            mono=mono,
            shuffle_files=shuffle_files,
            shuffle_slices=shuffle_slices,
            seek_offset=seek_offset,
            # Both are always needed to build the yielded tuple.
            with_position=True,
            with_filename=True,
        )

        # Window and hop share one length so that `spec_slice_size` samples
        # map onto about `spec_shape[-1]` spectrogram frames.
        frame_len = self.spec_slice_size // spec_shape[-1]
        self.speccer = AT.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=1024 * 2,
            win_length=frame_len,
            hop_length=frame_len,
            n_mels=spec_shape[-2],
            power=1.,
        )

    def __iter__(self):
        """Yield `(audio_slice, spec_slice, filename, position)` tuples."""
        window = self.spec_shape[-1]
        for audio, fn, pos in self.slice_ds:
            spec = self.speccer(audio)
            num_frames = spec.shape[-1]
            num_samples = audio.shape[-1]

            # Visit only offsets at which a complete spectrogram window fits;
            # the range is empty when the spectrogram is shorter than a window.
            last_start = max(0, num_frames - window + 1)
            for offset in range(0, last_start, self.spec_stride):
                # Map the frame offset to a sample offset proportionally.
                audio_offset = int(offset / num_frames * num_samples)
                if audio_offset + self.spec_slice_size > num_samples:
                    # The sample offset only grows with the frame offset, so
                    # no later window can fit either.
                    break

                spec_slice = spec[..., offset:offset + window]
                audio_slice = audio[..., audio_offset:audio_offset + self.spec_slice_size]
                yield audio_slice, spec_slice, fn, pos + audio_offset

# Try the dataset out: slice_size of 4 s and stride of 2 s worth of samples,
# quarter-second audio windows, SPEC_SHAPE spectrogram windows.
SAMPLE_RATE = 44_100
SPEC_SHAPE = (64, 64)
ds = AudioSpecIterableDataset(
    path="~/Music",
    recursive=True,
    slice_size=4 * SAMPLE_RATE,
    stride=2 * SAMPLE_RATE,
    spec_shape=SPEC_SHAPE,
    spec_slice_size=SAMPLE_RATE // 4,
    spec_stride=1,
    interleave_files=1000,
    mono=True,
)
#ds = IterableShuffle(ds, 100_000)
if 1:
    # Preview branch: print and plot the first 16 items as two image grids.
    audios = []
    specs = []
    for i, (audio, spec, filename, pos) in zip(range(16), ds):
        print(audio.shape, spec.shape, pos, str(filename)[-30:])
        peak = audio.abs().max()
        audios.append(plot_audio(audio / peak, tensor=True, shape=SPEC_SHAPE))
        channel_mean = spec.mean(0).unsqueeze(0)
        specs.append(VF.vflip(channel_mean) / spec.max())
    for images in (audios, specs):
        display(VF.to_pil_image(make_grid(images)))
else:
    # Throughput branch: run through the whole dataset, progress bar only.
    for audio, spec, filename, pos in tqdm(ds):
        pass

# Development: spectrogram to audio

In [None]:
# Build the spectrogram-to-audio decoder and restore its trained weights.
# One decoder input is a SPEC_SHAPE window; the output holds AUDIO_SIZE samples.
AUDIO_SIZE = SAMPLE_RATE // 4
if 1:
    # Plain two-layer MLP: flattened spectrogram -> hidden -> waveform in [-1, 1].
    decoder = nn.Sequential(
        nn.Flatten(1),
        nn.Linear(math.prod(SPEC_SHAPE), math.prod(SPEC_SHAPE)),
        nn.ReLU(),
        nn.Linear(math.prod(SPEC_SHAPE), AUDIO_SIZE),
        nn.Tanh(),
        Reshape((1, AUDIO_SIZE)),
    )
    checkpoint_path = "../checkpoints/sta6/best.pt"
else:
    # Alternative architecture from the training script (branch currently disabled).
    from scripts.train_spec_to_audio import SpecToWave
    decoder = SpecToWave(SPEC_SHAPE, AUDIO_SIZE, 10)
    checkpoint_path = "../checkpoints/sta5-sr4/best.pt"

# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a machine without CUDA; the decoder is applied to CPU tensors in the cells
# that follow.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
decoder.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["state_dict"])

In [None]:
# Reload the reference track and compute the mel spectrogram of a longer
# excerpt (samples SAMPLE_RATE*38 up to SAMPLE_RATE*52 of the channel mean).
# NOTE(review): the sample rate returned by torchaudio.load is discarded; the
# slicing below assumes the file is sampled at SAMPLE_RATE — confirm.
# NOTE(review): absolute, machine-specific paths — adjust before running elsewhere.
audio, _ = torchaudio.load("/home/bergi/Music/Scirocco/07 Bebop.mp3")
#audio, _ = torchaudio.load("/home/bergi/Music/Ray Kurzweil The Age of Spiritual Machines/(audiobook) Ray Kurzweil - The Age of Spiritual Machines - 1 of 4.mp3")
# Fresh transform instance with window and hop of SAMPLE_RATE // SPEC_SHAPE[-1]
# samples; it rebinds the global `speccer`.
speccer = AT.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=1024 * 2,
    win_length=SAMPLE_RATE // SPEC_SHAPE[-1],
    hop_length=SAMPLE_RATE // SPEC_SHAPE[-1],
    n_mels=SPEC_SHAPE[-2],
    power=1.,
)
spec = speccer(audio.mean(0)[SAMPLE_RATE * 38:SAMPLE_RATE * 52])
print(spec.shape)
# Show the spectrogram vertically flipped and normalised by its maximum.
VF.to_pil_image(VF.vflip(spec) / spec.max())

In [None]:
# Decode the first SPEC_SHAPE[-1] frames of the spectrogram into audio, then
# compare the input window with the spectrogram of the reconstruction.
in_spec = spec[..., :SPEC_SHAPE[-1]].unsqueeze(0)
# Inference only: skip autograd bookkeeping, like the full reconstruction
# further down does.
with torch.no_grad():
    recon = decoder(in_spec).squeeze(0)
display(plot_audio(recon, (128, 1000)))
out_spec = speccer(recon)
display(VF.to_pil_image(in_spec / in_spec.max()))
VF.to_pil_image(out_spec / out_spec.max())

In [None]:
def spec_to_audio(decoder, spec, spec_shape, audio_size, stride=1, sub_sample: float = 2):
    """Reconstruct a waveform from a spectrogram by averaging decoded windows.

    Windows of `spec_shape[-1]` frames are cut from `spec` every `stride`
    frames, decoded in batches of 32 and written into the output
    `max(1, int(audio_size / sub_sample))` samples apart. Output samples
    covered by more than one decoded window hold the average of all of them.

    Args:
        decoder: Callable applied to each batch of windows; its result is
            averaged over dim 1 and then iterated along dim 0, one waveform
            per window.
        spec: Spectrogram, sliced as `spec[:, a:b]`.
        spec_shape: Shape of one decoder input window; only the last entry
            (the number of frames) is used.
        audio_size: Nominal number of samples produced per window; used for
            sizing the output and for the write step.
        stride: Step between consecutive windows in frames, at least 1.
        sub_sample: Divisor applied to `audio_size` to obtain the write step.

    Returns:
        A 1-D tensor. It is allocated with twice
        `int(spec.shape[-1] / spec_shape[-1] * audio_size)` samples and grows
        when a decoded window reaches past the end; samples never covered by
        a window remain zero (so the tail is usually silent padding).

    Raises:
        ValueError: If `stride` is smaller than 1 (the window position would
            never advance).
        ZeroDivisionError: If `sub_sample` is 0.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    window = spec_shape[-1]
    # Distance in samples between the start of consecutive decoded windows.
    audio_step = max(1, int(audio_size / sub_sample))

    def _yield_patches():
        # One (spectrogram window, output sample position) pair per step.
        last_start = spec.shape[-1] - window
        for index, spec_x in enumerate(range(0, last_start + 1, stride)):
            yield spec[:, spec_x:spec_x + window], index * audio_step

    result = torch.zeros(int(spec.shape[-1] / window * audio_size) * 2)
    # Counts how many windows contributed to each output sample.
    result_sum = torch.zeros_like(result)

    # assumes `iter_batches` yields (stacked windows, positions) pairs — TODO confirm
    for spec_batch, pos_batch in iter_batches(_yield_patches(), 32):
        audio_batch = decoder(spec_batch).mean(1)
        for pos, audio_slice in zip(pos_batch, audio_batch):
            end = pos + audio_slice.shape[0]
            if end > result.shape[-1]:
                # Grow both buffers by the same amount; they always share one length.
                padding = torch.zeros(end - result.shape[-1])
                result = torch.concat([result, padding])
                result_sum = torch.concat([result_sum, padding])
            s = slice(pos, end)

            result[s] = result[s] + audio_slice
            result_sum[s] = result_sum[s] + 1.

    # Turn sums into means wherever at least one window contributed.
    mask = result_sum > 0
    result[mask] = result[mask] / result_sum[mask]
    return result


# Reconstruct audio for the whole spectrogram excerpt without tracking
# gradients: a window every 30 frames, each decoded window written
# AUDIO_SIZE samples after the previous one (sub_sample=1).
with torch.no_grad():
    wave = spec_to_audio(decoder, spec, SPEC_SHAPE, AUDIO_SIZE, stride=30, sub_sample=1)
# Report how many full windows fit into the spectrogram and the output length.
print(
    f"spec {spec.shape[-1]}/{SPEC_SHAPE[-1]}={spec.shape[-1]//SPEC_SHAPE[-1]}"
    f" audio-size {AUDIO_SIZE} wave: {wave.shape}")
# Audio player widget plus a plot of the reconstructed waveform.
display(Audio(wave, rate=SAMPLE_RATE))
plot_audio(wave)