In [None]:
from functools import partial

import numpy as np
import jax.numpy as jnp
from jax import value_and_grad, jit
from scipy import signal

In [None]:
from jaxdsp.loss import LossOptions, loss_fn, stft

# Compare SciPy's STFT with jaxdsp's `stft` on four seconds of low-amplitude uniform noise.
amp = 1e-2
sample_rate = 16_000
# Seed the generator so the noise, and therefore this check, is reproducible after a kernel restart.
rng = np.random.default_rng(0)
audio = amp * (rng.random(sample_rate * 4).astype(np.float32) * 2.0 - 1.0)
frame_size = 2048
hop_size = 128
# Fraction of each frame shared with the next one (0.9375 for these sizes).
overlap = 1.0 - float(hop_size) / frame_size
pad_end = True

# `scipy.signal.stft` expects `noverlap` as a number of samples. Passing `int(overlap)` would
# truncate the fractional overlap to 0, so pass the sample count implied by the hop size instead.
s_np = signal.stft(
    audio,
    nperseg=int(frame_size),
    noverlap=frame_size - hop_size,
    nfft=int(frame_size),
    padded=pad_end,
)
# NOTE(review): assumes jaxdsp's `stft` interprets `overlap` as a fraction of the frame and
# converts it to a sample count itself — confirm against jaxdsp.loss.stft.
s_jdsp = stft(audio, frame_size=frame_size, overlap=overlap, pad_end=pad_end)

# One flag per element of the returned results (indices 0, 1 and 2).
np.allclose(s_np[0], s_jdsp[0]), np.allclose(s_np[1], s_jdsp[1]), np.allclose(s_np[2], s_jdsp[2])

In [None]:
# Options handed to `loss_fn`: a weight of 1.0 under "cumsum_freq", and "L2" as the
# distance type for both the "sample" and "frequency" entries.
spectral_loss_opts = LossOptions(
    weights=dict(cumsum_freq=1.0),
    distance_types=dict(sample="L2", frequency="L2"),
)

# Evaluate the loss between an all-zeros input and an all-ones target of the same shape.
target_audio = jnp.ones((2, sample_rate // 2))
input_audio = jnp.zeros(target_audio.shape)
loss_value = loss_fn(input_audio, target_audio, spectral_loss_opts)
float(loss_value)

In [None]:
import matplotlib.pyplot as plt

@partial(jit, static_argnames=("length_seconds",))
def gen_sinusoid(frequency, amplitude=1.0, length_seconds=1.0):
    """Generate a sine wave sampled at the notebook-level `sample_rate`.

    Args:
        frequency: Frequency of the sine, in cycles per second.
        amplitude: Multiplier applied to the unit-amplitude sine.
        length_seconds: Duration of the signal in seconds. It is a static argument because
            it fixes the output length, which must be a concrete Python value under `jit`;
            without that, passing it explicitly would fail inside `int(...)`.

    Returns:
        A 1-D array holding `int(length_seconds * sample_rate)` samples.
    """
    num_samples = int(length_seconds * sample_rate)
    # `endpoint=False` makes the time step `length_seconds / num_samples`; including the
    # endpoint would stretch the step to `length_seconds / (num_samples - 1)` and slightly
    # shift the generated frequency.
    t = jnp.linspace(0.0, length_seconds, num_samples, endpoint=False)
    return amplitude * jnp.sin(2 * jnp.pi * frequency * t)

# 100 evenly spaced candidate frequencies from 420.0 to 444.0 (both ends included).
frequencies = np.linspace(start=420.0, stop=444.0, num=100)
# Frequency of the reference sinusoid that the candidates are compared against.
target_frequency = 443.0

def loss_for_frequency(frequency):
    """Return the loss between a sinusoid at `frequency` and one at `target_frequency`.

    Both signals are produced by `gen_sinusoid` with its default amplitude and length,
    and compared by `loss_fn` using the notebook-level `spectral_loss_opts`.
    """
    return loss_fn(
        gen_sinusoid(frequency),
        gen_sinusoid(target_frequency),
        spectral_loss_opts,
    )

In [None]:
# TODO: this should make the blog post:
# There are tiny non-convex ranges of the loss function when small FFT sizes are included in
# the multi-spectral loss.
# Plotting estimated frequency vs. loss with L1 distance shows a V with occasional tiny bumps.
# The function is optimizable between these bumps, but not across them.
# (All FFT sizes are enabled, fft_sizes=(2048, 1024, 512, 256, 128, 64), and cumsum_freq_weight=1.0.)
# TODO: I expect to see a relationship between the minimum optimizable frequency and the FFT sizes
# that allow smooth gradients. Note that a 20 Hz sine sampled at 16 kHz has a period of 800 samples.
# TODO: show sine frequency optimization across a 20-16000 Hz range for a 44100 Hz sample rate.
# Evaluate the loss at every candidate frequency and plot the resulting curve.
# A title and axis labels are set so the figure matches the other plot cells and reads on its own.
plt.title('Loss vs. estimated frequency')
plt.xlabel('Estimated frequency')
plt.ylabel('Loss')
_ = plt.plot(frequencies, [loss_for_frequency(frequency) for frequency in frequencies])

In [None]:
# Plain gradient descent on the sinusoid frequency: 100 steps with a fixed learning rate.
grad_fn = value_and_grad(loss_for_frequency)
learning_rate = 10.0
estimated_frequency = 400.0

# Per-step history of the loss and of the updated estimate.
losses = []
estimated_frequencies = []
for _ in range(100):
    loss, grad_value = grad_fn(estimated_frequency)
    estimated_frequency = estimated_frequency - learning_rate * grad_value

    # `loss` was evaluated before this step; the recorded frequency is the value after it.
    estimated_frequencies.append(estimated_frequency)
    losses.append(loss)

estimated_frequency

In [None]:
# Trajectory of the frequency estimate over the gradient steps, with the target as a dashed line.
plt.title('Estimated frequency over time')
plt.xlabel('Gradient step')
plt.ylabel('Frequency')
plt.plot(estimated_frequencies, label='Estimated frequency')
plt.axhline(y=target_frequency, c="g", linestyle="--", label='Target frequency')
# The `label` arguments are only displayed once a legend is drawn.
_ = plt.legend()

In [None]:
# Loss value recorded at each gradient step.
plt.plot(losses)
_ = plt.title('Loss over time')