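# Experiments on pooling
* pooling on a non-uniform sampling scheme (irregularly placed GHCN weather stations)
* pooling on an arbitrary part of the sphere (only some HEALPix pixels are present)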

In [None]:
# Reload imported modules automatically before each execution and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
import sys
sys.path.append("../..")
sys.path.append("..")
import numpy as np
import matplotlib.pyplot as plt
import healpy as hp
from tqdm import tqdm

from mpl_toolkits.mplot3d import Axes3D
import cartopy.crs as ccrs

In [None]:
# Output directory for saved figures (not used by any cell shown here).
pathfig = './figures/'

## Non-uniform sampling scheme

Using GHCN-Daily weather-station data (feature `TMIN`, years 2010–2014).

In [None]:
from GHCN.GHCN_preprocessing import get_data, get_stations, sphereGraph

# Location of the preprocessed GHCN-Daily files. Set the GHCN_DATAPATH environment variable to
# run elsewhere instead of editing this hard-coded cluster path (kept as the default).
datapath = os.environ.get("GHCN_DATAPATH", "/mnt/nas/LTS2/datasets/ghcn-daily/processed/")
years = np.arange(2010, 2015)  # 2010..2014 inclusive
features = ["TMIN"]

# The last two values returned by get_stations are not needed in this notebook.
n_stations, ghcn_to_local, lat, lon, _, _ = get_stations(datapath, years)
data, n_days = get_data(datapath, years, features, ghcn_to_local)

print(f'n_stations: {n_stations}, n_days: {n_days}')

In [None]:
# Move the station axis (data's axis 0, cf. lon[keepToo] / data[keepToo, ...]) to axis 1, then
# keep only the stations that have no NaN anywhere along the other two axes.
dataset = data.transpose((1, 0, 2))
keepToo = ~np.isnan(dataset).any(axis=(0, 2))
dataset = dataset[:, keepToo, :]

#### Original signal

In [None]:
# Raw station values (first entry along data's axes 1 and 2) on a globe viewed from the north pole.
zmin, zmax = -20, 40
fig = plt.figure(figsize=(25, 25))
ax = fig.add_subplot(111, projection=ccrs.Orthographic(central_longitude=-50, central_latitude=90))
ax.set_global()
ax.coastlines(linewidth=2)
sc = ax.scatter(
    lon[keepToo], lat[keepToo],
    c=np.clip(data[keepToo, 0, 0], zmin, zmax),
    s=10, cmap='RdYlBu_r', vmin=zmin, vmax=zmax,
    alpha=1, transform=ccrs.PlateCarree(),
)

#### Using spectral clustering

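Build a k-nearest-neighbour (KNN) graph on the stations, partition it with spectral clustering, and merge (pool) the stations of each cluster into a single vertex.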

In [None]:
# Pooling factor: on average `k` stations are merged into one cluster.
k = 5
NCluster = dataset.shape[1] // k
# Graph over the kept stations; 10 is presumably the neighbour count of the KNN graph
# (see sphereGraph in GHCN_preprocessing to confirm).
g_1 = sphereGraph(lon[keepToo], lat[keepToo], 10, rad=False, epsilon=False)
g_1.plot()
g_1.compute_laplacian('combinatorial')
# The clustering step skips the first eigenvector (constant for a connected graph) and uses
# U[:, 1:NCluster + 1], so NCluster + 1 eigenvectors must be computed, not NCluster.
g_1.compute_fourier_basis(n_eigenvectors=NCluster + 1)

In [None]:
from sklearn.cluster import KMeans

# Spectral embedding: skip the first eigenvector and keep the following ones, up to NCluster
# of them (as many as were computed in the previous cell).
eig_vectors = g_1.U[:, 1:NCluster + 1]
# Fixed random_state so the clustering, and every pooled map derived from it, is reproducible.
clusters = KMeans(n_clusters=NCluster, random_state=0).fit_predict(eig_vectors)

In [None]:
# Merge each spectral cluster into one vertex: its value is `pool` over the cluster's stations,
# and it is placed at the spherical centroid of those stations.
pool = 'max'
pool_fun = getattr(np, pool)
lon_kept, lat_kept = lon[keepToo], lat[keepToo]  # loop-invariant, hoisted
clust_lon, clust_lat = np.empty(NCluster), np.empty(NCluster)
new_map = np.zeros((dataset.shape[0], NCluster) + dataset.shape[2:])
for i in range(NCluster):
    indices = np.flatnonzero(clusters == i)
    # Average the unit vectors rather than the degrees: an arithmetic mean of longitudes is
    # wrong for clusters straddling the ±180° meridian (179° and -179° would average to 0°).
    lon_r = np.deg2rad(lon_kept[indices])
    lat_r = np.deg2rad(lat_kept[indices])
    cx = np.mean(np.cos(lat_r) * np.cos(lon_r))
    cy = np.mean(np.cos(lat_r) * np.sin(lon_r))
    cz = np.mean(np.sin(lat_r))
    clust_lon[i] = np.rad2deg(np.arctan2(cy, cx))
    clust_lat[i] = np.rad2deg(np.arctan2(cz, np.hypot(cx, cy)))
    new_map[:, i, :] = pool_fun(dataset[:, indices, :], axis=1)

In [None]:
# Level-1 pooled map: one dot per spectral cluster (first entry along axes 0 and 2 of new_map).
zmin, zmax = -20, 40
fig = plt.figure(figsize=(25, 25))
ax = fig.add_subplot(111, projection=ccrs.Orthographic(central_longitude=-50, central_latitude=90))
ax.set_global()
ax.coastlines(linewidth=2)
sc = ax.scatter(
    clust_lon, clust_lat,
    c=np.clip(new_map[0, :, 0], zmin, zmax),
    s=50, cmap='RdYlBu_r', vmin=zmin, vmax=zmax,
    alpha=1, transform=ccrs.PlateCarree(),
)

In [None]:
# Second pooling level: cluster the NCluster level-1 vertices again by a factor of about `k`.
# NCluster (== new_map.shape[1] right after level-1 pooling) is used instead of new_map, which
# the HEALPix section below overwrites, so re-running this cell later stays correct.
NCluster_2 = NCluster // k
g_2 = sphereGraph(clust_lon, clust_lat, 10, rad=False, epsilon=False)
fig = plt.figure(figsize=(25, 25))
axes = fig.add_subplot(111, projection='3d')
g_2.plot(vertex_size=50, ax=axes)
g_2.compute_laplacian('combinatorial')
# One extra eigenvector: the clustering cell skips the first and uses U[:, 1:NCluster_2 + 1].
g_2.compute_fourier_basis(n_eigenvectors=NCluster_2 + 1)

In [None]:
# Level-2 spectral embedding: skip the first eigenvector and keep the following ones, up to
# NCluster_2 of them (as many as were computed).
eig_vectors2 = g_2.U[:, 1:NCluster_2 + 1]
# Fixed random_state so the level-2 clustering is reproducible.
clusters2 = KMeans(n_clusters=NCluster_2, random_state=0).fit_predict(eig_vectors2)

In [None]:
# Merge each level-2 cluster of level-1 vertices into one vertex: value = `pool` over its
# members, position = spherical centroid of the members' positions.
pool = 'max'
pool_fun = getattr(np, pool)
clust_lon2, clust_lat2 = np.empty(NCluster_2), np.empty(NCluster_2)
new_map2 = np.zeros((new_map.shape[0], NCluster_2) + new_map.shape[2:])
for i in range(NCluster_2):
    indices = np.flatnonzero(clusters2 == i)
    # Average unit vectors, not degrees: a plain mean of longitudes breaks for clusters that
    # straddle the ±180° meridian.
    lon_r = np.deg2rad(clust_lon[indices])
    lat_r = np.deg2rad(clust_lat[indices])
    cx = np.mean(np.cos(lat_r) * np.cos(lon_r))
    cy = np.mean(np.cos(lat_r) * np.sin(lon_r))
    cz = np.mean(np.sin(lat_r))
    clust_lon2[i] = np.rad2deg(np.arctan2(cy, cx))
    clust_lat2[i] = np.rad2deg(np.arctan2(cz, np.hypot(cx, cy)))
    new_map2[:, i, :] = pool_fun(new_map[:, indices, :], axis=1)

In [None]:
# Level-2 pooled map: one dot per level-2 cluster (first entry along axes 0 and 2 of new_map2).
zmin, zmax = -20, 40
fig = plt.figure(figsize=(25, 25))
ax = fig.add_subplot(111, projection=ccrs.Orthographic(central_longitude=-50, central_latitude=90))
ax.set_global()
ax.coastlines(linewidth=2)
sc = ax.scatter(
    clust_lon2, clust_lat2,
    c=np.clip(new_map2[0, :, 0], zmin, zmax),
    s=40, cmap='RdYlBu_r', vmin=zmin, vmax=zmax,
    alpha=1, transform=ccrs.PlateCarree(),
)

TODO: implement these operations in TensorFlow.

#### Sparsify the graph

In [None]:
import pygsp as pg
from pygsp.reduction import graph_multiresolution, graph_sparsify

In [None]:
# Spectral sparsification of the level-1 station graph (pygsp graph_sparsify, epsilon=0.8).
gNew1 = graph_sparsify(g_1, 0.8)
# TODO: recover vertex coordinates for gNew1. Sparsification removes/reweights edges, so
# presumably the vertex set (and thus g_1's coordinates) is unchanged -- confirm.

#### Using a uniform sampling scheme (HEALPix)

In [None]:
import healpy as hp

In [None]:
# Pool the kept stations onto HEALPix pixels (NESTED ordering): every pixel that is one of the
# 4 interpolation neighbours of at least one station becomes a vertex whose value is `pool`
# over those stations.
# NOTE: this overwrites `new_map` from the spectral-clustering section above.
pool = 'max' # name of a numpy reduction accepting `axis`, e.g. 'max', 'mean', 'average'
Nside = 64
# With lonlat=True healpy expects longitude and latitude in degrees, so despite their names
# `theta` holds longitudes and `phi` latitudes.
theta = lon[keepToo]
phi = lat[keepToo]
# For each station: its 4 neighbouring pixels and their interpolation weights
# (one row per neighbour, one column per station).
pix, weights = hp.get_interp_weights(Nside, theta, phi, nest=True, lonlat=True)
indexes = np.unique(pix)  # sorted pixels touched by at least one station
size = dataset.shape
size = list(size)
size[1] = len(indexes) # one vertex per touched pixel, not hp.nside2npix(Nside) (full sphere)
size = tuple(size)
new_map = np.zeros(size)
pool_fun = getattr(np, pool)
for i, index in enumerate(indexes):
    # (neighbour row, station column) positions where this pixel appears in `pix`.
    pl = np.where(pix==index)
    # NOTE(review): interpolation weights lie in [0, 1], so 1/(w + 1e-8) is at least
    # 1/(1 + 1e-8) and the clip below leaves every factor in [1 - 1e-8, 1]. The weighting is
    # therefore effectively a no-op, and neighbours with zero weight still contribute fully.
    # Confirm whether a real weighting (or masking of zero weights) was intended.
    wght = 1/(weights[pl]+1e-8)
    wght[wght>1] = 1
    data_p = wght[np.newaxis,:,np.newaxis] * dataset[:, pl[1], :]
    new_map[:,i,:] = pool_fun(data_p, axis=1)
# Centres (lon, lat in degrees) of the pooled pixels, in the same order as `indexes`.
new_lon, new_lat = hp.pix2ang(Nside, indexes, nest=True, lonlat=True)

In [None]:
# HEALPix-pooled map: one dot per touched pixel centre (first entry along axes 0 and 2 of new_map).
zmin, zmax = -20, 40
fig = plt.figure(figsize=(25, 25))
ax = fig.add_subplot(111, projection=ccrs.Orthographic(central_longitude=-50, central_latitude=90))
ax.set_global()
ax.coastlines(linewidth=2)
sc = ax.scatter(
    new_lon, new_lat,
    c=np.clip(new_map[0, :, 0], zmin, zmax),
    s=40, cmap='RdYlBu_r', vmin=zmin, vmax=zmax,
    alpha=1, transform=ccrs.PlateCarree(),
)

## Part of sphere

In [None]:
def ds_index(index, nsides, nest=True):
    """Return the pixel indexes of a partial HEALPix map at several (coarser) resolutions.

    Parameters
    ----------
    index : array_like of int
        Pixel indexes (NESTED ordering) of the part of the sphere, sampled at ``nsides[0]``.
    nsides : list
        HEALPix Nside values. ``nsides[0]`` is the resolution of ``index``; every following
        entry must equal ``nsides[0]`` divided by a power of 2 (downsampling only).
    nest : bool
        Pixel ordering; only the NESTED ordering (True) is supported.

    Returns
    -------
    list
        ``[index, idx_1, idx_2, ...]``: the input unchanged, followed for every further Nside
        by the sorted, unique parent pixels (int array) covering ``index``.

    Raises
    ------
    NotImplementedError
        If a requested Nside is larger than ``nsides[0]`` (upsampling).
    ValueError
        If ``nsides[0] / nside`` is not a power of 2, so parent pixels are not well defined.
    """
    assert isinstance(nsides, list)
    assert len(nsides) > 1
    assert nest  # RING ordering not implemented yet

    fine = np.asarray(index)
    indexes = [index]
    for nside in nsides[1:]:
        ratio = nsides[0] / nside
        if ratio < 1:
            raise NotImplementedError("upsampling not implemented yet")
        factor = int(round(ratio))
        # In NESTED ordering the factor**2 children of a coarse pixel have consecutive
        # indexes, so the parent is an integer division -- valid only for power-of-2 ratios.
        if factor != ratio or factor & (factor - 1):
            raise ValueError(f"nsides[0] / nside must be a power of 2, got {ratio}")
        indexes.append(np.unique(fine // factor**2).astype(int))

    return indexes

In [None]:
def pool_part_max(x, p, Nside, index):
    """Max-pool a signal defined on part of a HEALPix sphere (NESTED ordering).

    Each parent pixel at resolution ``Nside // sqrt(p)`` receives the maximum over those of
    its ``p`` children (at resolution ``Nside``) that are present in ``index``; absent
    children are ignored.

    Parameters
    ----------
    x : tf.Tensor
        Rank-3 signal with the pixel axis on axis 1, ordered like ``index``
        (assumed (batch, len(index), features) -- confirm against the caller).
    p : int
        Pooling size = number of children per parent; must be a power of 4 (4, 16, ...).
        Values <= 1 mean no pooling.
    Nside : int
        HEALPix resolution of ``index``; must be divisible by ``sqrt(p)``.
    index : array_like of int
        NESTED pixel indexes of the partial sphere at resolution ``Nside``.

    Returns
    -------
    tf.Tensor
        ``x`` unchanged if ``p <= 1``; otherwise the pooled signal with axis 1 ordered like
        ``ds_index(index, [Nside, Nside // sqrt(p)])[1]`` (sorted parent pixels).

    Raises
    ------
    ValueError
        If ``p`` is not a power of 4 or ``Nside`` is not divisible by ``sqrt(p)``.
    """
    if p <= 1:
        return x  # nothing to pool

    children = int(round(p))
    step = int(round(p ** 0.5))  # Nside reduction factor
    if children != p or step * step != children or step & (step - 1):
        raise ValueError(f"p must be a power of 4, got {p}")
    if Nside % step:
        raise ValueError(f"Nside={Nside} is not divisible by sqrt(p)={step}")

    fine = np.asarray(index)
    coarse = ds_index(fine, [Nside, Nside // step])[1]
    # NESTED ordering: the parent of pixel i is i // p. Map every input pixel to the position
    # of its parent in the sorted `coarse` array; every coarse pixel has >= 1 child, so no
    # segment is empty.
    segment_ids = np.searchsorted(coarse, fine // children).astype(np.int32)

    # One vectorised segment reduction replaces building a full-sphere map pixel by pixel.
    # unsorted_segment_max reduces over axis 0, so move the pixel axis first and back again.
    pooled = tf.unsorted_segment_max(tf.transpose(x, [1, 0, 2]), segment_ids, len(coarse))
    return tf.transpose(pooled, [1, 0, 2])

This takes too much time.

TODO: find a way to make it faster.

In [None]:
import tensorflow as tf

# Start from a fresh graph so re-running this cell does not keep adding ops to the default graph.
tf.reset_default_graph()
x = tf.placeholder(dtype=tf.float32, shape=new_map.shape, name='input_data')
feed_dict = {x: new_map}
op_pool = pool_part_max(x, 4, Nside, indexes)  # 4 children per parent: Nside -> Nside // 2
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
# The context manager closes the session (releasing its GPU memory) once the result is fetched.
with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    new_map2 = sess.run(op_pool, feed_dict=feed_dict)

In [None]:
# Compare the HEALPix map at Nside (top) with its partial-sphere max-pooled version at
# Nside // 2 (bottom); the coarse pixel positions come from ds_index.
coarse_pixels = ds_index(indexes, [Nside, Nside//(4**0.5)])[1]
new_lon2, new_lat2 = hp.pix2ang(Nside//2, coarse_pixels, nest=True, lonlat=True)
zmin, zmax = -20, 40

fig = plt.figure(figsize=(25, 50))
panels = [(new_lon, new_lat, new_map), (new_lon2, new_lat2, new_map2)]
for row, (plon, plat, pmap) in enumerate(panels, start=1):
    ax = fig.add_subplot(2, 1, row,
                         projection=ccrs.Orthographic(central_longitude=-50, central_latitude=90))
    ax.set_global()
    ax.coastlines(linewidth=2)
    sc = ax.scatter(
        plon, plat,
        c=np.clip(pmap[0, :, 0], zmin, zmax),
        s=40, cmap='RdYlBu_r', vmin=zmin, vmax=zmax,
        alpha=1, transform=ccrs.PlateCarree(),
    )