[WIP] Introduce xi_min #737

Closed
wants to merge 58 commits into from

Conversation

OverLordGoldDragon
Collaborator

Introduces explicit guard against bandpasses ending up as pure sines.

See images. Attaining a proper wavelet can be very costly in padding, yet we also want to tile the entire frequency axis. I thus introduced max_pad_factor in the latest JTFS, so the user can choose faster compute at the expense of boundary effects in a small fraction of coefficients.

The rationale for xi_min = 2/N: the minimal possible wavelet is made with 3 samples in the frequency domain (arguably 2), so the peak cannot lie on bin 1; thus we put it at bin 2. As shown in "After this PR only", this alone doesn't guarantee a good wavelet. It does, however, spare much of the padding that'd be required to make psi1_f[-1] a proper wavelet.
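A minimal sketch of the "three samples" argument, using only NumPy rather than kymatio's filter code. A single nonzero DFT bin inverts to a pure tone with a flat envelope, while a peak at bin 2 with neighbors at bins 1 and 3 inverts to something with a varying envelope. The 0.5 neighbor weights and N = 64 are arbitrary assumptions chosen for illustration.

```python
import numpy as np

N = 64

# Pure tone: a single nonzero DFT bin.
tone_f = np.zeros(N)
tone_f[1] = 1.0

# Minimal "wavelet": peak at bin 2 with smaller neighbors at bins 1 and 3.
# The 0.5 weights are arbitrary and only serve the illustration.
wav_f = np.zeros(N)
wav_f[[1, 2, 3]] = [0.5, 1.0, 0.5]

# Envelopes (moduli) of the time-domain waveforms.
tone_env = np.abs(np.fft.ifft(tone_f))
wav_env = np.abs(np.fft.ifft(wav_f))

# xi_min = 2/N places the lowest peak on DFT bin 2.
xi_min = 2 / N
peak_bin = int(round(xi_min * N))

print("peak bin for xi_min = 2/N:", peak_bin)
print("tone envelope spread:     ", np.ptp(tone_env))
print("3-bin envelope spread:    ", np.ptp(wav_env))
```

The tone's envelope spread is at float-rounding level, whereas the three-bin filter's envelope actually rises and falls, which is the minimal sense in which it is a wavelet rather than a sine.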

Note that other PRs in progress will allow a bandpass to become exactly a pure sine, exacerbating the problem.

Current behavior

image

After this PR + others

image

After this PR only

image

Plots code

# run from different branches
import numpy as np
from kymatio.numpy import Scattering1D
from ssqueezepy import ifft
from ssqueezepy.visuals import plot, scat

N = 2048 * 4
J = int(np.log2(N) - 1)
Q = 16
scattering = Scattering1D(J, N, Q, max_pad_factor=None)

p_fr = scattering.psi1_f[-1][0]
p_t = ifft(p_fr)

scat(p_fr[:30], show=1, title="psi1_f[-1][:30]")
plot(p_t, complex=1)
plot(p_t, abs=1, linestyle='--', color='k', title="ifft(psi1_f[-1])")

janden and others added 30 commits March 12, 2020 20:59
Co-authored-by: edouardoyallon <edouard.oyallon@lip6.fr>
 (kymatio#614)

* MAINT updated readme with paper

* Update for citation bibtex to appear

* MAINT Cite paper correctly

* Updated with bibtex

* DOC removed repetition
Make tensors real

So padding and modulus returns real tensors, FFTs accept real inputs and IFFTs return real outputs. Upgrade NumPy FFT to SciPy FFT. Remove FFT class and replace with functions. General refactoring of padding, modulus, and FFT backend functions.

* ENH addition of real to complex

* ENH removal of pad_1d, removal of casting to complex, complex modolus renamed to modolus

* ENH added rfft to fft call

* API refactored 1d scattering in torch to take advantage of rfft

* API refactored numpy backend to take advantage of rfft

* API refactored tensorflow backend to take advantage of rfft

* MAINT removed print statements

* MAINT removed real backend primative

* API refactor torch 2d to use rfft

* TST changes to fix tests

* API refactor numpy 2d backend to take advantage of rfft

* API refactor tensorflow 2d backend to take advantage of rfft

* API refactor torch harmonic3d to take advantage of backend unification

* API refactor numpy harmonic3d to take advantage of backend unification

* API refactor tensorflow harmonic3d to take advantage of backend unification

* STY changed modulus_complex to modulus

* MAINT remove padding defaults

* STY removed white space

* MAINT fixed frontend stuffs

* MAINT removed _iscomplex from 3d backend

* MAINT removed complex_modulus function

* MAINT call type_check instead of repeating code

* TST test changes

* MAINT use scipy fft instead of numpy

* ENH add checks for complex

* MAINT add real and complex check in tensorflow backend

* MAINT added numpy real and complex check

* TEST check dtype rather than check type

* MAINT revmoed useless fft2

* MAINT rename

* MAINT remove unused sanity_check function

* MAINT pad header refactor

* MAINT replace scipy.fft with scipy.fftpack

* MAINT replaced torch.zeros_like with torch.zeros

* DOC removed unneeded docstrings

* MAINT renamed x to input_array

* TST added assert test, added spaces

* TST remove .squeeze

* TST changed tests to mimic numpy fft

* MAINT removed mode, value

* MAINT removed check for mode

* DOC docstring changes

* MAINT import fftpack, change modulus calculation

* TST removing unused code and line

* TST closed form for calculating fourier coefficents

* MAINT added contigouous check

* MAINT changed dtype of array

* MAINT changed way of initalizing modulus array

* API real tensors are now denoted with a final dim of 1.

* API 2d now marks tensors as real

* MAINT removed fft Complex to Complex

* API refactor each fft into its own function across all backends in scattering1d

* API removed base_backend, refactored FFT into seperate specific FFT functions.

* MAINT cdgmm now only takes in real tensors as input

* MAINT skcuda stuff runs

* MAINT modified kernels to output real arrays, added 3d modulus kernel.

* MAINT refactored numpy backend pad function

* MAINT revert addition of 3d modulus kernel

* MAINT removed usages of squeeze and unsqueeze introduced in the PR

* TST simplfication of complex coefficent calculation

* MAINT removed unused functions

* TST replace all instances of squeeze and unsqueeze

* MAINT specify axes in numpy FFTs

* MAINT use keepdims argument

* MAINT removed uneeded paren

* MAINT removed _is_real_ from 1d torch backend

* MAINT import _is_complex, cotiguous_check, complex_check

* MAINT replace checks with functions from general torch backend.

* MAINT replace all uses of new

* MAINT replaced new with empty_like

* TST readded tests

* TST readded test

* MAINT use empty or empty_like

* MAINT define n dimentional concatenate for 1d and 2d

* MAINT code changes

* TST remove squeeze

* MAINT added casting to complex in torch rfft backend primitives

* MAINT remove unused function

* MAINT use torch.empty

* MIANT replace np.real with .real

* MAINT we don't add a dimention and remove when padding

* MAINT added comments, changed modulus function return

* MAINT removed commas

* TST removed reshape operation

* TST added rfft test

* MAINT replace squeeze operation with reshape operation

* TST removed unused array, use meshgrid to generate indices now

* TST added FFT tests

* TST added fft tests

* MAINT removed paren

* MAINT remove unused import

* MAINT fft comments

* MAINT remove unused import

* TST removed indentation

* TST spaces

* MAINT changed  skcuda tensor initalization function

* TST changed fft tests

* MAINT spaces

* MAINT blank space

* MAINT removed empty line

* TST better FFT tests

* TST better fft test

* TST removed un-needed lines, whitespace

Co-authored-by: chaudhm <chaudhm@wwu.edu>
Right now, m2r is not compatible with Sphinx v3.
Safe way of doing this is with `python3 -m pip` instead of `pip3` since
this ensures that we're using the version of `pip` consistent with our
`python3` binary.
First, we output the current version of `pip`, then call `pip freeze` to
list the installed pacakges in a requirements-style format.
Need to access utils with `cupy._util` instead of `cupy.util`.
While SCC is investigating the feasability of upgrading the GPUs, we
only have access to the one.
janden and others added 11 commits February 12, 2021 20:35
MAINT Fix spelling of Skcuda backend class

TST Fix bug in 2D torch subsample tests

MAINT Update 2D Torch backend to use inheritance

MAINT Make torch_skcuda backend in 2D use inheritance

MAINT Make 3D torch backend use inheritance

MAINT Make 3D torch_skcuda backend use inheritance

FIX remove backend name as class init argument

ENH numpy 2d backend class

COSM remove commented code

ENH tensorflow 2d backend class

ENH numpy 3d backend class

ENH tensorflow 3d backend class

MAINT Update NP backends with class methods

MAINT Update TF backends with class methods

MAINT Update Torch backends to use class methods

MAINT remove cdgmm from tf backend because inherited from numpy backend

COSM remove commented code

MAINT make tf 3d backend use static and class methods

MAINT remove instantiation from torch backend test

MAINT remove tf import from tf 1d backend test

Co-authored-by: Joakim Andén <janden@kth.se>
We no longer support 3.5, so minimum version is 3.6.
The implementaiton used by TF to calculate the FFT is not very
accurate, so we need to relax the tolerance a little to eliminate
spurious test failures.
* MAINT bump version (kymatio#607)

* fix-653 (with CPU numpy only)

* splitting test into numpy torch

* fixing bugs

* adding CUDA support

* adding cupy..

* woof

* woof

* woof

* pytorch doesnt like multi processing

* figure out a solution

* figure out a solution

* figure out a solution

* figure out a solution

* ctg check

* adding more time

* improving tests of 1d 2d 3d

* imrpoving 123D test

* minor bug

* restructuring the tests

* refactor

* adding files

* modif

* minor fix

* minor fix

* doc

* doc

* micro bug with the backend

* micro bug with the backend

* addressing joakims comments and rebasing on dev

* random seeds

* removing some file

Co-authored-by: eickenberg <eickenberg@users.noreply.github.com>
Co-authored-by: edouardoyallon <edouard.oyallon@ens.fr>
@kymatio kymatio deleted a comment from codecov-commenter May 3, 2021
@kymatio kymatio deleted a comment from codecov-commenter May 3, 2021
@OverLordGoldDragon
Collaborator Author

OverLordGoldDragon commented May 6, 2021

The problem is worse than I thought: it's not just the "last few filters" but the last few octaves; for large Q, dozens of filters never decay to zero. With J, Q = 11, 16, here's psi1_f[-30], meaning 29 more wavelets have worse decay than this:

image

This is remedied by padding more, rather than via xi_min (though it helps by requiring less padding).
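A rough sketch of why more padding fixes the decay, assuming a plain Gaussian bandpass as a stand-in for the actual Morlet filters (boundary_decay and the xi, sigma values are hypothetical, not taken from kymatio). The filter is sampled on a longer frequency grid, and we compare its time-domain magnitude at the farthest point from its center against its peak.

```python
import numpy as np

def boundary_decay(n_pad, xi, sigma):
    """Ratio of |psi_t| at the point farthest from its center to its peak.

    Hypothetical helper: `psi_f` is a Gaussian bandpass, not kymatio's Morlet.
    """
    f = np.fft.fftfreq(n_pad)                        # cycles/sample
    psi_f = np.exp(-(f - xi) ** 2 / (2 * sigma ** 2))
    psi_t = np.abs(np.fft.ifft(psi_f))               # envelope, centered at index 0
    return psi_t[n_pad // 2] / psi_t.max()

N = 256
xi, sigma = 2 / N, 0.002   # same physical filter in both cases

r_unpadded = boundary_decay(N, xi, sigma)
r_padded = boundary_decay(4 * N, xi, sigma)

print("boundary/peak without padding:", r_unpadded)
print("boundary/peak with 4x padding: ", r_padded)
```

With the same center frequency and bandwidth, the unpadded filter is still a large fraction of its peak at the boundary, while the padded one has decayed by many orders of magnitude.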

@kymatio kymatio deleted a comment from codecov-commenter May 28, 2021
@OverLordGoldDragon
Collaborator Author

This can merge at any point (preferably after #673).

@lostanlen
Collaborator

so this PR effectively reduces the number of filters in the filterbank while leaving everything else the same?

i see your line of thinking for this bounds check but why allow the user to override it? wouldn't that add unnecessary complexity to filter design?

@OverLordGoldDragon
Collaborator Author

It's passed internally by scattering_filter_factory, not user-exposed.

@OverLordGoldDragon
Collaborator Author

OverLordGoldDragon commented Mar 26, 2022

Post inspired by #800.

Design concerns A) feature quality, and B) preservation of information. As we pad more, we lower the lowest possible frequency, and hence add more room for low-frequency wavelets. Likewise, more padding means populating more of the lower frequencies of the input signal, fft(x_pad). Incomplete tiling == inability to invert, hence loss of information... yet at the same time, padding == adding no information. Who "wins"?

First, the wavelets: top row setting the lowest frequency at input's minimum (DFT bin 1), bottom row setting it at padded's bin 1:

As a reminder, at this scale (J = log2(N)) we require padding, else we get a pure sine for a wavelet: (note, while the two psi_t look similar, only one suffers from circular conv.)

Taking x = randn(N) and comparing with padded:

Padding is by a factor of 16, hence unpadded bin 1 moves to padded bin 16. Yet, bins 1 to 15 are filled, and if we don't tile these, information is lost. However,

has nothing in 1-15, hence the answer: what's lost is what's needed to recover the padded portion. The "extra" frequencies in zero or whichever padding are required to encode the fact that our transform is to invert to zeros outside the original input's interval.

  • Caveat, bit of a lapse in reasoning: the padded portion serves not only as an "optional" segment to be inverted, but also as a contributing interval to a large kernel convolution. Since xi=1/N is all that's needed to fully tile the unpadded signal, and since the trimmed temporal waveforms are close in decay and identical in modulation... in zero-pad case we have N slidings of the wavelet with same carrier but different windowings, so all that'll differ for recovery is the un-windowing.
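The bin bookkeeping above can be checked numerically. This is a sketch under my own assumptions (N = 128, a seeded random input, and a periodic extension via np.tile as the comparison case); it is not the code behind the original plots. Zero padding by 16 fills padded bins 1 to 15, whereas a periodic extension of the same input has energy only at multiples of 16.

```python
import numpy as np

rng = np.random.default_rng(0)
N, factor = 128, 16
x = rng.standard_normal(N)

# Zero padding; where the zeros go only changes the phase, not the magnitude.
x_zero = np.pad(x, (0, N * (factor - 1)))
# Periodic extension of the same input to the same length.
x_per = np.tile(x, factor)

mag_zero = np.abs(np.fft.fft(x_zero))
mag_per = np.abs(np.fft.fft(x_per))

# Padded bins that are not multiples of `factor`.
off_grid = np.arange(N * factor) % factor != 0

print("zero pad, max |X| over bins 1-15:    ", mag_zero[1:factor].max())
print("periodic, max |X| off multiples of 16:", mag_per[off_grid].max())
```

The periodic case reproduces the unpadded spectrum on every 16th bin and nothing in between, so unpadded bin 1 sits at padded bin 16 and the bins below it carry only what the zero padding introduced.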

This covers B). For A), we ask whether a wavelet that completes less than one cycle over the input is important. This wavelet correlates to zero with periodic padding, as evident from the frequency domain, but not with other padding. While both 1/N and 1/N_pad for xi have the same support (since this is the non-CQT portion), of which the original input comprises only a small fraction, a lower frequency requires a longer observation interval to ascertain the measurement; in this sense 1/N_pad is worse.

Favoring 1/N also helps keep filterbank design independent of padding settings, convenient for a number of reasons, particularly in JTFS. The last point concerns: 1/N or 2/N? The latter helps guarantee a wavelet as opposed to a pure tone, as a wavelet is minimally three (arguably two) points in the frequency domain, and we can't plant at DC so the peak must be at 2. More importantly, insufficient padding with J=log2(N) yields a near-pure sine regardless, which risks not tiling bin 1 at all. Hence, 1/N.

@OverLordGoldDragon OverLordGoldDragon mentioned this pull request Mar 26, 2022
@OverLordGoldDragon OverLordGoldDragon changed the title Introduce xi_min [WIP] Introduce xi_min Mar 26, 2022
@OverLordGoldDragon
Collaborator Author

"WIP" just to change to 1/N and see if tests bork

@OverLordGoldDragon
Collaborator Author

Missed edge case; 'reflect' can meaningfully increase variability:

code
import numpy as np
from numpy.fft import rfft
import matplotlib.pyplot as plt

t = np.linspace(0, 1, 129, endpoint=True)
g = np.cos(2*np.pi * .5 * t)  # half a cycle over the input
gp = np.pad(g, [64, 63], mode='reflect')  # 129 + 64 + 63 = 256 samples
gf, gpf = rfft(g), rfft(gp)

tkw = dict(weight='bold', fontsize=17, loc='left')
fig, axes = plt.subplots(1, 2, sharey=True)
axes[0].plot(g)
axes[1].plot(gp)
axes[0].set_title("x", **tkw)
axes[1].set_title("x_pad", **tkw)
fig.subplots_adjust(left=0, right=1, wspace=.02)
fig.set_size_inches(11, 4.5)
plt.show()

fig, axes = plt.subplots(1, 2, sharey=True)
axes[0].plot(np.abs(gf))
axes[1].plot(np.abs(gpf))
axes[1].scatter(np.arange(len(gpf))[:3], np.abs(gpf)[:3])
axes[0].set_title("|rfft(x)|", **tkw)
axes[1].set_title("|rfft(x_pad)|", **tkw)
fig.subplots_adjust(left=0, right=1, wspace=.02)
fig.set_size_inches(11, 4.5)
plt.show()

Half cycle -> full cycle, and the padded wavelet, whose frequency is now lower than was possible without padding, will capture this frequency more accurately than the fully invertible unpadded filterbank, winning in "feature quality" and tying in invertibility.
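A quick numeric check of "half cycle -> full cycle", as a self-contained sketch that mirrors the signal from the plotting code (the comparison array full_cycle is my own addition). With these sizes, the reflect-padded half cosine is exactly one cosine cycle over 256 samples, so its spectrum sits entirely on padded bin 1, i.e. 1/256 cycles/sample, below the lowest nonzero unpadded bin of 1/129.

```python
import numpy as np

t = np.linspace(0, 1, 129)
g = np.cos(2 * np.pi * 0.5 * t)            # half a cycle over the input
gp = np.pad(g, [64, 63], mode='reflect')   # 256 samples after padding

# One full cosine cycle over the padded length, shifted by the left pad.
m = np.arange(len(gp))
full_cycle = np.cos(2 * np.pi * (m - 64) / len(gp))

spec = np.abs(np.fft.rfft(gp))

print("padded length:          ", len(gp))
print("matches one full cycle: ", np.allclose(gp, full_cycle))
print("peak bin of |rfft(x_pad)|:", int(np.argmax(spec)))
```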

Whether this is plausible or not is a secondary matter; the only important downside is making the filterbank depend upon padding settings. But this is meaningful, since the padded part cannot simply be "detached" from the original (despite me trying to do this). It's also arguable whether "feature quality" is improved since the half cycle persisting for a full cycle is a prior - but by retaining full tiling of padded we assure everything's captured for any padding choice. The added wavelets are max j, so again minimal speed and no memory overhead.

Hence, xi_min = 1/2**J_pad. This PR does 2/2**J_pad, but also in JTFS I made it so that xi_min is computed inside of compute_params_filterbank. Might update later, or just keep this as a "concept" PR and merge the real thing with JTFS.
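For reference, the four candidates discussed in this thread can be compared side by side. xi_min_candidates is a hypothetical helper written for this comment, not something in kymatio or in compute_params_filterbank, and it assumes the padded length is the next power of two at or above N * pad_factor.

```python
import math

def xi_min_candidates(N, pad_factor):
    # Hypothetical helper for illustration only; not part of kymatio.
    J_pad = math.ceil(math.log2(N * pad_factor))
    N_pad = 2 ** J_pad
    candidates = {
        "2/N": 2 / N,
        "1/N": 1 / N,
        "2/2**J_pad": 2 / N_pad,
        "1/2**J_pad": 1 / N_pad,
    }
    # Peak location of each candidate, in bins of the padded DFT.
    padded_bins = {name: xi * N_pad for name, xi in candidates.items()}
    return candidates, padded_bins

candidates, padded_bins = xi_min_candidates(N=2048, pad_factor=16)
for name in candidates:
    print(f"xi_min = {name}: {candidates[name]:.3e} cycles/sample, "
          f"padded bin {padded_bins[name]:g}")
```

With padding by a factor of 16, the 1/N choice peaks at padded bin 16 and leaves bins 1 to 15 untiled, while 1/2**J_pad peaks at padded bin 1.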

7 participants