In [2]:
"""
CosmoPower-JAX Training Script

Author: A. Spurio Mancini
"""

import os
import pickle
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
from jax.nn import sigmoid
import optax
from typing import Tuple, List, Dict, Any, Optional
from dataclasses import dataclass
from tqdm import tqdm
from cosmopower_jax.cosmopower_jax import CosmoPowerJAX

In [13]:
# Record the installed versions of every package whose name contains
# "jax", "numpy" or "diffrax" (shell escape; output is filtered by grep).
!pip list | grep -E 'jax|numpy|diffrax'

cosmopower_jax                    0.5.5
diffrax                           0.7.0
jax                               0.5.0
jax_autovmap                      0.2.1
jax-cosmo                         0.1.0
jax-cuda12-pjrt                   0.5.0
jax-cuda12-plugin                 0.5.0
jax-tqdm                          0.1.1
jaxlib                            0.5.0
jaxopt                            0.6
jaxtyping                         0.2.34
numpy                             1.26.4


In [3]:
# Sanity check: report whether PyTorch can see a CUDA device in this kernel.
import torch
print(f"torch.cuda.is_available() = {torch.cuda.is_available()}")

torch.cuda.is_available() = True


In [7]:
# Show the GPU(s), driver version and current memory/utilisation as reported
# by the NVIDIA driver (shell escape).
!nvidia-smi

Fri Aug  1 12:56:27 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.57.08              Driver Version: 575.57.08      CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          On  |   00000000:CA:00.0 Off |                    0 |
| N/A   29C    P0             52W /  400W |       5MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [6]:
# Print the version of the CUDA compiler (nvcc) found on PATH (shell escape).
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Tue_May__3_18:49:52_PDT_2022
Cuda compilation tools, release 11.7, V11.7.64
Build cuda_11.7.r11.7/compiler.31294372_0


  pid, fd = os.forkpty()


In [9]:
# List the devices JAX will run on.
# NOTE(review): the recorded output below is [CpuDevice(id=0)], i.e. JAX sees
# only the CPU, although nvidia-smi reports an A100 and PyTorch reports CUDA as
# available. Presumably the JAX CUDA plugin failed to initialise -- TODO
# investigate before training, otherwise everything runs on the CPU.
jax.devices()

[CpuDevice(id=0)]

### Dummy data

In [2]:
def generate_dummy_data(
    probe: str,
    n_samples: int = 10000,
    seed: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, List[str], np.ndarray]:
    """Generate dummy training data for testing.

    Parameters are drawn uniformly from fixed ranges and pushed through a
    simple closed-form toy spectrum; nothing here is a real Boltzmann-code
    output.

    Args:
        probe: One of 'cmb_tt', 'cmb_ee', 'cmb_te', 'cmb_pp' (all four share
            the same toy CMB spectrum), 'mpk_lin' or 'mpk_boost'.
        n_samples: Number of parameter vectors / spectra to generate.
        seed: Optional seed for reproducible draws. When None (the default)
            the global ``np.random`` state is used, exactly as before, so
            callers relying on ``np.random.seed`` keep working.

    Returns:
        A tuple ``(params, spectra, parameters, modes)`` where

        * ``params`` has shape ``(n_samples, 6)``; columns follow ``parameters``,
        * ``spectra`` has shape ``(n_samples, len(modes))``,
        * ``parameters`` is the list of parameter names,
        * ``modes`` is the 1-D grid the spectra are evaluated on
          (integers 2..2508 for CMB probes, 420 log-spaced values between
          1e-4 and 1e2 for matter-power probes).

    Raises:
        ValueError: If ``probe`` is not one of the supported names.
    """
    cmb_probes = ('cmb_tt', 'cmb_ee', 'cmb_te', 'cmb_pp')
    mpk_probes = ('mpk_lin', 'mpk_boost')
    if probe not in cmb_probes + mpk_probes:
        raise ValueError(f"Unknown probe: {probe}")

    # RandomState(seed) yields the same stream as np.random.seed(seed) followed
    # by module-level draws, so seeded and legacy usage stay consistent.
    rng = np.random if seed is None else np.random.RandomState(seed)

    if probe in cmb_probes:
        # CMB parameters: omega_b, omega_cdm, h, tau, n_s, ln10^10A_s
        parameters = ['omega_b', 'omega_cdm', 'h', 'tau_reio', 'n_s', 'ln10^{10}A_s']
        params = np.array([
            rng.uniform(0.019, 0.025, n_samples),  # omega_b
            rng.uniform(0.10, 0.14, n_samples),    # omega_cdm
            rng.uniform(0.64, 0.74, n_samples),    # h
            rng.uniform(0.04, 0.12, n_samples),    # tau
            rng.uniform(0.92, 1.00, n_samples),    # n_s
            rng.uniform(2.9, 3.3, n_samples),      # ln10^10A_s
        ]).T

        # CMB modes (ell)
        modes = np.arange(2, 2509)
        ell = modes.astype(float)

        # The sampled quantity is the *natural* log of 1e10 * A_s, so invert
        # it with exp(); 10**(x - 10) would overestimate A_s by ~50x.
        A_s = np.exp(params[:, 5]) * 1e-10
        n_s = params[:, 4]

        # Toy CMB-like spectrum, evaluated for all samples at once by
        # broadcasting (n_samples, 1) against (n_modes,). In-place updates
        # avoid holding several full-size temporaries.
        spectra = (ell / 100) ** (n_s[:, None] - 1)
        spectra *= A_s[:, None]
        spectra *= np.exp(-ell / 1000)
        spectra *= 1e12

    else:
        # Matter power spectrum parameters
        parameters = ['omega_b', 'omega_cdm', 'h', 'n_s', 'ln10^{10}A_s', 'z']
        params = np.array([
            rng.uniform(0.019, 0.025, n_samples),  # omega_b
            rng.uniform(0.10, 0.14, n_samples),    # omega_cdm
            rng.uniform(0.64, 0.74, n_samples),    # h
            rng.uniform(0.92, 1.00, n_samples),    # n_s
            rng.uniform(2.9, 3.3, n_samples),      # ln10^10A_s
            rng.uniform(0.0, 3.0, n_samples),      # z
        ]).T

        # k modes
        modes = np.logspace(-4, 2, 420)  # k in h/Mpc

        z = params[:, 5]  # redshift

        if probe == 'mpk_lin':
            # Natural-log inversion, see the CMB branch above.
            A_s = np.exp(params[:, 4]) * 1e-10
            n_s = params[:, 3]

            # Simple power-law matter power spectrum with (1 + z)^-2 scaling.
            spectra = (modes / 0.05) ** (n_s[:, None] - 1)
            spectra *= A_s[:, None]
            spectra *= ((1 + z) ** (-2))[:, None]
            spectra *= 1e4
        else:  # mpk_boost
            # Simple boost factor, linear in k and in (1 + z).
            spectra = 1 + 0.1 * modes * (1 + z)[:, None]

    return params, spectra, parameters, modes

In [3]:
# Build a 10,000-sample dummy 'mpk_lin' data set: sampled parameters, the
# corresponding spectra, the parameter names and the mode grid.
params, spectra, parameters, modes = generate_dummy_data(probe = 'mpk_lin', n_samples = 10000)

In [5]:
# NOTE(review): this bare expression is not the last line of the cell, so its
# value is evaluated but never displayed; only the print below produces output.
spectra.shape
# Memory footprint of the spectra array in megabytes (1 MB = 1e6 bytes).
print(f"Array size: {spectra.nbytes / 1e6:.2f} MB")

Array size: 33.60 MB
