In [None]:
from einops import rearrange
import IPython.display as ipd
import lightning as L
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
import torch
from tqdm import tqdm

from synthmap.synth import Snare808
from synthmap.data import SynthesizerDataModule
from synthmap.model import MLP
from synthmap.model import AutoEncoder
from synthmap.task import SynthMapTask
from synthmap.params import DiscretizedNumericalParameters

%load_ext autoreload
%autoreload 2

In [None]:
# Instantiate the snare synthesizer.
# NOTE(review): the two arguments look like (sample_rate, num_samples), i.e. one
# second of audio at 48 kHz — confirm against Snare808's signature.
snare = Snare808(48000, 48000)

# Draw a batch of two random parameter vectors, uniform in [0, 1).
num_params = snare.get_num_params()
params = torch.rand(2, num_params)

# Render the whole batch in one call.
y = snare(params)

# ipd.Audio treats the first axis of a 2-D array as audio *channels*, so passing
# the full batch would play the two unrelated sounds as the left/right channels
# of a single stereo clip. Show one player per batch item instead.
# Assumes y is laid out as (batch, samples) — TODO confirm.
for sound in y:
    ipd.display(ipd.Audio(sound, rate=48000))

In [None]:
# Discretizer for the synth parameters. It is sized from the last axis of the
# parameter batch created above; later cells use its `num_discrete_params`,
# `group_parameters` and `inverse` members.
# NOTE(review): 32 is presumably the number of bins per parameter — confirm
# against DiscretizedNumericalParameters.
dp = DiscretizedNumericalParameters(
    params.size(-1),
    32,
)

In [None]:
# Cross-entropy criterion handed to the training task as `loss_fn`.
# The reduction is spelled out explicitly; "mean" is PyTorch's default.
loss = torch.nn.CrossEntropyLoss(reduction="mean")

In [None]:
# Size of the latent code fed to the decoder.
bottleneck = 16
variational = True

# With the variational bottleneck enabled the encoder emits twice as many values
# as the latent size (presumably mean and log-variance per dimension — TODO confirm).
if variational:
    encoder_bottleneck = 2 * bottleneck
else:
    encoder_bottleneck = bottleneck

# Keyword options shared by the encoder and decoder MLPs.
mlp_options = dict(layer_norm=True, init_std=0.1)

# Positional arguments are passed exactly as before; they look like
# (in_features, hidden_size, out_features, num_layers, activation) — verify
# against MLP's signature. Each network gets its own activation instance.
encoder = MLP(
    num_params, 256, encoder_bottleneck, 3, torch.nn.LeakyReLU(), **mlp_options
)
decoder = MLP(
    bottleneck, 256, dp.num_discrete_params, 3, torch.nn.LeakyReLU(), **mlp_options
)

vae = AutoEncoder(encoder, decoder, bottleneck=variational, beta=0.01)

# Optional forward-pass smoke test, left disabled:
# vae(params)

In [None]:
# Training task wrapping the autoencoder, the parameter discretizer and the loss.
synthmap = SynthMapTask(
    vae,
    lr=1e-3,
    param_discretizer=dp,
    loss_fn=loss,
)

# Data module settings, gathered in one place so they are easy to tweak.
datamodule_options = {
    "synth": snare,
    "batch_size": 64,
    "num_train": 100_000,
    "return_sound": False,
}
data = SynthesizerDataModule(**datamodule_options)

# Prepare the "fit" stage by hand so the training dataloader can be pulled out
# and passed to the trainer directly.
data.setup("fit")
train_dataloader = data.train_dataloader()

In [None]:
# Train on a CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    accelerator = "gpu"
else:
    accelerator = "cpu"

# Short run: four passes over the training dataloader.
trainer = L.Trainer(accelerator=accelerator, max_epochs=4)
trainer.fit(
    synthmap,
    train_dataloaders=train_dataloader,
)

In [None]:
# Reconstruction check: draw two fresh random parameter vectors, run them through
# the trained model, and compare the reconstruction with the ground truth by ear
# and by value.
params = torch.rand(2, num_params)

# Inference only: disable autograd so no graph is built and every tensor that
# reaches the synth / ipd.Audio is free of grad tracking (NumPy conversion of a
# tensor that requires grad raises).
with torch.no_grad():
    # The task returns three values; only the first (the decoder output) is used.
    y, _, _ = synthmap(params.to(synthmap.device))

    # Map the model output back to synth parameters via the discretizer.
    p_hat = dp.inverse(dp.group_parameters(y))

    # Render only the first item of the batch, keeping the batch axis.
    audio_hat = snare(p_hat[0:1])
    audio_ref = snare(params[0:1])

# First player: reconstruction. Second player: original parameters.
ipd.display(ipd.Audio(audio_hat, rate=48000))
ipd.display(ipd.Audio(audio_ref, rate=48000))

print(p_hat)
print(params)

In [None]:
# Latent-space probe: encode the parameter batch from the previous cell, overwrite
# part of the first item's latent code, decode it, and listen to the result.

# Inference only: with autograd disabled no graph is built, and the in-place
# edit of the latent code below does not touch a grad-tracked tensor.
with torch.no_grad():
    z = synthmap.autoencoder.encoder(params.to(synthmap.device))
    # The bottleneck returns a pair; only the first element (the latent code) is used.
    z, _ = synthmap.autoencoder.bottleneck(z)

    # Pin latent dimensions 4..15 of the first batch item to a constant.
    z[0, 4:16] = 0.1

    y = synthmap.autoencoder.decoder(z)

    # Map the decoder output back to synth parameters via the discretizer.
    p_hat = dp.inverse(dp.group_parameters(y))

    # Render only the first (edited) item, keeping the batch axis.
    audio_probe = snare(p_hat[0:1])

ipd.display(ipd.Audio(audio_probe, rate=48000))