# Graph ConvNet for cosmology: whole sphere classification

[Nathanaël Perraudin](http://perraudin.info), [Michaël Defferrard](http://deff.ch), Tomasz Kacprzak

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = ""

import numpy as np
import healpy as hp  # Needed for hp.mollview below.
import matplotlib
import matplotlib.pyplot as plt
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from scnn import experiment_helper
from scnn import models, utils, plot
from scnn.data import LabeledDataset

In [None]:
plt.rcParams['figure.figsize'] = (17, 5)

In [None]:
EXP_NAME = 'whole_sphere'

## 1 Data loading

The data consists of a toy dataset that is small enough to have fun with. It is made of 200 maps of size `NSIDE=64`, split equally into 2 classes.
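
For reference, a HEALPix map with resolution `NSIDE` has `12 * NSIDE**2` pixels, so each map is a vector of 49,152 values. The arithmetic below assumes the 200 maps are balanced between the two classes (as asserted in the next cell) and uses the 150-map training set chosen later in this notebook; `healpix_npix` is a hypothetical helper.

```python
# Number of pixels of a HEALPix map: 12 base pixels, each subdivided into NSIDE**2.
def healpix_npix(nside):
    return 12 * nside ** 2

nside = 64
npix = healpix_npix(nside)        # dimensionality of one raw map

n_maps = 200                      # total number of maps in the toy dataset
n_per_class = n_maps // 2         # assumption: the two classes are balanced
ntrain = 150                      # training-set size used later in the notebook
ntest = n_maps - ntrain           # equals 2 * n_per_class - ntrain

print(npix, n_per_class, ntest)   # 49152 100 50
```

The same formula gives the graph sizes used by the network for `NSIDE` = 32 and 16.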

In [None]:
data = np.load('data/maps_downsampled_64.npz')
assert len(data['class1']) == len(data['class2'])  # The two classes are balanced.
nclass = len(data['class1'])  # Number of maps per class (not the number of classes).

Let us plot a map of each class. The differences are hard to spot visually.

In [None]:
import copy

cmin = min(np.min(data['class1']), np.min(data['class2']))
cmax = max(np.max(data['class1']), np.max(data['class2']))
cm = copy.copy(plt.cm.RdBu_r)  # Copy to avoid modifying matplotlib's global colormap.
cm.set_under('w')
hp.mollview(data['class1'][0], title='class 1', nest=True, cmap=cm, min=cmin, max=cmax)
hp.mollview(data['class2'][0], title='class 2', nest=True, cmap=cm, min=cmin, max=cmax)

However, the maps of the two classes have different power spectral densities (PSD).
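
The PSD here is the angular power spectrum of each map. Expanding a map $f$ in spherical harmonics, the standard estimator is given below. We assume `experiment_helper.psd` computes this quantity (for example with `healpy.anafast`), whose default $\ell_{\max} = 3 N_{\text{side}} - 1$ gives 192 multipoles for `NSIDE=64`, matching the array size used below. The plot then shows $\ell(\ell+1)\hat{C}_\ell$, the conventional scaling for comparing power across scales.

```latex
f(\theta, \phi) = \sum_{\ell=0}^{\ell_{\max}} \sum_{m=-\ell}^{\ell} a_{\ell m} \, Y_{\ell m}(\theta, \phi),
\qquad
\hat{C}_\ell = \frac{1}{2\ell + 1} \sum_{m=-\ell}^{\ell} \lvert a_{\ell m} \rvert^2,
\qquad
\ell_{\max} = 3 N_{\text{side}} - 1 = 191 .
```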

In [None]:
# One PSD value per multipole ell = 0, ..., 3*NSIDE-1, i.e. 192 values for NSIDE=64.
sample_psd_class1 = np.empty((nclass, 192))
sample_psd_class2 = np.empty((nclass, 192))

for i in range(nclass):
    sample_psd_class1[i] = experiment_helper.psd(data['class1'][i])
    sample_psd_class2[i] = experiment_helper.psd(data['class2'][i])

In [None]:
ell = np.arange(sample_psd_class1.shape[1])
plot.plot_with_std(ell, sample_psd_class1*ell*(ell+1), label='class 1, Omega_matter=0.3, mean', color='b')
plot.plot_with_std(ell, sample_psd_class2*ell*(ell+1), label='class 2, Omega_matter=0.5, mean', color='r')
plt.legend(fontsize=16);
plt.xlim([10, np.max(ell)])
plt.ylim([1e-6, 1e-3])
# plt.yscale('log')
plt.xscale('log')
# Raw strings so that '\e' and '\c' are not treated as (invalid) escape sequences.
plt.xlabel(r'$\ell$: spherical harmonic index', fontsize=18)
plt.ylabel(r'$C_\ell \cdot \ell \cdot (\ell+1)$', fontsize=18)
plt.title('Power Spectral Density, 3-arcmin smoothing, noiseless, Nside=1024', fontsize=18);

## 2 Data preparation

Let us split the data into training and testing sets. The raw data is stored in `x_raw` and the power spectral densities in `x_psd`.
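
`preprocessing.scale` standardizes each column (here, each multipole $\ell$) to zero mean and unit variance across maps, while the raw maps are only divided by one global scalar, their mean squared value, which rescales them without centering. A NumPy sketch of both operations on made-up data follows; the function names are hypothetical. Note that, as in the cell below, the statistics are computed on the full dataset before the split.

```python
import numpy as np

def standardize_columns(x):
    """Zero-mean, unit-variance scaling of each column (feature) of x."""
    mean = x.mean(axis=0)
    std = x.std(axis=0)          # population std (ddof=0)
    std[std == 0] = 1.0          # leave constant features unscaled
    return (x - mean) / std

def scale_by_mean_power(x):
    """Divide all entries by one global scalar, mean(x**2); no centering."""
    return x / np.mean(x ** 2)

rng = np.random.default_rng(0)
psd = rng.lognormal(size=(200, 192))   # fake PSD features: 200 maps x 192 multipoles
z = standardize_columns(psd)
print(np.allclose(z.mean(axis=0), 0), np.allclose(z.std(axis=0), 1))  # True True
```

Because the raw maps are multiplied by a single positive constant, their relative structure and the sign of their mean are unchanged.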

In [None]:
# Normalize and transform the data, i.e. extract features.
x_raw = np.vstack((data['class1'], data['class2']))
x_raw = x_raw / np.mean(x_raw**2)  # Rescale by the mean power, without centering (we do not want to remove the mean).
x_psd = preprocessing.scale(np.vstack((sample_psd_class1, sample_psd_class2)))  # Zero mean, unit variance per multipole.

# Create the label vector
labels = np.zeros([x_raw.shape[0]], dtype=int)
labels[nclass:] = 1

# Random train / test split
ntrain = 150
ret = train_test_split(x_raw, x_psd, labels, test_size=2*nclass-ntrain, shuffle=True)
x_raw_train, x_raw_test, x_psd_train, x_psd_test, labels_train, labels_test = ret

print('Class 1 VS class 2')
print('  Training set: {} / {}'.format(np.sum(labels_train==0), np.sum(labels_train==1)))
print('  Test set: {} / {}'.format(np.sum(labels_test==0), np.sum(labels_test==1)))

## 3 Classification using SVM

As a baseline, let us classify our data using an SVM classifier.

* An SVM trained on the raw pixels cannot discriminate the classes, because the dimensionality of the data (12 × 64² = 49,152 pixels per map) is too large for only 150 training samples.
* We observe, however, that the PSD features are linearly separable.
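
The cells below report errors through `experiment_helper.model_error`, whose implementation is not shown here. A plausible minimal version, assuming it returns the fraction of misclassified samples for any object with a scikit-learn-style `predict` method, is sketched below; `misclassification_rate` and `ThresholdClassifier` are hypothetical names.

```python
import numpy as np

def misclassification_rate(model, x, labels):
    """Fraction of samples whose predicted label differs from the true one.

    Hypothetical stand-in for experiment_helper.model_error; assumes
    `model.predict(x)` returns one integer label per row of x.
    """
    predictions = np.asarray(model.predict(x))
    return float(np.mean(predictions != np.asarray(labels)))

class ThresholdClassifier:
    """Toy classifier: label 1 if the first feature is positive."""
    def predict(self, x):
        return (np.asarray(x)[:, 0] > 0).astype(int)

x = np.array([[-1.0], [2.0], [3.0], [-4.0]])
y = np.array([0, 1, 0, 0])
print(misclassification_rate(ThresholdClassifier(), x, y))  # 0.25
```

Multiplying the result by 100 gives the percentages printed in the cells below.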

In [None]:
clf = SVC(kernel='rbf')
clf.fit(x_raw_train, labels_train) 

e_train = experiment_helper.model_error(clf, x_raw_train, labels_train)
e_test = experiment_helper.model_error(clf, x_raw_test, labels_test)
print('The training error is: {}%'.format(e_train*100))
print('The testing error is: {}%'.format(e_test*100))

In [None]:
clf = SVC(kernel='linear')
clf.fit(x_psd_train, labels_train) 

e_train = experiment_helper.model_error(clf, x_psd_train, labels_train)
e_test = experiment_helper.model_error(clf, x_psd_test, labels_test)
print('The training error is: {}%'.format(e_train*100))
print('The testing error is: {}%'.format(e_test*100))

## 4 Classification using a spherical CNN

Let us now classify our data using a spherical convolutional neural network.
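
The architecture below halves `NSIDE` from layer to layer (64 → 32 → 16) with average pooling (`'apool1'`). In the HEALPix NESTED ordering (the maps are plotted with `nest=True`), the 4 children of a pixel at `NSIDE/2` are stored consecutively, so, assuming the pooling operates on that layout, it reduces to averaging groups of 4 consecutive pixels. The helper `nest_average_pool` is a hypothetical sketch, not the `scnn` implementation.

```python
import numpy as np

def nest_average_pool(x, levels=1):
    """Average-pool NESTED-ordered HEALPix maps, halving NSIDE `levels` times.

    x has shape (n_maps, 12 * nside**2). In NESTED ordering, the 4 children
    of each parent pixel are contiguous, so pooling is a reshape + mean.
    """
    for _ in range(levels):
        n_maps, npix = x.shape
        x = x.reshape(n_maps, npix // 4, 4).mean(axis=2)
    return x

maps = np.random.default_rng(0).normal(size=(10, 12 * 64 ** 2))  # batch of 10 NSIDE=64 maps
for level, nside in enumerate([64, 32, 16]):
    pooled = nest_average_pool(maps, levels=level)
    print(nside, pooled.shape[1], 12 * nside ** 2)  # pixel counts agree
```

Because no neighbour search is needed, this kind of pooling is as cheap as pooling on a regular image grid.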

In [None]:
params = dict()
params['dir_name']       = EXP_NAME
params['num_epochs']     = 5
params['batch_size']     = 10
params['eval_frequency'] = 10

# Building blocks.
params['brelu']          = 'b1relu'  # Activation.
params['pool']           = 'apool1'  # Pooling.

# Architecture.
params['nsides']         = [64, 32, 16]  # Sizes of the laplacians are 12 * nsides**2.
params['F']              = [5, 10, 10]  # Number of graph convolutional filters.
params['K']              = [10, 10, 10]  # Polynomial orders.
params['batch_norm']     = [True, True, True]  # Batch norm.
params['M']              = [100, 2]  # Output dimensionality of fully connected layers.

# Optimization.
params['regularization'] = 1e-4
params['dropout']        = 0.5  # Keep probability: 1 means no dropout.
params['learning_rate']  = 1e-3
params['decay_rate']     = 0.98
params['momentum']       = 0.9
params['adam']           = True
params['decay_steps']    = ntrain / params['batch_size']  # Number of batches per epoch.
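
Each graph convolutional layer filters its input with a polynomial of the graph Laplacian, and `K` sets the polynomial order. The sketch below assumes the ChebNet formulation: rescaled Laplacian $\tilde{L} = 2L/\lambda_{\max} - I$, filter $y = \sum_{k<K} \theta_k T_k(\tilde{L})\, x$, computed with the recurrence $T_k = 2\tilde{L}T_{k-1} - T_{k-2}$. It uses a small ring graph instead of the HEALPix graph, and the helper names are hypothetical.

```python
import numpy as np

def ring_laplacian(n):
    """Combinatorial Laplacian L = D - A of a cycle graph with n vertices."""
    a = np.zeros((n, n))
    idx = np.arange(n)
    a[idx, (idx + 1) % n] = 1.0
    a[(idx + 1) % n, idx] = 1.0
    return np.diag(a.sum(axis=1)) - a

def chebyshev_filter(laplacian, x, theta, lmax):
    """Apply y = sum_k theta[k] * T_k(L_tilde) @ x with K = len(theta) terms."""
    n = laplacian.shape[0]
    l_tilde = 2.0 * laplacian / lmax - np.eye(n)   # spectrum mapped to [-1, 1]
    t_prev, t_curr = x, l_tilde @ x                # T_0 x and T_1 x
    y = theta[0] * t_prev
    if len(theta) > 1:
        y = y + theta[1] * t_curr
    for k in range(2, len(theta)):
        # Recurrence: T_k x = 2 L_tilde T_{k-1} x - T_{k-2} x
        t_prev, t_curr = t_curr, 2.0 * l_tilde @ t_curr - t_prev
        y = y + theta[k] * t_curr
    return y

n = 12
lap = ring_laplacian(n)
lmax = np.linalg.eigvalsh(lap).max()               # 4 for an even-length ring
signal = np.random.default_rng(0).normal(size=n)
theta = np.array([0.5, -0.3, 0.2])                 # K = 3 Chebyshev coefficients
out = chebyshev_filter(lap, signal, theta, lmax)
```

The recurrence only needs sparse matrix-vector products, which avoids computing the eigendecomposition of the large HEALPix Laplacian.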

In [None]:
model = models.scnn(**params)

In [None]:
training = LabeledDataset(x_raw_train, labels_train)
testing = LabeledDataset(x_raw_test, labels_test)

In [None]:
accuracy, loss, t_step = model.fit(training, testing)

In [None]:
e_train = experiment_helper.model_error(model, x_raw_train, labels_train)
e_test = experiment_helper.model_error(model, x_raw_test, labels_test)
print('The training error is: {}%'.format(e_train*100))
print('The testing error is: {}%'.format(e_test*100))

## 5 Filters visualization

The package offers a few different visualizations of the learned filters. First, we can simply look at the Chebyshev coefficients. This visualization is not very interpretable for humans, but it can help to debug problems related to the optimization.

In [None]:
layer=2
model.plot_chebyshev_coeffs(layer)

We can also plot the Chebyshev polynomials, i.e., the filters in the graph spectral domain. This visualization helps to understand which graph frequencies are picked up by the filtering operation. It is mostly interpretable by people from the graph signal processing community.
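
Under the same ChebNet assumption, a filter with coefficients $\theta_k$ has the frequency response $g(\lambda) = \sum_k \theta_k T_k(2\lambda/\lambda_{\max} - 1)$ over the Laplacian eigenvalues $\lambda \in [0, \lambda_{\max}]$. NumPy's Chebyshev module evaluates this directly; the coefficients and $\lambda_{\max}$ below are made up, and `spectral_response` is a hypothetical helper.

```python
import numpy as np
from numpy.polynomial import chebyshev

def spectral_response(theta, lmax, num=200):
    """Frequency response g(lambda) of a Chebyshev filter on [0, lmax]."""
    lam = np.linspace(0.0, lmax, num)
    return lam, chebyshev.chebval(2.0 * lam / lmax - 1.0, theta)

theta = np.array([0.5, -0.3, 0.2])     # hypothetical learned coefficients (K = 3)
lam, g = spectral_response(theta, lmax=2.0)
# g[0] = 0.5 + 0.3 + 0.2 = 1.0 at lambda = 0; g[-1] = 0.5 - 0.3 + 0.2 = 0.4 at lambda = lmax
```

Low $\lambda$ corresponds to smooth signals on the sphere and high $\lambda$ to rapidly varying ones, so this curve shows which scales a filter amplifies or suppresses.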

In [None]:
model.plot_filters_spectral(layer);

Here comes one of the most human-friendly representations of the filters. It consists of a section of the filters "projected" on the sphere. Because of the irregularity of the HEALPix sampling, this representation of the filters may not look very smooth.
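
A filter section is a 1D cut through the filter's response to a single active pixel. Such responses are compact because, assuming the ChebNet formulation with `K` coefficients, the filter is a polynomial of degree `K - 1` in the Laplacian, so the response to a Kronecker delta vanishes beyond `K - 1` hops. This is illustrated below on a ring graph (an assumption made so that hops are easy to count; the notebook uses the HEALPix graph).

```python
import numpy as np

n, K = 30, 4
idx = np.arange(n)
adj = np.zeros((n, n))
adj[idx, (idx + 1) % n] = adj[(idx + 1) % n, idx] = 1.0   # ring graph
lap = np.diag(adj.sum(axis=1)) - adj
l_tilde = lap / 2.0 - np.eye(n)        # lmax = 4 for an even ring, so 2L/lmax = L/2

theta = np.random.default_rng(1).normal(size=K)   # arbitrary filter coefficients
delta = np.zeros(n)
delta[0] = 1.0                                    # single active vertex

# Chebyshev recurrence applied to the delta: T_0 = delta, T_1 = L_tilde @ delta, ...
t_prev, t_curr = delta, l_tilde @ delta
response = theta[0] * t_prev + theta[1] * t_curr
for k in range(2, K):
    t_prev, t_curr = t_curr, 2.0 * l_tilde @ t_curr - t_prev
    response += theta[k] * t_curr

hops = np.minimum(idx, n - idx)        # graph distance from vertex 0 on the ring
print(np.abs(response[hops > K - 1]).max())   # 0.0: nothing beyond K - 1 hops
```

With `K = 10` as in this notebook, each filter would thus see at most 9 hops around a pixel.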

In [None]:
matplotlib.rcParams.update({'font.size': 16})
model.plot_filters_section(layer, title='');

Finally, we can simply look at the filters on the sphere. This representation clearly displays the sampling artifacts.
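
A gnomonic view projects a neighbourhood of the sphere onto the plane tangent at a chosen center, so the local pixel layout, including its sampling artifacts, appears with little distortion near that center. We assume `plot_filters_gnomonic` relies on such a projection (healpy provides `gnomview`); the standard formulas are sketched below, with longitudes and latitudes in radians, and `gnomonic` is a hypothetical helper.

```python
import numpy as np

def gnomonic(lon, lat, lon0, lat0):
    """Project (lon, lat) onto the plane tangent to the sphere at (lon0, lat0).

    Only points on the hemisphere centered at (lon0, lat0) (cos_c > 0) are valid.
    """
    cos_c = (np.sin(lat0) * np.sin(lat)
             + np.cos(lat0) * np.cos(lat) * np.cos(lon - lon0))
    x = np.cos(lat) * np.sin(lon - lon0) / cos_c
    y = (np.cos(lat0) * np.sin(lat)
         - np.sin(lat0) * np.cos(lat) * np.cos(lon - lon0)) / cos_c
    return x, y

x, y = gnomonic(np.pi / 4, 0.0, 0.0, 0.0)  # x = tan(pi/4) = 1 (up to rounding), y = 0
```

Distances from the center grow like $\tan$ of the angular distance, which is why the view is only used for small patches around a filter.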

In [None]:
model.plot_filters_gnomonic(layer, title='');