<a href="https://colab.research.google.com/github/michalshavitNYU/michalshavitnyu.github.io/blob/master/Periodic_Hilbert_Transform.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax.numpy as jnp

def x_grid(N):
    """Return N equally spaced sample points covering [-π, π).

    The left endpoint -π is included and the right endpoint π is excluded
    (``endpoint=False``), so consecutive points are 2π/N apart.
    """
    lower, upper = -jnp.pi, jnp.pi
    return jnp.linspace(lower, upper, num=N, endpoint=False)


In [1]:
def hilbert_periodic(values):
    """
    Hilbert transform of periodic samples, computed with the FFT.

    The transform acts along the last axis: each Fourier mode with
    frequency index k is multiplied by -i * sign(k). The k = 0 mode and,
    when the number of samples is even, the Nyquist mode are multiplied by 0.

    values: array of shape (..., N), samples on a uniform periodic grid
            (e.g. the grid over [-π, π)).
    returns: array of shape (..., N) holding the real part of the inverse FFT;
             any imaginary part is discarded.
    """
    n_pts = values.shape[-1]
    spectrum = jnp.fft.fft(values, axis=-1)

    # Sign of each mode's frequency, in FFT storage order.
    freq_sign = jnp.sign(jnp.fft.fftfreq(n_pts, d=1.0 / n_pts))

    # Modes without a sign: index 0 always, and index n_pts/2 when n_pts is
    # even (2 * idx == n_pts never holds for odd n_pts).
    idx = jnp.arange(n_pts)
    unsigned_mode = (idx == 0) | (2 * idx == n_pts)
    multiplier = jnp.where(unsigned_mode, 0.0, freq_sign)

    transformed = jnp.fft.ifft(spectrum * (-1j) * multiplier, axis=-1)
    return jnp.real(transformed)


In [None]:
# Demo: sample a two-mode trigonometric signal on a 256-point grid from x_grid.
N = 256
x = x_grid(N)
f = jnp.cos(3*x) + 0.2*jnp.sin(5*x)

# smooth taper: multiply f pointwise by a window evaluated on the same grid.
# NOTE(review): `smooth_window` is not defined in the cells visible here; it has
# to be defined or imported before this cell runs, otherwise a fresh-kernel
# "Run All" fails with NameError. `delta` is presumably the taper width —
# confirm against the definition of `smooth_window`.
f_smooth = f * smooth_window(x, delta=0.15*jnp.pi)

# apply Hilbert transform (acts along the last axis of f_smooth)
Hf = hilbert_periodic(f_smooth)
