# Graph ConvNet for cosmology: classification of parts of the sphere

[Nathanaël Perraudin](http://perraudin.info), [Michaël Defferrard](http://deff.ch), Tomasz Kacprzak

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
import healpy as hp

from scnn import models

In [None]:
plt.rcParams.update({'figure.figsize': (17, 5)})  # Wide default figures for the plots below.

## 1 Load and visualize spherical data 

Load two maps with the same power spectral density (PSD) but different higher-order statistics.

In [None]:
# Read both convergence maps and convert them from RING to NESTED ordering;
# the patch splitting below relies on NESTED ordering.
img1, img2 = [
    hp.reorder(hp.read_map(fname), r2n=True)
    for fname in ('data/same_psd/kappa_omega_m_0p3.fits',
                  'data/same_psd/kappa_omega_m_0p26.fits')
]

Downsample the maps.

In [None]:
Nside = 1024  # Working resolution of both maps.
# Change the resolution of both NESTED maps to Nside.
img1, img2 = [hp.ud_grade(sky, nside_out=Nside, order_in='NESTED')
              for sky in (img1, img2)]

Display the two maps.

In [None]:
# Use one colour range for both maps so they are directly comparable.
cmin = min(img1.min(), img2.min())
cmax = max(img1.max(), img2.max())
# NOTE(review): map 1's title quotes omega_m=0.31 although it is read from
# kappa_omega_m_0p3.fits, and it quotes pk_norm where map 2 quotes sigma_8 --
# confirm these parameters against the data set.
hp.mollview(img1, nest=True, min=cmin, max=cmax,
            title='Map 1, omega_m=0.31, pk_norm=0.82, h=0.7')
hp.mollview(img2, nest=True, min=cmin, max=cmax,
            title='Map 2, omega_m=0.26, sigma_8=0.91, h=0.7')

Let us cut the sphere into 192 smaller parts. We display 16 of them below.

In [None]:
# NOTE: although called `order`, this value is used as an nside: the coarse
# HEALPix grid at nside=4 has 12 * 4**2 = 192 pixels, one per patch.
order = 4

# Give the first order**2 = 16 patches distinct non-zero values (offset by 2
# so they stand out) and set every other patch to 0.
index = np.arange(hp.nside2npix(order)) + 2
# Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
mask = np.zeros_like(index, dtype=bool)
mask[:order**2] = True
index *= mask
hp.mollview(index, title='Some sphere subparts', nest=True)

# Highlight the coarse pixel with NESTED index 0, i.e. the patch whose
# fine-resolution pixels come first in NESTED order.
marker = np.zeros(hp.nside2npix(order))
marker[0] = 1
hp.mollview(marker, title='Selected indexes', nest=True)

## 2 Data preparation

### 2.1 Samples creation

We create samples by dividing each of the two full spheres into patches (based on the HEALPix sampling).

In [None]:
def hp_split(img, order, nest=True):
    """
    Split a HEALPix map into equally sized patches.

    Each patch is one pixel of a coarse HEALPix grid. In NESTED ordering the
    fine pixels that belong to one coarse pixel are contiguous, so the split
    is a plain reshape.

    Parameters
    ----------
    img : array_like
        Full-sky map with 12 * nside**2 pixels.
    order : int
        Despite its name, the nside of the coarse grid that defines the
        patches. It must be a power of 2 and not larger than the map's nside.
        The sphere is cut into 12 * order**2 patches, e.g. 192 for order=4.
    nest : bool
        True if `img` is in NESTED ordering. Otherwise `img` is assumed to be
        in RING ordering and is converted to NESTED first.

    Returns
    -------
    patches : ndarray, shape (12 * order**2, npix // (12 * order**2))
        One row per patch.
    index : ndarray
        NESTED pixel indices (at the map's nside) of the first patch.

    Raises
    ------
    ValueError
        If the map's nside or `order` is not a power of 2, or if `order` is
        larger than the map's nside.
    """
    img = np.asanyarray(img)
    npix = len(img)
    nside = hp.npix2nside(npix)
    # `order` is used as an nside (see nsample below), so compare both
    # resolutions on the same scale. nside2order also rejects values that
    # are not powers of 2 with a ValueError.
    if hp.nside2order(nside) < hp.nside2order(order):
        raise ValueError('Order not compatible with data.')
    if not nest:
        img = hp.reorder(img, r2n=True)
    nsample = 12 * order**2
    return img.reshape([nsample, npix // nsample]), np.arange(npix // nsample)

In [None]:
# One row per patch for each class. The returned `index` holds the pixel
# indices of the first patch; it is the same for both maps.
data = dict()
data['class1'], _ = hp_split(img1, order=order)
data['class2'], index = hp_split(img2, order=order)

print('The data is of shape {}'.format(data['class1'].shape))

Let's look at one data sample on the entire sphere.

In [None]:
# Keep only the pixels of the first patch (the first data['class1'].shape[1]
# pixels in NESTED order) and mask the rest of the sphere with hp.UNSEEN.
img = img1.copy()
img[data['class1'].shape[1]:] = hp.UNSEEN
img = hp.ma(img)

projected_map = hp.mollview(img, nest=True, return_projected_map=True, xsize=1600)

# Zoom on the unmasked patch in the Mollweide projection.
# NOTE(review): hard-coded pixel window tuned for xsize=1600 -- adjust it if
# xsize or the selected patch changes.
plt.figure()
plt.imshow(projected_map[380:520, 530:670]);

### 2.2 Normalization and train / test split 

Let us split the data into training and testing sets. The raw data is stored in `x_raw` and the standardized histograms in `x_trans`. Since both classes share the same power spectral density, it cannot be used as a discriminative feature; hence we use a histogram of the pixel values instead.

In [None]:
def histogram(x, cmin, cmax, bins=100):
    """
    Compute fixed-range histograms of the data.

    A 1-D input gives a single histogram of length `bins`. Any other input is
    treated as a stack of samples and gives one histogram per entry along the
    first axis (each entry is flattened), i.e. an array of shape
    (len(x), bins). All histograms use the range [cmin, cmax]; values outside
    it are not counted. Counts are returned as floats.
    """
    value_range = [cmin, cmax]
    if x.ndim != 1:
        counts = np.empty((len(x), bins), float)
        for row_idx, sample in enumerate(x):
            counts[row_idx], _ = np.histogram(sample, bins=bins, range=value_range)
        return counts
    single, _ = np.histogram(x, bins=bins, range=value_range)
    return single.astype(float)

In [None]:
# Normalize and transform the data, i.e. extract features.
x_raw = np.vstack((data['class1'], data['class2']))
# Rescale without centering: the data is divided by a global scale but its
# mean is not subtracted.
# NOTE(review): the scale is mean(x**2), not the RMS sqrt(mean(x**2)) --
# confirm which one is intended.
x_raw = x_raw / np.mean(x_raw**2)
cmin = np.min(x_raw)
cmax = np.max(x_raw)
x_hist = histogram(x_raw, cmin, cmax)
x_trans = preprocessing.scale(x_hist)  # Zero mean, unit variance per histogram bin.

# Create the label vector: 0 for class 1, 1 for class 2.
labels = np.zeros([x_raw.shape[0]], dtype=int)
labels[len(data['class1']):] = 1

# Random train / test split. The fixed seed makes the split reproducible
# across runs. Stratification keeps both classes equally represented in
# each set.
ntrain = 300
ret = train_test_split(x_raw, x_trans, labels, test_size=len(x_raw)-ntrain,
                       shuffle=True, stratify=labels, random_state=0)
x_raw_train, x_raw_test, x_trans_train, x_trans_test, labels_train, labels_test = ret

print('Class 1 VS class 2')
print('  Training set: {} / {}'.format(np.sum(labels_train==0), np.sum(labels_train==1)))
print('  Test set: {} / {}'.format(np.sum(labels_test==0), np.sum(labels_test==1)))

### 2.3 Histogram features visualization

Let us first plot the mean and then each feature individually.

In [None]:
# Left: class-averaged histograms. Right: every individual sample histogram
# (class 1 in blue, class 2 in red).
fig, axes = plt.subplots(1, 2)

axes[1].set_title('Histograms of individual samples')
axes[1].plot(x_hist[labels == 0].T, 'b')
axes[1].plot(x_hist[labels == 1].T, 'r')

axes[0].set_title('Mean histogram accross each class')
axes[0].plot(x_hist[labels == 0].mean(axis=0), label='class 1')
axes[0].plot(x_hist[labels == 1].mean(axis=0), label='class 2')
axes[0].legend();

## 3  Classification using SVM

Let us classify our data using an SVM classifier.

An SVM classifier fails on the raw data because of its high dimensionality, but we observe that it correctly classifies the dataset when given the histograms.

In [None]:
def print_error(model, x, labels, name):
    """
    Print the classification error rate of `model` on (`x`, `labels`).

    The error is the fraction of samples whose predicted label differs from
    the true label. It is therefore valid for any number of classes and any
    label dtype.

    Parameters
    ----------
    model : object
        Anything with a ``predict(x)`` method returning one label per sample.
    x : array_like
        Input samples passed to ``model.predict``.
    labels : array_like
        True labels, one per sample.
    name : str
        Prefix of the printed line, e.g. 'Training' or 'Test'.

    Raises
    ------
    ValueError
        If `labels` is empty or the number of predictions does not match the
        number of labels.
    """
    # Flatten both so that e.g. (n, 1) predictions are not broadcast against
    # (n,) labels into an (n, n) comparison.
    pred = np.asarray(model.predict(x)).ravel()
    labels = np.asarray(labels).ravel()
    if labels.size == 0:
        raise ValueError('labels must not be empty.')
    if pred.shape != labels.shape:
        raise ValueError('got {} predictions for {} labels.'.format(pred.size, labels.size))
    error = np.mean(pred != labels)
    print('{} error: {:.2%}'.format(name, error))

In [None]:
# Baseline: RBF-kernel SVM trained directly on the raw pixel values.
# SVC.fit returns the fitted estimator itself.
clf = SVC(kernel='rbf').fit(x_raw_train, labels_train)

print_error(clf, x_raw_train, labels_train, 'Training')
print_error(clf, x_raw_test, labels_test, 'Test')

In [None]:
# Same RBF-kernel SVM, now trained on the standardized histogram features.
# SVC.fit returns the fitted estimator itself.
clf = SVC(kernel='rbf').fit(x_trans_train, labels_train)

print_error(clf, x_trans_train, labels_train, 'Training')
print_error(clf, x_trans_test, labels_test, 'Test')

## 4 Classification using a spherical CNN

Let us now classify our data using a spherical convolutional neural network.

In [None]:
# Resolution (nside) of the graph at each layer of the spherical CNN.
nsides = [Nside, Nside, Nside // 2, min(Nside // 8, 128)]
# Other configurations tried: [2048, 1024, 256, 64] and [128, 32, 16].

# Number of patches (12 * order**2). For each nside, keep the first
# npix // nsample NESTED pixels, i.e. the pixel indices of the first patch.
nsample = 12 * order**2
indexes = [np.arange(hp.nside2npix(ns) // nsample) for ns in nsides]

In [None]:
C = 2  # Number of classes.

params = {
    'dir_name': 'sphere_part',
    # Training schedule.
    'num_epochs': 10,
    'batch_size': 20,
    'eval_frequency': 10,

    # Building blocks.
    'brelu': 'b1lrelu',  # Non-linearity (presumably bias + leaky ReLU -- confirm in scnn.models).
    'pool': 'apool1',  # Average pooling.

    # Architecture.
    'nsides': nsides,  # nside of each layer; a full-sphere graph has 12 * nside**2 vertices.
    'indexes': indexes,  # Pixel indices of one patch at each layer's nside.
    'F': [5, 20, 80, 10],  # Number of graph convolutional filters.
    'K': [10, 10, 10, 10],  # Polynomial orders.
    'batch_norm': [True, True, True, True],  # Batch norm.
    'M': [100, C],  # Output dimensionality of fully connected layers.

    # Optimization.
    'regularization': 2e-4,
    'dropout': 0.8,  # NOTE(review): keep probability or drop rate? Confirm in scnn.models.
    'learning_rate': 1e-3,
    'decay_rate': 0.95,
    'momentum': 0.9,
    'adam': True,
}
# ntrain / batch_size is the number of training batches per epoch.
params['decay_steps'] = ntrain / params['batch_size']

In [None]:
# Build the spherical CNN from the hyper-parameters defined above.
model = models.scnn(**params)

In [None]:
# Train on the raw patches; the test set is also passed to fit (presumably
# for periodic validation every eval_frequency steps -- confirm in scnn.models).
# NOTE(review): the three return values look like accuracy/loss histories and
# time per step -- verify against models.scnn.fit.
accuracy, loss, t_step = model.fit(x_raw_train, labels_train, x_raw_test, labels_test)

In [None]:
# Error rates of the trained spherical CNN on the training and test sets.
print_error(model, x_raw_train, labels_train, 'Training')
print_error(model, x_raw_test, labels_test, 'Test')

## 5 Discussion

Without subsampling
Training the spherical CNN for a few minutes on a CPU yields 96% validation accuracy.

The SVM consistently fails on the raw data but succeeds with the histograms.

Conclusion: the spherical CNN is able to discriminate between data with the same mean and the same PSD using only 1/192 of the sphere.

Effect of subsampling
 - N=512, errors on training/testing: 11.66%, 78.57% => complete fail
 - N=1024, errors on training/testing: 0%, 0-3% => partial success
 - N=2048, errors on training/testing: 0%, 3% => partial success

This variability may also be due to the fact that the training/validation split differs between runs.
    

## Some other plotting

In [None]:
# from scnn import utils
# nside_v = 32
# nsample = 12 * (order**2)
# ind = np.array(list(range(hp.nside2npix(nside_v)//nsample)))
# G = utils.healpix_graph(nside=nside_v, nest=True, indexes=ind)

# G.plot()