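# Example Notebook: SynthMap with the Snare808 synthesizer

This notebook works through the SynthMap pipeline on the Snare808 synthesizer, in five steps:

1. Render a random Snare808 patch.
2. Preview a batch from `SynthesizerDataModule`.
3. Build a (variational) MLP autoencoder over the synth's parameter vector.
4. Train it with `SynthMapTask` using Lightning.
5. Compare the audio rendered from original parameters with the audio rendered from reconstructed parameters.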

In [None]:
import IPython.display as ipd
import lightning as L
import torch
from tqdm import tqdm

from synthmap.synth import Snare808
from synthmap.data import SynthesizerDataModule
from synthmap.model import MLP
from synthmap.model import AutoEncoder
from synthmap.task import SynthMapTask
from synthmap.params import DiscretizedNumericalParameters

%load_ext autoreload
%autoreload 2

## Synthesizer

In [None]:
# Fix the global torch RNG so the random patches, dataset samples and
# (presumably) weight initialisation below repeat under Restart & Run All.
torch.manual_seed(0)

# Positional arguments kept from the original; presumably (sample_rate, num_samples)
# -- NOTE(review): confirm against the Snare808 signature.
snare = Snare808(48000, 48000)

num_params = snare.get_num_params()
# One random patch: every parameter drawn uniformly from [0, 1).
params = torch.rand(1, num_params)

y = snare(params)

# Detach and move to CPU before handing the tensor to IPython, as the evaluation
# cell further down does; numpy conversion fails on tensors that require grad
# or live on a GPU.
ipd.display(ipd.Audio(y.detach().cpu(), rate=48000))

## Datamodule

In [None]:
# One-off datamodule that also returns rendered audio; used only to preview a batch.
# It gets its own names so the training datamodule defined later does not shadow it.
audio_data = SynthesizerDataModule(
    synth=snare,
    batch_size=64,
    num_train=100000,
    return_sound=True,
)
audio_data.setup("fit")

# Grab a single batch: the rendered sounds and the parameter vectors behind them.
y, params = next(iter(audio_data.train_dataloader()))

In [None]:
# Report the batch shapes: rendered sounds first, then parameter vectors.
batch_shapes = (y.shape, params.shape)
print(*batch_shapes)

# Listen to the first sound in the batch.
first_sound = y[0]
ipd.display(ipd.Audio(first_sound, rate=48000))

## Encoder / Decoder

In [None]:
# Latent size of the autoencoder.
bottleneck = 8
variational = True
# A variational encoder emits two values per latent dimension (presumably mean
# and log-variance), so its output width is twice the bottleneck.
encoder_bottleneck = bottleneck * 2 if variational else bottleneck

# Shared MLP settings. Positional order follows the original calls, presumably
# (in_dim, hidden_dim, out_dim, num_layers, activation) -- confirm against MLP.
hidden_dim = 256
num_layers = 3

encoder = MLP(num_params, hidden_dim, encoder_bottleneck, num_layers, torch.nn.ReLU(), layer_norm=True)
decoder = MLP(bottleneck, hidden_dim, num_params, num_layers, torch.nn.ReLU(), layer_norm=True)
# NOTE(review): `bottleneck=` receives the boolean `variational`, not the latent
# size; presumably AutoEncoder treats it as an on/off switch for the variational
# bottleneck -- confirm against AutoEncoder's signature.
vae = AutoEncoder(encoder, decoder, bottleneck=variational, beta=1e-8)

# Smoke-test the untrained model on the preview batch. The first output is a
# parameter reconstruction (not audio), so it is named `params_hat`; no_grad
# avoids building an autograd graph just for a shape check.
with torch.no_grad():
    params_hat, _, _ = vae(params)
print(params_hat.shape)

## Task

In [None]:
# Wrap the autoencoder in the Lightning task; the learning rate is handed to the task.
synthmap = SynthMapTask(vae, lr=1e-3)

# Training datamodule: parameters only (no audio rendering) and smaller batches
# than the preview above.
param_data = SynthesizerDataModule(
    synth=snare,
    batch_size=8,
    num_train=100000,
    return_sound=False,
)
param_data.setup("fit")
train_dataloader = param_data.train_dataloader()

In [None]:
# Train on a CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    accelerator = "gpu"
else:
    accelerator = "cpu"

trainer = L.Trainer(accelerator=accelerator, max_epochs=4)
trainer.fit(synthmap, train_dataloaders=train_dataloader)

In [None]:
# Render a fresh random patch, then the patch reconstructed by the trained model.
# eval() switches off training-only behaviour left on after trainer.fit (for
# example latent sampling, if AutoEncoder implements it); no_grad skips autograd
# bookkeeping during inference.
synthmap.eval()
with torch.no_grad():
    params = torch.rand(1, num_params, device=synthmap.device)
    y = snare(params)

    # Call the module rather than .forward() so any registered hooks still run.
    p_hat, _, _ = synthmap(params)
    # Clamp into [0, 1], the range the input parameters are drawn from, since the
    # decoder output may fall outside it.
    y_hat = snare(torch.clamp(p_hat, 0.0, 1.0))

# Original patch first, reconstruction second.
ipd.display(ipd.Audio(y.detach().cpu(), rate=48000))
ipd.display(ipd.Audio(y_hat.detach().cpu(), rate=48000))

In [None]:
# Inspect the input parameter vector alongside the autoencoder's raw outputs for it.
ipd.display(params)
# no_grad: this is inspection only, so skip building an autograd graph.
with torch.no_grad():
    vae_outputs = vae(params)
vae_outputs

In [None]:
# Mean absolute error between the reconstructed and original parameters. Arguments
# follow PyTorch's (input, target) order; .item() displays a plain Python float.
torch.nn.functional.l1_loss(p_hat, params).item()