# Optimizing a Harmonic Synthesizer

In this section we look at using gradient descent to learn parameters for a harmonic
synthesizer to match an instrumental sound.

We build on the harmonic synthesizer from the previous section and add several features
that support gradient-based optimization. These additions are taken directly from Engel et al.'s
differentiable harmonic synthesizer [cite] and include:

1) constraining harmonic amplitudes to sum to one;
2) adding a global amplitude parameter;
3) parameter scaling to constrain the possible range of amplitudes;
4) removing harmonics whose frequencies lie above the Nyquist frequency, which would otherwise result in aliasing;
5) interpolation of parameters from frame rate to sample rate.

The updated formula for our harmonic synthesizer is:

$$
    y[n] = A[n]\sum_{k=1}^{K}\hat{\alpha}_k[n]\sin\left(k\sum_{m=0}^{n}\omega_{0}[m]\right)
$$

where $A[n]$ is a global amplitude parameter, and $\hat{\alpha}_k[n]$ is the normalized
amplitude for the $k^{\text{th}}$ sinusoidal component. $\hat{\alpha}_k[n]$ is normalized
such that $\sum_{k}\hat{\alpha}_k[n] = 1$ and $\hat{\alpha}_k[n] > 0$. $\omega_{0}[n]$ is
a time-varying fundamental frequency that is pre-computed using a pitch extraction algorithm.
Methods for parameter scaling and removing frequencies above the Nyquist frequency will be introduced inline below.

Instead of specifying parameters at a resolution equivalent to the audio sampling rate we'll specify parameters at a frame rate of 100Hz.
This sets a reasonable upper bound on the frequency of change of our control signals and has
the added benefit of decreasing the dimensionality of the optimization problem. We only
need to learn \~100 values per harmonic for each second of audio at 16kHz instead of 16k!

Finally, we'll use gradient descent with a spectral loss function to match sounds from 
the [NSynth test dataset](https://magenta.tensorflow.org/datasets/nsynth#files) [cite].

In [None]:
import math

import torch
import torchaudio
import matplotlib.pyplot as plt
import IPython.display as ipd
from tqdm import trange

import crepe
import auraloss

In [None]:
def scale_function(
    x: torch.Tensor,
    exponent: float = 10.0,
    max_value: float = 2.0,
    threshold: float = 1e-7,
):
    """
    Map an unconstrained parameter to a strictly positive amplitude.

    Computes ``max_value * sigmoid(x) ** log(exponent) + threshold``, where
    ``log`` is the natural logarithm. For the default arguments the output
    therefore lies between ``threshold`` and ``max_value + threshold``.
    The additive threshold keeps the output away from exactly zero.
    """
    # The sigmoid is raised to the natural log of `exponent`, not to `exponent` itself
    power = math.log(exponent)
    squashed = torch.sigmoid(x)
    return max_value * squashed.pow(power) + threshold

In [None]:
def remove_above_nyquist(harmonic_amps, frequencies):
    """
    Zero the amplitudes of components at or above the Nyquist frequency.

    `frequencies` are angular frequencies in radians per sample, so the
    Nyquist frequency corresponds to pi. Entries of `harmonic_amps` whose
    frequency is below pi are kept, all others are multiplied by zero.

    Returns the masked amplitudes (the inputs are not modified in place).
    """
    keep = (frequencies < torch.pi).float()
    return harmonic_amps * keep

In [None]:
def get_harmonic_frequencies(f0, num_harmonics):
    """
    Build the frequency track of every harmonic from a fundamental frequency track.

    Args:
        f0: Fundamental frequency with a leading batch dimension, e.g. (batch, n_frames).
        num_harmonics: Number of harmonics; harmonic k (1-based) has frequency k * f0.

    Returns:
        Tensor with a harmonic dimension inserted at position 1,
        e.g. (batch, num_harmonics, n_frames).
    """
    # Integer harmonic ratios 1..num_harmonics shaped (1, num_harmonics, 1).
    # They are created on the same device as f0 so the product also works
    # when f0 does not live on the CPU.
    harmonic_ratios = torch.arange(1, num_harmonics + 1, device=f0.device).view(
        1, -1, 1
    )

    # Broadcasting replaces the explicit repeat of f0 along the harmonic dimension
    return f0.unsqueeze(1) * harmonic_ratios

In [None]:
def additive_synth(
    frequencies: torch.Tensor,  # Angular frequencies (rad / sample) - frame rate
    amplitudes: torch.Tensor,  # Amplitudes
    n_samples: int,  # Number of samples to synthesize
):
    """
    Sum a bank of sinusoids with time-varying frequency and amplitude.

    Both control signals are given as (batch, n_frequencies, n_frames) and are
    linearly interpolated along the last dimension to `n_samples`. The result
    has shape (batch, n_samples).
    """
    assert (
        frequencies.ndim == 3
    ), "Frequencies must be 3D (batch, n_frequencies, n_frames)"
    assert (
        frequencies.shape == amplitudes.shape
    ), "Frequency and amplitude shapes must match"

    # Bring both envelopes from frame rate up to sample rate
    upsample = torch.nn.functional.interpolate
    freq_env = upsample(frequencies, size=n_samples, mode="linear")
    amp_env = upsample(amplitudes, size=n_samples, mode="linear")

    # Delay the frequency envelope by one sample: a zero goes in front and the
    # last value is dropped, so the accumulated phase starts at zero
    leading_zero = torch.zeros_like(freq_env[:, :, :1])
    delayed = torch.cat([leading_zero, freq_env[..., :-1]], dim=-1)

    # Running sum of the per-sample phase increments gives the phase track
    phase = delayed.cumsum(dim=-1)

    # Weight each sinusoid by its amplitude and mix down over the frequency dimension
    return (phase.sin() * amp_env).sum(dim=1)

In [None]:
def harmonic_synth(
    f0: torch.Tensor,  # Angular fundamental frequency (batch, n_frames)
    harmonic_amps: torch.Tensor,  # Unscaled harmonic amplitudes (batch, n_harmonics, n_frames)
    num_samples: int,  # Number of samples to synthesize
    global_amp: torch.Tensor = None,  # Unscaled global amplitude, applied to all partials
    normalize: bool = True,
):
    """
    Synthesize a harmonic signal from a fundamental frequency track.

    Args:
        f0: Angular fundamental frequency in radians per sample, 2D.
        harmonic_amps: Unconstrained amplitude parameters, one track per
            harmonic. They are passed through `scale_function` before use.
        num_samples: Length of the output; the control signals are
            interpolated to this length inside `additive_synth`.
        global_amp: Optional unconstrained global amplitude track. When given
            it is passed through `scale_function` and multiplied onto every
            harmonic; when None no global gain is applied.
        normalize: If True, the harmonic amplitudes are divided by their sum
            over the harmonic dimension so that they sum to one.

    Returns:
        The output of `additive_synth` for the harmonic frequencies and the
        processed amplitudes.
    """
    assert f0.ndim == 2, "Fundamental frequency must be 2D (batch, n_samples)"
    assert (
        harmonic_amps.ndim == 3
    ), "Harmonic amplitudes must be 3D (batch, n_harmonics, n_samples)"

    # Get the harmonic frequencies
    frequency = get_harmonic_frequencies(f0, harmonic_amps.shape[1])

    # Scale the amplitudes
    harmonic_amps = scale_function(harmonic_amps)

    # Remove frequencies at or above Nyquist (pi radians per sample)
    harmonic_amps = harmonic_amps * (frequency < torch.pi).float()

    # Normalize amplitudes to sum to 1 at each frame
    if normalize:
        total = torch.sum(harmonic_amps, dim=1, keepdim=True)
        # If every harmonic of a frame was masked out the sum is zero; divide
        # by one there so the frame stays silent instead of becoming NaN
        total = torch.where(total > 0, total, torch.ones_like(total))
        harmonic_amps = harmonic_amps / total

    # Apply the scaled global amplitude when one is provided; otherwise the
    # harmonic amplitudes are used as they are
    if global_amp is not None:
        global_amp = scale_function(global_amp)
        harmonic_amps = harmonic_amps * global_amp.unsqueeze(1)

    return additive_synth(frequency, harmonic_amps, num_samples)

In [None]:
# Load the target recording (waveform tensor and its sample rate)
audio, sample_rate = torchaudio.load("../audio/reed_acoustic_011-045-050.wav")

# Keep only the first 2.0 seconds of the audio (chop silence from the end)
audio = audio[:, : int(sample_rate * 2.0)]

# Render an audio player for the trimmed clip
ipd.Audio(audio.numpy(), rate=sample_rate)

In [None]:
# Hop size in samples that corresponds to 200 frames per second.
# NOTE(review): the pitch tracking cell uses frame_rate = 100 Hz and the STFT
# cells pass their own hop_length, so this value looks unused — confirm.
hop_length = int(sample_rate / 200.0)
print(f"hop_length: {hop_length}")

In [None]:
# frame_rate = 100 # Hz
# step_size = 1000.0 / frame_rate # ms

# _, f0, _, _ = pesto.predict(audio, sample_rate, step_size=step_size, convert_to_freq=True)

# plt.plot(f0)

In [None]:
# frame_rate = 100 # Hz

# hop_length = int(sample_rate / frame_rate) # milliseconds

# print(hop_length)
# print(sample_rate)
# f0 = torchcrepe.predict(audio, sample_rate, hop_length=hop_length, batch_size=128, device="cpu", model = 'full')

# plt.plot(f0[0].numpy())

In [None]:
# Rate at which the pitch track (and later the synth parameters) is specified
frame_rate = 100  # Hz
# f0 = torchaudio.functional.detect_pitch_frequency(audio, sample_rate, frame_time=1.0 / frame_rate, win_length=30)

# hop_length = int(sample_rate / frame_rate)
# f0, _, _ = librosa.pyin(audio.numpy()[0], fmin=50, fmax=2000, sr=sample_rate, hop_length=hop_length)
# f0 = torch.from_numpy(f0).unsqueeze(0)
# f0 = torch.nan_to_num(f0, nan=0.0)

# Estimate the pitch of the first channel with CREPE. step_size is given in
# milliseconds, so 1000 / frame_rate yields one estimate per frame.
time, frequency, confidence, activation = crepe.predict(
    audio.numpy()[0], sample_rate, step_size=1000 / frame_rate, viterbi=True
)
# Wrap the estimated frequency track in a tensor and add a leading batch dimension
f0 = torch.from_numpy(frequency).unsqueeze(0)

# timesteps, pitch, confidence, activations = pesto.predict(audio, sample_rate, step_size=1000.0/frame_rate)
# f0 = pitch.unsqueeze(0)

# Inspect the pitch track: curve, value range and shape
plt.plot(f0[0].numpy())
print(f0.min(), f0.max())

print(f0.shape)

In [None]:
# Complex STFT of the target audio (2048-point FFT, hop of 512 samples, Hann window)
X = torch.stft(
    audio,
    n_fft=2048,
    hop_length=512,
    return_complex=True,
    window=torch.hann_window(2048),
)
X_mag = torch.abs(X)
# Magnitude in decibels; the 1e-6 offset avoids log10(0)
X_db = 20.0 * torch.log10(X_mag + 1e-6)

plt.imshow(X_db[0].numpy(), aspect="auto", origin="lower")

# Express f0 as a (fractional) FFT bin index, assuming f0 is in Hz, and resample
# it to the number of spectrogram frames. The overlay itself is disabled below.
f0_bins = f0 * 2048 / sample_rate
f0_bins = f0_bins.unsqueeze(0)
f0_bins = torch.nn.functional.interpolate(f0_bins, size=X_db.shape[2], mode="linear")
# plt.plot(f0_bins[0,0].numpy(), 'r')
# plt.plot(f0_bins[0,0].numpy() * 4.0, 'r')

# TODO - fix the scaling on this
# NOTE(review): y_ticks and y_tick are computed but not applied to the plot
y_ticks = plt.yticks()
y_tick = torch.logspace(5, 13, 6, base=2.0)

In [None]:
def additive_synth(
    frequencies: torch.Tensor,  # Angular frequencies (rad / sample) - frame rate
    amplitudes: torch.Tensor,  # Amplitudes
    n_samples: int,  # Number of samples to synthesize
):
    """
    Render a sum of sinusoids from frame-rate frequency and amplitude tracks.

    `frequencies` and `amplitudes` share the shape
    (batch, n_frequencies, n_frames). Both are linearly interpolated to
    `n_samples` along the last dimension and the sinusoids are summed over
    the frequency dimension, giving a (batch, n_samples) tensor.
    """
    assert (
        frequencies.ndim == 3
    ), "Frequencies must be 3D (batch, n_frequencies, n_frames)"
    assert (
        frequencies.shape == amplitudes.shape
    ), "Frequency and amplitude shapes must match"

    # Upsample both control signals from frame rate to sample rate
    interpolate = torch.nn.functional.interpolate
    freq_up = interpolate(frequencies, size=n_samples, mode="linear")
    amp_up = interpolate(amplitudes, size=n_samples, mode="linear")

    # Prepend a zero and drop the final value so the phase starts at zero
    zero_start = torch.zeros_like(freq_up[:, :, :1])
    increments = torch.cat([zero_start, freq_up[..., :-1]], dim=-1)

    # Integrate the phase increments to obtain the phase of every sinusoid
    phase = increments.cumsum(dim=-1)

    # Apply the amplitude envelopes and mix the sinusoids together
    partials = phase.sin() * amp_up
    return partials.sum(dim=1)

In [None]:
def scale_function(
    x: torch.Tensor,
    exponent: float = 10.0,
    max_value: float = 2.0,
    threshold: float = 1e-7,
):
    """
    Map an unconstrained parameter to a strictly positive amplitude.

    Computes ``max_value * sigmoid(x) ** log(exponent) + threshold`` with the
    natural logarithm. Called with only `x` it evaluates
    ``2.0 * sigmoid(x) ** log(10.0) + 1e-7``; the keyword arguments make the
    constants configurable without changing that default behaviour.
    """
    return max_value * torch.sigmoid(x) ** math.log(exponent) + threshold

In [None]:
def get_harmonic_frequencies(f0, num_harmonics):
    """
    Build the frequency track of every harmonic from a fundamental frequency track.

    Args:
        f0: Fundamental frequency with a leading batch dimension, e.g. (batch, n_frames).
        num_harmonics: Number of harmonics; harmonic k (1-based) has frequency k * f0.

    Returns:
        Tensor with a harmonic dimension inserted at position 1,
        e.g. (batch, num_harmonics, n_frames).
    """
    # Integer harmonic ratios 1..num_harmonics shaped (1, num_harmonics, 1),
    # placed on the device of f0 so the product works for non-CPU inputs too
    harmonic_ratios = torch.arange(1, num_harmonics + 1, device=f0.device).view(
        1, -1, 1
    )

    # Broadcasting replaces the explicit repeat of f0 along the harmonic dimension
    return f0.unsqueeze(1) * harmonic_ratios

In [None]:
def harmonic_synth(
    f0: torch.Tensor,  # Angular fundamental frequency (batch, n_frames)
    harmonic_amps: torch.Tensor,  # Unscaled harmonic amplitudes (batch, n_harmonics, n_frames)
    num_samples: int,  # Number of samples to synthesize
    global_amp: torch.Tensor = None,  # Unscaled global amplitude, applied to all partials
    normalize: bool = True,
):
    """
    Synthesize a harmonic signal from a fundamental frequency track.

    Args:
        f0: Angular fundamental frequency in radians per sample, 2D.
        harmonic_amps: Unconstrained amplitude parameters, one track per
            harmonic. They are passed through `scale_function` before use.
        num_samples: Length of the output; the control signals are
            interpolated to this length inside `additive_synth`.
        global_amp: Optional unconstrained global amplitude track. When given
            it is passed through `scale_function` and multiplied onto every
            harmonic; when None no global gain is applied.
        normalize: If True, the harmonic amplitudes are divided by their sum
            over the harmonic dimension so that they sum to one.

    Returns:
        The output of `additive_synth` for the harmonic frequencies and the
        processed amplitudes.
    """
    assert f0.ndim == 2, "Fundamental frequency must be 2D (batch, n_samples)"
    assert (
        harmonic_amps.ndim == 3
    ), "Harmonic amplitudes must be 3D (batch, n_harmonics, n_samples)"

    # Get the harmonic frequencies
    frequency = get_harmonic_frequencies(f0, harmonic_amps.shape[1])

    # Scale the amplitudes
    harmonic_amps = scale_function(harmonic_amps)

    # Remove frequencies at or above Nyquist (pi radians per sample)
    harmonic_amps = harmonic_amps * (frequency < torch.pi).float()

    # Normalize amplitudes to sum to 1 at each frame
    if normalize:
        total = torch.sum(harmonic_amps, dim=1, keepdim=True)
        # A frame whose harmonics were all masked has a zero sum; dividing by
        # one there keeps it silent instead of producing NaN from 0 / 0
        total = torch.where(total > 0, total, torch.ones_like(total))
        harmonic_amps = harmonic_amps / total

    # Apply the scaled global amplitude when one is provided; otherwise the
    # harmonic amplitudes are used as they are
    if global_amp is not None:
        global_amp = scale_function(global_amp)
        harmonic_amps = harmonic_amps * global_amp.unsqueeze(1)

    return additive_synth(frequency, harmonic_amps, num_samples)

In [None]:
# Convert f0 to angular frequency in radians per sample (assuming f0 is in Hz)
w0 = f0 * 2 * torch.pi / sample_rate

In [None]:
# Number of harmonics used by the synthesizer
num_harmonics = 80

# Per-harmonic profile that decays as 1/k: [1, 1/2, 1/3, ...]
amplitudes = 1.0 / (torch.arange(num_harmonics) + 1).float()
print(amplitudes)
# Copy the profile to every frame of w0: shape (1, num_harmonics, n_frames)
amplitudes = torch.ones(1, num_harmonics, w0.shape[-1]) * amplitudes.view(1, -1, 1)

In [None]:
# Sanity check: the last dimension (number of frames) of both tensors should match
print(w0.shape, amplitudes.shape)

In [None]:
# Synthesize a signal as long as the target audio from the fixed profile (no global amplitude).
# NOTE: harmonic_synth passes harmonic_amps through scale_function, so the 1/k values
# act as pre-scaling parameters rather than as the final harmonic amplitudes.
y = harmonic_synth(w0, amplitudes, audio.shape[-1])

In [None]:
# Listen to the signal synthesized from the fixed amplitude profile
ipd.Audio(y[0].numpy(), rate=sample_rate)

In [None]:
# Spectrogram of the synthesized signal: 512-point FFT with a hop of N / 4 samples
N = 512
X = torch.stft(
    y, n_fft=N, hop_length=N // 4, return_complex=True, window=torch.hann_window(N)
)
X_mag = torch.abs(X)
# Magnitude in decibels; the 1e-6 offset avoids log10(0)
X_db = 20.0 * torch.log10(X_mag + 1e-6)

plt.imshow(X_db[0].numpy(), aspect="auto", origin="lower")

In [None]:
# FFT sizes of the multi-resolution spectral loss
n_ffts = [2048, 1024, 512, 256, 128, 64]
# n_ffts = [128]
# Each resolution uses a hop of a quarter of its FFT size and a window as long as the FFT
hop_sizes = [n // 4 for n in n_ffts]
# w_sc=0.0 switches the spectral convergence term off; the linear and log
# magnitude terms are weighted equally
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=n_ffts,
    hop_sizes=hop_sizes,
    win_lengths=n_ffts,
    w_sc=0.0,
    w_lin_mag=1.0,
    w_log_mag=1.0,
)
# loss_fn = auraloss.freq.MultiResolutionSTFTLoss()

In [None]:
# TODO: zero-out frequencies above Nyquist
# Random initial (pre-scaling) harmonic amplitude parameters, same shape as `amplitudes`
amp_param = torch.randn_like(amplitudes)
# Harmonic frequency tracks, kept to inspect the parameters above Nyquist after training
frequencies = get_harmonic_frequencies(w0, amp_param.shape[1])
# amp_param = amp_param * (frequencies < torch.pi).float()

# Register the tensors as learnable parameters
amp_param = torch.nn.Parameter(amp_param)
# One (pre-scaling) global amplitude value per frame, drawn uniformly from [0, 1)
global_amp = torch.nn.Parameter(torch.rand_like(w0))

# Adam updates both parameter tensors with a learning rate of 0.05
optimizer = torch.optim.Adam([amp_param, global_amp], lr=0.05)
# optimizer = torch.optim.SGD([amp_param, global_amp], lr=10.0)

In [None]:
# Show the scaled harmonic amplitude parameters before optimization
# (rows: harmonics, columns: frames)
plt.imshow(
    scale_function(amp_param)[0].detach().numpy(),
    aspect="auto",
    origin="lower",
    interpolation="none",
)

In [None]:
# Optimize the harmonic and global amplitude parameters with gradient descent
loss_log = []
t = trange(1000, desc="Error", leave=True)
for i in t:
    # 1. Compute a forward pass using our learned parameters
    y_pred = harmonic_synth(w0, amp_param, audio.shape[-1], global_amp=global_amp)

    # 2. Compute the multi-resolution spectral loss. The loss takes the
    #    prediction first and the reference second; the order only matters
    #    numerically once an asymmetric term such as w_sc is enabled.
    #    unsqueeze(0) adds the leading dimension the loss expects.
    loss = loss_fn(
        y_pred.unsqueeze(0), audio.unsqueeze(0)
    )  # + 0.1 * torch.mean(amp_param)

    # Store the current loss value for plotting later
    loss_log.append(loss.item())

    # 3. Reset gradients
    optimizer.zero_grad()

    # 4. Compute the gradients
    loss.backward()

    # 5. Update the parameters
    optimizer.step()

    # Show the current loss in the progress bar
    t.set_description(f"Error: {loss.detach().cpu().numpy()}")
    t.refresh()

In [None]:
# Mean absolute value of the parameters after zeroing every entry the synth keeps.
# The synth keeps frequencies < pi, so the masked-out set is frequencies >= pi.
torch.mean(torch.abs(amp_param * (frequencies >= torch.pi).float()))

In [None]:
# Plot the loss recorded at every optimization iteration
plt.plot(loss_log)

In [None]:
# Listen to the signal produced in the last optimization iteration
ipd.Audio(y_pred[0].detach().numpy(), rate=sample_rate)

In [None]:
# Spectrogram of the optimized signal, using the same STFT settings as for the target audio
X = torch.stft(
    y_pred.detach(),
    n_fft=2048,
    hop_length=512,
    return_complex=True,
    window=torch.hann_window(2048),
)
X_mag = torch.abs(X)
# Magnitude in decibels; the 1e-6 offset avoids -inf for bins with zero magnitude
X_db = 20.0 * torch.log10(X_mag + 1e-6)

plt.imshow(X_db[0].numpy(), aspect="auto", origin="lower")


# TODO - fix the scaling on this
y_ticks = plt.yticks()
y_tick = torch.logspace(5, 13, 6, base=2.0)

In [None]:
# Plot the learned global amplitude after scaling (one value per frame)
plt.plot(scale_function(global_amp[0]).detach().numpy())

In [None]:
# Show the scaled harmonic amplitude parameters after optimization
# (rows: harmonics, columns: frames)
plt.imshow(
    scale_function(amp_param)[0].detach().numpy(),
    aspect="auto",
    origin="lower",
    interpolation="none",
)

In [None]:
# Inspect the shape and value range of the scaled harmonic amplitude parameters
scaled_amp = scale_function(amp_param)
print(scaled_amp.shape)
print(scaled_amp.min(), scaled_amp.max())

In [None]:
# Evaluate the scaling for a strongly negative input; the result is expected
# to sit close to the additive 1e-7 floor
scale_function(torch.tensor(-20.0))