In [None]:
import torch

import IPython.display as ipd
import lightning as L

from synthmap.data.fitness import MultiScaleSpectralFitness
from synthmap.data.genetic import GeneticSynthDataLoader
from synthmap.synth import Snare808
from synthmap.utils.model_utils import load_model
from synthmap.utils.audio_utils import load_audio

# Re-import edited modules (e.g. changes to synthmap) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
# Target sample the genetic search tries to match.
# NOTE(review): this is a kick-drum sample ("kicks/BD 808 ...") while the synth
# used below is Snare808 — confirm the mismatch is intentional.
TARGET = "audio/kicks/BD 808 Sat Click Decay B 02.wav"
# Config for the parameter VAE that is trained later in the notebook.
VAE_CFG = "cfg/param_vae.yaml"

# Seed torch's RNGs (CPU and all CUDA devices) so torch-based randomness in this
# notebook is reproducible under Restart & Run All (presumably this includes the
# GA's initial population — confirm it draws from torch's generator).
SEED = 0
torch.manual_seed(SEED);

In [None]:
# Synth whose presets the GA dataloader below evolves.
# Args presumably (sample_rate, num_samples), matching load_audio(..., 48000, length=48000) — confirm.
snare = Snare808(48000, 48000)

# Size of one preset vector for this synth.
# (The former `params = torch.rand(1, num_params)` was removed: it was never read
# before the pre-seed cell reassigns `params`.)
num_params = snare.get_num_params()

In [None]:
# Two fitness terms for the GA, both computed against TARGET.
# The positional 48000, 48000 are presumably sample rate and length — confirm.

# Single-resolution (2048-point FFT) mel-scale spectral term with 128 bins.
mel_fitness = MultiScaleSpectralFitness(
    TARGET,
    48000,
    48000,
    fft_sizes=[2048],
    scale="mel",
    n_bins=128,
)

# Multi-resolution linear-scale spectral term. The weights suggest spectral
# convergence + linear magnitude with the log-magnitude term disabled — confirm.
stft_fitness = MultiScaleSpectralFitness(
    TARGET,
    48000,
    48000,
    fft_sizes=[1024, 512, 256, 64],
    w_sc=1.0,
    w_log_mag=0.0,
    w_lin_mag=1.0,
    sum_loss=True,
)

In [None]:
# Dataloader backed by a genetic algorithm over `snare` presets, scored by both
# fitness functions. NOTE(review): the positional 10 and 128 are unnamed here —
# confirm their meaning against GeneticSynthDataLoader's signature.
dataloader = GeneticSynthDataLoader(
    snare,
    10,
    128,
    fitness_fns=[mel_fitness, stft_fitness],
    verbose=False,
    reset_on_epoch=False,
    device="cuda",
)

In [None]:
# Pre-seed: load a pretrained model (and its synth) from a W&B run directory.
RUN_DIR = "lightning_logs/wandb/run-20240419_095659-7vtkk1s1"
config = f"{RUN_DIR}/files/model-config.yaml"
ckpt = f"{RUN_DIR}/epoch=9-step=312500.ckpt"
model, synth = load_model(config, ckpt, return_synth=True, device="cuda")

In [None]:
# Predict a preset for the target with the pretrained model, then render it.
# The audio is moved to "cuda" to match load_model(..., device="cuda"); this is a
# no-op if load_audio already returned a CUDA tensor.
audio = load_audio(TARGET, 48000, length=48000).to("cuda")

# Inference only: no autograd graph for either the prediction or the render, so
# .numpy() below cannot fail on a grad-tracking tensor.
with torch.no_grad():
    params, _, _ = model(audio=audio)
    y_seed = synth(params)

ipd.display(params)
ipd.Audio(y_seed.cpu().numpy(), rate=48000)

In [None]:
# Inject the model's prediction into the GA by overwriting the first individual.
# access_values() presumably returns a writable view of the population's values,
# so the assignment edits the GA in place — confirm.
population = dataloader.ga.population.access_values()
seed_preset = params[0]
population[0] = seed_preset

In [None]:
# Run the genetic search: iterating the dataloader advances the GA.
# The yielded (preset, audio) batches are not needed here. Binding them to
# throwaway names avoids clobbering `audio` (the target loaded for pre-seeding).
for _preset_batch, _audio_batch in dataloader:
    pass

In [None]:
# Best preset found by the genetic search, cloned so it is independent of the
# population's storage.
best_params = dataloader.ga.population.take_best(1).values.clone()
best_params

In [None]:
# Render the GA's best preset. Computing under no_grad means .numpy() cannot fail
# on a grad-tracking tensor (the later cells guard the same case with .detach()).
# NOTE(review): this renders with the pretrained model's `synth`, whereas the GA
# scored presets with `snare` — confirm the two are equivalent.
with torch.no_grad():
    y_hat = synth(best_params)

ipd.Audio(y_hat.cpu().numpy(), rate=48000)

In [None]:
# Reference: play the target sample itself, for comparison with the renders.
audio = load_audio(TARGET, 48000, length=48000)
ipd.Audio(audio.cpu()[0].numpy(), rate=48000)

In [None]:
# Parameter VAE to be trained on the GA dataloader. It is loaded on CPU and left
# to Lightning for device placement during fit. load_data=True presumably also
# builds its own data — confirm.
task = load_model(VAE_CFG, load_data=True, device="cpu")

In [None]:
# Train the VAE on batches from the GA dataloader: single GPU, 50 epochs.
trainer = L.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=50,
)
trainer.fit(task, train_dataloaders=dataloader)

In [None]:
# Round-trip the GA's best preset through the trained VAE and listen to the result.
p_in = dataloader.ga.population.take_best(1).values.clone()

# Inference only: no autograd graph is built, so no .detach() is needed afterwards.
# NOTE(review): `task` is in whatever train/eval mode fit() left it. Call
# task.eval() first if any stochastic/dropout behaviour should be disabled.
with torch.no_grad():
    # `task` may live on a different device than the GA population (it was loaded
    # on CPU), so feed it input on its own device...
    p_out, _, _ = task(params=p_in.to(task.device))
    # ...then clamp into [0, 1] and move back to p_in's device, where `synth` is
    # given input when rendering p_in directly.
    out = torch.clamp(p_out, 0.0, 1.0).to(p_in.device)
    y_recon = synth(out)

ipd.Audio(y_recon.cpu().numpy(), rate=48000)

In [None]:
# For comparison: render the GA's best preset directly, without the VAE round-trip.
p_in = dataloader.ga.population.take_best(1).values.clone()
y_direct = synth(p_in).detach()
ipd.Audio(y_direct.cpu().numpy(), rate=48000)