# Analyse synthesis outputs

In [1]:
%matplotlib notebook

import numpy as np
from matplotlib import pyplot as plt
import glob
from importlib import reload

import scatcovjax.Sphere_lib as sphlib
from scatcovjax.Scattering_lib import scat_cov_axi, scat_cov_dir
from s2wav.filter_factory.filters import filters_directional_vectorised

import s2fft
import s2wav
import scatcovjax.plotting as plot

# Apply the notebook plotting defaults provided by scatcovjax.plotting.
# NOTE(review): which matplotlib settings this changes is not visible here — check the helper.
plot.notebook_plot_format()


# Parameters

In [2]:
# Run configuration. L, N, J_min and epochs also appear in the name of the
# output directory that is loaded further down.
reality = True  # forwarded to s2fft.inverse_jax; when True the stored flm are expanded with make_flm_full
sampling = 'mw'  # only referenced by the disabled plot_sphere call — presumably the sampling scheme, TODO confirm
multiresolution = True  # not used in the cells shown here — TODO confirm it is still needed
L = 256  # passed to the filter, flm and inverse-transform calls (harmonic band-limit in s2fft/s2wav terms)
N = 2  # second argument of filters_directional_vectorised — presumably the azimuthal band-limit
epochs = 400  # presumably the number of optimisation epochs of the run to load
J_min = 1  # third argument of filters_directional_vectorised — presumably the lowest wavelet scale

J_max = s2wav.utils.shapes.j_max(L)  # J_max as computed by s2wav for this L
J = J_max - J_min + 1  # number of scales from J_min to J_max inclusive
print(f'{J=} {J_max=}')

J=8 J_max=8


# Make filters

In [3]:
# Build the filters for the (L, N, J_min) configuration and plot them.
# NOTE(review): the meaning of real=True and m=L-1 is defined by plot.plot_filters — confirm there.
filters = filters_directional_vectorised(L, N, J_min)
plot.plot_filters(filters, real=True, m=L-1)

<IPython.core.display.Javascript object>

(<Figure size 800x600 with 1 Axes>, <Axes: >)

# Load the data

In [6]:
# Output directories of the jobs run with the parameters chosen above.
# glob returns matches in arbitrary order, so the list is sorted to make the
# `run` index below select the same directory on every execution.
job = 7083
job_pattern = f'/travail/lmousset/scatJAX_tests_mai2023/*{job}*L{L}_N{N}_Jmin{J_min}_epochs{epochs}'
job_list = sorted(glob.glob(job_pattern))
print(len(job_list), job_list)

# Fail with an explicit message instead of a bare IndexError when nothing matches.
if not job_list:
    raise FileNotFoundError(f'No output directory matches {job_pattern}')

# Choose the run
run = 0
output_dir = job_list[run]
print('\n', output_dir)

# flm: coefficients of the target field and of the synthesis start / end states
flm_target = np.load(output_dir + '/flm_target.npy')
flm_start = np.load(output_dir + '/flm_start.npy')
flm_end = np.load(output_dir + '/flm_end.npy')
print(flm_end.shape)

# Loss
loss_history = np.load(output_dir + '/loss.npy')

# Coeffs: each file unpacks into (mean, var, S1, P00, C01, C11).
# NOTE: allow_pickle=True unpickles arbitrary Python objects — only load files
# produced by a trusted run.
tmean, tvar, tS1, tP00, tC01, tC11 = np.load(output_dir + '/coeffs_target.npy', allow_pickle=True)
smean, svar, sS1, sP00, sC01, sC11 = np.load(output_dir + '/coeffs_start.npy', allow_pickle=True)
emean, evar, eS1, eP00, eC01, eC11 = np.load(output_dir + '/coeffs_end.npy', allow_pickle=True)

1 ['/travail/lmousset/scatJAX_tests_mai2023/Job7083_L256_N2_Jmin1_epochs400']

 /travail/lmousset/scatJAX_tests_mai2023/Job7083_L256_N2_Jmin1_epochs400
(256, 256)


In [7]:
# Display the target P00 coefficients read from coeffs_target.npy.
tP00

Array([4.48798951e-01, 4.48798951e-01, 4.48798951e-01, 1.04719755e-01,
       1.04719755e-01, 1.04719755e-01, 2.53354246e-02, 2.53354246e-02,
       2.53354246e-02, 6.23331876e-03, 6.23331876e-03, 6.23331876e-03,
       1.54605938e-03, 1.54605938e-03, 1.54605938e-03, 3.84999100e-04,
       3.84999100e-04, 3.84999100e-04, 9.60614192e-05, 9.60614192e-05,
       9.60614192e-05, 9.60614192e-05, 9.60614192e-05, 9.60614192e-05],      dtype=float64)

In [8]:
# When reality is set, the arrays read from disk are expanded to the full flm
# with sphlib.make_flm_full. The names are rebound in place, so this cell must
# be run exactly once after the loading cell.
if reality:
    flm_target, flm_start, flm_end = (
        sphlib.make_flm_full(stored_flm, L)
        for stored_flm in (flm_target, flm_start, flm_end)
    )
print(flm_end.shape)

(256, 511)


In [9]:
### Cut the flm that are not constrained
# Zero the first 2**J_min + 1 rows (indices 0 .. 2**J_min of the first axis) of the
# target coefficients. Only the target is cut; flm_start and flm_end are left as loaded.
# `.at[...].set(...)` returns an updated array that is rebound to flm_target; this
# assumes flm_target is a JAX array at this point — TODO confirm for reality=False.
flm_target = flm_target.at[0: 2**J_min + 1, :].set(0. + 0.j)

In [10]:
# Inverse transform of each set of coefficients: one map per state
# (target, start, end), all computed with the same L and reality setting.
f_target, f_start, f_end = (
    s2fft.inverse_jax(coeffs, L, reality=reality)
    for coeffs in (flm_target, flm_start, flm_end)
)

# Plots

In [11]:
# Standard deviation of the target map values.
np.std(f_target)

Array(0.99977529, dtype=float64)

In [12]:
# Loss curve of the synthesis run.
# The x-axis is derived from the length of loss_history so that every stored
# value is drawn. The previous fixed slice [:nit1/step] dropped the last one.
nit1 = epochs  # iterations of the loaded run (replaces a hard-coded 400)
nit2 = 1000  # iterations of an optional second stage; not used by this plot
step = 10  # assumed spacing, in iterations, between two stored loss values — TODO confirm against the training script

iterations = step * np.arange(len(loss_history))

plt.figure(figsize=(8, 6))
plt.plot(iterations, loss_history, 'ro--', label='All coeffs')
plt.yscale('log')
plt.ylabel('Loss')
plt.xlabel('Number of iterations')
plt.legend()

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x7f951a728340>

In [13]:
# Display the stored loss values read from loss.npy.
loss_history

array([7.87623083e+00, 5.16967353e+00, 1.50239465e+00, 4.34571698e-01,
       1.26603293e-01, 5.01709637e-02, 3.83917864e-02, 3.07387943e-02,
       2.21033793e-02, 1.73558185e-02, 1.51679689e-02, 1.31590561e-02,
       1.17444457e-02, 1.07244653e-02, 9.69177684e-03, 8.86511259e-03,
       8.30803910e-03, 7.72217649e-03, 7.17922731e-03, 6.80810987e-03,
       6.57654339e-03, 6.30588745e-03, 6.00317046e-03, 5.81199078e-03,
       5.70196381e-03, 5.61092629e-03, 5.41496518e-03, 5.27412732e-03,
       5.20972571e-03, 5.19928565e-03, 5.12912348e-03, 4.99897649e-03,
       4.95710381e-03, 4.95157139e-03, 4.88911365e-03, 4.79541785e-03,
       4.74162228e-03, 4.68165219e-03, 4.63288775e-03, 4.57574113e-03,
       4.54491693e-03])

In [14]:
# Fixed colour range shared by the three panels here and reused by the
# Mollweide plots that follow. The data-driven range (nanmin/nanmax of f_target)
# was computed and immediately overwritten, so that dead computation is removed.
mn, mx = -1, 3

# Side-by-side view of the target, start and end maps on the same colour scale.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 3))
for panel, field, panel_title in zip((ax1, ax2, ax3),
                                     (f_target, f_start, f_end),
                                     ('Target', 'Start', 'End')):
    panel.imshow(field, vmax=mx, vmin=mn, cmap='viridis')
    panel.set_title(panel_title)  # label each panel so the figure stands alone
plt.show()

<IPython.core.display.Javascript object>

In [15]:
# Compare the standard deviation of the target map with that of the starting map.
print(np.std(f_target), np.std(f_start))

0.999775287940524 0.7096516898787615


In [16]:
# Re-import the plotting module so that edits made to it are picked up without
# restarting the kernel, then draw the target map with the shared (mn, mx) range.
reload(plot)
plot.plot_map_MW_Mollweide(f_target, vmin=mn, vmax=mx, figsize=(10, 6), title='')#, title=f'Target - {mn=:.2f}, {mx=:.2f}')

<IPython.core.display.Javascript object>

In [17]:
# Map obtained at the end of the synthesis, drawn with the shared (mn, mx) range.
plot.plot_map_MW_Mollweide(f_end, vmin=mn, vmax=mx, figsize=(10, 6), title='')#, title=f'End - {mn=:.2f}, {mx=:.2f}')

<IPython.core.display.Javascript object>

In [18]:
# Map used as the starting point of the synthesis, drawn with the shared (mn, mx) range.
plot.plot_map_MW_Mollweide(f_start, vmin=mn, vmax=mx, title='', figsize=(10, 6))#title=f'Start - {mn=:.2f}, {mx=:.2f}')

<IPython.core.display.Javascript object>

In [19]:
# Root-mean-square modulus of the starting flm, skipping index 0 of the first axis.
np.sqrt(np.mean(np.abs(flm_start[1:]) ** 2))

0.007742070432287793

In [20]:
# Squared modulus of every target coefficient (the product is kept as a complex array).
Ilm_square = flm_target * np.conj(flm_target)
# Variance estimate under test.
# NOTE(review): the expression below subtracts the [0, L - 1] entry (the (l=0, m=0)
# term) from every element before averaging, i.e. it evaluates
# mean(Ilm_square) - Ilm_square[0, L - 1]. It does not remove that single term from
# the sum; the two only agree when that entry is zero — confirm which is intended.
# Todo: TEST
# Alternative normalisation:
# var = (np.sum(Ilm_square) - Ilm_square[0, L - 1]) / (4 * np.pi)
var = np.mean(Ilm_square - Ilm_square[0, L - 1])
print(var)

(9.587336160556921e-05+0j)


In [21]:
# plot.plot_sphere(f_end, L, sampling=sampling)

### Power spectrum

In [22]:
# Power spectrum of each set of coefficients, computed with sphlib.compute_ps.
ps_target, ps_start, ps_end = map(sphlib.compute_ps, (flm_target, flm_start, flm_end))

# Overlay the three spectra on a logarithmic y-axis.
plt.figure(figsize=(8, 6))
for spectrum, line_fmt, curve_label in (
    (ps_target, 'b', "Target"),
    (ps_start, 'g', "Start"),
    (ps_end, 'r', "End"),
):
    plt.plot(spectrum, line_fmt, label=curve_label)
plt.yscale("log")
plt.xlabel(r'Multipole $\ell$')
plt.ylabel('Power spectrum')
plt.grid()
# Optional axis limits, currently disabled:
#plt.xlim(2, 256)
#plt.ylim(1e-4, 1)
plt.legend()

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x7f94ec176d60>

In [23]:
# Mean and variance entries stored in the coeffs_*.npy files, printed in the
# order target, start, end.
print('Mean:', tmean, smean, emean)
print('Var:', tvar, svar, evar)

Mean: 0.016352476981190633 0.002051320457021274 0.01736145983804501
Var: 9.637311955554874e-05 5.99396545785049e-05 9.598407014609487e-05


In [24]:
# Plot the (S1, P00, C01, C11) coefficients of the target, start and end states.
# Only the first call passes hold=True; the other two pass hold=False.
coeff_sets = (
    ((tS1, tP00, tC01, tC11), 'Target', True, 'blue'),
    ((sS1, sP00, sC01, sC11), 'Start', False, 'green'),
    ((eS1, eP00, eC01, eC11), 'End', False, 'red'),
)
for coeff_group, run_name, hold_flag, run_color in coeff_sets:
    plot.plot_scatcov_coeffs(*coeff_group, name=run_name, hold=hold_flag, color=run_color)

<IPython.core.display.Javascript object>