In [None]:
import h5py, os, math
import numpy as np
import healpy as hp
import matplotlib.pyplot as plt
import cmasher as cmr

from matplotlib.colors import LinearSegmentedColormap
from matplotlib.patches import Ellipse
from tqdm import tqdm

from msfm.utils import files, scales

In [None]:
# Directory that holds one `cosmo_<sobol index>` sub-directory per selected cosmology.
# The default is the original local path; set the COSMOGRID_S8_DIR environment
# variable to run the notebook on another machine without editing this cell.
S8_DIR = os.environ.get("COSMOGRID_S8_DIR", "/Users/arne/data/CosmoGrid_example/S8")

# Every figure written by this notebook goes to ./plots
os.makedirs("plots", exist_ok=True)

In [None]:
def add_black_rim(ax, color="black", linewidth=1):
    """Outline a Mollweide map by adding an unfilled ellipse to `ax`.

    The ellipse is centred on (0, 0) with width 4 and height 2. These numbers
    were tuned by hand (trial and error) so the outline follows the map edge.

    Parameters
    ----------
    ax : matplotlib axes
        Axes holding the Mollweide projection.
    color : str, optional
        Edge colour of the rim.
    linewidth : float, optional
        Line width of the rim.
    """
    rim_style = dict(edgecolor=color, facecolor="none", linewidth=linewidth)
    rim = Ellipse(xy=(0, 0), width=4, height=2, **rim_style)
    ax.add_patch(rim)

In [None]:
def _deinterleave(local, k):
    """Return (x, y) from Morton (interleaved) local index for k bits per coord."""
    x = 0
    y = 0
    for i in range(k):
        x |= ((local >> (2 * i)) & 1) << i
        y |= ((local >> (2 * i + 1)) & 1) << i
    return x, y


def superpixel_patch(
    hpmap, nside_high, nside_super, ipix_super="equator", nest=True, origin="lower", return_ipix=False
):
    """
    Extract the high-resolution pixels inside one low-resolution (super-)pixel
    and return them as a square 2D array suitable for plt.imshow.

    In NESTED ordering the children of super-pixel `p` are the contiguous index
    range [p * factor**2, (p + 1) * factor**2), and the offset inside that range
    is a Morton (bit-interleaved) code of the (x, y) position within the
    super-pixel. The whole patch is filled with vectorised bit operations and a
    single fancy-index assignment, with no Python loop over pixels.

    Parameters
    ----------
    hpmap : 1D array
        HEALPix map at nside_high (length must be 12*nside_high**2).
    nside_high : int
        High-resolution nside (e.g. 512).
    nside_super : int
        Lower-resolution nside of the super-pixel (e.g. 2).
    ipix_super : int or str
        NESTED index of the super-pixel (0 <= ipix_super < 12 * nside_super**2),
        OR one of the special strings:
            - "equator" → choose the superpixel centered closest to lat=0°, lon=0°.
    nest : bool, optional
        True if `hpmap` is in NESTED ordering. If False, the function will
        internally convert it using healpy (ring -> nest).
    origin : {'lower','upper'}, optional
        Pixel-origin convention to return (matches plt.imshow origin).
    return_ipix : bool, optional
        If True, also return a grid of high-res pixel indices.

    Returns
    -------
    patch : 2D ndarray
        Array of shape (factor, factor) with the pixel values, where
        factor = nside_high // nside_super.
    ipix_grid : 2D ndarray (optional)
        Array of same shape with the corresponding high-res NESTED pixel indices
        (only if return_ipix=True).

    Raises
    ------
    ValueError
        If the nside values are incompatible, `origin` or `ipix_super` is
        invalid, or the map length does not match `nside_high`.
    """
    # sanity checks
    if nside_high % nside_super != 0:
        raise ValueError("nside_high must be an integer multiple of nside_super")
    factor = nside_high // nside_super
    if factor & (factor - 1):
        raise ValueError("ratio (nside_high/nside_super) must be a power of two")
    if origin not in ("lower", "upper"):
        raise ValueError("origin must be 'lower' or 'upper'")

    # convert map to NESTED if user supplied ring ordering
    if not nest:
        hpmap = hp.reorder(hpmap, r2n=True)

    npix_high = hp.nside2npix(nside_high)
    if hpmap.size != npix_high:
        raise ValueError(f"hpmap length ({hpmap.size}) does not match nside_high ({npix_high})")

    n_superpix = hp.nside2npix(nside_super)

    # choose equatorial superpixel automatically
    if isinstance(ipix_super, str):
        key = ipix_super.lower()
        if key == "equator":
            # Find the superpixel whose center is closest to lat=0°, lon=0°
            theta, phi = hp.pix2ang(nside_super, np.arange(n_superpix), nest=True)
            lat = 90.0 - np.degrees(theta)
            lon = np.degrees(phi)
            # wrap longitude to [-180, 180) so that e.g. 359° counts as -1°
            lon = ((lon + 180) % 360) - 180
            dist2 = lat**2 + lon**2
            ipix_super = int(np.argmin(dist2))
        else:
            raise ValueError(f"Unknown string for ipix_super: {ipix_super}")

    if not (0 <= ipix_super < n_superpix):
        raise ValueError(f"ipix_super must be in [0, {n_superpix-1}]")

    # number of bits per coordinate and first child index of the super-pixel
    k = int(round(math.log2(factor)))
    n_sub = factor * factor
    start = ipix_super * n_sub

    # De-interleave all local Morton codes at once: even bits -> x, odd bits -> y
    local = np.arange(n_sub, dtype=np.int64)
    x = np.zeros(n_sub, dtype=np.int64)
    y = np.zeros(n_sub, dtype=np.int64)
    for bit in range(k):
        x |= ((local >> (2 * bit)) & 1) << bit
        y |= ((local >> (2 * bit + 1)) & 1) << bit

    patch = np.empty((factor, factor), dtype=hpmap.dtype)
    ipix_grid = np.empty((factor, factor), dtype=np.int64)
    # (y, x) pairs are unique, so each cell of the grid is written exactly once
    patch[y, x] = hpmap[start : start + n_sub]
    ipix_grid[y, x] = start + local

    if origin == "upper":
        patch = np.flipud(patch)
        ipix_grid = np.flipud(ipix_grid)

    if return_ipix:
        return patch, ipix_grid
    return patch

In [None]:
# Read the table of grid cosmologies from the CosmoGridV1 meta-info file.
META_FILE = "/Users/arne/git/multiprobe-simulation-forward-model/data/CosmoGridV1_metainfo.h5"
with h5py.File(META_FILE, "r") as f:
    grid_params = f["parameters/grid"][:]

sobol = grid_params["sobol_index"]
w0 = grid_params["w0"]

# Keep only cosmologies whose w0 lies within 0.1 of -1.
w0_mask = np.abs(w0 + 1.0) < 0.1
grid_params = grid_params[w0_mask]
sobol = sobol[w0_mask]

# S8 = s8 * sqrt(Om / 0.3), computed on the filtered table, then sorted ascending.
S8 = grid_params["s8"] * np.sqrt(grid_params["Om"] / 0.3)
S8_inds = np.argsort(S8)

S8 = S8[S8_inds]
sobol = sobol[S8_inds]

# Select the cosmologies with the smallest, middle and largest S8 and print,
# for each of them, the rsync command that copies its shell file into S8_DIR.
sobol_select = []
S8_select = []
for i in [0, len(S8)//2, -1]:
    current_sobol = sobol[i]
    current_S8 = S8[i]
    print(f"S8 = {current_S8:.4f}, i_sobol = {current_sobol:06d}")
    print(f"rsync -ahv athomsen@login.phys.ethz.ch:/home/ipa/refreg/data/data_products/CosmoGrid/raw/grid/cosmo_{current_sobol:06d}/run_0/compressed_shells.npz {S8_DIR}/cosmo_{current_sobol:06d}/")
    sobol_select.append(current_sobol)
    S8_select.append(current_S8)

In [None]:
nside = 2048

# Load the shells of every selected cosmology. The NpzFile is opened as a context
# manager so its file handle is closed as soon as the array has been read.
all_raw_shells = []
for i_sobol in tqdm(sobol_select):
    SHELL_FILE = os.path.join(S8_DIR, f"cosmo_{i_sobol:06d}", "compressed_shells.npz")
    with np.load(SHELL_FILE) as shell_file:
        all_raw_shells.append(shell_file["shells"])

# first axis indexes the selected cosmologies, in the order of `sobol_select`
all_raw_shells = np.stack(all_raw_shells)

# particles per shell (sum over the pixels of each shell) for the first cosmology
particles = [np.sum(shell) for shell in all_raw_shells[0]]

# Compare the total against 832**3
# (presumably the particle count of the simulation — TODO confirm)
print(832**3, np.sum(particles), np.sum(particles) / 832**3)

fig, ax = plt.subplots()
ax.plot(particles)
ax.set(yscale="log", xlabel="shell index", ylabel="particles per shell")

# individual

In [None]:
# Index along the shell axis of `all_raw_shells`; this shell is used for all maps in this cell.
i_z = 20

# Colormap shared by every figure below (alternative kept for reference).
# cmap = "plasma"
cmap = cmr.wildfire

def make_gnom(m, filename):
    """Render a gnomonic projection of `m` and save it as plots/<filename>.

    The projected image is 2000 x 4000 pixels (x by y). The resolution is chosen
    so that the x extent spans 30 degrees, i.e. 0.9 arcmin per pixel. The global
    `cmap` is used both for the healpy figure and for the saved image.

    Returns
    -------
    The projected 2D map returned by `hp.gnomview`.
    """
    n_px_x = 2000
    n_px_y = 2 * n_px_x
    # 30 degrees, expressed in arcmin, spread over the x pixels
    arcmin_per_px = (30 * 60) / n_px_x

    projected = hp.gnomview(
        m,
        title="",
        cmap=cmap,
        reso=arcmin_per_px,
        xsize=n_px_x,
        ysize=n_px_y,
        return_projected_map=True,
    )
    plt.imsave(f"plots/{filename}", projected, cmap=cmap, origin="lower")
    return projected


def make_moll(m, filename):
    """Render a Mollweide projection of `m` with a black rim and save it as plots/<filename>.

    The colour range is clipped to the 1st and 99th percentile of `m`. The figure
    is 20 x 10 inches and saved at 200 dpi (xsize / figsize).

    Returns
    -------
    The projected 2D map returned by `hp.mollview`.
    """
    figsize = 10
    xsize = 2000
    lo, hi = np.percentile(m, [1, 99])

    plt.figure(figsize=(2 * figsize, figsize))
    projected = hp.mollview(
        m,
        title="",
        cmap=cmap,
        min=lo,
        max=hi,
        xsize=xsize,
        cbar=False,
        sub=(1, 1, 1),
        return_projected_map=True,
    )
    # outline the map on the axes healpy has just created
    add_black_rim(plt.gca(), linewidth=4)
    plt.savefig(f"plots/{filename}", bbox_inches="tight", dpi=xsize / figsize)

    return projected


def make_healpix(m, filename):
    """Cut one nside=1 super-pixel out of the RING-ordered map `m` and save it as plots/<filename>.

    `superpixel_patch` is called with its default `ipix_super="equator"`, so the
    super-pixel whose centre is closest to lat=0°, lon=0° is used. The colour range
    of the saved image is the 1st-99th percentile of the full map, not of the patch.
    The map resolution is taken from the global `nside`.

    Returns
    -------
    The square 2D patch.
    """
    lo, hi = np.percentile(m, [1, 99])
    face = superpixel_patch(m, nside, 1, nest=False)
    plt.imsave(f"plots/{filename}", face, cmap=cmap, origin="lower", vmin=lo, vmax=hi)
    return face

# Render shell `i_z` of every selected cosmology in three projections.
gnoms = []
molls = []
healpix = []
for i, shell in enumerate(all_raw_shells[:, i_z, :]):
    m = shell.copy()

    # Gaussian smoothing with a FWHM of two pixel sizes; the pixel size follows
    # the notebook-wide `nside` instead of a hard-coded 2048
    m = hp.smoothing(m, fwhm=2 * hp.nside2resol(nside))

    # normalise by the mean, then take log10; non-positive values are replaced
    # by the smallest positive value so the logarithm is defined everywhere
    m = m / np.mean(m)
    m[m <= 0] = np.min(m[m > 0])
    m = np.log10(m)

    filename = f"S8={S8_select[i]:.2f},i={i_z}"
    filename += "_wildfire"
    # filename += "_plasma"
    # filename += "_raw"

    gnom = make_gnom(m, filename=f"gnom_{filename}.png")
    moll = make_moll(m, filename=f"moll_{filename}.png")
    hpx = make_healpix(m, filename=f"hpx_{filename}.png")

    gnoms.append(gnom)
    molls.append(moll)
    healpix.append(hpx)

# one stacked array per projection, first axis = cosmology
gnoms = np.stack(gnoms)
molls = np.stack(molls)
healpix = np.stack(healpix)

In [None]:
# Shape of the last patch returned by make_healpix; `hpx` is the variable left over from the loop above.
hpx.shape

In [None]:
# shared colorbar

min_percentile = 1
max_percentile = 99

# Left half of every HEALPix patch: these are the arrays actually displayed below
panels = healpix[:, :, : healpix.shape[2] // 2]

# One colour range for all panels, taken from the displayed HEALPix data
# (the limits were previously computed from `gnoms`, which this figure does not show)
vmin, vmax = np.percentile(panels, [min_percentile, max_percentile])

# Desired native pixel resolution for composite figure
ncols = len(panels)
panel_h_px, panel_w_px = panels.shape[1:]
W_px = panel_w_px * ncols
H_px = panel_h_px

# Choose any dpi and derive figure size in inches to match exact pixel dims
# Using a round dpi keeps spine width conversions simple
dpi = 100
fig_w_in = W_px / dpi
fig_h_in = H_px / dpi

# Border width in pixels -> convert to points for Matplotlib spines
border_px = 5
spine_width_pt = 72 * border_px / dpi

fig, ax = plt.subplots(ncols=ncols, nrows=1, figsize=(fig_w_in, fig_h_in), dpi=dpi)
# plt.subplots returns a bare Axes for a single column; make indexing uniform
ax = np.atleast_1d(ax)
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)

for i in range(ncols):
    patch = panels[i]

    # Use imshow to map data pixels 1:1 to figure pixels
    ax[i].imshow(patch, cmap=cmap, origin="lower", vmin=vmin, vmax=vmax, interpolation="nearest")
    ax[i].set_aspect("equal")
    ax[i].set_xticks([])
    ax[i].set_yticks([])
    for spine in ax[i].spines.values():
        spine.set_linewidth(spine_width_pt)  # in points
        spine.set_edgecolor("black")

# Save without tight bbox to avoid any automatic resizing/cropping
# NOTE: other cells of this notebook write to the same path and overwrite this figure
fig.savefig("plots/S8_evo.png", dpi=dpi, bbox_inches=None, pad_inches=0)

In [None]:
# separate transforms: each HEALPix panel gets its own percentile-based colour range

min_percentile = 1
max_percentile = 99

# Desired native pixel resolution for composite figure
ncols = 3

# TODO
# panel_h_px, panel_w_px = healpix[0].shape
# NOTE(review): the panel size is hard-coded; with nside=2048 and nside_super=1 the
# displayed half-patches are 2048 x 1024, so they are shown downsampled by 2, not 1:1
panel_h_px, panel_w_px = 1024, 512


W_px = panel_w_px * ncols
H_px = panel_h_px

# Choose any dpi and derive figure size in inches to match exact pixel dims
# Using a round dpi keeps spine width conversions simple
dpi = 100
fig_w_in = W_px / dpi
fig_h_in = H_px / dpi

# Border width in pixels -> convert to points for Matplotlib spines
border_px = 20
spine_width_pt = 72 * border_px / dpi

# One axes per panel, with all margins and gaps removed so the panels tile the canvas
fig, ax = plt.subplots(ncols=ncols, nrows=1, figsize=(fig_w_in, fig_h_in), dpi=dpi)
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)

for i in range(ncols):
    # show only the left half of the square patch
    patch = healpix[i]
    patch = patch[:, :(patch.shape[1]//2)]

    # per-panel colour limits
    vmin = np.percentile(patch, min_percentile)
    vmax = np.percentile(patch, max_percentile)

    # Use imshow to map data pixels 1:1 to figure pixels
    ax[i].imshow(patch, cmap=cmap, origin="lower", vmin=vmin, vmax=vmax, interpolation="nearest")
    ax[i].set_aspect("equal")
    ax[i].set_xticks([])
    ax[i].set_yticks([])
    for spine in ax[i].spines.values():
        spine.set_linewidth(spine_width_pt)  # in points
        spine.set_edgecolor("black")

# Save without tight bbox to avoid any automatic resizing/cropping
# NOTE: other cells of this notebook write to the same path, so the figures overwrite each other
fig.savefig("plots/S8_evo.png", dpi=dpi, bbox_inches=None, pad_inches=0)

# combined

In [None]:
# separate transforms: each gnomonic panel gets its own percentile-based colour range

min_percentile = 1
max_percentile = 99

# One column per gnomonic map; the canvas is sized so every panel keeps its native pixel size
ncols = len(gnoms)
panel_h_px, panel_w_px = gnoms[0].shape
W_px, H_px = panel_w_px * ncols, panel_h_px

# Figure size in inches follows from the pixel size and a round dpi
dpi = 100
fig_w_in, fig_h_in = W_px / dpi, H_px / dpi

# Spine widths are given in points: 72 pt per inch, `dpi` pixels per inch
border_px = 20
spine_width_pt = 72 * border_px / dpi

fig, ax = plt.subplots(ncols=ncols, nrows=1, figsize=(fig_w_in, fig_h_in), dpi=dpi)
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)

for i, gnom in enumerate(gnoms):
    # per-panel colour limits
    vmin, vmax = np.percentile(gnom, [min_percentile, max_percentile])

    ax[i].imshow(gnom, cmap=cmap, origin="lower", vmin=vmin, vmax=vmax, interpolation="nearest")
    ax[i].set(aspect="equal", xticks=[], yticks=[])
    # black frame of constant width around every panel
    plt.setp(list(ax[i].spines.values()), linewidth=spine_width_pt, edgecolor="black")

# No tight bbox: the canvas must keep the exact pixel size computed above
fig.savefig("plots/S8_evo.png", dpi=dpi, bbox_inches=None, pad_inches=0)

In [None]:
# shared colorbar: all gnomonic panels use one colour range

min_percentile = 1
max_percentile = 99

# Colour limits from the percentiles of all gnomonic maps together
vmin, vmax = np.percentile(gnoms, min_percentile), np.percentile(gnoms, max_percentile)

# Desired native pixel resolution for composite figure
ncols = len(gnoms)
panel_h_px, panel_w_px = gnoms[0].shape
W_px = panel_w_px * ncols
H_px = panel_h_px

# Choose any dpi and derive figure size in inches to match exact pixel dims
# Using a round dpi keeps spine width conversions simple
dpi = 100
fig_w_in = W_px / dpi
fig_h_in = H_px / dpi

# Border width in pixels -> convert to points for Matplotlib spines
border_px = 5
spine_width_pt = 72 * border_px / dpi

# One axes per panel, with all margins and gaps removed so the panels tile the canvas
fig, ax = plt.subplots(ncols=ncols, nrows=1, figsize=(fig_w_in, fig_h_in), dpi=dpi)
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)

for i, gnom in enumerate(gnoms):
    # Use imshow to map data pixels 1:1 to figure pixels
    ax[i].imshow(gnom, cmap=cmap, origin="lower", vmin=vmin, vmax=vmax, interpolation="nearest")
    # ax[i].imshow(gnom, cmap=cmr.wildfire, origin="lower", vmin=vmin, vmax=vmax, interpolation="nearest")
    ax[i].set_aspect("equal")
    ax[i].set_xticks([])
    ax[i].set_yticks([])
    for spine in ax[i].spines.values():
        spine.set_linewidth(spine_width_pt)  # in points
        spine.set_edgecolor("black")

# Save without tight bbox to avoid any automatic resizing/cropping
# NOTE(review): the file name says "plasma" but `cmap` is set to cmr.wildfire earlier
# in this notebook — confirm which name matches the active colormap
# fig.savefig("plots/S8_evo_wildfire.png", dpi=dpi, bbox_inches=None, pad_inches=0)
fig.savefig("plots/S8_evo_plasma.png", dpi=dpi, bbox_inches=None, pad_inches=0)

In [None]:
def extract_superpixel_block(
    map_hi,
    super_ipix,
    *,
    hi_nside=2048,
    super_nside=2,
    map_order="RING",
    super_order="RING",
    return_indices=False,
):
    """Return the hi-res pixels belonging to a super-pixel as a square array.

    Parameters
    ----------
    map_hi : 1D array
        HEALPix map at `hi_nside` (length 12 * hi_nside**2).
    super_ipix : int
        Index of the super-pixel at `super_nside`, in the scheme given by `super_order`.
    hi_nside, super_nside : int
        Resolution of the map and of the super-pixel grid. Both must be powers
        of two because the extraction relies on NESTED indexing.
    map_order, super_order : {"RING", "NEST", "NESTED"}
        Ordering scheme of `map_hi` and of `super_ipix` (case-insensitive).
    return_indices : bool
        If True, also return the hi-res pixel indices of the children.

    Returns
    -------
    block : 2D ndarray
        Array of shape (ratio, ratio), ratio = hi_nside // super_nside, indexed
        as block[y, x] with (x, y) the position inside the base-pixel face.
    child_indices : 1D ndarray (only if return_indices=True)
        Indices of the child pixels in the ordering scheme of `map_hi`. The
        array is flat and follows the NESTED order of the children, not the
        layout of `block`.

    Raises
    ------
    ValueError
        For invalid nside values, ordering strings, super-pixel index or map length.
    RuntimeError
        If the child pixels are not all on the face of the parent pixel.
    """

    def _is_nested(order, argname):
        # accept both spellings healpy users commonly write; reject anything else
        key = str(order).upper()
        if key in ("NEST", "NESTED"):
            return True
        if key == "RING":
            return False
        raise ValueError(f"{argname} must be 'RING' or 'NEST', got {order!r}.")

    for n in (hi_nside, super_nside):
        if not hp.isnsideok(n):
            raise ValueError(f"nside={n} is invalid for HEALPix.")
    if super_nside & (super_nside - 1):
        raise ValueError("super_nside must be a power of two (NESTED indexing requirement).")
    if hi_nside % super_nside:
        raise ValueError("hi_nside must be an integer multiple of super_nside.")
    ratio = hi_nside // super_nside
    if ratio & (ratio - 1):
        raise ValueError("hi_nside/super_nside must be a power of two (hierarchical requirement).")

    map_is_nested = _is_nested(map_order, "map_order")
    super_is_nested = _is_nested(super_order, "super_order")

    n_super = hp.nside2npix(super_nside)
    if not (0 <= super_ipix < n_super):
        raise ValueError(f"super_ipix must be in [0, {n_super - 1}].")

    map_nested = np.asarray(map_hi) if map_is_nested else hp.reorder(map_hi, r2n=True)
    npix_hi = hp.nside2npix(hi_nside)
    if map_nested.size != npix_hi:
        raise ValueError(f"map_hi length ({map_nested.size}) does not match hi_nside ({npix_hi}).")

    parent_nest = int(super_ipix) if super_is_nested else int(hp.ring2nest(super_nside, super_ipix))
    # number of resolution levels between the super-pixel grid and the map
    diff = int(ratio).bit_length() - 1

    # in NESTED ordering the children of a pixel form one contiguous index range
    child_base = parent_nest << (2 * diff)
    child_nests = child_base + np.arange(4**diff, dtype=np.int64)

    # (x, y, face) coordinates of children and parent; healpy provides these via
    # pix2xyf(nside, ipix, nest=...)
    x_child, y_child, face_child = hp.pix2xyf(hi_nside, child_nests, nest=True)
    x_parent, y_parent, face_parent = hp.pix2xyf(super_nside, parent_nest, nest=True)
    if not np.all(face_child == face_parent):
        raise RuntimeError("Child pixels span multiple faces; check your inputs.")

    # position of every child relative to the corner of the parent pixel
    local_x = x_child - x_parent * ratio
    local_y = y_child - y_parent * ratio

    block = np.empty((ratio, ratio), dtype=map_nested.dtype)
    block[local_y, local_x] = map_nested[child_nests]

    if return_indices:
        child_indices = child_nests if map_is_nested else hp.nest2ring(hi_nside, child_nests)
        return block, child_indices

    return block


In [None]:
# Raw shell `i_z` of the first selected cosmology (no smoothing or log transform applied).
full_sky = all_raw_shells[0, i_z]

# Pixels of the nside=2048 map that lie inside super-pixel 0 of an nside=2 grid
# (map and super-pixel index are both RING-ordered by default).
cutout = extract_superpixel_block(full_sky, super_ipix=0, hi_nside=2048, super_nside=2)

In [None]:
# Shape of the block extracted in the previous cell; `example` is not defined
# anywhere in this notebook, so this cell raised NameError on a fresh kernel
cutout.shape

In [None]:
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

def stretch_midband(base_cmap, *, center=0.5, width=0.30, compression=0.25, samples=512):
    """
    Create a copy of `base_cmap` (e.g. 'cmr.wildfire') with a widened dark band.

    A data interval of size `width` around `center` is mapped onto a narrower
    colormap interval of size `width * compression` around the same `center`.
    The colours near `center` therefore occupy a larger share of the data range,
    and the rest of the colormap is compressed into the two remaining intervals.

    Parameters
    ----------
    base_cmap : Colormap or str
        Colormap instance, or the name of a registered colormap.
    center : float
        Location of the dark band in normalized [0, 1] colormap space.
        Values outside [0, 1] are clipped.
    width : float
        Fraction of the colormap domain you want to devote to the dark band.
        Values outside [0, 1] are clipped.
    compression : float
        0 < compression < 1 shrinks the hue change inside the band (smaller -> darker plateau).
    samples : int
        Number of color samples used when reconstructing the map.

    Returns
    -------
    LinearSegmentedColormap
        The warped colormap, named "<name of base colormap>_midstretch".
    """
    # accept registered colormap names as well as colormap objects
    cmap = plt.get_cmap(base_cmap) if isinstance(base_cmap, str) else base_cmap

    # keep the band inside [0, 1]; an out-of-range center would make `xp`
    # non-monotonic, which np.interp does not check for
    center = float(np.clip(center, 0.0, 1.0))
    width = np.clip(width, 0.0, 1.0)
    half = width / 2
    lo = max(0.0, center - half)
    hi = min(1.0, center + half)

    # xp: how far you advance through *data* space
    xp = np.array([0.0, lo, hi, 1.0])
    # fp: where those points land inside the *colormap*; compression<1 squeezes the ramp
    pad = half * compression
    fp = np.array([0.0,
                   max(0.0, center - pad),
                   min(1.0, center + pad),
                   1.0])

    x = np.linspace(0, 1, samples)
    warped = np.interp(x, xp, fp)
    colors = cmap(warped)

    # use the colormap's own name; formatting the object itself would embed its repr
    base_name = getattr(cmap, "name", "cmap")
    return LinearSegmentedColormap.from_list(f"{base_name}_midstretch", colors)

# Example use in your notebook:
# stretched_wildfire = stretch_midband(cmr.wildfire, center=0.0, width=0.5, compression=0.2)
# NOTE(review): `center` is documented as a position in normalized [0, 1] colormap space,
# but (vmax - vmin)/2 is half the data range left over from an earlier cell — confirm intent.
stretched_wildfire = stretch_midband(cmr.wildfire, center=(vmax - vmin)/2, width=0.1, compression=0.2)
# `m` is the last processed map left over from the rendering loop. mollview is called
# without min/max, so its colour scaling does not use vmin/vmax.
hp.mollview(m, cmap=stretched_wildfire)