r"""
Filtering a signal
==================

A graph signal is filtered by transforming it to the spectral domain (via the
graph Fourier transform), performing a point-wise multiplication (motivated by
the convolution theorem), and transforming it back to the vertex domain (via
the inverse graph Fourier transform).

.. note::

    In practice, filtering is implemented in the vertex domain to avoid the
    computationally expensive graph Fourier transform. To do so, filters are
    implemented as polynomials of the eigenvalues, which is equivalent to
    polynomials of the graph Laplacian :math:`L`. Hence, filtering a signal
    reduces to repeated multiplications by a sparse matrix (the graph
    Laplacian).

"""
import numpy as np
from matplotlib import pyplot as plt
import pygsp as pg
# Graph: a random sensor graph with a fixed seed so the figure is
# reproducible. The full eigendecomposition is computed up front because
# G.gft() and the spectral plot below both need the eigenvalues G.e and
# the Fourier basis.
G = pg.graphs.Sensor(seed=42)
G.compute_fourier_basis()

# Kernel g: an exponential-window filter (see pygsp.filters.Expwin for the
# meaning of band_max). Its shape is shown in the middle panel.
g = pg.filters.Expwin(G, band_max=0.5)

# Input signal: seeded Gaussian noise on the vertices, rescaled to have an
# l2 norm of 3.
x = np.random.default_rng(1).normal(size=G.N)
x = 3 * x / np.linalg.norm(x)

# Filter the signal, then express both the input and the output in the
# spectral domain so their frequency content can be compared with g.
y = g.filter(x)
x_hat = G.gft(x).squeeze()
y_hat = G.gft(y).squeeze()

# One row of three panels: input (vertex domain), spectra, output (vertex
# domain). No manual spacing is set because fig.tight_layout() at the end
# recomputes the subplot parameters anyway.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))

# Both vertex-domain plots use the input's value range as their colour
# limits, so input and output amplitudes are drawn on the same scale.
limits = [x.min(), x.max()]

# Left panel: input signal, annotated with its Dirichlet energy x^T L x.
G.plot(x, limits=limits, ax=axes[0], title='input signal $x$ in the vertex domain')
axes[0].text(0, -0.1, '$x^T L x = {:.2f}$'.format(G.dirichlet_energy(x)))
axes[0].set_axis_off()

# Middle panel: the kernel g with the magnitudes of both spectra overlaid.
g.plot(ax=axes[1], alpha=1)
# NOTE(review): this assumes pygsp's Filter.plot leaves the kernel curve as
# the second-to-last line on the axes. The handle is only used for the
# legend below; re-check it if pygsp's plotting internals change.
line_filt = axes[1].lines[-2]
line_in, = axes[1].plot(G.e, np.abs(x_hat), '.-')
line_out, = axes[1].plot(G.e, np.abs(y_hat), '.-')
axes[1].set_xlabel(r'graph frequency $\lambda$')
axes[1].set_ylabel(r'frequency content $\hat{x}(\lambda)$')
axes[1].set_title(r'signals in the spectral domain')
# A single legend call with explicit handles. Label strings that contain
# backslashes are raw strings to avoid invalid escape sequences.
labels = [
    r'input signal $\hat{x}$',
    'kernel $g$',
    r'filtered signal $\hat{y}$',
]
axes[1].legend([line_in, line_filt, line_out], labels, loc='upper right')

# Right panel: filtered signal, annotated with its Dirichlet energy y^T L y.
G.plot(y, limits=limits, ax=axes[2], title='filtered signal $y$ in the vertex domain')
axes[2].text(0, -0.1, '$y^T L y = {:.2f}$'.format(G.dirichlet_energy(y)))
axes[2].set_axis_off()

fig.tight_layout()