/
filter_bank.py
733 lines (645 loc) · 27.6 KB
/
filter_bank.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
import numpy as np
import math
import warnings
from scipy.fftpack import ifft
def adaptive_choice_P(sigma, eps=1e-7):
    """
    Choose how many periods of the frequency axis to use when sampling a
    Morlet filter in the Fourier domain.

    The Morlet filter is built from Gaussians of width sigma in Fourier,
        hat g_sigma(omega) = exp(-omega^2 / (2 sigma^2)),
    namely a Gabor term hat psi(omega) = hat g_sigma(omega - xi) with
    0 < xi < 1 and a low-pass term hat phi proportional to hat g_sigma.
    When sigma is large these Gaussians have not decayed inside [0, 1]
    (the interval used by numpy.fft), so they are evaluated on [1 - P, P]
    and folded back. P is chosen so that, at the ends of that interval,
        |hat psi(P)| <= eps  and  |hat phi(1 - P)| <= eps.
    Because 0 <= xi <= 1, the second condition implies the first, and
    solving hat g_sigma(P - 1) <= eps gives P >= 1 + sqrt(-2 sigma^2 log eps).

    Parameters
    ----------
    sigma : float
        Positive number controlling the bandwidth of the filters.
    eps : float, optional
        Positive number containing the required precision. Defaults to 1e-7.

    Returns
    -------
    P : int
        Smallest integer satisfying the bound above. The frequency
        interval [1 - P, P[ is then used, i.e. 2*P - 1 periods.
    """
    log_eps = math.log(eps)
    # Distance from 0 at which the Gaussian amplitude drops to eps.
    decay_radius = math.sqrt(-2 * (sigma**2) * log_eps)
    n_periods = math.ceil(decay_radius + 1)
    return int(n_periods)
def periodize_filter_fourier(h_f, nperiods=1):
    """
    Computes a periodization of a filter provided in the Fourier domain.

    The input is split into `nperiods` consecutive segments of length
    N = len(h_f) // nperiods, which are averaged together. Periodizing in
    Fourier over 2**k periods corresponds to subsampling by 2**k in time.

    Parameters
    ----------
    h_f : array_like
        numpy array of shape (N * nperiods,)
    nperiods : int, optional
        Number of periods which should be used to periodize. Defaults to 1.

    Returns
    -------
    v_f : array_like
        numpy array of shape (N,) given by the *mean* over the periods:
            v_f[k] = (1 / nperiods) * sum_{i=0}^{nperiods - 1} h_f[i * N + k]

    Raises
    ------
    ValueError
        if the length of h_f is not a multiple of nperiods.
    """
    total_length = h_f.shape[0]
    if total_length % nperiods != 0:
        raise ValueError(
            'Length of h_f ({}) is not a multiple of nperiods ({})'.format(
                total_length, nperiods))
    N = total_length // nperiods
    # Row i of the reshaped array is period i; average them column-wise.
    v_f = h_f.reshape(nperiods, N).mean(axis=0)
    return v_f
def morlet_1d(N, xi, sigma, normalize='l1', P_max=5, eps=1e-7):
    """
    Computes the Fourier transform of a Morlet filter.

    A Morlet filter is a Gabor filter minus a multiple of a Gaussian
    low-pass filter of the same width. The multiple is chosen so that the
    result has exactly zero mean in time (its Fourier transform vanishes
    at frequency 0). In time it reads
        psi(t) = g_{sigma}(t) (e^{i xi t} - beta)
    where g_{sigma} is a Gaussian envelope, xi is a frequency and beta is
    the cancelling parameter.

    Both terms are evaluated on the frequency interval [1 - P, P[ and then
    periodized back onto N bins, with
    P = min(adaptive_choice_P(sigma, eps), P_max).

    Parameters
    ----------
    N : int
        size of the temporal support
    xi : float
        central frequency (in [0, 1])
    sigma : float
        bandwidth parameter
    normalize : string, optional
        normalization type for the filter, 'l1' or 'l2' (understood in the
        time domain). Defaults to 'l1'.
    P_max : int, optional
        maximal value of P. At most 2*P_max - 1 periods are used, placed
        symmetrically around 0.5. Must be an integer >= 1 (Python or NumPy
        integer; bool is rejected). Defaults to 5.
    eps : float, optional
        required machine precision (to choose the adequate P).
        Defaults to 1e-7.

    Returns
    -------
    morlet_f : array_like
        numpy array of size (N,) containing the Fourier transform of the
        Morlet filter at the frequencies given by np.fft.fftfreq(N).

    Raises
    ------
    ValueError
        if P_max is not an integer or is < 1. get_normalizing_factor also
        raises ValueError for a numerically zero filter or an unsupported
        `normalize`.
    """
    # bool is a subclass of int but is not a meaningful P_max, so it is
    # rejected explicitly; NumPy integers are accepted.
    if isinstance(P_max, bool) or not isinstance(P_max, (int, np.integer)):
        raise ValueError('P_max should be an int, got {}'.format(type(P_max)))
    if P_max < 1:
        raise ValueError('P_max should be >= 1, got {}'.format(P_max))
    # Enough periods for the Gaussians to decay below eps, capped by P_max.
    P = min(adaptive_choice_P(sigma, eps=eps), int(P_max))
    n_periods = 2 * P - 1
    # Frequencies over [1 - P, P[ with N bins per unit period.
    freqs = np.arange((1 - P) * N, P * N, dtype=float) / float(N)
    if P == 1:
        # With a single period, use [-0.5, 0.5[ for the low-pass so that it
        # is continuous around frequency 0.
        freqs_low = np.fft.fftfreq(N)
    else:
        freqs_low = freqs
    # Gabor term at frequency xi and low-pass term, both of width sigma.
    gabor_f = np.exp(-(freqs - xi)**2 / (2 * sigma**2))
    low_pass_f = np.exp(-(freqs_low**2) / (2 * sigma**2))
    # Discretizing in time corresponds to periodizing in Fourier.
    gabor_f = periodize_filter_fourier(gabor_f, nperiods=n_periods)
    low_pass_f = periodize_filter_fourier(low_pass_f, nperiods=n_periods)
    # kappa makes morlet_f[0] == 0, i.e. the filter has zero mean in time.
    kappa = gabor_f[0] / low_pass_f[0]
    morlet_f = gabor_f - kappa * low_pass_f
    morlet_f *= get_normalizing_factor(morlet_f, normalize=normalize)
    return morlet_f
def get_normalizing_factor(h_f, normalize='l1'):
    """
    Computes the desired normalization factor for a filter defined in Fourier.

    The norm is measured on the time-domain filter ifft(h_f).

    Parameters
    ----------
    h_f : array_like
        numpy vector containing the Fourier transform of a filter
    normalize : string, optional
        desired normalization type, either 'l1' or 'l2'. Defaults to 'l1'.

    Returns
    -------
    norm_factor : float
        such that h_f * norm_factor is the adequately normalized vector.

    Raises
    ------
    ValueError
        if the l1 norm of the time-domain filter is below 1e-7, or if
        `normalize` is neither 'l1' nor 'l2'.
    """
    h_abs = np.abs(ifft(h_f))
    # Computed once: used both as the zero guard and as the l1 norm.
    l1_norm = h_abs.sum()
    if l1_norm < 1e-7:
        raise ValueError('Zero division error is very likely to occur, '
                         'aborting computations now.')
    if normalize == 'l1':
        return 1. / l1_norm
    if normalize == 'l2':
        return 1. / np.sqrt((h_abs**2).sum())
    raise ValueError("Supported normalizations only include 'l1' and 'l2'")
def gauss_1d(N, sigma, normalize='l1', P_max=5, eps=1e-7):
    """
    Computes the Fourier transform of a low pass gaussian window,
        hat g_sigma(omega) = exp(-omega^2 / (2 sigma^2)),
    evaluated on 2*P - 1 periods and periodized back onto N bins, with
    P = min(adaptive_choice_P(sigma, eps), P_max).

    Parameters
    ----------
    N : int
        size of the temporal support
    sigma : float
        bandwidth parameter
    normalize : string, optional
        normalization type for the filter, 'l1' or 'l2' (understood in the
        time domain). Defaults to 'l1'.
    P_max : int, optional
        maximal value of P. At most 2*P_max - 1 periods are used, placed
        symmetrically around 0.5. Must be an integer >= 1 (Python or NumPy
        integer; bool is rejected). Defaults to 5.
    eps : float, optional
        required machine precision (to choose the adequate P).
        Defaults to 1e-7.

    Returns
    -------
    g_f : array_like
        numpy array of size (N,) containing the Fourier transform of the
        filter (with the frequencies in the np.fft.fftfreq convention).

    Raises
    ------
    ValueError
        if P_max is not an integer or is < 1. get_normalizing_factor also
        raises ValueError for a numerically zero filter or an unsupported
        `normalize`.
    """
    # bool is a subclass of int but is not a meaningful P_max, so it is
    # rejected explicitly; NumPy integers are accepted.
    if isinstance(P_max, bool) or not isinstance(P_max, (int, np.integer)):
        raise ValueError('P_max should be an int, got {}'.format(type(P_max)))
    if P_max < 1:
        raise ValueError('P_max should be >= 1, got {}'.format(P_max))
    P = min(adaptive_choice_P(sigma, eps=eps), int(P_max))
    if P == 1:
        # Single period: [-0.5, 0.5[ keeps the window continuous around 0.
        freqs_low = np.fft.fftfreq(N)
    else:
        # [1 - P, P[ with N bins per unit period.
        freqs_low = np.arange((1 - P) * N, P * N, dtype=float) / float(N)
    g_f = np.exp(-freqs_low**2 / (2 * sigma**2))
    # Discretizing in time corresponds to periodizing in Fourier.
    g_f = periodize_filter_fourier(g_f, nperiods=2 * P - 1)
    g_f *= get_normalizing_factor(g_f, normalize=normalize)
    return g_f
def compute_sigma_psi(xi, Q, r=math.sqrt(0.5)):
    """
    Return the frequential width sigma of a Morlet wavelet centred at xi in
    a family with Q wavelets per octave.

    sigma is chosen so that the frequency responses of two adjacent
    wavelets (centre frequencies in ratio 2^{1/Q}) cross at amplitude r.
    This ensures that the whole frequency axis is covered.

    Parameters
    ----------
    xi : float
        frequency of the filter in [0, 1]
    Q : int
        number of filters per octave, an integer >= 1
    r : float, optional
        positive crossing amplitude, expected to be < 1. Larger r gives
        wider filters in frequency. Defaults to sqrt(0.5); keeping the
        default is recommended.

    Returns
    -------
    sigma : float
        frequential width of the Morlet wavelet.

    Refs
    ----
    Convolutional operators in the time-frequency domain, V. Lostanlen,
    PhD Thesis, 2017
    https://tel.archives-ouvertes.fr/tel-01559667
    """
    # Ratio between the centre frequencies of two neighbouring wavelets.
    neighbour_ratio = 1. / math.pow(2, 1. / Q)
    # Relative half-gap between the two centre frequencies.
    relative_gap = (1 - neighbour_ratio) / (1 + neighbour_ratio)
    # Converts a half-width at amplitude r into a Gaussian sigma.
    amplitude_scale = 1. / math.sqrt(2 * math.log(1. / r))
    scaled_xi = xi * relative_gap
    return scaled_xi * amplitude_scale
def compute_temporal_support(h_f, criterion_amplitude=1e-3):
    """
    Computes the (half) temporal support of a family of centered,
    symmetric filters h provided in the Fourier domain.

    The target is the smallest integer T such that, for all signals x and
    all filters h of the family,
        || x conv h - x conv h_{[-T, T]} ||_inf <= epsilon || x ||_inf   (1)
    where 0 < epsilon < 1 is criterion_amplitude and h_{[-T, T]} is h
    restricted to [-T, T]. T is then used to pad signals and avoid
    boundary effects.

    Concretely, the filters are brought to the time domain, and the l1
    mass of |h| from each index t to the middle of the array is computed.
    Only the first half of the samples is used, since |h| is assumed
    symmetric. T is one plus the first index at which this residual is
    <= criterion_amplitude for every filter.

    If no such index exists, T is half the length of the filters and a
    UserWarning is raised.

    Parameters
    ----------
    h_f : array_like
        numpy array of shape (batch, time), where each row contains the
        Fourier transform of a filter which is centered and whose absolute
        value is symmetric. A single filter of shape (time,) is also
        accepted.
    criterion_amplitude : float, optional
        value epsilon controlling the numerical error. The larger
        criterion_amplitude, the smaller the temporal support and the
        larger the numerical error. Defaults to 1e-3.

    Returns
    -------
    t_max : int
        temporal support which ensures (1) for all rows of h_f
    """
    # Promote a single filter to a batch of one so axis=1 is always time.
    h = ifft(np.atleast_2d(h_f), axis=1)
    half_support = h.shape[1] // 2
    # l1_residual[:, t] = sum_{s=t}^{half_support-1} |h[:, s]|
    l1_residual = np.fliplr(
        np.cumsum(np.fliplr(np.abs(h)[:, :half_support]), axis=1))
    # The support must be acceptable for every filter: take the worst one.
    worst_residual = np.max(l1_residual, axis=0)
    # Indices where the residual is at or below the criterion.
    admissible = np.flatnonzero(worst_residual <= criterion_amplitude)
    if admissible.size > 0:
        return int(admissible[0]) + 1
    # No admissible index: fall back to half the support and warn that
    # border effects will remain.
    warnings.warn('Signal support is too small to avoid border effects')
    return half_support
def get_max_dyadic_subsampling(xi, sigma, alpha=5.):
    """
    Return the largest j such that a Gabor (or Morlet) filter of frequency
    xi and width sigma can be subsampled by 2^j without aliasing.

    The filter is considered to extend up to
        omega_0 = min(xi + alpha * sigma, 0.5),
    and j is the largest integer with omega_0 < 2^{-(j + 1)}, computed as
    floor(-log2(omega_0)) - 1.

    Parameters
    ----------
    xi : float
        frequency of the filter in [0, 1]
    sigma : float
        frequential width of the filter
    alpha : float, optional
        controls the aliasing error: the larger alpha, the smaller the
        error. Defaults to 5.

    Returns
    -------
    j : int
        2^j is the maximal subsampling accepted by the filter without
        aliasing.
    """
    # Effective upper edge of the filter, clipped at the Nyquist frequency.
    omega_0 = min(xi + alpha * sigma, 0.5)
    octaves_below_edge = -math.log2(omega_0)
    return int(math.floor(octaves_below_edge) - 1)
def move_one_dyadic_step(cv, Q, alpha=5.):
    """
    Computes the parameters of the next wavelet on the low frequency side,
    based on the parameters of the current wavelet.

    This function is used in the loop defining all the filters, starting
    at the highest wavelet frequency and moving towards low frequencies.
    The steps are defined as:
        xi_{n+1} = 2^{-1/Q} xi_n
        sigma_{n+1} = 2^{-1/Q} sigma_n
    and j_{n+1} is recomputed with get_max_dyadic_subsampling.

    Parameters
    ----------
    cv : dictionary
        stands for current_value, with keys:
        * 'key': int, counter n of the current wavelet
        * 'xi': central frequency of the wavelet
        * 'sigma': width of the wavelet
        * 'j': maximal dyadic subsampling of the wavelet (not read here)
    Q : int
        number of wavelets per octave. Controls the ratio between the
        frequency and width of the current wavelet and the next one.
    alpha : float, optional
        tolerance parameter for the aliasing. The larger alpha,
        the more conservative the algorithm is. Defaults to 5.

    Returns
    -------
    new_cv : dictionary
        a new dictionary with keys 'xi', 'sigma', 'j' and 'key' holding the
        parameters of the next wavelet ('key' is incremented by one). The
        input dictionary is not modified.
    """
    factor = 1. / math.pow(2., 1. / Q)
    new_xi = cv['xi'] * factor
    new_sigma = cv['sigma'] * factor
    return {
        'xi': new_xi,
        'sigma': new_sigma,
        'j': get_max_dyadic_subsampling(new_xi, new_sigma, alpha=alpha),
        'key': cv['key'] + 1,
    }
def compute_xi_max(Q):
    """
    Return the largest centre frequency of a Morlet family with Q wavelets
    per octave: 1 / (1 + 2^{3/Q}), but never less than 0.35.

    Parameters
    ----------
    Q : int
        number of wavelets per octave (integer >= 1)

    Returns
    -------
    xi_max : float
        largest frequency of the wavelet frame.
    """
    geometric_xi = 1. / (1. + math.pow(2., 3. / Q))
    return max(geometric_xi, 0.35)
def compute_params_filterbank(sigma_low, Q, r_psi=math.sqrt(0.5), alpha=5.):
    """
    Computes the parameters of a Morlet wavelet filterbank.

    Starting at xi_max = compute_xi_max(Q), wavelets are placed with
    constant ratios between the frequencies and widths of adjacent filters
    (both divided by 2^{1/Q} at each step), as long as their width stays
    above sigma_low. The remaining low frequencies are then covered by
    Q - 1 wavelets of width sigma_low. Their centre frequencies are spaced
    linearly between the last geometric frequency and 0 (both excluded).
    This ensures that the low-pass filter has the largest temporal support
    among all filters, while preserving the coverage of the whole
    frequency axis.

    Parameters
    ----------
    sigma_low : float
        frequential width of the low-pass filter, must be > 0. This acts as
        a lower bound on the frequential widths of the band-pass filters,
        so that the low-pass filter has the largest temporal support among
        all filters.
    Q : int
        number of wavelets per octave.
    r_psi : float, optional
        Should be > 0 and < 1. Controls the redundancy of the filters
        (the larger r_psi, the larger the overlap between adjacent
        wavelets). Defaults to sqrt(0.5).
    alpha : float, optional
        tolerance factor for the aliasing after subsampling.
        The larger alpha, the more conservative the value of maximal
        subsampling is. Defaults to 5.

    Returns
    -------
    xi : list of float
        central frequencies of the wavelets, from the highest (index 0)
        to the lowest.
    sigma : list of float
        frequential widths of the wavelets, in the same order as xi.
    j : list of int
        maximal dyadic subsampling accepted by each wavelet, in the same
        order as xi.

    Raises
    ------
    ValueError
        if sigma_low is not strictly positive, since the geometric loop
        below would otherwise only stop on floating-point underflow.

    Refs
    ----
    Convolutional operators in the time-frequency domain, 2.1.3,
    V. Lostanlen, PhD Thesis, 2017
    https://tel.archives-ouvertes.fr/tel-01559667
    """
    # Written as `not >` so that NaN is rejected too.
    if not sigma_low > 0:
        raise ValueError(
            'sigma_low should be positive, got {}'.format(sigma_low))
    xi_max = compute_xi_max(Q)
    sigma_max = compute_sigma_psi(xi_max, Q, r=r_psi)
    xi = []
    sigma = []
    j = []
    if sigma_max <= sigma_low:
        # Not even the first wavelet is wider than the low-pass, so the
        # geometric part is skipped.
        # NOTE(review): last_xi is set to a width (sigma_max), not a
        # frequency; this is kept as in the original, confirm intent.
        last_xi = sigma_max
    else:
        # j = 0 for the first wavelet: xi_max >= 0.35 > 1/4, for which
        # get_max_dyadic_subsampling returns 0 whenever alpha * sigma >= 0.
        current = {'key': 0, 'j': 0, 'xi': xi_max, 'sigma': sigma_max}
        # Geometric part: keep going while the width exceeds sigma_low.
        while current['sigma'] > sigma_low:
            xi.append(current['xi'])
            sigma.append(current['sigma'])
            j.append(current['j'])
            current = move_one_dyadic_step(current, Q, alpha=alpha)
        last_xi = xi[-1]
    # Linear part: Q - 1 wavelets of width sigma_low strictly between
    # last_xi and 0.
    num_intermediate = Q - 1
    for q in range(1, num_intermediate + 1):
        factor = (num_intermediate + 1. - q) / (num_intermediate + 1.)
        new_xi = factor * last_xi
        new_sigma = sigma_low
        xi.append(new_xi)
        sigma.append(new_sigma)
        j.append(get_max_dyadic_subsampling(new_xi, new_sigma, alpha=alpha))
    return xi, sigma, j
def calibrate_scattering_filters(J, Q, r_psi=math.sqrt(0.5), sigma0=0.1,
                                 alpha=5.):
    """
    Calibrates the parameters of the filters used at the 1st and 2nd orders
    of the scattering transform.

    Both orders share the same low-pass filter but use a different number
    of wavelets per octave: Q_1 = Q and Q_2 = 1. The parameters of each
    order are returned as parallel lists indexed by the filter counter n
    (n = 0 is the highest-frequency filter).

    Parameters
    ----------
    J : int
        maximal scale of the scattering (controls the number of wavelets)
    Q : int
        number of wavelets per octave for the first order, must be >= 1
    r_psi : float, optional
        Should be > 0 and < 1. Controls the redundancy of the filters
        (the larger r_psi, the larger the overlap between adjacent
        wavelets). Defaults to sqrt(0.5).
    sigma0 : float, optional
        frequential width of the low-pass filter at scale J=0
        (the subsequent widths are defined by sigma_J = sigma0 / 2^J).
        Defaults to 1e-1.
    alpha : float, optional
        tolerance factor for the aliasing after subsampling.
        The larger alpha, the more conservative the value of maximal
        subsampling is. Defaults to 5.

    Returns
    -------
    sigma_low : float
        frequential width of the low-pass filter
    xi1 : list of float
        center frequencies of the first order filters
    sigma1 : list of float
        frequential widths of the first order filters
    j1 : list of int
        maximal dyadic subsampling accepted by each first order filter
    xi2 : list of float
        center frequencies of the second order filters
    sigma2 : list of float
        frequential widths of the second order filters
    j2 : list of int
        maximal dyadic subsampling accepted by each second order filter

    Raises
    ------
    ValueError
        if Q < 1.
    """
    if Q < 1:
        raise ValueError('Q should always be >= 1, got {}'.format(Q))
    # Width of the low-pass filter, halved for each unit of J.
    sigma_low = sigma0 / math.pow(2, J)
    # First order: Q wavelets per octave.
    xi1, sigma1, j1 = compute_params_filterbank(sigma_low, Q, r_psi=r_psi,
                                                alpha=alpha)
    # Second order: one wavelet per octave.
    xi2, sigma2, j2 = compute_params_filterbank(sigma_low, 1, r_psi=r_psi,
                                                alpha=alpha)
    return sigma_low, xi1, sigma1, j1, xi2, sigma2, j2
def _periodize_at_subsamplings(h0_f, max_sub):
    """
    Return {0: h0_f, k: h0_f periodized over 2**k periods} for
    1 <= k <= max_sub. In Fourier, periodizing over 2**k periods
    corresponds to subsampling the filter by 2**k in time.
    """
    filters = {0: h0_f}
    for k in range(1, max_sub + 1):
        filters[k] = periodize_filter_fourier(h0_f, nperiods=2**k)
    return filters


def scattering_filter_factory(J_support, J_scattering, Q, r_psi=math.sqrt(0.5),
                              criterion_amplitude=1e-3, normalize='l1',
                              max_subsampling=None, sigma0=0.1, alpha=5.,
                              P_max=5, eps=1e-7, **kwargs):
    """
    Builds in Fourier the Morlet filters used for the scattering transform.

    Each single filter is provided as a dictionary with the following keys:
    * 'xi': central frequency, 0 for the low-pass filter.
    * 'sigma': frequential width.
    * 'j': maximal dyadic subsampling accepted by the filter (0 for the
      low-pass filter).
    * k where k is an integer >= 0. The maximal value for k depends on the
      type of filter; it is chosen from max_subsampling and the
      characteristics of the filters. Each value for k is an array of size
      2**(J_support - k) containing the Fourier transform of the filter
      after subsampling by 2**k.

    Parameters
    ----------
    J_support : int
        2**J_support is the desired support size of the filters
    J_scattering : int
        parameter for the scattering transform (2**J_scattering
        corresponds to the averaging support of the low-pass filter)
    Q : int
        number of wavelets per octave at the first order. For audio signals,
        a value Q >= 12 is recommended in order to separate partials.
    r_psi : float, optional
        Should be > 0 and < 1. Controls the redundancy of the filters
        (the larger r_psi, the larger the overlap between adjacent
        wavelets). Defaults to sqrt(0.5).
    criterion_amplitude : float, optional
        Represents the numerical error which is allowed to be lost after
        convolution and padding. Defaults to 1e-3.
    normalize : string, optional
        Normalization convention for the band-pass filters (in the temporal
        domain). Supported values include 'l1' and 'l2'; a ValueError is
        raised otherwise. Defaults to 'l1'.
    max_subsampling : int or None, optional
        maximal dyadic subsampling to compute, in order to save computation
        time if it is not required. Defaults to None, in which case this
        value is adjusted depending on the filters.
    sigma0 : float, optional
        frequential width of the low-pass filter at J_scattering=0; for a
        given J_scattering it is equal to sigma0 / 2**J_scattering.
        Defaults to 1e-1.
    alpha : float, optional
        tolerance factor for the aliasing after subsampling.
        The larger alpha, the more conservative the value of maximal
        subsampling is. Defaults to 5.
    P_max : int, optional
        maximal number of periods to use to make sure that the Fourier
        transform of the filters is periodic. P_max = 5 is more than enough
        for double precision. Defaults to 5. Should be >= 1.
    eps : float, optional
        required machine precision for the periodization (single
        floating point is enough for deep learning applications).
        Defaults to 1e-7.
    **kwargs
        accepted and ignored.

    Returns
    -------
    phi_f : dictionary
        the low-pass filter at all the subsamplings it may receive, i.e.
        those performed on top of the 1st and 2nd order transforms.
    psi1_f : list of dictionaries
        the band-pass filters of the 1st order, ordered by the filter
        counter n (n = 0 is the highest frequency). Only k = 0 is computed,
        as their input is not subsampled.
    psi2_f : list of dictionaries
        the band-pass filters of the 2nd order, ordered by the filter
        counter n, at all the subsamplings their input may have.
    t_max_phi : int
        temporal size to use to pad the signal on the right and on the
        left while making at most criterion_amplitude error. Assumes that
        the temporal support of the low-pass filter is larger than that of
        all other filters.

    Refs
    ----
    Convolutional operators in the time-frequency domain, V. Lostanlen,
    PhD Thesis, 2017
    https://tel.archives-ouvertes.fr/tel-01559667
    """
    # Spectral parameters (centre frequency, width, maximal subsampling).
    sigma_low, xi1, sigma1, j1s, xi2, sigma2, j2s = \
        calibrate_scattering_filters(J_scattering, Q, r_psi=r_psi,
                                     sigma0=sigma0, alpha=alpha)
    # Every filter is first built at the base resolution, of size T.
    T = 2**J_support

    # 2nd order band-pass filters. Their input may have been subsampled
    # after the 1st order, so they are needed at several resolutions.
    psi2_f = []
    for xi, sigma, j2 in zip(xi2, sigma2, j2s):
        if max_subsampling is None:
            # Only 1st order filters with j1 < j2 are considered as inputs;
            # the largest such j1 (0 if none) bounds the subsampling.
            max_sub_psi2 = max((j1 for j1 in j1s if j1 < j2), default=0)
        else:
            max_sub_psi2 = max_subsampling
        psi_f = _periodize_at_subsamplings(
            morlet_1d(T, xi, sigma, normalize=normalize, P_max=P_max,
                      eps=eps),
            max_sub_psi2)
        psi_f['xi'] = xi
        psi_f['sigma'] = sigma
        psi_f['j'] = j2
        psi2_f.append(psi_f)

    # 1st order band-pass filters. Their input is not subsampled, so only
    # the base resolution T is needed.
    psi1_f = []
    for xi, sigma, j1 in zip(xi1, sigma1, j1s):
        psi1_f.append({
            0: morlet_1d(T, xi, sigma, normalize=normalize, P_max=P_max,
                         eps=eps),
            'xi': xi,
            'sigma': sigma,
            'j': j1,
        })

    # Low-pass filter. It may receive the output of either order, so its
    # maximal subsampling is the largest j over both orders.
    if max_subsampling is None:
        max_sub_phi = max(j1s + j2s, default=0)
    else:
        max_sub_phi = max_subsampling
    # NOTE(review): `normalize` is not forwarded here, so phi always uses
    # gauss_1d's default ('l1') normalization. This is kept as in the
    # original; confirm it is intended.
    phi_f = _periodize_at_subsamplings(
        gauss_1d(T, sigma_low, P_max=P_max, eps=eps), max_sub_phi)
    phi_f['xi'] = 0.
    phi_f['sigma'] = sigma_low
    phi_f['j'] = 0

    # Support size allowing padding without boundary errors at the finest
    # resolution.
    t_max_phi = compute_temporal_support(
        phi_f[0].reshape(1, -1), criterion_amplitude=criterion_amplitude)
    return phi_f, psi1_f, psi2_f, t_max_phi