In [None]:
# Pin the process to CUDA device 1. This must happen before JAX initialises its
# backends, so it stays ahead of the JAX import below.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Run JAX in double precision. The config object lives on the top-level `jax`
# module; the old `from jax.config import config` form was removed in newer JAX.
import jax
jax.config.update("jax_enable_x64", True)

# Report the active backend (e.g. "gpu" or "cpu") to check we're running on GPU.
# `jax.default_backend()` replaces the deprecated `jax.lib.xla_bridge.get_backend()`.
print(jax.default_backend())

from jax import jit, grad 
import jax.numpy as jnp 
import numpy as np 

from matplotlib.image import pil_to_array
from PIL import Image

import s2fft 
import s2wav 
from scatcovjax.Scattering_lib import scat_cov_dir, quadrature

In [None]:
# Harmonic / wavelet configuration shared by every cell below.
L = 256                 # harmonic band-limit
N = 3                   # directional (azimuthal) band-limit of the wavelets
J_min = 0               # lowest wavelet scale used
sampling = "mw"         # sampling scheme passed to the transforms
reality = True          # real signal: only the m >= 0 half of flm is kept
multiresolution = True  # forwarded to the wavelet/quadrature helpers

# Fix the legacy global NumPy RNG so the random starting field is reproducible.
np.random.seed(98956)

In [None]:
# Target map.
# Alternative target (disabled): a grey-scale planetary texture band-limited with pyssht.
# import pyssht as ssht
# grayscale_pil_image = Image.open('../texture_maps/venus.jpg').convert("L")
# I = pil_to_array(grayscale_pil_image).astype(np.float64)
# I = np.ascontiguousarray(I[:,:-1])
# L_temp = I.shape[0]
# Ilm = ssht.forward(I, L_temp, Reality=reality)
# Ilm = Ilm[:L**2]
# I = ssht.inverse(Ilm, L, Reality=reality)

# Active target: the weak-lensing map stored for band-limit L. It is standardised
# to zero mean and unit standard deviation using NaN-aware statistics, then
# transformed to harmonic space.
raw_map = np.load(f'../texture_maps/raw_data/CosmoML_shell_40_L_{L}.npy')
centred_map = raw_map - np.nanmean(raw_map)
I = centred_map / np.nanstd(centred_map)
Ilm = s2fft.forward_jax(I, L, reality=reality)

In [None]:
# Build the wavelet filters, quadrature weights and Wigner precomputes once; they
# are passed into every scat_cov_dir call below.
filters = s2wav.filter_factory.filters.filters_directional_vectorised(L, N, J_min)[0]
weights = quadrature(L, J_min, sampling, None, multiresolution)
precomps = s2wav.transforms.jax_wavelets.generate_wigner_precomputes(
    L, N, J_min, 2.0, sampling, None, False, reality, multiresolution  # 2.0: presumably the wavelet dilation factor — TODO confirm
)

# Starting point: Gaussian white noise on the (L, 2L-1) grid, taken to harmonic space.
f = np.random.standard_normal(size=(L, 2 * L - 1))  # already float64
flm = s2fft.forward_jax(f, L, reality=reality)
if reality:
    # Keep only the m >= 0 columns (m = 0 sits at column L-1).
    flm = flm[:, L - 1:]
flm_start = jnp.copy(flm)  # frozen copy of the initial condition

In [None]:
# Warm start: first fit only the power spectrum of the target before matching the
# full scattering-covariance statistics.
# Use the same harmonic layout as `flm`: the m >= 0 half when `reality` is set,
# otherwise the full array (this mirrors how `flm` is built from the random field).
Ilm_target = Ilm[:, L - 1:] if reality else Ilm
ps_true = jnp.sum(jnp.abs(Ilm_target)**2, axis=-1)


@jit
def func_ps_only(flm):
    """Squared-error loss between the power spectrum of `flm` and `ps_true`.

    The power spectrum is sum_m |flm|^2 along the last axis (one value per row);
    the loss is the sum of squared absolute differences to `ps_true`.
    """
    ps = jnp.sum(jnp.abs(flm)**2, axis=-1)
    return jnp.sum(jnp.abs(ps - ps_true)**2)


grad_func_ps = jit(grad(func_ps_only))

loss_0_ps = func_ps_only(flm_start)

import optax
lr = 1e-3
optimizer = optax.adam(lr)
opt_state = optimizer.init(flm)

for i in range(1000):
    # For a real-valued loss of complex parameters, JAX recommends stepping along
    # the conjugate of grad(f), so conjugate before handing the gradient to Adam.
    grads = jnp.conj(grad_func_ps(flm))
    updates, opt_state = optimizer.update(grads, opt_state)
    flm = optax.apply_updates(flm, updates)
    if i % 100 == 0:
        print(f"Iteration {i}: Loss/Loss-0 = {func_ps_only(flm)}/{loss_0_ps}")

In [None]:
# Target statistics of the reference map. scat_cov_dir gets the same harmonic
# layout as the optimised `flm`: the m >= 0 half when `reality` is set, else the full array.
Ilm_target = Ilm[:, L - 1:] if reality else Ilm
mean, var, S1, P00, C01, C11 = scat_cov_dir(Ilm_target, L, N, J_min, sampling, None, reality, multiresolution, filters=filters, quads=weights, precomps=precomps)

In [None]:
@jit
def func(flm):
    """Scattering-covariance matching loss for the harmonic coefficients `flm`.

    Computes the directional scattering-covariance statistics of `flm` and returns
    the squared absolute error against the target statistics (mean, var, S1, P00,
    C01, C11) of the reference map. The array-valued statistics are summed over
    all entries.
    """
    current_stats = scat_cov_dir(flm, L, N, J_min, sampling, None, reality, multiresolution, filters=filters, quads=weights, precomps=precomps)
    mean_new, var_new, S1_new, P00_new, C01_new, C11_new = current_stats

    # Scalar moments: mean and variance.
    loss = jnp.abs(mean - mean_new)**2 + jnp.abs(var - var_new)**2

    # Scattering-covariance coefficients, accumulated in the order S1, P00, C01, C11.
    for target, candidate in ((S1, S1_new), (P00, P00_new), (C01, C01_new), (C11, C11_new)):
        loss = loss + jnp.sum(jnp.abs(target - candidate)**2)

    return loss

grad_func = jit(grad(func))
loss_0 = func(flm_start)

In [None]:
# Main optimisation: match all scattering-covariance statistics with Adam,
# continuing from the `flm` left by the power-spectrum warm start (when cells run in order).
import optax
lr = 1e-2
optimizer = optax.adam(lr)
opt_state = optimizer.init(flm)

n_steps, log_every = 100000, 10
for step in range(n_steps):
    # Conjugate JAX's complex gradient so Adam steps in the descent direction.
    descent_grads = jnp.conj(grad_func(flm))
    updates, opt_state = optimizer.update(descent_grads, opt_state)
    flm = optax.apply_updates(flm, updates)
    if step % log_every == 0:
        print(f"Iteration {step}: Loss/Loss-0 = {func(flm)}/{loss_0}")

In [None]:
from matplotlib import pyplot as plt


def _hermitian_extend(flm_half):
    """Rebuild the full (L, 2L-1) harmonic array from its m >= 0 half.

    Column L-1 holds m = 0 and columns L..2L-2 hold m = 1..L-1. The m < 0
    columns are filled using f_{l,-m} = (-1)^m conj(f_{l,m}), the conjugate
    symmetry of a real-valued signal.

    Args:
        flm_half: array of shape (L, L) holding the m >= 0 coefficients.

    Returns:
        complex128 array of shape (L, 2L-1).
    """
    L_band = flm_half.shape[0]
    msigns = (-1) ** jnp.arange(1, L_band)  # (-1)^m for m = 1..L-1
    flm_full = jnp.zeros((L_band, 2 * L_band - 1), dtype=jnp.complex128)
    flm_full = flm_full.at[:, L_band - 1:].set(flm_half)
    # Apply sign and conjugation for m = 1..L-1, then flip so columns run m = -(L-1)..-1.
    negative_m = jnp.flip(jnp.conj(flm_full[:, L_band:]) * msigns, axis=-1)
    return flm_full.at[:, : L_band - 1].set(negative_m)


if reality:
    flm_full_end = _hermitian_extend(flm)
    flm_full_start = _hermitian_extend(flm_start)
else:
    # Without the reality reduction `flm` is never halved, so it already spans every m.
    flm_full_end = flm
    flm_full_start = flm_start

# Back to pixel space: band-limited target, random start and optimised field.
I = s2fft.inverse_jax(Ilm, L, reality=reality)
f_start = s2fft.inverse_jax(flm_full_start, L, reality=reality)
f_end = s2fft.inverse_jax(flm_full_end, L, reality=reality)

# Take real parts so imshow also works if the inverse returns complex values
# (likely when reality is False). All panels share the target's colour range.
maps = [np.real(np.asarray(field)) for field in (I, f_start, f_end)]
mx, mn = np.nanmax(maps[0]), np.nanmin(maps[0])
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(14, 20), dpi=200)
for ax, image, title in zip((ax1, ax2, ax3), maps, ("Target", "Start", "End")):
    ax.imshow(image, vmax=mx, vmin=mn, cmap='magma')
    ax.set_title(title)
plt.show()

In [None]:
# Upscale for nice plot: zero-pad each map's harmonic coefficients from band-limit
# L to L_new = 2L and resynthesise on the finer grid. ssht orders coefficients by l,
# so the first L**2 entries are the l < L ones.
import pyssht as ssht

L_new = 2 * L
data = [I, f_start, f_end]
data_new = []
for datum in data:
    coeffs = ssht.forward(np.array(datum), L)
    padded = np.zeros(L_new**2, dtype=np.complex128)
    padded[:L**2] = coeffs
    data_new.append(np.real(ssht.inverse(padded, L_new)))

# Target, start and end on the colour range of the upscaled target.
mx, mn = np.nanmax(data_new[0]), np.nanmin(data_new[0])
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(14, 20), dpi=200)
for ax, image in zip((ax1, ax2, ax3), data_new):
    ax.imshow(image, vmax=mx, vmin=mn, cmap='magma')
plt.show()

In [None]:
def compute_ps(flm):
    """Return sum of |flm|**2 over the last axis (one value per row, i.e. per l)."""
    power = jnp.abs(flm) ** 2
    return power.sum(axis=-1)

In [None]:
# Power spectra of the full harmonic arrays (all m) for target, start and end.
# Stored as `ps_target` rather than reusing `ps_true`: that global is read by the
# warm-start loss `func_ps_only` and holds a different quantity (m >= 0 half only
# when reality is set).
from matplotlib import pyplot as plt
ps_target = compute_ps(Ilm)
ps_start = compute_ps(flm_full_start)
ps_end = compute_ps(flm_full_end)

fig, ax = plt.subplots()
ax.plot(ps_target, 'b', label="target")
ax.plot(ps_start, 'r', label="start")
ax.plot(ps_end, 'g', label="end")
ax.set_yscale("log")
ax.set_xlabel(r"$\ell$")
ax.set_ylabel(r"$\sum_m |f_{\ell m}|^2$")
ax.set_title("Power spectrum")
ax.legend()
plt.show()

In [None]:
# Scattering-covariance statistics of the initial (random) and optimised fields.
stats_start = scat_cov_dir(flm_start, L, N, J_min, sampling, None, reality, multiresolution, filters=filters, quads=weights, precomps=precomps)
mean_start, var_start, S1_start, P00_start, C01_start, C11_start = stats_start
stats_end = scat_cov_dir(flm, L, N, J_min, sampling, None, reality, multiresolution, filters=filters, quads=weights, precomps=precomps)
mean_end, var_end, S1_end, P00_end, C01_end, C11_end = stats_end

# Target / start / end values of the two scalar moments (real parts).
for moment_name, triple in (("mean", (mean, mean_start, mean_end)), ("var", (var, var_start, var_end))):
    print(moment_name, *(np.real(value) for value in triple))

In [None]:
# S1 coefficients: target (top), random start (middle), optimised end (bottom),
# all on the target's y-range. Each panel is labelled with its own series.
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(40, 8))
y_max, y_min = np.nanmax(np.real(S1)), np.nanmin(np.real(S1))
for ax, values, colour, label in ((ax1, S1, 'b', "Target"),
                                  (ax2, S1_start, 'r', "Start"),
                                  (ax3, S1_end, 'g', "End")):
    ax.plot(np.real(values), colour, label=label)
    ax.set_ylim([y_min, y_max])
    ax.set_ylabel("Re(S1)")
    ax.legend(loc="upper right")
ax1.set_title("S1 coefficients")
plt.show()

In [None]:
# P00 coefficients: target (top), random start (middle), optimised end (bottom),
# all on the target's y-range. Each panel is labelled with its own series.
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(40, 8))
y_max, y_min = np.nanmax(np.real(P00)), np.nanmin(np.real(P00))
for ax, values, colour, label in ((ax1, P00, 'b', "Target"),
                                  (ax2, P00_start, 'r', "Start"),
                                  (ax3, P00_end, 'g', "End")):
    ax.plot(np.real(values), colour, label=label)
    ax.set_ylim([y_min, y_max])
    ax.set_ylabel("Re(P00)")
    ax.legend(loc="upper right")
ax1.set_title("P00 coefficients")
plt.show()

In [None]:
# C01 coefficients: target (top), random start (middle), optimised end (bottom),
# all on the target's y-range. Each panel is labelled with its own series.
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(40, 8))
y_max, y_min = np.nanmax(np.real(C01)), np.nanmin(np.real(C01))
for ax, values, colour, label in ((ax1, C01, 'b', "Target"),
                                  (ax2, C01_start, 'r', "Start"),
                                  (ax3, C01_end, 'g', "End")):
    ax.plot(np.real(values), colour, label=label)
    ax.set_ylim([y_min, y_max])
    ax.set_ylabel("Re(C01)")
    ax.legend(loc="upper right")
ax1.set_title("C01 coefficients")
plt.show()

In [None]:
# C11 coefficients: target (top), random start (middle), optimised end (bottom),
# all on the target's y-range. Each panel is labelled with its own series.
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(40, 8))
y_max, y_min = np.nanmax(np.real(C11)), np.nanmin(np.real(C11))
for ax, values, colour, label in ((ax1, C11, 'b', "Target"),
                                  (ax2, C11_start, 'r', "Start"),
                                  (ax3, C11_end, 'g', "End")):
    ax.plot(np.real(values), colour, label=label)
    ax.set_ylim([y_min, y_max])
    ax.set_ylabel("Re(C11)")
    ax.legend(loc="upper right")
ax1.set_title("C11 coefficients")
plt.show()