# Optimizing a Harmonic Synthesizer

In this section we look at using gradient descent to learn parameters for a harmonic
synthesizer to match an instrumental sound.

We build on the harmonic synthesizer from the previous section and add several features
that support gradient-based optimization. These additions are taken directly from Engel et al.'s
differentiable harmonic synthesizer [cite] and include: 
1) constraining harmonic amplitudes to
sum to one; 
2) adding a global amplitude parameter; 
3) parameter scaling to constrain the possible range of amplitudes;
4) removing harmonics above the Nyquist frequency, which would otherwise cause aliasing;
5) interpolation of parameters from frame rate to sample rate.

The updated formula for our harmonic synthesizer is:

$$
    y[n] = A[n]\sum_{k=1}^{K}\hat{\alpha}_k[n]\sin\left(k\sum_{m=0}^{n}\omega_{0}[m]\right)
$$

where $A[n]$ is a global amplitude parameter, and $\hat{\alpha}_k[n]$ is the normalized
amplitude for the $k^{\text{th}}$ sinusoidal component. $\hat{\alpha}_k[n]$ is normalized
such that $\sum_{k}\hat{\alpha}_k[n] = 1$ and $\hat{\alpha}_k[n] \geq 0$. $\omega_{0}[n]$ is
a time-varying fundamental frequency that is pre-computed using a pitch extraction algorithm.
Methods for parameter scaling and removing frequencies above the Nyquist frequency will be introduced inline below.

Instead of specifying parameters at a resolution equivalent to the audio sampling rate we'll specify parameters at a frame rate of 100Hz.
This sets a reasonable upper bound on the frequency of change of our control signals and has
the added benefit of decreasing the dimensionality of the optimization problem. We only
need to learn \~100 values per harmonic for each second of audio instead of 16k at a 16kHz sampling rate!

Finally, we'll use gradient descent with a spectral loss function to match sounds from 
the [NSynth test dataset](https://magenta.tensorflow.org/datasets/nsynth#files) [cite].

In [None]:
import math

import torch
import torchaudio
import matplotlib.pyplot as plt
import matplotlib.ticker as mplticker
from matplotlib.animation import FuncAnimation
import IPython.display as ipd
from tqdm import trange

import crepe
import auraloss

torch.manual_seed(0)

In [None]:
def scale_function(
    x: torch.Tensor,
    exponent: float = 10.0,
    max_value: float = 2.0,
    threshold: float = 1e-7,
):
    """
    Squash an unconstrained parameter into a bounded, strictly positive range.

    The input is passed through a sigmoid, raised to the power ``log(exponent)``,
    multiplied by ``max_value`` and finally offset by ``threshold``. For
    ``exponent > 1`` the result therefore lies between ``threshold`` and
    ``max_value + threshold``. The offset keeps the output away from exactly
    zero, which stabilizes the gradient near zero.
    """
    power = math.log(exponent)
    squashed = torch.sigmoid(x)
    return torch.pow(squashed, power) * max_value + threshold

In [None]:
def remove_above_nyquist(
    harmonic_amps: torch.Tensor,
    frequencies: torch.Tensor,
):
    """
    Zero the amplitude of every component whose frequency is not below pi.

    Frequencies are compared against pi, which is the Nyquist limit when they
    are expressed as angular frequencies in radians per sample. Amplitudes of
    components strictly below that limit are passed through unchanged.
    """
    below_nyquist = torch.lt(frequencies, torch.pi).float()
    return below_nyquist * harmonic_amps

In [None]:
def get_harmonic_frequencies(f0, num_harmonics):
    """
    Build the frequency of each harmonic from a fundamental frequency track.

    Args:
        f0: Fundamental frequency with shape (batch, n_frames).
        num_harmonics: Number of harmonics to generate (integer multiples
            1, 2, ..., num_harmonics of the fundamental).

    Returns:
        Tensor of shape (batch, num_harmonics, n_frames) where entry
        [b, k, t] equals (k + 1) * f0[b, t].
    """
    # Integer harmonic ratios shaped (1, num_harmonics, 1). They are created on
    # the same device as f0 so the multiplication below also works when f0
    # does not live on the default device.
    harmonic_ratios = torch.arange(1, num_harmonics + 1, device=f0.device).view(
        1, -1, 1
    )

    # Broadcasting (batch, 1, n_frames) against (1, num_harmonics, 1) yields one
    # scaled copy of the fundamental per harmonic without an explicit repeat.
    return f0.unsqueeze(1) * harmonic_ratios

In [None]:
def additive_synth(
    frequencies: torch.Tensor,  # Angular frequencies (rad / sample) - frame rate
    amplitudes: torch.Tensor,  # Amplitudes
    n_samples: int,  # Number of samples to synthesize
):
    """
    Render a bank of sinusoids from frame-rate frequency and amplitude envelopes.

    Both envelopes are linearly interpolated to ``n_samples`` values, the
    per-sample frequencies are accumulated into a phase track that starts at
    zero, and the amplitude-weighted sinusoids are summed over the frequency
    dimension. The result has shape (batch, n_samples).
    """
    assert (
        frequencies.ndim == 3
    ), "Frequencies must be 3D (batch, n_frequencies, n_frames)"
    assert (
        frequencies.shape == amplitudes.shape
    ), "Frequency and amplitude shapes must match"

    interpolate = torch.nn.functional.interpolate

    # Bring both control signals from frame rate up to sample rate
    freq_env = interpolate(frequencies, size=n_samples, mode="linear")
    amp_env = interpolate(amplitudes, size=n_samples, mode="linear")

    # Delay the frequency envelope by one sample: a zero is placed in front
    # (the initial phase) and the last value is dropped to keep n_samples values
    zero_start = torch.zeros_like(freq_env[..., :1])
    delayed_freq = torch.cat([zero_start, freq_env[..., :-1]], dim=-1)

    # Running sum of per-sample angular frequency gives the instantaneous phase
    phase = delayed_freq.cumsum(dim=-1)

    # Weight each sinusoid by its amplitude envelope and mix them together
    return (phase.sin() * amp_env).sum(dim=1)

In [None]:
def harmonic_synth(
    f0: torch.Tensor,  # Angular fundamental frequency (batch, n_frames)
    global_amp: torch.Tensor,  # Global amplitude (batch, n_frames)
    harmonic_amps: torch.Tensor,  # Amplitudes of harmonics (batch, n_harmonics, n_frames)
    num_samples: int,  # Number of samples to synthesize
):
    """
    Synthesize audio from a fundamental frequency and unconstrained amplitudes.

    The raw ``global_amp`` and ``harmonic_amps`` values are first mapped to a
    positive range with ``scale_function``. Harmonics that are not below the
    Nyquist frequency are silenced, the remaining harmonic amplitudes are
    normalized to sum to one per frame, and the result is scaled by the global
    amplitude before being rendered by ``additive_synth``.

    All control inputs share the same time resolution; ``additive_synth``
    interpolates them to ``num_samples`` values.

    Returns:
        Tensor of shape (batch, num_samples) containing the synthesized audio.
    """
    assert f0.ndim == 2, "Fundamental frequency must be 2D (batch, n_samples)"
    assert (
        harmonic_amps.ndim == 3
    ), "Harmonic amplitudes must be 3D (batch, n_harmonics, n_samples)"

    # Get the harmonic frequencies
    frequency = get_harmonic_frequencies(f0, harmonic_amps.shape[1])

    # Scale the amplitudes
    harmonic_amps = scale_function(harmonic_amps)
    global_amp = scale_function(global_amp)

    # Remove frequencies above Nyquist
    harmonic_amps = remove_above_nyquist(harmonic_amps, frequency)

    # Normalize the harmonic amplitudes so they sum to one in every frame.
    # If every harmonic of a frame was removed above, the sum is exactly zero;
    # dividing by one there keeps those frames silent instead of producing NaNs
    # that would propagate into the loss and the gradients.
    amp_sum = torch.sum(harmonic_amps, dim=1, keepdim=True)
    amp_sum = torch.where(amp_sum > 0, amp_sum, torch.ones_like(amp_sum))
    harmonic_amps = harmonic_amps / amp_sum

    # Multiply all harmonic amplitudes by the global amplitude
    harmonic_amps = harmonic_amps * global_amp.unsqueeze(1)

    return additive_synth(frequency, harmonic_amps, num_samples)

In [None]:
# Load audio
audio, sample_rate = torchaudio.load("../audio/reed_acoustic_011-045-050.wav")

# Keep only the first 2.0 seconds of the audio (chop silence from the end)
audio = audio[:, : int(sample_rate * 2.0)]

# Render an audio player for the trimmed target sound
ipd.Audio(audio.numpy(), rate=sample_rate)

## Pitch Extraction

In [None]:
frame_rate = 100  # Hz
frame_ms = 1000 / frame_rate  # ms

# Estimate the fundamental frequency of the first audio channel with CREPE,
# using a step of frame_ms milliseconds between estimates and Viterbi decoding.
# Only the second returned value (f0) is kept; the other three are discarded.
# NOTE(review): f0 is presumably in Hz (it is later scaled by 2*pi / sample_rate)
# — confirm against the CREPE documentation.
_, f0, _, _ = crepe.predict(
    audio.numpy()[0], sample_rate, step_size=frame_ms, viterbi=True
)
# Convert to a tensor and add a leading batch dimension
f0 = torch.from_numpy(f0).unsqueeze(0)

In [None]:
# Sanity check: plot the detected pitch track and print its range
plt.plot(f0[0].numpy())
print(f0.min(), f0.max())

# Shape of the pitch track, including the batch dimension added above
print(f0.shape)

In [None]:
def plot_spectrogram(
    x: torch.Tensor, sample_rate: int, fig: plt.Figure = None, ax: plt.Axes = None
):
    """
    Draw a log-frequency magnitude spectrogram (in dB) of the first row of ``x``.

    Args:
        x: Audio tensor; only ``x[0]`` is plotted, so a leading batch / channel
            dimension is expected.
        sample_rate: Sampling rate of ``x`` in Hz, used to label both axes.
        fig: Figure to draw into. If omitted, the figure owning ``ax`` is used,
            or a new figure is created when ``ax`` is omitted as well.
        ax: Axes to draw into. If omitted, the current axes of ``fig`` are used,
            or new axes are created when ``fig`` is omitted as well.

    Returns:
        Tuple ``(fig, ax, times, fft_freqs)`` where ``times`` holds the start
        time of each STFT frame in seconds and ``fft_freqs`` the center
        frequency of each FFT bin in hertz.
    """
    n_fft = 2048
    hop_length = 512
    X = torch.stft(
        x,
        n_fft=n_fft,
        hop_length=hop_length,
        return_complex=True,
        window=torch.hann_window(n_fft),
    )

    # Convert to decibels; the small offset avoids log10(0) for silent bins
    X_mag = torch.abs(X)
    X_db = 20.0 * torch.log10(X_mag + 1e-6)

    # Frequency of each one-sided FFT bin in hertz (0 Hz up to Nyquist),
    # derived from n_fft so it always matches the STFT above
    fft_freqs = torch.fft.rfftfreq(n_fft, 1 / sample_rate)

    # Time in seconds for each frame
    times = torch.arange(X_db.shape[-1]) * hop_length / sample_rate

    # Resolve the drawing target so that fig and ax are both valid no matter
    # which of the two the caller supplied
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots()
        else:
            ax = fig.gca()
    elif fig is None:
        fig = ax.figure

    # Plot the spectrogram
    ax.pcolormesh(times, fft_freqs, X_db[0].numpy())

    # Set the y-axis to log scale
    ax.set_yscale("symlog", base=2.0)
    ax.set_ylim(40.0, 8000.0)

    ax.yaxis.set_major_formatter(mplticker.ScalarFormatter())
    ax.yaxis.set_label_text("Frequency (Hz)")

    ax.xaxis.set_major_formatter(mplticker.ScalarFormatter())
    ax.xaxis.set_label_text("Time (Seconds)")

    return fig, ax, times, fft_freqs


# Draw the target spectrogram and keep its time axis for the F0 overlay
fig, ax, xaxis, yaxis = plot_spectrogram(audio, sample_rate)

# Resample the F0 track to one value per spectrogram frame. A leading
# dimension is added first because the track is interpolated as a 3D tensor.
f0_interp = torch.nn.functional.interpolate(
    f0.unsqueeze(0), size=xaxis.shape[0], mode="linear"
)

# Overlay the detected pitch on top of the spectrogram
ax.plot(xaxis, f0_interp[0, 0].numpy(), color="red", label="Detected F0")

ax.set_title("Target Spectrogram with F0")
ax.legend()
fig.tight_layout()
plt.show()

## Multiresolution Spectral Loss

[todo] write own code

In [None]:
# FFT sizes of the individual STFT resolutions, largest first
n_ffts = [2048, 1024, 512, 256, 128, 64]
# Each resolution hops by a quarter of its FFT size
hop_sizes = [fft_size // 4 for fft_size in n_ffts]

# Window lengths equal the FFT sizes. The linear- and log-magnitude terms are
# weighted equally and the w_sc term is given zero weight.
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=n_ffts,
    win_lengths=n_ffts,
    hop_sizes=hop_sizes,
    w_lin_mag=1.0,
    w_log_mag=1.0,
    w_sc=0.0,
)

## Optimization

In [None]:
# Convert fundamental frequency to angular frequency
w0 = f0 * 2 * torch.pi / sample_rate

# Create a harmonic amplitude envelope: one randomly initialized value per
# harmonic and per f0 frame, shaped (1, num_harmonics, n_frames)
num_harmonics = 80
harmonic_amplitudes = torch.randn(1, num_harmonics, f0.shape[-1])

# Create PyTorch parameters - these are the variables we will optimize.
# Both hold unconstrained values; harmonic_synth maps them to a positive
# range with scale_function before they are used as amplitudes.
harmonic_amplitudes = torch.nn.Parameter(harmonic_amplitudes)
# The global amplitude has the same shape as w0 (one value per f0 frame)
global_amp = torch.nn.Parameter(torch.randn_like(w0))

In [None]:
# Adam updates both sets of learnable amplitudes; f0 / w0 is not optimized
optimizer = torch.optim.Adam([harmonic_amplitudes, global_amp], lr=0.05)

In [None]:
# Synthesize with the randomly initialized parameters, before any optimization,
# producing as many samples as the target audio
y_hat = harmonic_synth(w0, global_amp, harmonic_amplitudes, audio.shape[-1])
# Detach from the autograd graph before plotting
fig, ax, *_ = plot_spectrogram(y_hat.detach(), sample_rate)

ax.set_title("Initial Randomized Harmonic Synthesis")
plt.tight_layout()
plt.show()

In [None]:
# loss_log receives one value per iteration; audio_log receives a snapshot of
# the synthesized audio every 50 iterations
loss_log = []
audio_log = []
t = trange(1000, desc="Error", leave=True)
for i in t:
    # 1. Compute a forward pass using our learned parameters
    y_pred = harmonic_synth(w0, global_amp, harmonic_amplitudes, audio.shape[-1])

    # 2. Compute the multiresolution spectral loss (a leading dimension is
    #    added to both signals before they are passed to the loss)
    # NOTE(review): the target is passed first and the prediction second;
    # confirm this matches the argument order auraloss expects.
    loss = loss_fn(audio.unsqueeze(0), y_pred.unsqueeze(0))

    # Store the current loss value for plotting later
    loss_log.append(loss.item())

    # 3. Reset gradients
    optimizer.zero_grad()

    # 4. Compute the gradients
    loss.backward()

    # 5. Update the parameters
    optimizer.step()

    # Log the audio produced by this iteration's forward pass (i.e. with the
    # parameters as they were before the update above)
    if i % 50 == 0:
        audio_log.append(y_pred.detach().cpu().numpy())

    # Show the current loss in the progress bar
    t.set_description(f"Error: {loss.detach().cpu().numpy()}")
    t.refresh()

In [None]:
# Left panel: spectrogram of the synthesized audio; right panel: loss curve
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
fig, im_ax, *_ = plot_spectrogram(y_hat.detach(), sample_rate, fig=fig, ax=axes[0])

axes[1].set_ylim(0.0, max(loss_log))
axes[1].set_xlim(0, len(loss_log))

# Empty line that is filled in progressively by animate()
(line,) = axes[1].plot([], [], lw=2)

# One animation frame per logged audio snapshot
num_frames = len(audio_log)


def animate(i):
    """Redraw the spectrogram and loss curve for the i-th logged snapshot."""
    # Audio snapshots were logged every 50 iterations of the optimization loop
    iteration = i * 50

    # Remove the previously drawn spectrogram so meshes do not pile up on the
    # axes from frame to frame; plot_spectrogram restores scale and labels
    im_ax.clear()
    _ = plot_spectrogram(torch.from_numpy(audio_log[i]), sample_rate, fig=fig, ax=im_ax)
    # The title is set after clearing, since clearing the axes also resets it
    im_ax.set_title(f"Harmonic Synthesis Iteration: {iteration}")

    axes[1].set_title("Loss: {:.4f}".format(loss_log[iteration]))
    line.set_data(torch.arange(iteration), loss_log[:iteration])
    return (line,)


# Create the animation
anim = FuncAnimation(fig, animate, frames=num_frames, interval=250, blit=True)

# Close the static figure so only the rendered video is shown
plt.close(fig)
# To display the animation in the Jupyter notebook:
display(ipd.HTML(anim.to_html5_video()))

In [None]:
# Listen to the audio from the last forward pass of the optimization loop
ipd.Audio(y_pred[0].detach().numpy(), rate=sample_rate)

In [None]:
# Side-by-side comparison of the optimized synthesis and the target sound
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))

# Left panel: output of the last forward pass of the optimization loop
_ = plot_spectrogram(y_pred.detach(), sample_rate, fig=fig, ax=axes[0])
axes[0].set_title("Optimized Harmonic Synthesis")

# Right panel: the recording we were trying to match
_ = plot_spectrogram(audio, sample_rate, fig=fig, ax=axes[1])
axes[1].set_title("Target Audio")

plt.tight_layout()
plt.show()