In [None]:
import sys
sys.path.append("..")

import random
import math
import itertools
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
import plotly.graph_objects as go
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"
from sklearn.manifold import TSNE

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
import torchaudio.transforms as AT
from torchvision.utils import make_grid
from IPython.display import display, Audio
import torchaudio
from torchaudio.io import StreamReader

from src.datasets import *
from src.algo import GreedyLibrary
from src.util.image import *
from src.util import to_torch_device, iter_batches
from src.patchdb import PatchDB, PatchDBIndex
from src.models.encoder import *
from src.util.audio import *
from src.util.files import *
from src.util.embedding import *
from scripts.train_vae_audio import VAEAudioConv

In [None]:
@torch.no_grad()
def play(audio, shape=(200, 1024), normalize=False):
    """Render an inline audio player followed by a waveform plot.

    :param audio: audio tensor, played back at the global ``SAMPLERATE``
    :param shape: plot size, passed through to ``plot_audio``
    :param normalize: when True, divide by the absolute peak first
        (skipped if the peak is zero)

    The player is created with ``normalize=False``, so the samples handed to it
    are clamped to [-1, 1]; the plot shows the unclamped signal.
    """
    if normalize:
        peak = audio.abs().max()
        if peak:
            audio = audio / peak
    clipped = audio.clamp(-1, 1)
    display(Audio(clipped, rate=SAMPLERATE, normalize=False))
    display(plot_audio(audio, shape))
    
def read_stream(stream, seconds: float = 3., samplerate: Optional[int] = None) -> torch.Tensor:
    """Read up to ``seconds`` of audio from a chunk stream and down-mix it to mono.

    :param stream: iterable of 1-tuples ``(chunk,)``, e.g. ``StreamReader.stream()``
        with a single output stream; the last axis of each chunk is averaged
        (presumably the channel axis -- chunks assumed to be (frames, channels))
    :param seconds: maximum duration to read
    :param samplerate: sample rate of the stream; defaults to the global
        ``SAMPLERATE`` looked up at call time
    :return: 1-D tensor with at most ``int(seconds * samplerate)`` samples;
        an empty tensor if the stream yields nothing
    """
    if samplerate is None:
        samplerate = SAMPLERATE
    max_samples = int(seconds * samplerate)

    chunks = []
    num_samples = 0
    for chunk, in stream:
        mono = chunk.mean(-1)
        chunks.append(mono)
        num_samples += mono.shape[-1]
        if num_samples >= max_samples:
            break

    # torch.concat([]) raises; an exhausted/empty stream simply yields no audio
    if not chunks:
        return torch.zeros(0)
    return torch.concat(chunks)[:max_samples]

# Load the trained audio VAE

In [None]:
SLICE_SIZE = 1024   # samples per model input slice (an earlier variant used 1024 * 4)
LATENT_SIZE = 128   # size of the VAE latent vector
SAMPLERATE = 44_100  # an earlier variant used 22_050
CHECKPOINT_FILE = Path("../checkpoints/audio-vae-4/snapshot.pt")

# The architecture has to match the checkpoint, otherwise load_state_dict fails.
# Other configurations that were tried:
#   VAEAudioConv(slice_size=SLICE_SIZE, latent_dims=LATENT_SIZE, channels=[1, 16, 24, 32])
#   VAEAudioConv(slice_size=SLICE_SIZE, latent_dims=LATENT_SIZE, channels=[1, 16, 16, 16], kernel_size=15)
model = VAEAudioConv(slice_size=SLICE_SIZE, latent_dims=LATENT_SIZE, channels=[1, 16, 16, 16], kernel_size=7)

# map_location="cpu": nothing in this notebook moves tensors to a GPU, and this
# lets a checkpoint saved from a CUDA run load on a CPU-only machine
checkpoint = torch.load(CHECKPOINT_FILE, map_location="cpu")
print("steps: {:,}".format(checkpoint["num_input_steps"]))
model.load_state_dict(checkpoint["state_dict"])
# inference only: switch off any training-mode behaviour the modules may have
model.eval()
model

# load audio

In [None]:
# Audio file to analyse (relative to ~/Music). Files used in earlier experiments:
#   BR80-backup/ROLAND/LIVEREC/LIVE0033-2023-09-10-poesnek.WAV
#   Aphex Twin/Acoustica Alarm Will Sound Performs Aphex Twin/01 Cock_Ver 10.mp3
#   Ray Kurzweil The Age of Spiritual Machines/(audiobook) Ray Kurzweil - The Age of Spiritual Machines - 1 of 4.mp3
MUSIC_DIR = Path.home() / "Music"
filename = str(MUSIC_DIR / "Scirocco" / "07 Bebop.mp3")
CHUNK_FRAMES = 4096  # frames per chunk yielded by reader.stream()

reader = StreamReader(filename)
# resample to SAMPLERATE, the rate all slicing and playback code in this notebook assumes
reader.add_basic_audio_stream(CHUNK_FRAMES, sample_rate=SAMPLERATE)

reader.seek(1)  # start one second into the file
play(read_stream(reader.stream(), 10))

# encode audio

In [None]:
@torch.no_grad()
def encode_audio(
        stream, 
        sub_sample: float = 1.1, 
        seconds: float = 3.,
        noise: float = 0.,
        samplerate: Optional[int] = None,
        slice_size: Optional[int] = None,
):
    """Encode the first ``seconds`` of an audio stream into VAE latent vectors.

    The stream is cut into overlapping slices (stride ``slice_size / sub_sample``)
    by ``iter_audio_slices``; each slice is down-mixed to mono and passed through
    the global ``model.encoder``.

    :param stream: audio stream accepted by ``iter_audio_slices``
        (e.g. ``StreamReader.stream()``)
    :param sub_sample: overlap factor; stride between slices is ``slice_size / sub_sample``
    :param seconds: amount of audio to encode; only slices lying completely
        inside the first ``seconds * samplerate`` samples are returned
    :param noise: std of gaussian noise added to each embedding (0 = off)
    :param samplerate: sample rate of the stream; defaults to the global ``SAMPLERATE``
    :param slice_size: samples per slice; defaults to the global ``SLICE_SIZE``
    :return: ``(embeddings, positions)`` -- one embedding row per slice and the
        slice start positions (int64, in samples)
    :raises ValueError: if the stream yields no slices
    """
    if samplerate is None:
        samplerate = SAMPLERATE
    if slice_size is None:
        slice_size = SLICE_SIZE
    num_samples = int(seconds * samplerate)
    # guard against a zero stride when sub_sample > slice_size
    stride = max(1, int(slice_size / sub_sample))

    embedding_batches = []
    pos_batches = []
    for slice_batch, pos_batch in iter_batches(
        tqdm(iter_audio_slices(
            stream, 
            slice_size=slice_size, 
            with_position=True, 
            stride=stride,
        )), 
        batch_size=128
    ):
        # as_tensor keeps integer positions exact; torch.Tensor() would go through
        # float32 and round positions above 2**24 samples
        pos_batch = torch.as_tensor(pos_batch, dtype=torch.int64)
        # down-mix the last axis and add a channel axis
        # (assumes slice_batch is (batch, slice_size, channels) -- as in reconstruct_audio)
        slice_batch = slice_batch.mean(-1).unsqueeze(-2)
        embedding_batch = model.encoder(slice_batch)
        if noise:
            embedding_batch = embedding_batch + noise * torch.randn_like(embedding_batch)
        embedding_batches.append(embedding_batch)
        pos_batches.append(pos_batch)
        # slices presumably arrive in stream order, so once one reaches past the
        # requested duration the following ones will too
        if pos_batch.numel() and int(pos_batch[-1]) + slice_size > num_samples:
            break

    if not embedding_batches:
        raise ValueError("stream did not yield any audio slices")

    embeddings = torch.concat(embedding_batches)
    positions = torch.concat(pos_batches)
    # limit by sample position, not by embedding count
    keep = positions + slice_size <= num_samples
    return embeddings[keep], positions[keep]
# reorder latent dimensions so that similar ones end up next to each other
SORT_LATENTS_BY_TSNE = True

reader.seek(1)
slices, positions = encode_audio(reader.stream(), seconds=10)
# one row per latent dimension, one column per audio slice
slices = slices.permute(1, 0)
if SORT_LATENTS_BY_TSNE:
    # 1-D t-SNE of the latent dimensions; argsort turns the coordinate into an order.
    # random_state makes the ordering reproducible between runs.
    tsne = TSNE(1, perplexity=10, random_state=0)
    order = torch.from_numpy(tsne.fit_transform(slices.numpy())).squeeze(-1).argsort()
    slices = slices[order]
# scale up and clip to the [-1, 1] range signed_to_image is given
slices = (slices * 10.).clamp(-1, 1)
img = signed_to_image(slices.unsqueeze(0))
img = VF.resize(img, [s * 3 for s in img.shape[-2:]], VF.InterpolationMode.NEAREST)
VF.to_pil_image(img)

# reconstruct audio

In [None]:
@torch.no_grad()
def reconstruct_audio(
        stream, 
        sub_sample: float = 1.1, 
        seconds: float = 3.,
        noise: float = 0.,
        echo: float = 0.,
        transform: Optional[Callable] = None,
        samplerate: int = SAMPLERATE,
        slice_size: int = SLICE_SIZE,
):
    """Pass an audio stream through the VAE and overlap-add the decoded slices.

    Slices of ``slice_size`` samples (stride ``slice_size / sub_sample``) are
    down-mixed to mono, encoded with ``model.encoder``, optionally manipulated
    in latent space, decoded with ``model.decoder`` and summed with a Hamming
    window. The sum is then divided by the summed window weights.

    :param stream: audio stream accepted by ``iter_audio_slices``
    :param sub_sample: overlap factor; stride between slices is ``slice_size / sub_sample``
    :param seconds: length of the returned audio
    :param noise: std of gaussian noise added to each embedding (0 = off)
    :param echo: if non-zero, each embedding becomes
        ``(1 - echo) * embedding + echo * previous_embedding``, carried across batches
    :param transform: optional callable applied to each embedding batch
        (after noise, before echo)
    :param samplerate: sample rate of the stream
    :param slice_size: samples per slice
    :return: 1-D tensor of ``int(seconds * samplerate)`` samples; samples covered by
        no slice stay 0. The result is not clamped to [-1, 1].
    """
    num_samples = int(seconds * samplerate)
    repro = torch.zeros(num_samples)
    repro_sum = torch.zeros(num_samples)
    window = torch.hamming_window(slice_size)
    # guard against a zero stride when sub_sample > slice_size
    stride = max(1, int(slice_size / sub_sample))

    last_embedding = None
    for slice_batch, pos_batch in iter_batches(
        tqdm(iter_audio_slices(
            stream, 
            slice_size=slice_size, 
            with_position=True, 
            stride=stride,
        )), 
        batch_size=128
    ):
        slice_batch = slice_batch.mean(-1).unsqueeze(-2)
        embedding_batch = model.encoder(slice_batch)
        if noise:
            embedding_batch = embedding_batch + noise * torch.randn_like(embedding_batch)
        if transform:
            embedding_batch = transform(embedding_batch)

        if echo:
            # recursive smoothing over consecutive slices (modifies embedding_batch in place)
            for i in range(embedding_batch.shape[0]):
                if last_embedding is not None:
                    embedding_batch[i] = embedding_batch[i] * (1. - echo) + echo * last_embedding
                last_embedding = embedding_batch[i]

        repro_batch = model.decoder(embedding_batch)

        done = False
        for repro_slice, pos in zip(repro_batch, pos_batch):
            # a slice ending exactly at num_samples still fits into the buffer
            if pos + slice_size > num_samples:
                done = True
                break
            repro[pos: pos + slice_size] += window * repro_slice[0]
            repro_sum[pos: pos + slice_size] += window
        if done:
            break

    mask = repro_sum > 0
    repro[mask] = repro[mask] / repro_sum[mask]
    return repro

# Reconstruct 10 s of audio starting one second into the file.
# Uncomment one of the keyword arguments to experiment in latent space.
reader.seek(1)
reconstruction = reconstruct_audio(
    reader.stream(),
    seconds=10,
    noise=0.00,
    # sub_sample=4.,
    # echo=.9,
    # transform=lambda emb: torch.concat([emb[:, -1:], emb[:, :-1]], dim=-1),  # rotate latent vector by one
    # transform=lambda emb: emb.permute(1, 0)
    # transform=lambda emb: emb.clamp(-1, 0) + emb.clamp(0, 1) * 10# + .0001 * torch.linspace(0, 10, emb.shape[-1])
)
play(reconstruction, normalize=True)

# Convert the VAE weights to EncoderConv1d (encoder and decoder)

In [None]:
# show the VAE's module tree; the weight conversion below reads
# encoder.encoder[0].layers, encoder.linear_mu, decoder[0] and decoder[2].layers
model

In [None]:
def _copy_param(target: torch.Tensor, source: torch.Tensor) -> None:
    """Copy ``source`` into ``target`` in place, refusing silent broadcasting."""
    if target.shape != source.shape:
        raise ValueError(f"shape mismatch: {tuple(target.shape)} vs {tuple(source.shape)}")
    target.copy_(source)


# convolution stack of the VAE encoder; its config defines both new models
vae_conv = model.encoder.encoder[0]
common_args = dict(
    shape=(1, SLICE_SIZE),
    kernel_size=vae_conv.layers[0].kernel_size[0],
    code_size=LATENT_SIZE,
)
enc = EncoderConv1d(**common_args, channels=vae_conv.channels[1:])
# reverse=True presumably builds the mirrored (decoding) stack
dec = EncoderConv1d(**common_args, channels=vae_conv.channels[1:], reverse=True)

with torch.no_grad():
    # encoder: the VAE's mean head (linear_mu) becomes the code projection
    _copy_param(enc.linear.weight, model.encoder.linear_mu.weight)
    _copy_param(enc.linear.bias, model.encoder.linear_mu.bias)
    # assumes conv layers sit at even indices with parameter-free layers in
    # between -- TODO confirm against the model repr above
    for i in range(0, len(vae_conv.layers), 2):
        _copy_param(enc.convolution.layers[i].weight, vae_conv.layers[i].weight)
        _copy_param(enc.convolution.layers[i].bias, vae_conv.layers[i].bias)

    # decoder: decoder[0] is the linear layer, decoder[2] the conv stack
    _copy_param(dec.linear.weight, model.decoder[0].weight)
    _copy_param(dec.linear.bias, model.decoder[0].bias)
    for i in range(0, len(model.decoder[2].layers), 2):
        _copy_param(dec.convolution.layers[i].weight, model.decoder[2].layers[i].weight)
        _copy_param(dec.convolution.layers[i].bias, model.decoder[2].layers[i].bias)

enc

In [None]:
with torch.no_grad():
    # take the 3rd slice from 31 s into the file (or the last one if fewer exist)
    reader.seek(31)
    *_, audio_slice = itertools.islice(iter_audio_slices(reader.stream(), SLICE_SIZE), 3)
    # down-mix to mono with a channel axis -> (1, SLICE_SIZE), like encode_audio /
    # reconstruct_audio; keeping all channels breaks the 1-channel model on stereo files
    audio_slice = audio_slice.mean(-1).unsqueeze(0)

    emb1 = model.encoder(audio_slice.unsqueeze(0))[0]
    emb2 = enc(audio_slice.unsqueeze(0))[0]

    slice_repro = dec(emb2.unsqueeze(0))[0]
    display(px.line(pd.DataFrame({"org": audio_slice[0], "repro": slice_repro[0]})))

    display(px.line(pd.DataFrame({"vae": emb1, "enc": emb2})))

# Save the EncoderConv1d encoder and decoder

In [None]:
# list encoder files already in the export directory (the save cell writes here)
!ls ../models/encoder1d

In [None]:
# Export the converted encoder/decoder; file names encode input shape and code size.
EXPORT_DIR = Path("../models/encoder1d")
# torch.save does not create missing parent directories
EXPORT_DIR.mkdir(parents=True, exist_ok=True)

encoder_file = EXPORT_DIR / f"conv-1x{SLICE_SIZE}-{LATENT_SIZE}.pt"
torch.save(enc.state_dict(), encoder_file)
decoder_file = EXPORT_DIR / f"conv-1x{SLICE_SIZE}-{LATENT_SIZE}-decoder.pt"
torch.save(dec.state_dict(), decoder_file)

In [None]:
reader.seek(1)
audio = read_stream(reader.stream(), 3.)

mel_spectrogram = AT.MelSpectrogram(
    sample_rate=SAMPLERATE,
    n_fft=2048,
    n_mels=200,
    #f_max=1000,
    power=.5,  # magnitude ** 0.5 compresses the dynamic range for display
    #normalized=True,
)
spec = mel_spectrogram(audio)
# mark empty bins as NaN so imshow leaves them blank (without mutating in place)
spec = spec.masked_fill(spec == 0, torch.nan)
# aspect="auto" lets the (mels x frames) image fill the figure;
# origin="lower" puts low frequencies at the bottom
px.imshow(spec, height=800, aspect="auto", origin="lower")

In [None]:
# IPython help lookup: shows the MelSpectrogram docstring (parameters used above).
# Developer aid only -- it computes nothing and can be removed before sharing.
AT.MelSpectrogram?