In [None]:
import numpy as np
import jax.numpy as jnp
from scipy import signal

In [None]:
from jaxdsp.ddsp import spectral_ops

# Compare SciPy's reference STFT with the jaxdsp implementation on
# low-amplitude uniform noise in [-amp, amp).
amp = 1e-2
rng = np.random.default_rng(0)  # fixed seed so the comparison is reproducible
audio = amp * (rng.random(64000).astype(np.float32) * 2.0 - 1.0)
frame_size = 2048
hop_size = 128
# Fractional overlap between consecutive frames; this is the form handed to
# spectral_ops.stft below.
overlap = 1.0 - float(hop_size) / frame_size
pad_end = True

# SciPy's `noverlap` is a number of samples, not a fraction. Truncating the
# fractional `overlap` with int() yields 0 (no overlap at all), so the sample
# count is derived from the frame and hop sizes instead.
s_np = signal.stft(audio,
                   nperseg=int(frame_size),
                   noverlap=int(frame_size - hop_size),
                   nfft=int(frame_size),
                   padded=pad_end)
s_jdsp = spectral_ops.stft(audio, frame_size=frame_size, overlap=overlap, pad_end=pad_end)

# NOTE(review): signal.stft returns a (frequencies, times, Zxx) tuple; indexing
# s_jdsp the same way assumes it has a matching layout -- confirm against
# spectral_ops.stft's actual return value.
np.allclose(s_np[0], s_jdsp[0]), np.allclose(s_np[1], s_jdsp[1]), np.allclose(s_np[2], s_jdsp[2])

In [None]:
import jaxdsp.loss

# Sanity check: spectral loss between an all-zeros signal and an all-ones
# signal of the same shape.
audio_shape = (2, 8000)  # presumably (batch, samples) -- TODO confirm against jaxdsp.loss.spectral
input_audio = jnp.zeros(audio_shape)
target_audio = jnp.ones(audio_shape)

loss_value = jaxdsp.loss.spectral(input_audio, target_audio)
float(loss_value)

In [None]:
import matplotlib.pyplot as plt

# Samples per second used for every sinusoid generated below.
sample_rate = 16_000

# 100 evenly spaced frequency offsets from 0 through 10 (both ends included).
frequency_diffs = np.linspace(start=0, stop=10, num=100)

def gen_sinusoid(frequency, amp, sample_rate, audio_len_sec):
    """Generate a zero-phase sine wave sampled at a fixed rate.

    Args:
        frequency: Oscillation frequency in cycles per second.
        amp: Peak amplitude the sine is scaled by.
        sample_rate: Number of samples per second.
        audio_len_sec: Duration of the signal in seconds.

    Returns:
        A 1-D NumPy array holding ``int(audio_len_sec * sample_rate)`` samples.
    """
    num_samples = int(audio_len_sec * sample_rate)
    # Consecutive samples must be exactly 1 / sample_rate apart. An
    # endpoint-inclusive linspace over [0, audio_len_sec] spaces them by
    # audio_len_sec / (num_samples - 1) instead, which shifts the pitch of
    # the generated tone slightly upward.
    t = np.arange(num_samples) / sample_rate
    return amp * np.sin(2 * np.pi * frequency * t)

def loss_for_frequency_diff(frequency_diff):
    """Spectral loss between a 440 Hz sine and a copy detuned by `frequency_diff`.

    Both tones are generated with unit amplitude and a one-second duration at
    the notebook-level `sample_rate`.

    Args:
        frequency_diff: Offset added to the 440 Hz base frequency for the
            second tone.

    Returns:
        Whatever `jaxdsp.loss.spectral` returns for the pair of signals.
    """
    base_frequency = 440.0
    reference, detuned = (
        gen_sinusoid(frequency, 1.0, sample_rate, 1.0)
        for frequency in (base_frequency, base_frequency + frequency_diff)
    )
    return jaxdsp.loss.spectral(reference, detuned)

In [None]:
# Evaluate the loss at every detuning offset, then plot loss against offset.
losses = [loss_for_frequency_diff(diff) for diff in frequency_diffs]
_ = plt.plot(frequency_diffs, losses)