In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import numpy as np
import scipy.special
from cosmoprimo import Cosmology, PowerSpectrumBAOFilter
import argparse
import multiprocessing
from utils_paths import path_template
from utils import cosmology, redshift_distributions
from utils_template import PowerSpectrumMultipoles, CorrelationFunctionMultipoles, WThetaCalculator

# 1. Arguments (fixed in-notebook values; `argparse` is imported but not used here)
class Args:
    """Run configuration for building the template.

    Attributes:
        cosmology_template: name handed to ``cosmology(...)`` ("planck").
        nz_flag: name handed to ``redshift_distributions(...)`` ("fid";
            "clusteringz" is an alternative choice).
        include_wiggles: "y" keeps the BAO wiggles. Right after construction
            the script replaces this flag with a path suffix ('' for "y",
            '_nowiggles' for anything else).
    """

    def __init__(self):
        self.cosmology_template, self.nz_flag, self.include_wiggles = "planck", "fid", "y"
args = Args()
# Turn the y/n wiggle flag into the suffix used when naming template outputs:
# 'y' -> '' (with wiggles); any other value -> '_nowiggles'.
if args.include_wiggles == 'y':
    args.include_wiggles = ''
else:
    args.include_wiggles = '_nowiggles'

# 2. Directory where the template is saved (created if it does not exist).
# NOTE(review): this rebinds the imported `path_template` factory to the
# directory it returns; the calculators below receive that directory as
# `path_template=path_template`, so the name is kept as is.
path_template = path_template(
    include_wiggles=args.include_wiggles,
    nz_flag=args.nz_flag,
    cosmology_template=args.cosmology_template,
)()
os.makedirs(path_template, exist_ok=True)

# 3. Numerical resolution of the calculation. Larger values make the template
#    more accurate (and presumably slower to compute).
#    Consumers: Nk, Nmu -> PowerSpectrumMultipoles; Nr -> CorrelationFunctionMultipoles;
#    Nz, Ntheta -> WThetaCalculator.
# Settings used for DES Y6 BAO:
Nz = 1000
Nk = 200000
Nmu = 50000
Nr = 50000
Ntheta = 1000
# Lower-resolution alternative:
# Nz, Nk, Nmu, Nr, Ntheta = 1000, 10000, 10000, 10000, 100

# 4. Redshift distributions selected by `args.nz_flag`.
nz_instance = redshift_distributions(args.nz_flag)

# 5. Cosmology: the parameter set named by `args.cosmology_template`,
#    evaluated with the CLASS engine. The CDM density is whatever remains of
#    Omega_m once baryons and massive neutrinos are subtracted.
params = cosmology(args.cosmology_template)
cosmo = Cosmology(
    h=params.h,
    Omega_cdm=(params.Omega_m
               - params.Omega_b
               - params.Omega_nu_massive),
    Omega_b=params.Omega_b,
    sigma8=params.sigma_8,
    n_s=params.n_s,
    Omega_ncdm=params.Omega_nu_massive,  # massive-neutrino density from the parameter set
    N_eff=3.046,                         # fixed here, not read from `params`
    engine='class',
)

# 6. Power-spectrum multipoles (`compute_pk_ell`), using a pool with one
#    worker per redshift bin.
pk_calculator = PowerSpectrumMultipoles(
    cosmo=cosmo,
    include_wiggles=args.include_wiggles,
    nz_instance=nz_instance,
    Nk=Nk,
    Nmu=Nmu,
    path_template=path_template,
)
# Pool.map pickles the bound method (and therefore `pk_calculator`) for the
# workers and returns a *list* ordered by bin index, despite the `_dict` name.
with multiprocessing.Pool(processes=nz_instance.nbins) as pool:
    pk_ell_dict = pool.map(
        pk_calculator.compute_pk_ell,
        range(nz_instance.nbins),
    )

# 7. Correlation-function multipoles (`compute_xi_ell`), built on top of the
#    power-spectrum calculator; again one pool worker per redshift bin.
xi_calculator = CorrelationFunctionMultipoles(
    power_spectrum_multipoles=pk_calculator,
    Nr=Nr,
)
# As above, Pool.map returns a list ordered by bin index.
with multiprocessing.Pool(processes=nz_instance.nbins) as pool:
    xi_ell_dict = pool.map(
        xi_calculator.compute_xi_ell,
        range(nz_instance.nbins),
    )

# 8. Calculation of w(theta), one redshift bin at a time.
# Worker count: at most 100, but never more than the CPUs the OS reports, so
# smaller machines are not oversubscribed. os.cpu_count() can return None;
# fall back to 1 in that case.
# NOTE(review): assumes `n_cpu` only controls parallelism inside
# WThetaCalculator and does not change the numerical result; confirm in
# utils_template.
n_cpu_wtheta = min(100, os.cpu_count() or 1)
calculator = WThetaCalculator(
    correlation_function_multipoles=xi_calculator,
    Nz=Nz,
    Ntheta=Ntheta,
    n_cpu=n_cpu_wtheta,
)
for bin_z in range(nz_instance.nbins):
    # Any return value of compute_and_save_wtheta is not used here.
    calculator.compute_and_save_wtheta(bin_z)
