In [None]:
# Reload imported modules automatically before each cell executes, so edits to
# library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pygsp
import healpy as hp
import pycgsp

## Ground truth sky

In [None]:
# Working resolution of every map in this notebook.
nside = 2*16

# Load the WHAM map, shift it so it is strictly positive, log-compress it,
# rescale it to [0, 1] and finally degrade it to the working resolution.
wham_raw = hp.read_map("data/lambda_WHAM_1_256.fits")
wham_log = np.log10(wham_raw - (wham_raw.min() - 1e-7))
wham_log = wham_log - wham_log.min()
wham_norm = wham_log / np.max(wham_log)
wham_gt = hp.pixelfunc.ud_grade(wham_norm, nside)

In [None]:
hp.orthview(wham_gt, cmap="cubehelix", title="WHAM ground truth (log-scaled, normalised)")

In [None]:
hp.mollview(wham_gt, cmap="cubehelix", title="WHAM ground truth (log-scaled, normalised)")

## Create spherical graph

In [None]:
from functools import partial
from scipy import sparse, spatial, linalg

def _kernel_exponential(distance, power=1, value_at_one=0.5):
        cst = np.log(value_at_one)
        return np.exp(cst * distance**power)

def hpix_nngraph(hpix_map):
    """Build a weighted neighbour graph on the pixels of a HEALPix map.

    Every pixel is connected to the neighbours returned by
    ``hp.get_all_neighbours`` (RING ordering). The adjacency is symmetrised and
    each edge (i, j) is weighted with ``exp(-(||r_i - r_j|| / rho)**2)``, where
    ``r_i`` is the unit vector of pixel i and ``rho`` is the mean edge length.

    Parameters
    ----------
    hpix_map : array-like
        HEALPix map; only its length (the number of pixels) is used.

    Returns
    -------
    W : scipy.sparse.csc_matrix
        Symmetric ``(npix, npix)`` weighted adjacency matrix.
    R : numpy.ndarray
        ``(npix, 3)`` array of pixel-centre unit vectors.
    """
    npix = len(hpix_map)
    nside = hp.npix2nside(npix)
    pixels = np.arange(npix)
    x, y, z = hp.pix2vec(nside, pixels)
    R = np.stack((x, y, z), axis=-1)

    # get_all_neighbours returns an (8, npix) array; transposing makes the
    # neighbours of each pixel contiguous once flattened, matching `rows`.
    neighbours = hp.get_all_neighbours(nside, pixels).T
    rows = np.repeat(pixels, neighbours.shape[1])
    cols = neighbours.reshape(-1)

    # Some pixels only have 7 neighbours; healpy marks the missing one with -1.
    # Drop those entries rather than wiring them to an unrelated pixel, which
    # would add spurious long-range edges (and bias `rho` below).
    valid = cols >= 0
    rows, cols = rows[valid], cols[valid]

    # Symmetrise by adding every edge in both directions; tocsr() merges the
    # duplicated entries, and the weights are overwritten below anyway.
    sym_rows = np.concatenate([rows, cols])
    sym_cols = np.concatenate([cols, rows])
    W = sparse.coo_matrix(
        (np.ones(sym_rows.size), (sym_rows, sym_cols)), shape=(npix, npix)
    ).tocsr().tocoo()

    # Gaussian kernel on the Euclidean distance between pixel unit vectors,
    # with the mean edge length as bandwidth.
    distance = linalg.norm(R[W.row, :] - R[W.col, :], axis=-1)
    rho = np.mean(distance)
    W.data = np.exp(- (distance / rho) ** 2)

    return W.tocsc(), R



def hpix_graph(hpix_map, indexes=None, nest=False, **kwargs):
    """Build a k-nearest-neighbour graph on (a subset of) HEALPix pixels.

    Parameters
    ----------
    hpix_map : array-like
        HEALPix map; only its length is used, to infer NSIDE.
    indexes : array-like of int, optional
        Pixel indices to use as graph vertices. Defaults to all pixels.
    nest : bool
        Whether ``indexes`` are in NESTED (True) or RING (False) ordering.
    **kwargs
        ``k`` (default 10): number of neighbours per vertex. Other keys are
        currently ignored.

    Returns
    -------
    W : scipy.sparse.csr_matrix
        ``(n_vertices, n_vertices)`` kNN adjacency (not symmetrised), weighted
        with ``_kernel_exponential(d / mean(d), power=2)``.
    coords : numpy.ndarray
        ``(n_vertices, 3)`` pixel-centre unit vectors.
    lat, lon : numpy.ndarray
        Latitude and longitude of each vertex, in radians.

    Raises
    ------
    ValueError
        If ``k`` is not in ``[1, n_vertices - 1]``.
    """
    nside = hp.npix2nside(len(hpix_map))
    if indexes is None:
        indexes = np.arange(len(hpix_map))
    else:
        indexes = np.asarray(indexes)
    n_vertices = len(indexes)

    k = kwargs.pop('k', 10)
    if not 1 <= k < n_vertices:
        raise ValueError(f'k must be in [1, {n_vertices - 1}], got {k}.')

    x, y, z = hp.pix2vec(nside, indexes, nest=nest)
    coords = np.stack([x, y, z], axis=1)
    tree = spatial.cKDTree(coords)

    # Query k + 1 neighbours: the first column is each vertex itself (distance 0).
    distances, neighbors = tree.query(coords, p=2, k=k + 1, eps=0)
    value = distances[:, 1:].ravel()
    row = np.repeat(np.arange(n_vertices), k)
    col = neighbors[:, 1:].ravel()
    W = sparse.csr_matrix((value, (row, col)), (n_vertices, n_vertices))

    # Turn distances into similarities, normalised by the mean neighbour distance.
    kernel_width = np.mean(W.data)
    kernel = partial(_kernel_exponential, power=2)
    W.data = kernel(W.data / kernel_width)

    colat, lon = hp.pix2ang(nside, indexes, nest=nest, lonlat=False)
    lat = np.pi/2 - colat  # colatitude to latitude
    return W, coords, lat, lon


In [None]:
import pycsou.abc as pyca
class SphericalPooling(pyca.LinOp):
    r"""Linear pooling operator between two HEALPix resolutions.

    ``apply`` degrades a map from ``nside_in`` to ``nside_out`` and ``adjoint``
    upgrades a pooled map back to ``nside_in``; both delegate to
    :py:func:`healpy.pixelfunc.ud_grade`.
    """

    def __init__(self, nside_in: int, nside_out: int, order_in: str = 'RING', order_out: str = 'RING',
                 pooling_func: str = 'mean'):
        r"""

        Parameters
        ----------
        nside_in: int
            Parameter NSIDE of the input HEALPix map.
        nside_out: int
            Parameter NSIDE of the pooled HEALPix map.
        order_in: str ['RING', 'NESTED']
            Ordering of the input HEALPix map.
        order_out: str ['RING', 'NESTED']
            Ordering of the pooled HEALPix map.
        pooling_func: str ['mean', 'sum']
            Pooling function: each pooled pixel is the mean (resp. sum) of the
            input pixels it covers.

        Raises
        ------
        ValueError
            If ``nside_out >= nside_in`` or if ``pooling_func`` is neither
            ``'mean'`` nor ``'sum'``.
        """
        if nside_out >= nside_in:
            raise ValueError('Parameter nside_out must be smaller than nside_in.')
        if pooling_func not in ('mean', 'sum'):
            raise ValueError("Parameter pooling_func must be either 'mean' or 'sum'.")
        self.nside_in = nside_in
        self.nside_out = nside_out
        self.order_in = order_in
        self.order_out = order_out
        # healpy's ud_grade divides its result by (nside_in / nside_out) ** power,
        # where nside_in/nside_out are the resolutions of *that* call (no rescaling
        # when power is None). With r = self.nside_in / self.nside_out, each pooled
        # pixel covers r**2 input pixels:
        #  * forward (degrade): power=None averages the children, power=-2 sums them;
        #  * adjoint (upgrade): power=None copies the parent into every child (the
        #    adjoint of 'sum'); power=-2 also divides by r**2 (the adjoint of 'mean').
        if pooling_func == 'mean':
            self._power, self._adjoint_power = None, -2
        else:
            self._power, self._adjoint_power = -2, None
        dim = hp.pixelfunc.nside2npix(nside_in)
        codim = hp.pixelfunc.nside2npix(nside_out)
        super(SphericalPooling, self).__init__(shape=(codim, dim))

    def apply(self, map_in: np.ndarray) -> np.ndarray:
        """Pool ``map_in`` (NSIDE ``nside_in``) down to NSIDE ``nside_out``."""
        return hp.pixelfunc.ud_grade(map_in=map_in, nside_out=self.nside_out, order_in=self.order_in,
                                     order_out=self.order_out, dtype=map_in.dtype, power=self._power)

    def adjoint(self, pooled_map: np.ndarray) -> np.ndarray:
        """Adjoint of ``apply``: spread ``pooled_map`` back onto the NSIDE ``nside_in`` grid."""
        return hp.pixelfunc.ud_grade(map_in=pooled_map, nside_out=self.nside_in, order_in=self.order_out,
                                     order_out=self.order_in, dtype=pooled_map.dtype, power=self._adjoint_power)

In [None]:
# Build the neighbour HEALPix graph (a k-NN alternative is `hpix_graph(wham_gt)`),
# then wrap it in a pygsp Graph and display its vertex and edge counts.
W, R = hpix_nngraph(wham_gt)
sphere = pygsp.graphs.Graph(W, coords=R)
sphere.N, sphere.Ne

In [None]:
# Left: pixel centres coloured by the map; right: pygsp's own signal plot.
fig, (ax_points, ax_graph) = plt.subplots(1, 2, figsize=(15, 5), subplot_kw={"projection": "3d"})
ax_points.scatter(R[:, 0], R[:, 1], R[:, 2], s=5, c=wham_gt, alpha=0.5, cmap="cubehelix")
sphere.plot_signal(wham_gt, ax=ax_graph)
ax_graph.set_title("")
plt.show()

## Data acquisition (blurring)

In [None]:
import pycgsp.operator.linop.conv as pycgspc
import pycgsp.operator.linop.diff as pycgspd

In [None]:
# Graph Laplacian operator built from the weighted adjacency matrix W (shown for inspection).
laplacian = pycgspd.GraphLaplacian(W)
laplacian

In [None]:
# Non-tight (tight=False) Lipschitz constant of the Laplacian, displayed only;
# NOTE(review): presumably a cheap upper bound on its largest eigenvalue — the
# convolution cell below recomputes the tight value instead of using this one.
lmax = laplacian.lipschitz(tight=False)
lmax

In [None]:
def kernel(x):
    """Spectral response of the blurring filter, exp(-2 x) (presumably evaluated on Laplacian eigenvalues)."""
    return np.exp(-2*x)

# Spectral graph filter; order=20 is presumably the polynomial approximation
# order. The two extra values returned by GraphConvolution are not needed.
lmax_tight = laplacian.lipschitz(tight=True)
conv, _, _ = pycgspc.GraphConvolution(L=laplacian, kernel=kernel, lmax=lmax_tight, order=20)
conv.lipschitz(tight=True)

In [None]:
# Seed the global NumPy RNG so the noise (and the random initial points drawn
# in later cells with np.random.randn) are reproducible across runs.
np.random.seed(0)
noise_std = 0.01
n = noise_std * np.random.randn(sphere.N)
# Observation model: blurred ground truth plus white Gaussian noise.
y = conv(wham_gt) + n

In [None]:
hp.mollview(wham_gt, cmap="cubehelix", title="Ground truth")
hp.mollview(y, cmap="cubehelix", title="Blurred + noisy observation y")

In [None]:
from pycsou.opt.solver.pds import PD3O
from pycsou._dev import SquaredL2Norm, L1Norm

In [None]:
# Smooth data-fidelity term: 1/2 ||conv(x) - y||^2.
l22_loss = (1/2) * SquaredL2Norm(y.size).argshift(-y)
fidelity = l22_loss * conv
# Graph-gradient operator used inside the penalty H(K x).
K = pycgspd.GraphGradient(W)
# Regularisation: H penalises the L1 norm of the graph gradient, G the L1 norm of x.
H = 0.0005 * L1Norm()
G = 0.1 * L1Norm()
# Tight Lipschitz constants, previously computed and silently discarded; now
# displayed. NOTE(review): kept in case the computation also caches the values
# used by the solver's step-size tuning — confirm against pycsou.
fidelity_diff_lipschitz = fidelity.diff_lipschitz(tight=True)
K_lipschitz = K.lipschitz(tight=True)
fidelity_diff_lipschitz, K_lipschitz

In [None]:
# Full objective and a random starting point; print each term's value at x0,
# then the total.
loss = fidelity + G + H * K
x0 = np.random.randn(wham_gt.size)
penalty = H * K
for term in (fidelity, G, penalty, loss):
    print(term(x0))

In [None]:
# Minimise f(x) + g(x) + h(K x) with PD3O: f = smooth data fidelity, g = L1
# penalty on x, h = L1 penalty on the graph gradient K x.
# NOTE(review): verbosity=250 presumably logs every 250 iterations and
# tuning_strategy=2 selects a step-size heuristic — confirm in the pycsou docs.
solver = PD3O(f=fidelity, g=G, h=H, K=K, show_progress=True, verbosity=250)
solver.fit(x0=x0, tuning_strategy=2)
recons = solver.solution()

In [None]:
hp.mollview(wham_gt, cmap="cubehelix", title="Ground truth")
hp.mollview(y, cmap="cubehelix", title="Observation y")
hp.mollview(recons, cmap="cubehelix", title="PD3O reconstruction")

## Add downsampling (pooling)

In [None]:
# NSIDE reduction factor; each sensor pixel covers d_factor**2 input pixels.
d_factor = 2
nside_low = nside // d_factor
nsensors = hp.nside2npix(nside_low)

pooling = SphericalPooling(nside_in=nside, nside_out=nside_low, pooling_func="sum")
wham_pooled = pooling(wham_gt)

# Observation model: blur, sum-pool, then add white Gaussian noise.
n = 0.01 * np.random.randn(nsensors)
y2 = pooling(conv(wham_gt)) + n

In [None]:
hp.mollview(wham_gt, cmap="cubehelix", title="Ground truth")
hp.mollview(y2, cmap="cubehelix", title="Blurred, pooled + noisy observation y2")

In [None]:
# Data-fidelity term for the pooled observation: 1/2 ||pooling(conv(x)) - y2||^2.
l22_loss = (1/2) * SquaredL2Norm(y2.size).argshift(-y2)
fidelity2 = l22_loss * pooling * conv
# NOTE(review): results not displayed; kept in case these calls cache the
# constants used by the solver's step-size tuning — confirm against pycsou.
fidelity2.diff_lipschitz(tight=True)
K = pycgspd.GraphGradient(W)
K.lipschitz(tight=True)

loss = fidelity2 + G + H * K
x0 = np.random.randn(wham_gt.size)
# Value of each term of the pooled objective at x0, then the total.
print(fidelity2(x0))
print(G(x0))
print((H * K)(x0))
print(loss(x0))

In [None]:
# Same PD3O set-up as for the blurred-only problem, but with the pooled
# data-fidelity term fidelity2 (blur followed by sum-pooling).
solver = PD3O(f=fidelity2, g=G, h=H, K=K, show_progress=True, verbosity=250)
solver.fit(x0=x0, tuning_strategy=2)
recons2 = solver.solution()

In [None]:
hp.mollview(wham_gt, cmap="cubehelix", title="Ground truth")
hp.mollview(y2, cmap="cubehelix", title="Pooled observation y2")
hp.mollview(recons2, cmap="cubehelix", title="PD3O reconstruction from pooled data")