# Graph ConvNet for cosmology: demo of spherical convolution

[Nathanaël Perraudin](http://perraudin.info), [Michaël Defferrard](http://deff.ch), Tomasz Kacprzak

In this small notebook, we test an implementation of a spherical convolution. The general idea is to use a graph instead of the traditional two-dimensional grid as the support for convolution.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import numpy as np
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import healpy as hp
from pygsp import filters

from scnn import utils

In [None]:
plt.rcParams['figure.figsize'] = (17, 5)  # (9, 4) for matplotlib notebook

## 1 Graph

Let us start by constructing two small graphs on the healpix sampling scheme and visualizing them (with and without edges).

In [None]:
fig = plt.figure()

ax = fig.add_subplot(121, projection='3d')
G = utils.healpix_graph(nside=8, nest=True)
G.plotting.update(vertex_size=10)
G.plot(show_edges=False, ax=ax)

ax = fig.add_subplot(122, projection='3d')
G = utils.healpix_graph(nside=4, nest=True)
G.plotting.update(vertex_size=20)
G.plot(ax=ax)

The healpix sampling induces an 8-nearest-neighbor graph, i.e. a graph where each vertex is connected to 8 other vertices. A few vertices, however, have only 7 neighbors.

In [None]:
for i in np.unique(G.d):
    print('Number of nodes with {} neighbors: {}'.format(i, np.sum(G.d == i)))

## 2 Fourier basis

Graph convolution is defined as pointwise multiplication in the graph spectral domain, so it is important to check the spectral properties of the graph. Note that computing the spectrum requires diagonalizing the Laplacian, which is very costly in computation and memory. Fortunately, when it comes to convolution, there exist fast methods that only require sparse matrix multiplications.
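
To illustrate why the diagonalization can be avoided, here is a minimal NumPy sketch on a hypothetical ring graph (not the healpix graph, and not the PyGSP implementation). A polynomial kernel $k(\lambda) = \sum_j c_j \lambda^j$ is applied once through the full eigendecomposition and once through repeated multiplications by $L$. The helper names and the coefficients are assumptions made for this example.

```python
import numpy as np

def ring_laplacian(n):
    """Combinatorial Laplacian L = D - W of an unweighted ring graph (illustrative helper)."""
    W = np.zeros((n, n))
    idx = np.arange(n)
    W[idx, (idx + 1) % n] = 1
    W[(idx + 1) % n, idx] = 1
    return np.diag(W.sum(axis=1)) - W

def filter_spectral(L, coeffs, f):
    """Apply k(L) f through the full eigendecomposition (costly for large graphs)."""
    lam, U = np.linalg.eigh(L)
    k = sum(c * lam**j for j, c in enumerate(coeffs))
    return U @ (k * (U.T @ f))

def filter_polynomial(L, coeffs, f):
    """Apply k(L) f using only matrix-vector products with L."""
    out = np.zeros_like(f)
    power = f.copy()  # holds L^j f at iteration j
    for c in coeffs:
        out += c * power
        power = L @ power
    return out

L = ring_laplacian(32)
f = np.random.default_rng(0).standard_normal(32)
coeffs = [1.0, -0.5, 0.05]  # k(x) = 1 - 0.5 x + 0.05 x^2
y_spectral = filter_spectral(L, coeffs, f)
y_polynomial = filter_polynomial(L, coeffs, f)
print(np.allclose(y_spectral, y_polynomial))
```

When $L$ is stored as a sparse matrix with at most 8 non-zero off-diagonal entries per row, each product with $L$ costs a number of operations proportional to the number of vertices. The `method='chebyshev'` option used later in this notebook relies on the same principle, with Chebyshev polynomials instead of monomials.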

In [None]:
G = utils.healpix_graph(nside=16, lap_type='normalized', nest=True, dtype=np.float64)

print('max weighted degree: {:.2f}'.format(G.dw.max()))
print('min weighted degree: {:.2f}'.format(G.dw.min()))
print('mean weighted degree: {:.2f}'.format(G.dw.mean()))
print('Is the graph directed? {}'.format(G.is_directed()))
print('Number of nodes: {}'.format(G.N))

The eigenvectors are obtained by diagonalizing the normalized graph Laplacian, defined as $L=I-D^{-\frac{1}{2}}WD^{-\frac{1}{2}}$, where $W$ is the weight/adjacency matrix and $D$ the diagonal degree matrix.

The Fourier basis $U$ by definition satisfies
$$ L  = U \Lambda U^*. $$
Here the eigenvalues on the diagonal of $\Lambda$ can be interpreted as squared graph frequencies.
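
As a sanity check of these definitions, the following sketch builds the normalized Laplacian $I - D^{-1/2} W D^{-1/2}$ of a small path graph with NumPy and verifies the decomposition. The path graph is only a hypothetical stand-in for the healpix graph. Since $L$ is real and symmetric, $U$ is real and $U^* = U^\top$.

```python
import numpy as np

# Weight matrix of a small path graph (a hypothetical stand-in for the healpix graph).
n = 6
W = np.zeros((n, n))
for i in range(n - 1):
    W[i, i + 1] = W[i + 1, i] = 1.0

# Normalized Laplacian L = I - D^{-1/2} W D^{-1/2}.
d = W.sum(axis=1)
D_inv_sqrt = np.diag(1.0 / np.sqrt(d))
L = np.eye(n) - D_inv_sqrt @ W @ D_inv_sqrt

# L is real and symmetric: eigh returns real eigenvalues in ascending order
# and an orthonormal matrix of eigenvectors.
lam, U = np.linalg.eigh(L)

print(np.allclose(U @ np.diag(lam) @ U.T, L))  # L = U Lambda U^T
print(np.allclose(U.T @ U, np.eye(n)))         # U is orthonormal
print(lam.min(), lam.max())                    # eigenvalues lie in [0, 2]
```

The eigenvalues of a normalized Laplacian always lie in $[0, 2]$, and the smallest one is zero for a connected graph.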

In [None]:
# Compute all eigenvectors.
G.compute_fourier_basis()

In [None]:
print('Mean: {}'.format(np.mean(np.abs(G.U.ravel()))))
print('Min: {}'.format(np.min(G.U.ravel())))
print('Max: {}'.format(np.max(G.U.ravel())))
print('Max per row: {}'.format(np.max(G.U, axis=1)))

Let us display a few Fourier modes on the healpix map.

In [None]:
fig = plt.figure()
ne = 16

for ind in range(ne):
    hp.mollview(G.U[:, ind], 
                title='Eigenvector {}'.format(ind), 
                nest=True, 
                sub=(np.sqrt(ne), np.sqrt(ne), ind+1),
                max=np.max(np.abs(G.U[:, :ne])),
                min=-np.max(np.abs(G.U[:, :ne])),
                cbar=False)

We should also check higher frequency modes as they can be more localized.

In [None]:
IND = 3000
fig = plt.figure(figsize=(3, 2))
hp.mollview(G.U[:, IND], title="Eigenvector {}".format(IND), nest=True, cbar=False, sub=(1, 1, 1))

The most localized eigenvector is taken to be the one with the highest coherence, i.e. the one whose largest entry in absolute value is the largest.

In [None]:
ind = np.argmax(np.max(np.abs(G.U), axis=0))
fig = plt.figure(figsize=(3, 2))
hp.mollview(G.U[:, ind], title="Eigenvector {}".format(ind), nest=True, cbar=False, sub=(1, 1, 1))

This eigenvector is clearly very localized. Let us display the modulus of the whole Fourier basis to get a more general idea about all eigenvectors.

In [None]:
fig = plt.figure()
plt.imshow(np.abs(G.U), cmap='Greys');

## 3 Convolution on graphs

The convolution of a signal $f$ and a kernel $k(x)$ on a graph is defined as the pointwise multiplication in the spectral domain, i.e.
$$f_c  = U k(\Lambda)U^*f. $$
Here $U^*f$ is the graph Fourier transform of $f$ and $k(\Lambda)$ is a diagonal matrix where the kernel $k$ is applied on each element of the diagonal of $\Lambda$. 
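
To connect this definition with the usual notion of convolution, here is a small NumPy sketch on a ring graph, the graph analogue of a periodic 1-D grid. The graph, the low-pass kernel and the function name `graph_conv` are assumptions chosen for illustration. On this particular graph, the spectral definition coincides with circular convolution by the impulse response of the filter.

```python
import numpy as np

# Ring graph: the graph analogue of a periodic 1-D grid (illustrative assumption).
n = 24
idx = np.arange(n)
W = np.zeros((n, n))
W[idx, (idx + 1) % n] = 1.0
W[(idx + 1) % n, idx] = 1.0
L = np.diag(W.sum(axis=1)) - W
lam, U = np.linalg.eigh(L)

def graph_conv(f, kernel):
    """Compute U k(Lambda) U^T f for a real symmetric Laplacian."""
    f_hat = U.T @ f                   # graph Fourier transform
    return U @ (kernel(lam) * f_hat)  # filter, then inverse transform

kernel = lambda x: 1.0 / (1.0 + x)    # hypothetical low-pass kernel
f = np.random.default_rng(1).standard_normal(n)

# Filtering a delta gives the kernel localized at that vertex (impulse response).
delta = np.zeros(n)
delta[0] = 1.0
h = graph_conv(delta, kernel)

# On the ring, graph convolution matches circular convolution with h.
y_graph = graph_conv(f, kernel)
y_circ = np.real(np.fft.ifft(np.fft.fft(f) * np.fft.fft(h)))
print(np.allclose(y_graph, y_circ))
```

On an irregular graph there is no such shift structure, but the spectral formula remains well defined, which is what makes it usable on the sphere.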

Let us start with the heat diffusion problem. We solve the following equation on the graph:
$$ \tau L f(t) = -\partial_t f(t),$$
where $f(t): \mathbb{R}_+ \rightarrow \mathbb{R}^N$ is a vector-valued function of time, $L$ is the positive semi-definite matrix representing the Laplacian of the graph, and $\tau$ is a positive constant that sets the diffusion rate.

Given the vector $f_0 = f(0)$, the solution of this equation for time $t$ can be written as:
$$ f(t) = K_t(L) f_0, $$
where 
$$ K_t(L) = e^{-\tau t L}.$$
In the equation $f(t) = K_t(L) f_0$, the kernel $K_t(x)=e^{-\tau t x}$ can be considered as the convolution kernel and the heat diffusion problem can be solved using a simple convolution on the graph.
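
As a quick check of this closed form, the sketch below compares $K_t(L) f_0$ with an explicit Euler integration of $\partial_t f = -\tau L f$. It uses plain NumPy on a hypothetical ring graph with a combinatorial Laplacian. The values of $\tau$, $t$ and the number of steps are arbitrary choices for the example.

```python
import numpy as np

# Combinatorial Laplacian of a small ring graph (illustrative assumption).
n = 16
idx = np.arange(n)
W = np.zeros((n, n))
W[idx, (idx + 1) % n] = 1.0
W[(idx + 1) % n, idx] = 1.0
L = np.diag(W.sum(axis=1)) - W

tau, t = 0.5, 2.0
f0 = np.zeros(n)
f0[0] = 1.0  # Kronecker delta: all the heat starts at vertex 0

# Closed form: f(t) = U exp(-tau t Lambda) U^T f0.
lam, U = np.linalg.eigh(L)
f_kernel = U @ (np.exp(-tau * t * lam) * (U.T @ f0))

# Reference: explicit Euler integration of d/dt f = -tau L f.
steps = 20000
dt = t / steps
f_euler = f0.copy()
for _ in range(steps):
    f_euler = f_euler - dt * tau * (L @ f_euler)

print(np.max(np.abs(f_kernel - f_euler)))  # small discretization error
print(f_kernel.sum())                      # total heat is conserved here
```

The conservation of the sum holds for the combinatorial Laplacian used in this sketch, whose rows sum to zero. It does not hold in general for a normalized Laplacian.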

In [None]:
taus = [5, 10, 20, 50]
hf = filters.Heat(G, tau=taus)
fig, ax = plt.subplots()
hf.plot(plot_eigenvalues=True, show_sum=False, ax=ax)
ax.set_title('Filter frequency response');

In [None]:
for ind0 in [0, 500]:
    
    sig = np.zeros(G.N)
    sig[ind0] = 1
    conv = hf.analyze(sig)

    fig = plt.figure()

    for i, tau in enumerate(taus):
        hp.mollview(conv[:, i], 
                    title='ind0={}, tau={}'.format(ind0, tau), 
                    nest=True, 
                    sub=(np.sqrt(len(taus)), np.sqrt(len(taus)), i+1))

## 4 Smoothing a Planck map

Let us play with a Planck map.

In [None]:
map_cmb, map_noise, map_mask = hp.read_map('data/COM_CMB_IQU-smica_1024_R2.02_full.fits', field=(0, 1, 3), nest=True)

In [None]:
hp.mollview(map_cmb, title='cmb', nest=True)
# hp.mollview(map_noise, title='noise', cmap='RdBu')
# hp.mollview(map_mask, title='mask', cmap='RdBu')

Let us first select a lower resolution, NSIDE=256, which gives a total number of pixels of $N = 12 \cdot 256^2 = 786432$.

In [None]:
nside = 256
map_cmb_lores = hp.ud_grade(map_cmb, nside_out=nside, order_in='NESTED')
G = utils.healpix_graph(nside=nside, nest=True)
G.estimate_lmax()

Let us apply our heat operator. It will smooth the map.

In [None]:
taus = [5, 10, 20, 50]
hf = filters.Heat(G, tau=taus)
conv_map_lowres = hf.analyze(map_cmb_lores)

fig = plt.figure(figsize=(9, 6))
for i, tau in enumerate(taus):
    hp.mollview(conv_map_lowres[:, i], 
                title="Tau: {}".format(tau), 
                nest=True, 
                sub=(np.sqrt(len(taus)), np.sqrt(len(taus)), i+1))

## 5 Power spectral density

Let us now compute the power spectral density on the sphere. It is defined with respect to the graph Fourier basis and will therefore differ from the traditional one.
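
The idea behind the estimator can be sketched on a small ring graph where the exact Fourier basis is available. The signal energy in each frequency band is divided by the number of eigenvalues in that band, and this number is itself estimated from the energy of filtered random $\pm 1$ signals. The ideal band-pass filters, the test signal and all names below are assumptions for illustration. The function defined in the next cell instead uses an Itersine filterbank applied with Chebyshev polynomials.

```python
import numpy as np

rng = np.random.default_rng(0)

# Small ring graph with its exact Fourier basis (only feasible for small N).
n = 128
idx = np.arange(n)
W = np.zeros((n, n))
W[idx, (idx + 1) % n] = 1.0
W[(idx + 1) % n, idx] = 1.0
L = np.diag(W.sum(axis=1)) - W
lam, U = np.linalg.eigh(L)

# Assign every eigenvalue to one of n_bands ideal band-pass filters on [0, lmax].
n_bands = 4
band = np.clip((lam / lam.max() * n_bands).astype(int), 0, n_bands - 1)

def band_energy(x, b):
    """Energy of x (columns are signals) after ideal band-pass filtering in band b."""
    x_hat = U.T @ x
    return np.sum(x_hat[band == b] ** 2, axis=0)

# Test signal: white noise shaped by a known PSD p(lambda) = exp(-lambda).
sig = U @ (np.exp(-lam / 2) * (U.T @ rng.standard_normal(n)))

# Random +/-1 probes: their mean energy in a band estimates
# the number of eigenvalues in that band.
probes = rng.integers(0, 2, size=(n, 200)) * 2.0 - 1.0

counts_true = np.array([np.sum(band == b) for b in range(n_bands)])
counts_est = np.array([band_energy(probes, b).mean() for b in range(n_bands)])
psd_est = np.array([band_energy(sig, b) for b in range(n_bands)]) / counts_est

print(counts_true, np.round(counts_est, 1))
print(psd_est)
```

With a single realization of the signal, the values in `psd_est` are noisy estimates of the average of $p(\lambda)$ over each band. Averaging over more realizations or using wider bands reduces the variance.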

In [None]:
def estimate_graph_psd(G, sig, Nrand=10, Npoint=30):
    """Estimate the power spectral density on graph.
    
    Parameters
    ----------
    G : graph
    sig : ndarray
        Signal whose PSD is to be estimated.
    Nrand : int
        Number of random signals used for the estimation.
    Npoint : int
        Number of points at which the PSD is estimated.
    """
    
    # Define filterbank.
    g = filters.Itersine(G, Nf=Npoint, overlap=2)
    mu = np.linspace(0, G.lmax, Npoint)
    
    # Filter signal.
    sig_filt = g.filter(sig, method='chebyshev', order=2*Npoint)
    sig_dist = np.sum(sig_filt**2, axis=0)
    if sig_dist.ndim > 1:
        sig_dist = np.mean(sig_dist, axis=0).squeeze()
    
    # Estimate the number of eigenvalues in each band by filtering random +/-1 signals.
    rand_sig = np.random.binomial(n=1, p=0.5, size=[G.N, Nrand]) * 2 - 1
    rand_sig_filtered = g.filter(rand_sig, method='chebyshev', order=2*Npoint)
    eig_dist = np.mean(np.sum(rand_sig_filtered**2, axis=0), axis=0).squeeze()
    
    # Compute PSD.
    psd_values = sig_dist / eig_dist
    inter = interp1d(mu, psd_values, kind='linear')
    
    return filters.Filter(G, inter), (mu, psd_values)

psd_filter, psd_point = estimate_graph_psd(G, map_cmb_lores, Nrand=5, Npoint=30)

In [None]:
psd_filter.plot()
plt.plot(*psd_point, 'x')

plt.figure()
plt.semilogy(*psd_point);