<a href="https://colab.research.google.com/github/davidwhogg/StarDemodulator/blob/main/notebooks/wrap_finufft.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wrap `finufft` so it looks like `scipy.fftpack`

## Author:
- **David W. Hogg** *(NYU) (MPIA) (Flatiron)*

## To-do items:
- Make regularly sampled data for performing exact comparisons to the `scipy` FFT.
- Rearrange the output ordering and units to match `scipy`.
- Remove the dependency on the ad-hoc pickle data file.

In [None]:
!pip install finufft

In [None]:
import numpy as np
import scipy.sparse.linalg as sp
from functools import partial
from scipy.fftpack import fftfreq
import finufft

In [None]:
FEPS, ATOL, BTOL = 1.e-6, 1.e-6, 1.e-6 # made up

def nufft1d2_pinv(x, c, N):
    """
    The pseudo-inverse of `nufft1d2()`.

    Solves the linear least-squares problem ``nufft1d2(x, f) ~= c`` for the
    `N` Fourier mode amplitudes `f` with `scipy.sparse.linalg.lsqr`. It uses
    `nufft1d2` as the forward operator and `nufft1d1` as its adjoint.

    ## Inputs:
    - `x`: shape `(M, )` real non-uniform sample positions (passed to `finufft`).
    - `c`: shape `(M, )` data values at `x` (cast to complex).
    - `N`: number of Fourier modes to solve for.

    ## Outputs:
    - shape `(N, )` complex mode amplitudes, in `finufft`'s default mode order.

    ## Side effects:
    - Prints the number of `lsqr` iterations.
    - Warns if `lsqr` stopped at its iteration limit.
    """
    import warnings
    # `finufft` wants contiguous float64 positions and complex128 data.
    x = np.ascontiguousarray(x, dtype=float)
    c = np.ascontiguousarray(c, dtype=complex)
    M = len(x)
    R = partial(finufft.nufft1d2, x, eps=FEPS)  # modes -> data (forward)
    RT = partial(finufft.nufft1d1, x, n_modes=N, eps=FEPS)  # data -> modes (adjoint)
    RR = sp.LinearOperator((M, N), matvec=R, rmatvec=RT, dtype=complex)
    # Initial guess: the adjoint applied to the data, rescaled by the step
    # that minimizes the residual along that direction. Every diagonal
    # element of RT R equals M, so the unscaled adjoint overshoots by roughly
    # a factor of M and starts `lsqr` worse off than x0 = 0. A multiple of
    # RT(c) still lies in the range of the adjoint, so the minimum-norm
    # least-squares solution is unchanged.
    g = RT(c)
    Rg = R(g)
    denom = np.vdot(Rg, Rg).real
    alpha = np.vdot(Rg, c) / denom if denom > 0. else 0.
    f0 = alpha * g
    res = sp.lsqr(RR, c, x0=f0, atol=ATOL, btol=BTOL)
    print("nufft1d2_pinv(): completed in", res[2], "iterations")
    if res[1] == 7:  # lsqr's istop code for "iteration limit reached"
        warnings.warn("nufft1d2_pinv(): lsqr stopped at its iteration limit; "
                      "the solution may not have converged")
    return res[0]

In [None]:
def _hogg_delta_omega(ts):
    """
    ## Bugs / issues:
    - Doesn't check that the input has the right units!
    - Doesn't have a proper code header.
    """
    Nt = len(ts)
    assert ts.shape == (Nt, )
    # check that `ts` is units of time.
    # choose a sensible conversion of `ts` to dimensionless positions `xs`.
    return 2. * np.pi * (Nt / (Nt + 1)) / (max(ts) - min(ts))

def hogg_ft_1d(ts, ys, max_freq=np.inf):
    """
    Estimate Fourier amplitudes of (possibly irregularly sampled) data.

    The times are shifted to their median and scaled by
    `_hogg_delta_omega(ts)` into dimensionless `finufft` positions. The
    amplitudes are the least-squares (pseudo-inverse) solution of
    `nufft1d2()` on a symmetric, odd-length grid of frequencies.

    ## Inputs:
    - `ts`: shape `(Nt, )` sample times.
    - `ys`: shape `(Nt, )` data values at `ts`.
    - `max_freq`: keep only frequencies with ``abs(freq) < max_freq``, in
      cycles per unit of `ts`. The default keeps all
      ``(Nt // 2) * 2 + 1`` frequencies.

    ## Outputs:
    - `freqs`: shape `(Nf, )` frequencies in cycles per unit of `ts`,
      increasing and centered on zero (`Nf` is odd).
    - `fs`: shape `(Nf, )` complex Fourier amplitudes aligned with `freqs`.

    ## Raises:
    - `ValueError` if `max_freq` excludes every frequency.

    ## Bugs / issues:
    - This makes tons of decisions "for" the user.
    - Not tested.
    - Output `fs` have really weird units (not normalized like `scipy.fftpack`).
    - Doesn't rearrange the frequencies into `scipy.fftpack` format.
    """
    # NOTE: the default is `np.inf`; the old `np.Inf` alias was removed in NumPy 2.0.
    Nt = len(ts)
    assert ts.shape == ys.shape
    # check that `ts` is units of time.
    # Shift to the median time and scale so that the positions `xs` span
    # less than one `finufft` period (2 pi).
    median_t = np.nanmedian(ts)
    Delta_omega = _hogg_delta_omega(ts)
    xs = Delta_omega * (ts - median_t)
    assert np.ptp(xs) < 2. * np.pi
    # Symmetric, odd-length grid of integer modes -K..K, which matches
    # `finufft`'s default mode ordering for an odd number of modes.
    Nf = (Nt // 2) * 2 + 1 # user doesn't choose! We might want to change this.
    omegas = Delta_omega * np.arange(-(Nf // 2), Nf // 2 + 0.5)
    freqs = omegas / (2. * np.pi)
    # The grid is symmetric about zero, so this cut keeps it odd and centered.
    freqs = freqs[np.abs(freqs) < max_freq]
    Nf = len(freqs)
    if Nf == 0:
        raise ValueError("hogg_ft_1d(): `max_freq` excludes every frequency")
    # run `finufft` pseudo-inverse.
    fs = nufft1d2_pinv(xs, ys.astype(complex), Nf)
    # TODO: convert output amplitudes to physical units.
    assert freqs.shape == fs.shape
    return freqs, fs

def hogg_ift_1d(freqs, fs, ts):
    """
    Evaluate a Fourier model `(freqs, fs)` from `hogg_ft_1d()` at times `ts`.

    ## Inputs:
    - `freqs`: shape `(Nf, )` frequencies as returned by `hogg_ft_1d()`;
      `Nf` must be odd.
    - `fs`: shape `(Nf, )` Fourier amplitudes aligned with `freqs`.
    - `ts`: shape `(Nt, )` times. These should be the same times that were
      passed to `hogg_ft_1d()`, because the time origin (median) and the
      frequency spacing are recomputed from `ts`, not read from `freqs`.

    ## Outputs:
    - shape `(Nt, )` complex model values at `ts`.

    ## Raises:
    - `ValueError` if `freqs` is not the centered frequency grid implied by `ts`.
    """
    Nf = len(freqs)
    assert Nf % 2 == 1
    assert freqs.shape == fs.shape
    assert freqs.shape == (Nf, )
    Delta_omega = _hogg_delta_omega(ts)
    # `finufft` never sees `freqs`: it assumes mode k sits at angular
    # frequency k * Delta_omega. Check that assumption so that mismatched
    # `ts` fail loudly instead of silently giving a wrong reconstruction.
    expected = (Delta_omega / (2. * np.pi)) * np.arange(-(Nf // 2), Nf // 2 + 1)
    if not np.allclose(freqs, expected, rtol=1e-8, atol=0.):
        raise ValueError("hogg_ift_1d(): `freqs` do not match the frequency grid implied by `ts`")
    xs = Delta_omega * (ts - np.nanmedian(ts))
    # Same precision as the forward transform in `nufft1d2_pinv()`.
    return finufft.nufft1d2(xs, np.ascontiguousarray(fs, dtype=complex), eps=FEPS)

In [None]:
import pickle
# NOTE(review): `pickle.load` can execute arbitrary code; only load
# `data.pkl` if it comes from a trusted source.
with open("data.pkl", "rb") as fd:
    # Transpose the loaded table so that it is indexed by column first
    # (assumes the pickle holds a 2-D array-like -- TODO confirm).
    foo = pickle.load(fd).T
print(foo.shape)

In [None]:
import matplotlib.pyplot as plt
# Use row 0 of the transposed table as the sample times and row 7 as the
# data values (TODO confirm what column 7 of `data.pkl` holds).
ts, ys = foo[0], foo[7]
plt.plot(ts, ys, "k.")

In [None]:
# Note the `max_freq` input: if you don't set it, `hogg_ft_1d()` uses
# (Nt // 2) * 2 + 1 modes, i.e. a complete basis, and the least-squares fit
# does crazy things to match the data exactly.
# Here only frequencies with |freq| < 0.05 (cycles per unit of `ts`) are kept.
freqs, fs = hogg_ft_1d(ts, ys, max_freq=0.05)

In [None]:
# Plot the Fourier amplitudes for 0 <= freq <= 0.05:
# |fs| as black points plus a step curve, the real part in blue and the
# imaginary part in red. The green line marks freq = 1 / 100.
plt.axvline(1. / 100., color="g", lw=0.5)
plt.axhline(0., color='k', lw=0.5)
plt.plot(freqs, np.abs(fs), "k.")
plt.step(freqs, np.abs(fs), where="mid", color="k", lw=0.5)
plt.plot(freqs, fs.real, "b.", alpha=0.25)
plt.plot(freqs, fs.imag, "r.", alpha=0.25)
plt.xlim(0., 0.05)
# Scale the y range to the peak amplitude (this rebinds the name `foo`).
foo = np.max(np.abs(fs))
plt.ylim(-0.2 * foo, 1.2 * foo)

In [None]:
# This reconstructs the data from the ft.
# The reconstruction is not perfect because the basis isn't complete.
# NOTE(review): `hogg_ift_1d()` returns a complex array; matplotlib casts it
# to real when plotting, so only the real part is drawn. Plotting `foo.real`
# would make that explicit.
foo = hogg_ift_1d(freqs, fs, ts)
plt.plot(ts, foo, "k.")

In [None]:
# This filters the data: keep only amplitudes with 0.007 <= freq <= 0.013.
# Note that it distorts heavily at the edges, because it has to make a periodic
# model!
# NOTE(review): the `freqs < 0.007` mask also zeroes every negative
# frequency, so the filtered model is one-sided and `foo` is complex. For
# real data its real part is half the amplitude of a symmetric band-pass
# (mask on `np.abs(freqs)` for that). Confirm the one-sided output is intended.
fs_filtered = 1. * fs  # new array, so `fs` itself is not modified
fs_filtered[freqs > 0.013] = 0.+0.j
fs_filtered[freqs < 0.007] = 0.+0.j
foo = hogg_ift_1d(freqs, fs_filtered, ts)
plt.plot(ts, foo, "k.")

In [None]:
# Narrow filter: keep only amplitudes with 0.0195 <= freq <= 0.0205.
# NOTE(review): the `freqs < 0.0195` mask also zeroes every negative
# frequency, so the filtered model is one-sided and `foo` is complex. For
# real data its real part is half the amplitude of a symmetric band-pass
# (mask on `np.abs(freqs)` for that). Confirm the one-sided output is intended.
fs_filtered = 1. * fs  # new array, so `fs` itself is not modified
fs_filtered[freqs > 0.0205] = 0.+0.j
fs_filtered[freqs < 0.0195] = 0.+0.j
foo = hogg_ift_1d(freqs, fs_filtered, ts)
plt.plot(ts, foo, "k.")