In [1]:
# General imports

%load_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import torch

from gaussian_npe import gaussian_npe, utils
from gaussian_npe.gaussian_npe_network_disco_dj_CG_Chebyshev_Qx_tests_masked import Gaussian_NPE_Network

from datetime import datetime
current_time = datetime.now().strftime("%y%m%d_%H%M%S")

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
########## BOX PARAMETERS ##########

box_params = {
        'box_size': 1000.,       #Mpc/h
        'grid_res': 64,         #resolution
        'h': 0.6711,
        'dim': 3
        }

# box = utils.Power_Spectrum_Sampler(box_parameters, device = 'cpu')

########## COSMO PARAMETERS ##########

cosmo_params = {
        'h': 0.6711,
        'Omega_b': 0.049,
        'Omega_cdm': 0.2685,
        # 'A_s': 2.1413e-09,
        'n_s': 0.9624,
        'non linear': 'halofit',
        'sigma8': 0.834,
    }

### Prior

In [3]:
box = utils.Power_Spectrum_Sampler(box_params, device = device)
prior = box.get_prior_Q_factors(lambda k: torch.tensor(utils.get_pk_class(cosmo_params, 0, np.array(k)), device = device))

  prior = box.get_prior_Q_factors(lambda k: torch.tensor(utils.get_pk_class(cosmo_params, 0, np.array(k)), device = device))


### Sampling the trained model

In [9]:
sample_obs = torch.load('sample_obs_disco_dj.pt', weights_only=False)

mask = torch.tensor(utils.create_cone_mask(fov_angle=[53.13], res=box_params['grid_res']), device='cuda')
sigma_noise = 0.1

delta_fin = torch.from_numpy(sample_obs['delta_fin'].astype('f')).cuda()
delta_fin = delta_fin * mask  + torch.randn_like(delta_fin) * sigma_noise  # Apply cone mask and noise
delta_ic = torch.from_numpy(sample_obs['delta_ic'].astype('f')).cuda()
delta_ic = delta_ic # * mask

sample_obs['delta_ic'] = delta_ic.cpu().numpy().astype('f')
sample_obs['delta_fin'] = delta_fin.cpu().numpy().astype('f')
sample_obs['delta_obs'] = sample_obs['delta_fin']

In [7]:
model = Gaussian_NPE_Network(box, prior, mask.float(), k_cut=0.03).to(device)
model.load_state_dict(torch.load('plots/250908_121758_Chebyshev_20ep_2ksims_lr_5e-3_Qx_masked_init_with_mask/best_model_250908_121758_Chebyshev_20ep_2ksims_lr_5e-3_Qx_masked_init_with_mask.pt'))
model.eval()

Gaussian_NPE_Network(
  (unet): UNet(
    (conv_l0): ConvBlock(
      (convs): Sequential(
        (0): Conv3d(1, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (1): LeakyReLU(negative_slope=0.01)
        (2): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (3): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): LeakyReLU(negative_slope=0.01)
      )
    )
    (down_l0): ConvBlock(
      (convs): Sequential(
        (0): Conv3d(8, 8, kernel_size=(2, 2, 2), stride=(2, 2, 2))
        (1): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.01)
      )
    )
    (conv_l1): ConvBlock(
      (convs): Sequential(
        (0): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (1): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.01)
        (3): Conv3d(8, 8, kernel_size=(3, 3, 3), stride

In [27]:
# samples = model.sample(num_samples = 100, x_obs = delta_fin)
# z_MAP = model.get_z_MAP(delta_fin)
# samples = torch.tensor(samples)

run_name = 'OU_sampling_Qx_mask_ic'
utils.plot_samples_analysis(sample_obs, samples, z_MAP, box_params, cosmo_params, time=current_time, run_name=run_name, mask = mask)

Sample absolute standard deviation from the truth: 10^3 * std(delta_ic - sample) = 1.0736234


  def hartley(x, dim = (-3, -2, -1)):



Computing power spectra of the fields...
Time FFTS = 0.00
Time loop = 0.01
Time taken = 0.01 seconds

Computing power spectra of the fields...
Time FFTS = 0.00
Time loop = 0.01
Time taken = 0.01 seconds

Computing power spectra of the fields...
Time FFTS = 0.00
Time loop = 0.01
Time taken = 0.01 seconds

Computing power spectra of the fields...
Time FFTS = 0.00
Time loop = 0.01
Time taken = 0.01 seconds

Computing power spectra of the fields...
Time FFTS = 0.00
Time loop = 0.01
Time taken = 0.01 seconds

Computing power spectra of the fields...
Time FFTS = 0.00
Time loop = 0.01
Time taken = 0.01 seconds

Computing power spectra of the fields...
Time FFTS = 0.00
Time loop = 0.01
Time taken = 0.01 seconds

Computing power spectra of the fields...
Time FFTS = 0.00
Time loop = 0.01
Time taken = 0.01 seconds

Computing power spectra of the fields...
Time FFTS = 0.00
Time loop = 0.01
Time taken = 0.01 seconds

Computing power spectra of the fields...
Time FFTS = 0.00
Time loop = 0.01
Time t

In [8]:
plt.figure()
plt.title(r'Q matrix diagonal values as a function of k, $Q_{like} = U^T D \, U$')
plt.xlabel(r'$k~[h{\rm Mpc}^{-1}]$')
plt.ylabel(r'$D(k)$')

K = box.get_k().cpu().numpy().flatten()

k_Nq = model.box.k_Nq
# K = get_k().flatten()
mask = (K < k_Nq) * (K > 1e-3)
print(mask.sum())
plt.scatter(K[mask][::100], model.Q_like.D.detach().flatten().cpu().numpy()[mask][::100], s=1, label=r'$D_{like}$')
plt.scatter(K[mask][::100], model.Q_prior.D.detach().flatten().cpu().numpy()[mask][::100], s=1, label=r'$D_{prior}$')
plt.scatter(K[mask][::1], model.Q_like.D.detach().flatten().cpu().numpy()[mask][::1] + model.Q_prior.D.detach().cpu().numpy()[mask][::1], s=1, label=r'$D_{posterior}$', alpha=0.5)
plt.axvline(x = k_Nq, color='r', linestyle='--', label='$k_{Nyq}$')

plt.legend(loc='best')
plt.xscale('log')
plt.savefig("./plots/"+current_time+"_"+run_name+"/Q_FFT_matrix_"+current_time+"_"+run_name+".png")
plt.show()
# plt.ylim([5e-3, k_Nq])
# plt.yscale('log')
plt.ylim([0, 100])

137059


(0.0, 100.0)

In [20]:
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams['xtick.direction'] = 'out'
plt.rcParams['ytick.direction'] = 'out'

# samples = samples.cpu().numpy().astype('f')
draw = samples[0]
delta_target = sample_obs['delta_ic']
MAP = z_MAP.cpu().numpy()
var = samples.var(axis=0)
# residual = draw - delta_target
# print('Draw absolute standard deviation from the truth: 10^3 * std(delta - draw) =', 1e3 * residual.std())

extent=[0, 1000, 0, 1000]
vmin1, vmax1 = -3, 3
vmin2, vmax2 = -0.5, 2.5
# ticks = np.arange(0, 1001, 100)
# ticklabels = ['0', '', '200', '', '400', '', '600', '', '800', '', '1000']


fig, axes = plt.subplots(2, 3,  figsize=(12, 8), layout='compressed')#, sharex=True, sharey=True,)

figure = axes[0, 0].imshow(sample_obs['delta_ic'][-4:].mean(0), origin='lower', cmap='seismic', vmin=vmin1, vmax=vmax1, extent=extent)
axes[0, 0].set_title(r'True initial', fontsize=22)
axes[0, 0].set_xticklabels([])
axes[0, 0].set_yticklabels([])
# axes[0, 0].set_ylabel(r'$[{\rm Mpc} / h]$', fontsize=14)
# axes[0, 0].set_xticklabels([])
# cbar_ax1 = fig.add_axes([axes[0, 0].get_position().x1 + 0.01, axes[0, 0].get_position().y0, 0.01, axes[0, 0].get_position().height])
# cbar_ax1.set_title(r'$\times 10^{-2}$', fontsize=12)
# plt.colorbar(figure, ax=axes[0, :])#, cax=cbar_ax1, ticks=[-2, -1, 0, 1, 2])

figure = axes[0, 2].imshow(sample_obs['delta_fin'][-4:].mean(0), origin='lower', cmap='inferno', vmin=vmin2, vmax=vmax2, extent=extent)
axes[0, 2].set_title(r'True final', fontsize=22)
# axes[0, 2].set_xlabel(r'$[h \ {\rm Mpc}^{-1}]$', fontsize=12)
# axes[0, 2].set_ylabel(r'$[h \ {\rm Mpc}^{-1}]$', fontsize=12)
axes[0, 2].set_xticklabels([])
axes[0, 2].set_yticklabels([])
# cbar_ax1 = fig.add_axes([axes[0, 1].get_position().x1 + 0.015, axes[0, 1].get_position().y0, 0.01, axes[1, 0].get_position().height])
# plt.colorbar(figure, ax=axes[1, 0], pad = 0.006, aspect=30)#, cax=cbar_ax1)

figure = axes[0, 1].imshow(samples[1][-4:].mean(0), origin='lower', cmap='seismic', vmin=vmin1, vmax=vmax1, extent=extent)
axes[0, 1].set_title(r'Sample 1', fontsize=22)
axes[0, 1].set_xticklabels([])
axes[0, 1].set_yticklabels([])
# axes[0, 1].set_ylabel(r'$[{\rm Mpc} / h]$', fontsize=14)
# axes[0, 1].set_xlabel(r'$[{\rm Mpc} / h]$', fontsize=14)
# clb = plt.colorbar(figure, ax=axes[1, 1], pad=0.006, aspect=30, ticks=[-2, -1, 0, 1, 2])
# clb.ax.set_title(r'$\times 10^{-2}$', fontsize=12)

# axes[1, 1].set_ylabel(r'$[{\rm Mpc} / h]$', fontsize=12)
# figure = axes[1, 1].imshow(1e2 * draws[i][-4:].mean(0), origin='lower', cmap='seismic', vmin=vmin1, vmax=vmax1, extent=extent)
# axes[1, 1].set_title(titles1[i], fontsize=16)
# axes[1, 1].set_xlabel(r'$[{\rm Mpc} / h]$', fontsize=12)
# cbar_ax1 = fig.add_axes([axes[1, 1].get_position().x1 + 0.015, axes[1, 1].get_position().y0, 0.01, axes[1, 1].get_position().height])
# cbar_ax1.set_title(r'$\times 10^{-2}$', fontsize=10)
# plt.colorbar(figure, ax=axes[1, :], cax=cbar_ax1, ticks=[-2, -1, 0, 1, 2])

figure = axes[1, 1].imshow(samples[2][-4:].mean(0), origin='lower', cmap='seismic', vmin=vmin1, vmax=vmax1, extent=extent)
# axes[1, 1].set_title(f'Sample #{frame + 1}', fontsize=16)
axes[1, 1].set_title(r'Sample 2', fontsize=22)
axes[1, 1].set_xticklabels([])
axes[1, 1].set_yticklabels([])
# axes[1, 1].set_yticklabels([])
# axes[1, 1].set_xlabel(r'$[{\rm Mpc} / h]$', fontsize=14)

figure = axes[1, 0].imshow(MAP[-4:].mean(0), origin='lower', cmap='seismic', vmin=vmin1, vmax=vmax1, extent=extent)
axes[1, 0].set_title(r'MAP', fontsize=22)
axes[1, 0].set_xticklabels([])
axes[1, 0].set_yticklabels([])
# axes[0, 1].set_ylabel(r'$[{\rm Mpc} / h]$', fontsize=12)
# axes[0, 1].set_xlabel(r'$[{\rm Mpc} / h]$', fontsize=12)
# axes[0, 2].set_xticklabels([])
# axes[0, 2].set_yticklabels([])
# clb = plt.colorbar(figure, ax=axes[0, 2], pad=0.006, aspect=30, ticks=[-2, -1, 0, 1, 2])
# clb.ax.set_title(r'$\times 10^{-2}$', fontsize=12)

figure = axes[1, 2].imshow(var[-4:].mean(0), origin='lower', cmap='Purples', extent=extent)#, vmin=vmin1, vmax=vmax1, )
axes[1, 2].set_title(r'Variance', fontsize=22)
axes[1, 2].set_xticklabels([])
axes[1, 2].set_yticklabels([])
# axes[0, 1].set_ylabel(r'$[{\rm Mpc} / h]$', fontsize=12)
# axes[0, 1].set_xlabel(r'$[{\rm Mpc} / h]$', fontsize=12)
# axes[1, 2].set_xticklabels([])
# axes[1, 2].set_yticklabels([])
# clb = plt.colorbar(figure, ax=axes[1, 2], pad=0.006, aspect=30, ticks=[-2, -1, 0, 1, 2])
# clb.ax.set_title(r'$\times 10^{-2}$', fontsize=12)


# figure = axes[1, 1].imshow(1e2 * draws[2][-4:].mean(0), origin='lower', cmap='seismic', vmin=vmin1, vmax=vmax1, extent=extent)
# # axes[1, 1].set_title(f'Sample #{frame + 1}', fontsize=16)
# axes[1, 1].set_title(r'Sample 2', fontsize=20)
# axes[1, 1].set_xlabel(r'$[{\rm Mpc} / h]$', fontsize=12)

# K = get_k().flatten()
# # mask = (K.cpu() < k_Nq) * (K.cpu() > 1e-3)
# mask1 = (K.cpu() > 1e-3) * (K.cpu() < 0.03)
# mask2 = (K.cpu() > 0.03) * (K.cpu() < 0.1)
# mask3 = (K.cpu() > 0.1) * (K.cpu() < k_Nq)
# print(mask.sum())



# plt.tight_layout()

fig.savefig("./plots/"+current_time+"_"+run_name+"/samples_"+current_time+"_"+run_name+".pdf")#, bbox_inches='tight')

### Animation

In [11]:
from matplotlib.animation import FuncAnimation

In [15]:
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams['xtick.direction'] = 'out'
plt.rcParams['ytick.direction'] = 'out'

extent=[0, 1000, 0, 1000]
vmin1, vmax1 = -3, 3
vmin2, vmax2 = -0.5, 2.5


fig, axes = plt.subplots(2, 2,  figsize=(12, 12), sharex=True, sharey=True, layout='compressed')

figure = axes[0, 0].imshow(sample_obs['delta_ic'][45:64].mean(0), origin='lower', cmap='seismic', vmin=vmin1, vmax=vmax1, extent=extent)
axes[0, 0].set_title(r'True initial', fontsize=20)
axes[0, 0].set_ylabel(r'$[{\rm Mpc} / h]$', fontsize=12)

figure = axes[0, 1].imshow(sample_obs['delta_fin'][45:64].mean(0), origin='lower', cmap='inferno', vmin=vmin2, vmax=vmax2, extent=extent)
axes[0, 1].set_title(r'True final', fontsize=20)
plt.colorbar(figure, ax=axes[0, :], pad = 0.006, aspect=30)#, cax=cbar_ax1)

figure = axes[1, 1].imshow(z_MAP.cpu().numpy()[45:64].mean(0), origin='lower', cmap='seismic', vmin=vmin1, vmax=vmax1, extent=extent)
axes[1, 1].set_title(r'MAP estimation', fontsize=20)
axes[1, 1].set_xlabel(r'$[{\rm Mpc} / h]$', fontsize=12)
clb = plt.colorbar(figure, ax=axes[1, :], pad=0.006, aspect=30, ticks=[-2, -1, 0, 1, 2])
clb.ax.set_title(r'$\times 10^{-2}$', fontsize=10)

# Function to update the plot for each frame
def update(frame):
    axes[1, 0].cla()
    figure = axes[1, 0].imshow(samples.cpu().numpy().astype('f')[frame][45:64].mean(0), origin='lower', cmap='seismic', vmin=vmin1, vmax=vmax1, extent=extent)
    axes[1, 0].set_title(r'Sample ' + str(frame + 1), fontsize=20)
    axes[1, 0].set_xlabel(r'$[{\rm Mpc} / h]$', fontsize=12)
    axes[1, 0].set_ylabel(r'$[{\rm Mpc} / h]$', fontsize=12)

# plt.tight_layout()

# Create the animation
animation = FuncAnimation(fig, update, frames=10, interval=200)

# Save the animation as a GIF file
animation.save('./plots/'+current_time+'_'+run_name+'/animation'+'_'+current_time+'_'+run_name+'.gif', writer='pillow')

# Show the animation (if you want to display it in a Jupyter notebook or similar)
# plt.show()

NameError: name 'FuncAnimation' is not defined