In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pygsp
import healpy as hp
import pycgsp

## Ground truth sky

In [None]:
# Empty HEALPix map (nside = 512).
nside = 2**9
gt = np.zeros(hp.pixelfunc.nside2npix(nside))

# Field of view of 20 degrees centred on the north pole. query_disc expects a
# direction vector; [0, 0, 1] is the pole, which is where the circles below are
# placed (third coordinate ~ 1) and where the views use rot=(0, 90, 0).
fov = 20 / 180 * np.pi  # rad
zoom = hp.query_disc(nside, [0, 0, 1], radius=fov/2)

# Add randomly sized and positioned discs to the sky (overlaps accumulate).
n_circles = 10
max_size = fov / 5  # largest disc diameter, rad
circles = []
rng = np.random.default_rng(0)
for _ in range(n_circles):
    size = rng.uniform(low=max_size/2, high=max_size)
    # (x, y) kept at least max_size away from the field-of-view border so the
    # discs (radius <= max_size/2) stay roughly inside it.
    pos = rng.uniform(low=-fov/2 + max_size, high=fov/2 - max_size, size=2)
    # 3rd coordinate places the point on the unit sphere (northern hemisphere)
    pos = np.concatenate([pos, np.sqrt(1 - (pos**2).sum(keepdims=True))])
    circles.append((pos, size))

    # Draw the disc: `size` is its diameter, hence radius=size/2.
    ids = hp.query_disc(nside=nside, vec=pos, radius=size/2)
    gt[ids] += 1.

# Normalise to a peak of 1 (skip if nothing was drawn, to avoid 0/0).
peak = gt.max()
if peak > 0:
    gt /= peak

print(len(gt))

In [None]:
# Orthographic view of the ground truth, centred on (lon, lat) = (0, 90).
hp.orthview(gt, rot=(0, 90, 0), half_sky=True, cmap="cubehelix")

In [None]:
# Gnomonic (tangent-plane) zoom on the ground truth, centred on the pole.
hp.gnomview(gt, rot=(0, 90, 0), cmap="cubehelix", xsize=1000)

## Create spherical graph

In [None]:
from functools import partial
from scipy import sparse, spatial, linalg
import time
def _kernel_exponential(distance, power=1, value_at_one=0.5):
        cst = np.log(value_at_one)
        return np.exp(cst * distance**power)

def hpix_nngraph(hpix_map, ids=None, k=8):
    """Build a symmetric k-nearest-neighbour graph on HEALPix pixel centres.

    Parameters
    ----------
    hpix_map : array_like
        HEALPix map; only its length is used (to infer ``npix`` and ``nside``).
    ids : array_like of int, optional
        Indices of the pixels used as vertices (interpreted with healpy's
        default ordering in ``hp.pix2vec``). Defaults to every pixel.
    k : int, optional
        Number of nearest neighbours given to ``pygsp.graphs.NNGraph``
        (default 8, the previously hard-coded value).

    Returns
    -------
    W : scipy.sparse.csc_matrix
        Symmetric weight matrix. Edge (i, j) exists if either pixel is among
        the other's k nearest neighbours; its weight is exp(-(d / rho)**2),
        with d the Euclidean (chordal) distance between the pixel vectors and
        rho the mean of d over all stored entries.
    R : numpy.ndarray, shape (len(ids), 3)
        Pixel-centre vectors from ``hp.pix2vec``, one row per vertex.
    """
    npix = len(hpix_map)
    nside = hp.npix2nside(npix)
    # Default only once npix is known (the old code read npix before assigning
    # it, raising UnboundLocalError whenever ids was None).
    if ids is None:
        ids = np.arange(npix)

    x, y, z = hp.pix2vec(nside, ids)
    R = np.stack((x, y, z), axis=-1)
    G = pygsp.graphs.NNGraph(R, k=k)

    # Symmetrise the sparsity pattern by adding the transpose; duplicate
    # entries are merged by the CSR conversion. Values are placeholders since
    # every weight is recomputed from the distances below.
    knn = G.W.tocoo()
    rows = np.concatenate([knn.row, knn.col])
    cols = np.concatenate([knn.col, knn.row])
    pattern = sparse.coo_matrix((np.ones(rows.size), (rows, cols)), shape=knn.shape)
    W = pattern.tocsr().tocoo()

    # Gaussian weights with the bandwidth set to the mean edge length.
    distance = linalg.norm(R[W.row, :] - R[W.col, :], axis=-1)
    rho = np.mean(distance)
    W.data = np.exp(-(distance / rho) ** 2)
    return W.tocsc(), R

# Graph restricted to the pixels inside the field of view.
W, R = hpix_nngraph(hpix_map=gt, ids=zoom)
sphere = pygsp.graphs.Graph(W, coords=R)
# Display the pygsp vertex and edge counts.
(sphere.N, sphere.Ne)

## Create measured data on the sphere


In [None]:
import pycgsp.operator.linop.conv as pycgspc
import pycgsp.operator.linop.diff as pycgspd

laplacian = pycgspd.GraphLaplacian(W)


def kernel(x):
    """Filter response exp(-5 |x|) given to GraphConvolution.

    Presumably evaluated on the Laplacian spectrum in [0, lmax] through an
    order-20 polynomial approximation -- confirm in pycgsp.
    """
    return np.exp(-5*abs(x))


# Tight bound computed once and reused (previously a loose bound was stored in
# `lmax` but never used, and the tight bound was recomputed inline).
lmax = laplacian.lipschitz(tight=True)
conv, _, _ = pycgspc.GraphConvolution(L=laplacian,
                                      kernel=kernel,
                                      lmax=lmax,
                                      order=20)

# Additive Gaussian noise from a dedicated, seeded generator so that the
# measurements are reproducible across kernel restarts (np.random.randn drew
# from the unseeded global state).
NOISE_STD = 0.05
noise_rng = np.random.default_rng(1)
n = noise_rng.normal(scale=NOISE_STD, size=sphere.N)

# Last expression so the value is displayed.
conv.lipschitz(tight=True)

In [None]:
# Measurements: blurred in-view ground truth plus noise; pixels outside the
# field of view stay at zero.
y = np.zeros_like(gt)
y[zoom] = conv(gt[zoom]) + n

for sky_map in (gt, y):
    hp.gnomview(sky_map, rot=(0, 90, 0), xsize=1000, cmap="cubehelix")

## Create a direction-cosine grid (tangent-plane approximation) and interpolate the measurements onto it

In [None]:
# Regular (l, m) direction-cosine grid with spacing ~ the HEALPix pixel size.
resolution = hp.nside2resol(nside)  # rad
dcos_ax = np.arange(-fov/2, fov/2, resolution)
n_pix = len(dcos_ax)
# m is negated so that row 0 of the meshgrid holds the largest m.
dcos_x, dcos_y = np.meshgrid(dcos_ax, -dcos_ax)
l_flat, m_flat = dcos_x.ravel(), dcos_y.ravel()
# Lift each (l, m) onto the unit sphere: third cosine = sqrt(1 - l**2 - m**2).
# (The previous expression evaluated sqrt(1 - l**2 + m**2) due to precedence.)
dcos = np.stack([l_flat, m_flat, np.sqrt(1 - l_flat**2 - m_flat**2)]).T

theta, phi = hp.pixelfunc.vec2ang(dcos, lonlat=False)
# Interpolate the HEALPix measurements at the grid directions. After the
# transpose, axis 0 indexes l (increasing) and axis 1 indexes m (decreasing);
# the back-interpolation onto the sphere later in the notebook relies on this.
y_dcos = hp.pixelfunc.get_interp_val(y, theta, phi, nest=False, lonlat=False).reshape(n_pix, n_pix).T

In [None]:
# Measurements on the direction-cosine grid.
# NOTE(review): y_dcos is transposed when built, so its axis 0 indexes l, and
# imshow draws axis 0 vertically -- confirm whether y_dcos.T is needed for the
# axis labels below to match.
dcos_extent = [-fov/2, fov/2, -fov/2, fov/2]
plt.imshow(y_dcos, cmap="cubehelix", interpolation='none', extent=dcos_extent)
plt.colorbar()
plt.xlabel("l")
plt.ylabel("m")
plt.title("Tangent plane projection (Direction cosines)")

## Reconstruct on the spherical graph vs. in the tangent plane

In [None]:
from pycsou.opt.solver.pds import PD3O
from pycsou.operator.func import SquaredL2Norm, L1Norm

### In spherical graph

In [None]:
# Data fidelity on the field-of-view pixels: 0.5 * ||conv(x) - y[zoom]||^2.
l22_loss = 0.5 * SquaredL2Norm(zoom.size).argshift(-y[zoom])
fidelity = l22_loss * conv
# The Lipschitz evaluations are kept for their presumed caching of constants
# used by the solver (TODO confirm in pycsou); only the last one is displayed.
fidelity.diff_lipschitz(tight=True)
K = pycgspd.GraphGradient(W)
K.lipschitz(tight=True)

In [None]:
# Penalties: G = 0.01 * ||x||_1 and H = 0.01 * ||.||_1 applied to K x,
# where K is the graph gradient defined above.
H = 0.01 * L1Norm()
G = 0.01 * L1Norm()
loss = fidelity + G + H * K
# Seeded initial point: np.random.randn used the unseeded global state, so
# every run started the solver from a different point.
x0 = np.random.default_rng(2).standard_normal(zoom.size)
print(fidelity(x0))
print(G(x0))
print((H * K)(x0))
print(loss(x0))

In [None]:
# Minimise fidelity + G + H(K x), i.e. the `loss` assembled above.
solver = PD3O(
    f=fidelity,
    g=G,
    h=H,
    K=K,
    show_progress=True,
    verbosity=250,
)
solver.fit(x0=x0, tuning_strategy=2)
recons = solver.solution()

In [None]:
# Scatter the reconstruction back into a full-sky map (zero outside the view).
y_recons = np.zeros_like(gt)
y_recons[zoom] = recons
# Ground truth, measurements, reconstruction.
for sky_map in (gt, y, y_recons):
    hp.gnomview(sky_map, rot=(0, 90, 0), xsize=1000, cmap="cubehelix")
plt.show()

### In tangent plane projection

In [None]:
from pycsou.operator.linop import Gradient, Convolve

In [None]:
# 1D kernel for the tangent-plane convolution, sampled at the grid spacing.
# NOTE(review): on the sphere `kernel` is passed to GraphConvolution with lmax
# (presumably a spectral response), whereas here it is evaluated on angular
# offsets of at most half_width * resolution rad. At that scale exp(-5|x|) is
# nearly flat, so the two forward models likely differ -- confirm intent.
k_width = 10
# -(k // 2), not -k // 2: the latter is (-k) // 2, which is off-centre for odd k.
half_width = k_width // 2
support = np.arange(-half_width, half_width + 1) * resolution
dcos_kernel = kernel(support)
dcos_kernel /= dcos_kernel.sum()  # taps sum to 1

plt.plot(support, dcos_kernel)
plt.xlabel("offset [rad]")
plt.ylabel("weight")
plt.title("Tangent-plane convolution kernel (1D)")
plt.show()

In [None]:
# 2D convolution on the direction-cosine grid: the same 1D kernel along each
# axis, anchored at its middle tap.
arg_shape = y_dcos.shape
kernel_center = len(dcos_kernel) // 2
conv_dcos = Convolve(
    arg_shape=arg_shape,
    kernel=[dcos_kernel, dcos_kernel],
    center=[kernel_center, kernel_center],
)
# Data fidelity 0.5 * ||conv_dcos(x) - y_dcos||^2 on the flattened grid.
l22_loss = 0.5 * SquaredL2Norm(y_dcos.size).argshift(-y_dcos.ravel())
fidelity = l22_loss * conv_dcos
fidelity.diff_lipschitz(tight=False)
K = Gradient(arg_shape=arg_shape)
K.lipschitz(tight=False)

In [None]:
# Penalties: G = 0.01 * ||x||_1 and H = 0.01 * ||.||_1 applied to K x,
# where K is the grid Gradient defined above.
H = 0.01 * L1Norm()
G = 0.01 * L1Norm()
loss = fidelity + G + H * K
# Seeded initial point: np.random.randn used the unseeded global state, so
# every run started the solver from a different point.
x0 = np.random.default_rng(3).standard_normal(y_dcos.size)
print(fidelity(x0))
print(G(x0))
print((H * K)(x0))
print(loss(x0))

In [None]:
# Minimise fidelity + G + H(K x) on the flattened grid, then restore 2D shape.
solver = PD3O(
    f=fidelity,
    g=G,
    h=H,
    K=K,
    show_progress=True,
    verbosity=250,
)
solver.fit(x0=x0, tuning_strategy=2)
flat_solution = solver.solution()
recons_dcos = flat_solution.reshape(n_pix, n_pix)

In [None]:
# Measurements vs. reconstruction on the direction-cosine grid.
# NOTE(review): both arrays have l along axis 0 (y_dcos is transposed when
# built), which imshow draws vertically -- confirm against the axis labels.
dcos_extent = [-fov/2, fov/2, -fov/2, fov/2]
panels = [
    (121, y_dcos, "Tangent plane projection (Direction cosines)"),
    (122, recons_dcos, "Reconstruction"),
]
plt.figure(figsize=(15,5))
for position, image, panel_title in panels:
    plt.subplot(position)
    plt.imshow(image, cmap="cubehelix", interpolation='none', extent=dcos_extent)
    plt.colorbar()
    plt.xlabel("l")
    plt.ylabel("m")
    plt.title(panel_title)
plt.show()

In [None]:
# Interpolate the tangent-plane reconstruction back onto the field-of-view
# HEALPix pixels.
from scipy.interpolate import interpn

# recons_dcos has y_dcos's layout: entry [i, j] sits at l = dcos_ax[i] and
# m = -dcos_ax[j] (m decreasing). interpn needs ascending axes, so flip axis 1
# and pair it with the matching ascending m values -dcos_ax[::-1]. Pairing it
# with dcos_ax shifted m by up to one grid step.
m_ax = -dcos_ax[::-1]
y_recons_dcos = np.zeros_like(gt)
y_recons_dcos[zoom] = interpn(
    (dcos_ax, m_ax),
    np.fliplr(recons_dcos),
    R[:, :2],  # (x, y) of the pixel vectors, i.e. their (l, m)
    method="linear",
    bounds_error=False,
    # arange() excludes +fov/2, so pixels at the rim of the disc can fall
    # (by less than one grid step) outside the grid. Extrapolate there instead
    # of filling with NaN, which made the DCOS error metric NaN.
    fill_value=None,
)


In [None]:
# Spherical-graph reconstruction, then the tangent-plane one mapped back.
for sky_map in (y_recons, y_recons_dcos):
    hp.gnomview(sky_map, rot=(0, 90, 0), xsize=1000, cmap="cubehelix")
plt.show()

In [None]:
def nmse(x, y):
    """Relative L2 error ``||x - y|| / ||x||`` of ``y`` against reference ``x``.

    Despite the name, the result is not squared: it is the square root of the
    normalised mean squared error.
    """
    residual_norm = np.linalg.norm(x - y)
    return residual_norm / np.linalg.norm(x)


print(f"Error for HPix grid = {nmse(gt[zoom],y_recons[zoom])*100:0.2f}%")
print(f"Error for DCOS grid = {nmse(gt[zoom],y_recons_dcos[zoom])*100:0.2f}%")