In [10]:
!export JAX_PLATFORMS=cpu
import tszpower
from classy_sz import Class as Class_sz
import jax
import jax.numpy as jnp 
import numpy as np
import torch
from sbi.utils import BoxUniform
from sbi.utils.user_input_checks import (
    check_sbi_inputs,
    process_prior,
    process_simulator,
)
from sbi.analysis import pairplot
from sbi.inference import NPE
from sbi.inference import NLE
from sbi.inference import NRE_A
# from sbi.inference import FMPE
import matplotlib.pyplot as plt

  pid, fd = os.forkpty()


In [11]:
# 1) Define parameters
allpars = {
    # 'output': '',
    'omega_b': 0.0225,
    'omega_cdm': 0.12,
    'H0': 67.66,
    'tau_reio': 0.0561,
    'ln10^{10}A_s': 3.0,
    'n_s': 0.9665,
    'M_min': 1e10,
    'M_max': 3.5e15,
    # 'ell_min': 2,
    # 'ell_max': 8000,
    # 'dlogell': 0.1,
    'z_min': 5e-3,
    'z_max': 3.0,
    'P0GNFW': 8.130,
    'c500': 1.156,
    'gammaGNFW': 0.3292,
    'alphaGNFW': 1.0620,
    'betaGNFW': 5.4807,
    'B': 1.0,
    # "cosmo_model": 1, # use mnu-lcdm emulators
    'jax': 1
}

In [12]:
tszpower.classy_sz.set(allpars)
tszpower.classy_sz.compute_class_szfast()
# relavant_pars = tszpower.classy_sz.get_all_relevant_params(allpars)

In [13]:
def simulator(theta: torch.Tensor) -> torch.Tensor:
    """
    Simulator that wraps tszpower.compute_Cl_yy_total.
    
    The input `theta` is expected to be a torch.Tensor of shape (batch, 9)
    with columns ordered as:
      [logA, omega_b, omega_cdm, H0, n_s, B, A_cib, A_rs, A_ir]
      
    This function uses the torch tensor directly by converting each element
    to a Python float, calls the tszpower simulator, and then returns a torch.Tensor.
    """
    batch_size = theta.shape[0]
    
    # Generate a base key and split it for each simulation.
    base_key = jax.random.PRNGKey(2837)
    keys = jax.random.split(base_key, batch_size)
    # print(keys)
    
    sim_list = []
    for i in range(batch_size):
        # Extract each parameter as a Python float
        logA      = float(theta[i, 0])
        omega_b   = float(theta[i, 1])
        omega_cdm = float(theta[i, 2])
        H0        = float(theta[i, 3])
        n_s       = float(theta[i, 4])
        B         = float(theta[i, 5])
        A_cib     = float(theta[i, 6])
        A_rs      = float(theta[i, 7])
        A_ir      = float(theta[i, 8])
        
        # Call your tszpower simulator (which uses JAX internally)
        sim_i = tszpower.compute_Cl_yy_total(
            logA,
            omega_b,
            omega_cdm,
            H0,
            n_s,
            B,
            A_cib,
            A_rs,
            A_ir,
            keys[i],
            params_values_dict=allpars,  # your global parameter dictionary
            n_realizations=1
        )
        # Convert the returned JAX array to a NumPy array and then to a torch.Tensor
        sim_torch = torch.tensor(np.array(sim_i), dtype=torch.float32)
        sim_list.append(sim_torch)
    
    # Stack the results to form a tensor of shape (batch, n_ell)
    return torch.stack(sim_list, dim=0)


In [14]:
# --- Define the Prior ---
# We need a 9-dimensional prior (one for each free parameter).
# Order: [logA, omega_b, omega_cdm, H0, n_s, B, A_cib, A_rs, A_ir]

low = torch.tensor([2.5,   0.02,  0.11, 55.,  0.94, 1.0, 0.0, 0.0, 0.0])
high = torch.tensor([3.5, 0.025, 0.13, 90.,  1.0,  2.0, 5.0, 5.0, 5.0])
prior = BoxUniform(low=low, high=high)

# Process the prior.
prior, num_parameters, prior_returns_numpy = process_prior(prior)
# Process the simulator to ensure it returns batched outputs and torch.Tensors.
simulator = process_simulator(simulator, prior, prior_returns_numpy)
# Check consistency.
check_sbi_inputs(simulator, prior)

In [15]:
theta = prior.sample((10000,))
x = simulator(theta)
print("theta is", theta)
print("x is", x)
print("x shape is", x.shape)

theta is tensor([[3.0176, 0.0240, 0.1130,  ..., 4.3077, 0.0568, 2.9076],
        [2.8796, 0.0232, 0.1132,  ..., 4.5891, 0.5272, 4.2974],
        [3.0101, 0.0229, 0.1260,  ..., 0.1043, 0.6794, 2.3754],
        ...,
        [3.4286, 0.0241, 0.1198,  ..., 0.3638, 1.0195, 0.9560],
        [2.9786, 0.0248, 0.1265,  ..., 1.9338, 0.2506, 1.4157],
        [3.2569, 0.0204, 0.1121,  ..., 3.0460, 3.2573, 3.1324]])
x is tensor([[ 9.1051e-03,  2.6906e-03,  5.3921e-03,  ...,  2.2004e+00,
          3.2241e+00,  4.3651e+00],
        [-1.7084e-03,  6.2439e-03, -6.2991e-03,  ...,  2.5262e+00,
          3.7055e+00,  5.1380e+00],
        [ 1.0130e-02,  1.6295e-02,  2.6257e-02,  ...,  1.6339e+00,
          2.1461e+00,  2.8643e+00],
        ...,
        [-8.4075e-04,  4.8074e-04,  1.2115e-02,  ...,  1.0310e+00,
          1.4078e+00,  1.8825e+00],
        [ 7.0748e-03,  1.7674e-02,  4.2691e-02,  ...,  2.2044e+00,
          2.9021e+00,  3.6707e+00],
        [ 1.2858e-02, -4.5128e-03,  4.6135e-03,  ...,  2.700

In [16]:
# 保存 theta（假设 prior.sample() 生成的 theta 已经是 tensor [2000, D]
torch.save(theta, 'theta_10ksamps.pt')  # 原生 PyTorch 格式
# 或保存为可读文本
# with open('theta.txt', 'w') as f:
#     for row in theta:
#         f.write(' '.join(map(str, row.tolist())) + '\n')

# 保存 x（形状 [2000, 18]
torch.save(x, 'x_3ksamps.pt')
# 或分隔符型文本格式
# with open('x.txt', 'w') as f:
#     for data_point in x:
#         f.write(' '.join(f"{num:.6f}" for num in data_point.tolist()) + '\n')
