In [None]:
from pathlib import Path
from functools import partial

from einops import repeat
from einops import rearrange

import IPython.display as ipd
import torch
import torchaudio
from hear_ced import ced_base

from evotorch import Problem
from evotorch.algorithms import SteadyStateGA
from evotorch.operators import (
    SimulatedBinaryCrossOver,
    GaussianMutation,
)
from evotorch.logging import StdOutLogger

from synthmap.synth import Snare808

%load_ext autoreload
%autoreload 2

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Directory holding the reference snare recordings (searched recursively).
background_path = Path("../audio/snares")

# rglob yields entries in filesystem order, which differs between machines;
# sorting keeps the file list (and everything derived from it) reproducible.
audio_files = sorted(background_path.rglob("*.wav"))

In [None]:
# Load the CED base model from hear_ced onto the selected device. Later cells
# read `model.sample_rate` to decide what rate audio must be resampled to.
model = ced_base.load_model(device=device)

In [None]:
# Load every reference recording as a mono clip at the CED model's sample
# rate, zero-padded on the right to at least `model.sample_rate` samples.
audio = []
for file in audio_files:
    waveform, sample_rate = torchaudio.load(file)

    # Multi-channel files: keep only the first channel.
    waveform = waveform[:1] if waveform.shape[0] > 1 else waveform

    # Bring the clip to the model's sample rate when it differs.
    if sample_rate != model.sample_rate:
        waveform = torchaudio.functional.resample(
            waveform,
            sample_rate,
            model.sample_rate,
            lowpass_filter_width=512,
        )

    # Right-pad short clips with zeros; longer clips are left untouched.
    missing = model.sample_rate - waveform.shape[-1]
    if missing > 0:
        waveform = torch.nn.functional.pad(waveform, (0, missing))

    audio.append(waveform)

In [None]:
# Embed every reference clip with CED and summarise the whole set by its mean
# vector and covariance matrix; these are the reference statistics used by the
# Frechet distance in the fitness function.
#
# The statistics are fixed targets, so no autograd graph is needed: compute
# them under no_grad, matching how the fitness function embeds synth output.
# NOTE(review): this cell uses ced_base.get_scene_embeddings while the fitness
# function calls model.clip_embedding — confirm both produce the same features.
with torch.no_grad():
    clip_embeddings = [
        ced_base.get_scene_embeddings(waveform.to(device), model)
        for waveform in audio
    ]

# One row per clip.
embeddings = torch.vstack(clip_embeddings)

mean = torch.mean(embeddings, dim=0)
# torch.cov expects variables in rows, hence the transpose.
cov = torch.cov(embeddings.T)

In [None]:
# Sanity check: Frechet distance between the reference statistics and
# themselves (expected to be approximately zero — confirm on the cell output).
torchaudio.functional.frechet_distance(mean, cov, mean, cov)

# Audio Distance

In [None]:
# Instantiate the 808 snare synthesizer and audition one random patch.
# NOTE(review): the two constructor arguments are presumably the sample rate
# and the number of output samples — confirm against synthmap.synth.Snare808.
snare = Snare808(48000, 48000)
num_params = snare.get_num_params()

# A single parameter vector drawn uniformly from [0, 1).
params = torch.rand((1, num_params))
y = snare(params)

preview = ipd.Audio(y, rate=48000)
ipd.display(preview)

In [None]:
# Mel-spectrogram front end for the spectral-distance objective. It is
# configured with the synthesizer's sample rate so that synth output and the
# reference recordings are analysed identically.
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=snare.sample_rate,
    n_fft=2048,
    hop_length=128,
    n_mels=128,
)
mel_spectrogram = mel_spectrogram.to(device)

target_mels = []

for file in audio_files:
    waveform, sample_rate = torchaudio.load(file)

    # Convert to mono by keeping the first channel
    if waveform.shape[0] > 1:
        waveform = waveform[:1]

    # Resample to the synthesizer rate if necessary. The guard has to compare
    # against snare.sample_rate (the rate we convert to); comparing against
    # model.sample_rate would leave a file that is already at the CED rate
    # un-resampled and therefore at the wrong rate for this front end.
    if sample_rate != snare.sample_rate:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, snare.sample_rate, lowpass_filter_width=512
        )

    # Ensure 1sec length: zero-pad short clips, truncate long ones
    if waveform.shape[-1] < snare.sample_rate:
        waveform = torch.nn.functional.pad(
            waveform, (0, snare.sample_rate - waveform.shape[-1])
        )
    elif waveform.shape[-1] > snare.sample_rate:
        waveform = waveform[:, : snare.sample_rate]

    # Peak-normalize. An all-zero clip is left as is instead of being divided
    # by zero, which would fill the target set with NaNs.
    peak = waveform.abs().max()
    if peak > 0:
        waveform = waveform / peak

    target_mels.append(mel_spectrogram(waveform.to(device)))

# Stack the per-file spectrograms along the first (file) dimension.
mel_target = torch.vstack(target_mels)

In [None]:
def mel_minmax(audio: torch.Tensor, targets: torch.Tensor):
    """
    Compare each item of a batch against every target.

    For every item in ``audio`` the mean absolute difference to each entry of
    ``targets`` is computed (averaged over the last two dimensions). The
    smallest, largest and average of those per-target errors are reported.

    Args:
        audio: Batch of items; the first dimension indexes the batch.
        targets: Batch of targets that each item is broadcast against.

    Returns:
        Tuple ``(minimums, maximums, error)`` of 1-D tensors with one entry
        per batch item, allocated on ``audio.device``.
    """
    batch_size = audio.shape[0]
    best = torch.zeros(batch_size, device=audio.device)
    worst = torch.zeros(batch_size, device=audio.device)
    average = torch.zeros(batch_size, device=audio.device)

    # One item at a time, so only a single (targets-sized) difference tensor
    # is alive at any moment.
    for idx, item in enumerate(audio):
        per_target = (item - targets).abs().mean(dim=(-1, -2))
        best[idx] = per_target.min()
        worst[idx] = per_target.max()
        average[idx] = per_target.mean()

    return best, worst, average

# Mel Embedding

In [None]:
class MelEmbedding(torch.nn.Module):
    """
    Fixed-size summary of a mel spectrogram.

    The waveform is turned into a mel spectrogram and then collapsed over the
    time axis into two statistics per mel band: the mean energy and the mean
    frame-to-frame difference. Both are concatenated along the feature axis.

    Args:
        **mel_kwargs: Forwarded unchanged to
            ``torchaudio.transforms.MelSpectrogram``.
    """

    def __init__(self, **mel_kwargs):
        super().__init__()
        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(**mel_kwargs)

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Embed ``audio``.

        Args:
            audio: Waveform tensor whose last dimension is time; any leading
                dimensions are preserved.

        Returns:
            Tensor whose last dimension holds the per-band means followed by
            the per-band mean temporal differences.
        """
        x = self.mel_spectrogram(audio)

        # Summarize over time (the last axis of the spectrogram).
        x_mean = torch.mean(x, dim=-1)
        # NOTE: with a single frame torch.diff is empty and this mean is NaN.
        x_diff = torch.mean(torch.diff(x, dim=-1), dim=-1)

        # Join along the feature axis. torch.hstack joins along dim 1, which is
        # the feature axis only for 1-D/2-D summaries; dim=-1 gives the same
        # result there and stays correct with extra leading dimensions.
        return torch.cat([x_mean, x_diff], dim=-1)

In [None]:
# Build the embedding module (2048-point FFT, hop of 128, 128 mel bands at the
# synthesizer's sample rate) and try it on the example snare `y`.
mel_embed = MelEmbedding(
    sample_rate=snare.sample_rate,
    n_fft=2048,
    hop_length=128,
    n_mels=128,
).to(device)

embed = mel_embed(y.to(device))
print(embed.shape)

In [None]:
# Embed every reference recording with MelEmbedding and summarise the set by
# its mean and covariance (reference statistics for a mel-based Frechet
# distance).
target_embeddings = []
for file in audio_files:
    waveform, sample_rate = torchaudio.load(file)

    # Convert to mono by keeping the first channel
    if waveform.shape[0] > 1:
        waveform = waveform[:1]

    # Resample to the synthesizer rate if necessary. The guard has to compare
    # against snare.sample_rate (the rate we convert to); comparing against
    # model.sample_rate would leave a file that is already at the CED rate
    # un-resampled and therefore at the wrong rate for mel_embed.
    if sample_rate != snare.sample_rate:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, snare.sample_rate, lowpass_filter_width=512
        )

    # Ensure 1sec length: zero-pad short clips, truncate long ones
    if waveform.shape[-1] < snare.sample_rate:
        waveform = torch.nn.functional.pad(
            waveform, (0, snare.sample_rate - waveform.shape[-1])
        )
    elif waveform.shape[-1] > snare.sample_rate:
        waveform = waveform[:, : snare.sample_rate]

    # Peak-normalize. An all-zero clip is left as is instead of being divided
    # by zero, which would put NaNs into the statistics below.
    peak = waveform.abs().max()
    if peak > 0:
        waveform = waveform / peak

    target_embeddings.append(mel_embed(waveform.to(device)))

# One row per recording.
mel_embed_target = torch.vstack(target_embeddings)
print(mel_embed_target.shape)

mel_mean = torch.mean(mel_embed_target, dim=0)
# torch.cov expects variables in rows, hence the transpose.
mel_cov = torch.cov(mel_embed_target.T)

print(mel_mean.shape, mel_cov.shape)

# EvoTorch

In [None]:
def compute_synth_distance(params: torch.Tensor) -> torch.Tensor:
    """
    Two-objective fitness for a batch of synthesizer parameter vectors.

    Objective 0 is a Frechet distance between CED embeddings of the rendered
    audio and the reference statistics ``mean``/``cov``. The batch is split
    into up to 8 chunks; one distance is computed per chunk and shared by all
    members of that chunk. Objective 1 is, per candidate, the largest mean
    absolute mel-spectrogram error against any entry of ``mel_target``.

    Args:
        params: Batch of parameter vectors, one row per candidate. Values are
            clamped to [0, 1] before synthesis.

    Returns:
        Tensor whose last dimension has size 2: ``[frechet, worst_mel_error]``
        for each candidate.
    """
    # Generate audio
    y = snare(torch.clamp(params, 0.0, 1.0))

    # Resample from the synthesizer rate to the CED sample rate
    y_down = torchaudio.functional.resample(
        y, snare.sample_rate, model.sample_rate, lowpass_filter_width=512
    )

    # Compute embeddings
    with torch.no_grad():
        # NOTE(review): the clone is kept from the original code; presumably it
        # protects the resampled audio from in-place changes — confirm.
        y_down = y_down.clone()
        emb = model.clip_embedding(y_down)

    # Split into chunks and compute Frechet distance for each chunk.
    # NOTE: a chunk holding a single embedding has an undefined covariance
    # (torch.cov returns NaN), so very small batches yield NaN distances.
    chunk_distances = []
    for chunk in torch.chunk(emb, 8, dim=0):
        emb_mean = chunk.mean(dim=0)
        emb_cov = torch.cov(chunk.T)

        chunk_dist = torchaudio.functional.frechet_distance(
            mean, cov, emb_mean, emb_cov
        )
        # Every member of the chunk receives the chunk-level distance.
        chunk_dist = repeat(chunk_dist.unsqueeze(0), "() -> b", b=chunk.shape[0])
        chunk_distances.append(chunk_dist)

    dist = torch.hstack(chunk_distances)

    # Peak-normalize each rendered sound before the mel comparison. A silent
    # render keeps a divisor of 1 so it stays all-zero instead of becoming NaN.
    peak = torch.max(torch.abs(y), dim=-1, keepdim=True).values
    y = y / torch.where(peak > 0, peak, torch.ones_like(peak))
    mel_audio = mel_spectrogram(y)

    # Only the worst-case (maximum) error is used as an objective.
    _, worst_error, _ = mel_minmax(mel_audio, mel_target)

    # Ideas sketched here earlier but not in use: a Frechet distance on
    # MelEmbedding features (mel_mean / mel_cov), a chunked maximum error, and
    # a pairwise-distance diversity term between the candidates' mel means.

    fitness = torch.stack([dist, worst_error], dim=-1)
    return fitness

In [None]:
# Optimisation problem: every solution is a synth parameter vector in [0, 1].
prob = Problem(
    # Two objectives, both minimised
    ["min", "min"],
    compute_synth_distance,
    initial_bounds=(0.0, 1.0),
    bounds=(0.0, 1.0),
    solution_length=num_params,
    vectorized=True,
    device=device,
)

# Steady-state GA with a population of 200.
ga = SteadyStateGA(prob, popsize=200)

# Variation operators, registered in this order: crossover, then mutation.
crossover = SimulatedBinaryCrossOver(
    prob,
    tournament_size=4,
    cross_over_rate=1.0,
    eta=8,
)
ga.use(crossover)

mutation = GaussianMutation(prob, stdev=0.3)
ga.use(mutation)

# Report the search status on stdout.
logger = StdOutLogger(ga)

In [None]:
# Run the search. The argument is presumably the number of generations —
# confirm against the evotorch SearchAlgorithm.run documentation.
ga.run(500)

In [None]:
# params = ga.population.values.clone()
# y = snare(torch.clamp(params, 0.0 ,1.0))

# # Normalize each sample
# y_max = torch.max(torch.abs(y), dim=1).values
# y = y / y_max[:, None]

# y = rearrange(y, 'b n -> 1 (b n)')
# ipd.display(ipd.Audio(y.detach().cpu(), rate=48000))

In [None]:
# Audition the first 50 members of the final population: print each one's
# objective values, then render and play the corresponding sound.
for solution in ga.population[:50]:
    print(solution.evals)

    # Same clamping as in the fitness function, plus a leading batch axis.
    patch = torch.clamp(solution.values, 0.0, 1.0)
    y = snare(patch[None])

    player = ipd.Audio(y.detach().cpu(), rate=48000)
    ipd.display(player)