Skip to content

Commit

Permalink
Merge pull request #36 from epfl-lts2/wave-filter
Browse files Browse the repository at this point in the history
Implement a filter that solves the wave equation
  • Loading branch information
mdeff committed Dec 19, 2018
2 parents fa0e3d0 + ca0e79b commit c30cec5
Show file tree
Hide file tree
Showing 11 changed files with 262 additions and 84 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ main objects of the package.
>>> G = graphs.Logo()
>>> G.compute_fourier_basis() # Fourier to plot the eigenvalues.
>>> # G.estimate_lmax() is otherwise sufficient.
>>> g = filters.Heat(G, tau=50)
>>> g = filters.Heat(G, scale=50)
>>> fig, ax = g.plot()

.. image:: ../pygsp/data/readme_example_filter.png
Expand Down
1 change: 1 addition & 0 deletions doc/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ History
* Better documentation of the frame and its bounds.
* ``g.inverse()`` returns the pseudo-inverse of the filter bank.
* ``g.complement()`` returns the filter that makes the frame tight.
* Wave filter bank which application simulates the propagation of a wave.

Experimental filter API (to be tested and validated):

Expand Down
14 changes: 14 additions & 0 deletions doc/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,17 @@ @inproceedings{leonardi2011wavelet
year={2011},
organization={IEEE}
}

@inproceedings{grassi2016timevertex,
title={Tracking time-vertex propagation using dynamic graph wavelets},
author={Grassi, Francesco and Perraudin, Nathanael and Ricaud, Benjamin},
year={2016},
booktitle={Signal and Information Processing (GlobalSIP), 2016 IEEE Global Conference on},
}

@article{grassi2018timevertex,
title={A time-vertex signal processing framework: Scalable processing and meaningful representations for time-series on graphs},
author={Grassi, Francesco and Loukas, Andreas and Perraudin, Nathanael and Ricaud, Benjamin},
year={2018},
journal={IEEE Transactions on Signal Processing},
}
11 changes: 11 additions & 0 deletions pygsp/filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@
Then, derived classes implement various common graph filters.
**Filters that solve differential equations**
The following filters solve partial differential equations (PDEs) on graphs,
which model processes such as heat diffusion or wave propagation.
.. autosummary::
Heat
Wave
**Low-pass filters**
.. autosummary::
Expand Down Expand Up @@ -122,6 +132,7 @@
'Regular',
'Simoncelli',
'SimpleTight',
'Wave',
]
_APPROXIMATIONS = [
'compute_cheby_coeff',
Expand Down
2 changes: 1 addition & 1 deletion pygsp/filters/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def estimate_frame_bounds(self, x=None):
Without a null-space, the heat kernel forms a frame:
>>> g = filters.Heat(G, tau=[1, 10])
>>> g = filters.Heat(G, scale=[1, 10])
>>> A, B = g.estimate_frame_bounds()
>>> print('A={:.3f}, B={:.3f}'.format(A, B))
A=0.135, B=2.000
Expand Down
2 changes: 1 addition & 1 deletion pygsp/filters/gabor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Gabor(Filter):
>>>
>>> g1 = filters.Expwin(G, band_min=None, band_max=0, slope=3)
>>> g2 = filters.Rectangular(G, band_min=-0.05, band_max=0.05)
>>> g3 = filters.Heat(G, tau=10)
>>> g3 = filters.Heat(G, scale=10)
>>>
>>> fig, axes = plt.subplots(3, 2, figsize=(10, 10))
>>> for g, ax in zip([g1, g2, g3], axes):
Expand Down
94 changes: 48 additions & 46 deletions pygsp/filters/heat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,67 +10,55 @@
class Heat(Filter):
r"""Design a filter bank of heat kernels.
The (low-pass) heat kernel filter is defined in the spectral domain as
The (low-pass) heat kernel is defined in the spectral domain as
.. math:: \hat{g}_\tau(\lambda) =
\exp \left( -\tau \frac{\lambda}{\lambda_\text{max}} \right).
.. math:: g_\tau(\lambda) = \exp(-\tau \lambda),
where :math:`\lambda \in [0, 1]` are the normalized eigenvalues of the
graph Laplacian, and :math:`\tau` is a parameter that captures both time
and thermal diffusivity.
The heat kernel is the fundamental solution to the heat equation
.. math:: \tau L f(t) = - \partial_t f(t),
.. math:: - \tau L f(t) = \partial_t f(t),
where :math:`f: \mathbb{R}_+ \rightarrow \mathbb{R}^N`. Given the initial
condition :math:`f(0)`, the solution of the heat equation is expressed as
where :math:`f: \mathbb{R}_+ \rightarrow \mathbb{R}^N` is the heat
distribution over the graph at time :math:`t`. Given the initial condition
:math:`f(0)`, the solution of the heat equation is expressed as
.. math:: f(t) = e^{-L \tau t} f(0)
= U e^{-\Lambda \tau t} U^\top f(0)
= K_t(L) f(0).
.. math:: f(t) = e^{-\tau t L} f(0)
= U e^{-\tau t \Lambda} U^\top f(0)
= g_{\tau t}(L) f(0).
The above is, by definition, the convolution of the signal :math:`f(0)`
with the kernel :math:`K_t(\lambda) = \exp(-\tau t \lambda) = \hat{g}_\tau
(t \lambda \lambda_\text{max})`.
with the kernel :math:`g_{\tau t}(\lambda) = \exp(-\tau t \lambda)`.
Hence, applying this filter to a signal simulates heat diffusion.
Since the kernel is applied to the graph eigenvalues :math:`\Lambda`, which
Since the kernel is applied to the graph eigenvalues :math:`\lambda`, which
can be interpreted as squared frequencies, it can also be considered as a
generalization of the Gaussian kernel on graphs.
Parameters
----------
G : graph
tau : int or list of ints
Scaling parameter. If a list, creates a filter bank with one filter per
value of tau.
scale : float or iterable
Scaling parameter. When solving heat diffusion, it encompasses both
time and thermal diffusivity.
If iterable, creates a filter bank with one filter per value.
normalize : bool
Normalizes the kernel. Needs the eigenvalues.
Whether to normalize the kernel to have unit L2 norm.
The normalization needs the eigenvalues of the graph Laplacian.
Examples
--------
Regular heat kernel.
>>> G = graphs.Logo()
>>> g = filters.Heat(G, tau=[5, 10])
>>> print('{} filters'.format(g.Nf))
2 filters
>>> y = g.evaluate(G.e)
>>> print('{:.2f}'.format(np.linalg.norm(y[0])))
9.76
Normalized heat kernel.
>>> g = filters.Heat(G, tau=[5, 10], normalize=True)
>>> y = g.evaluate(G.e)
>>> print('{:.2f}'.format(np.linalg.norm(y[0])))
1.00
Filter bank's representation in Fourier and time (ring graph) domains.
>>> import matplotlib.pyplot as plt
>>> G = graphs.Ring(N=20)
>>> G.estimate_lmax()
>>> G.set_coordinates('line1D')
>>> g = filters.Heat(G, tau=[5, 10, 100])
>>> g = filters.Heat(G, scale=[5, 10, 100])
>>> s = g.localize(G.N // 2)
>>> fig, axes = plt.subplots(1, 2)
>>> _ = g.plot(ax=axes[0])
Expand All @@ -89,7 +77,8 @@ class Heat(Filter):
>>> delta = np.zeros(graph.n_vertices)
>>> delta[sources] = 5
>>> steps = np.array([1, 5])
>>> g = filters.Heat(graph, tau=10*steps)
>>> diffusivity = 10
>>> g = filters.Heat(graph, scale=diffusivity*steps)
>>> diffused = g.filter(delta)
>>> fig, axes = plt.subplots(1, len(steps), figsize=(10, 4))
>>> _ = fig.suptitle('Heat diffusion', fontsize=16)
Expand All @@ -99,28 +88,41 @@ class Heat(Filter):
... ax.set_aspect('equal', 'box')
... ax.set_axis_off()
Normalized heat kernel.
>>> G = graphs.Logo()
>>> G.compute_fourier_basis()
>>> g = filters.Heat(G, scale=5)
>>> y = g.evaluate(G.e)
>>> print('norm: {:.2f}'.format(np.linalg.norm(y[0])))
norm: 9.76
>>> g = filters.Heat(G, scale=5, normalize=True)
>>> y = g.evaluate(G.e)
>>> print('norm: {:.2f}'.format(np.linalg.norm(y[0])))
norm: 1.00
"""

def __init__(self, G, tau=10, normalize=False):
def __init__(self, G, scale=10, normalize=False):

try:
iter(tau)
iter(scale)
except TypeError:
tau = [tau]
scale = [scale]

self.tau = tau
self.scale = scale
self.normalize = normalize

def kernel(x, t):
return np.minimum(np.exp(-t * x / G.lmax), 1)
def kernel(x, scale):
return np.minimum(np.exp(-scale * x / G.lmax), 1)

kernels = []
for t in tau:
norm = np.linalg.norm(kernel(G.e, t)) if normalize else 1
kernels.append(lambda x, t=t, norm=norm: kernel(x, t) / norm)
for s in scale:
norm = np.linalg.norm(kernel(G.e, s)) if normalize else 1
kernels.append(lambda x, s=s, norm=norm: kernel(x, s) / norm)

super(Heat, self).__init__(G, kernels)

def _get_extra_repr(self):
tau = '[' + ', '.join('{:.2f}'.format(t) for t in self.tau) + ']'
return dict(tau=tau, normalize=self.normalize)
scale = '[' + ', '.join('{:.2f}'.format(s) for s in self.scale) + ']'
return dict(scale=scale, normalize=self.normalize)
132 changes: 132 additions & 0 deletions pygsp/filters/wave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-

from __future__ import division

from functools import partial

import numpy as np

from . import Filter # prevent circular import in Python < 3.5


class Wave(Filter):
r"""Design a filter bank of wave kernels.
The wave kernel is defined in the spectral domain as
.. math:: g_{\tau, t}(\lambda) = \cos \left( t
\arccos \left( 1 - \frac{\tau^2}{2} \lambda \right) \right),
where :math:`\lambda \in [0, 1]` are the normalized eigenvalues of the
graph Laplacian, :math:`t` is time, and :math:`\tau` is the propagation
speed.
The wave kernel is the fundamental solution to the wave equation
.. math:: - \tau^2 L f(t) = \partial_{tt} f(t),
where :math:`f: \mathbb{R}_+ \rightarrow \mathbb{R}^N` models, for example,
the mechanical displacement of a wave on a graph. Given the initial
condition :math:`f(0)` and assuming a vanishing initial velocity, i.e., the
first derivative in time of the initial distribution equals zero, the
solution of the wave equation is expressed as
.. math:: f(t) = U g_{\tau, t}(\Lambda) U^\top f(0)
= g_{\tau, t}(L) f(0).
The above is, by definition, the convolution of the signal :math:`f(0)`
with the kernel :math:`g_{\tau, t}`.
Hence, applying this filter to a signal simulates wave propagation.
Parameters
----------
G : graph
time : float or iterable
Time step.
If iterable, creates a filter bank with one filter per value.
speed : float or iterable
Propagation speed, bounded by 0 (included) and 2 (excluded).
If iterable, creates a filter bank with one filter per value.
References
----------
:cite:`grassi2016timevertex`, :cite:`grassi2018timevertex`
Examples
--------
Filter bank's representation in Fourier and time (ring graph) domains.
>>> import matplotlib.pyplot as plt
>>> G = graphs.Ring(N=20)
>>> G.estimate_lmax()
>>> G.set_coordinates('line1D')
>>> g = filters.Wave(G, time=[5, 15], speed=1)
>>> s = g.localize(G.N // 2)
>>> fig, axes = plt.subplots(1, 2)
>>> _ = g.plot(ax=axes[0])
>>> _ = G.plot(s, ax=axes[1])
Wave propagation from two sources on a grid.
>>> import matplotlib.pyplot as plt
>>> n_side = 11
>>> graph = graphs.Grid2d(n_side)
>>> graph.estimate_lmax()
>>> sources = [
... (n_side//4 * n_side) + (n_side//4),
... (n_side*3//4 * n_side) + (n_side*3//4),
... ]
>>> delta = np.zeros(graph.n_vertices)
>>> delta[sources] = 5
>>> steps = np.array([5, 10])
>>> g = filters.Wave(graph, time=steps, speed=1)
>>> propagated = g.filter(delta)
>>> fig, axes = plt.subplots(1, len(steps), figsize=(10, 4))
>>> _ = fig.suptitle('Wave propagation', fontsize=16)
>>> for i, ax in enumerate(axes):
... _ = graph.plot(propagated[:, i], highlight=sources,
... title='step {}'.format(steps[i]), ax=ax)
... ax.set_aspect('equal', 'box')
... ax.set_axis_off()
"""

def __init__(self, G, time=10, speed=1):

try:
iter(time)
except TypeError:
time = [time]
try:
iter(speed)
except TypeError:
speed = [speed]

self.time = time
self.speed = speed

if len(time) != len(speed):
if len(speed) == 1:
speed = speed * len(time)
elif len(time) == 1:
time = time * len(speed)
else:
raise ValueError('If both parameters are iterable, '
'they should have the same length.')

if np.any(np.asarray(speed) >= 2):
raise ValueError('The wave propagation speed should be in [0, 2[')

def kernel(x, time, speed):
return np.cos(time * np.arccos(1 - speed**2 * x / G.lmax / 2))

kernels = [partial(kernel, time=t, speed=s)
for t, s in zip(time, speed)]

super(Wave, self).__init__(G, kernels)

def _get_extra_repr(self):
time = '[' + ', '.join('{:.2f}'.format(t) for t in self.time) + ']'
speed = '[' + ', '.join('{:.2f}'.format(s) for s in self.speed) + ']'
return dict(time=time, speed=speed)

0 comments on commit c30cec5

Please sign in to comment.