In [None]:

import IPython.display as ipd
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
import torch

from synthmap.synth import Snare808
from synthmap.params import DiscretizedNumericalParameters

%load_ext autoreload
%autoreload 2

In [None]:
# Fix the RNG seed so the sampled parameters (and everything derived from
# them in later cells) are identical after Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Playback rate handed to ipd.Audio at the end of this cell.
SAMPLE_RATE = 48000

# NOTE(review): both constructor arguments are 48000; presumably they are
# (sample_rate, num_samples), i.e. one second of audio -- confirm against
# synthmap.synth.Snare808.
snare = Snare808(48000, 48000)

# Draw a batch of two parameter vectors, uniformly random in [0, 1).
num_params = snare.get_num_params()
params = torch.rand(2, num_params)

# Run the synth on the sampled parameters.
y = snare(params)

# NOTE(review): ipd.Audio interprets a 2-D array as (channels, samples); if y
# is (batch, samples) the two renders play as the left/right channels of one
# stereo clip rather than as two separate clips -- confirm this is intended.
ipd.display(ipd.Audio(y, rate=SAMPLE_RATE))

In [None]:
# Build the discretiser from the width of the continuous parameter vector.
# NOTE(review): 64 is presumably the number of bins per parameter -- confirm
# against synthmap.params.DiscretizedNumericalParameters.
dp = DiscretizedNumericalParameters(params.size(-1), 64)
print(dp.num_discrete_params)

# Discretise the sampled parameters and report the shape of the result.
# NOTE(review): the variable name suggests a one-hot encoding -- confirm.
one_hot = dp.discretize(params)
print(one_hot.shape)

In [None]:
# Slice out index 0 of the first and last axes of one_hot as a NumPy vector.
x_1 = one_hot[0, :, 0].numpy()

# Blur the vector with a 1-D Gaussian (sigma = 2.0), convert back to a tensor,
# then rescale so the entries sum to one.
x_smooth = torch.from_numpy(gaussian_filter1d(x_1, sigma=2.0))
x_smooth = x_smooth / x_smooth.sum()

In [None]:
# Overlay the discretised slice and its smoothed, renormalised version on
# explicit axes so the figure is labelled and readable on its own.
fig, ax = plt.subplots()
ax.plot(x_1, label="x_1 (discretized slice)")
ax.plot(x_smooth, label="x_smooth (Gaussian-smoothed, normalized)")
ax.set(xlabel="index", ylabel="value", title="Discretized vs. smoothed target")
ax.legend()

# Render the figure explicitly so no Line2D/Legend repr is echoed as output.
plt.show()

In [None]:
# Criterion used by the fitting loop. reduction="mean" is the library default,
# spelled out here so the averaging is visible at the point of construction.
loss = torch.nn.CrossEntropyLoss(reduction="mean")

In [None]:
# Free logits that are optimised directly, one row per sampled parameter
# vector. The batch size is read from `params` rather than repeating the
# literal 2, so this cell stays in step with the sampling cell if the batch
# size there is changed.
batch_size = params.shape[0]
logits = torch.rand(batch_size, dp.num_discrete_params, requires_grad=True)
print(logits.shape)

# Regroup the flat logits with the discretiser and inspect the result.
# NOTE(review): the shape presumably has to match one_hot.shape for the loss
# used later -- confirm.
grouped = dp.group_parameters(logits)
print(grouped.shape)

In [None]:
# Adam updates the logits tensor directly; it is the only optimised variable.
optim = torch.optim.Adam(params=[logits], lr=1e-2)

In [None]:
# Optimise the logits against one_hot for a fixed number of Adam steps,
# printing the loss at regular intervals.
num_steps, log_every = 1000, 100

for step in range(num_steps):
    # Clear gradients accumulated by the previous iteration.
    optim.zero_grad()

    # Forward pass: regroup the logits and score them against the target.
    grouped = dp.group_parameters(logits)
    loss_val = loss(grouped, one_hot)

    # Backward pass and parameter update.
    loss_val.backward()
    optim.step()

    # Report on steps 0, 100, ..., 900 (the value is from before this step's update).
    if not step % log_every:
        print(loss_val.item())

In [None]:
# Decode from the *current* logits. The `grouped` tensor left over from the
# training loop was computed before that iteration's optim.step(), so it lags
# the optimised logits by one update and still carries autograd history.
# Recomputing under no_grad gives the final estimate without a graph attached.
with torch.no_grad():
    p_hat = dp.inverse(dp.group_parameters(logits))
print(p_hat)

In [None]:
# Print the originally sampled parameters; presumably meant to be compared by
# eye with the recovered p_hat printed by the previous cell.
print(params)