# PySM3 Component Examinations - Make Component Maps

This notebook produces a map of each PySM3 sky component at each target frequency and saves it to disk.

We find no spatial variation between realizations (repeated simulations with different random seeds) for any $N_{side} \in \{64, 1024, 2048\}$, using the components `["d9", "s4", "f1", "a1", "co1", "cib1", "ksz1", "tsz1", "rg1", "d1", "s1"]`.

# Parameters

In [1]:
# For loading from cmbml directory (temporary solution):
# make the repository root (parent of the notebook's working directory) importable.
import sys
import os

repo_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
# Put repo_root first on sys.path exactly once. A plain insert(0, ...) would add a
# duplicate entry every time this cell (or an identical cell) is re-run.
sys.path[:] = [repo_root] + [entry for entry in sys.path if entry != repo_root]

In [None]:
import hydra
from hydra import compose, initialize
from omegaconf import DictConfig, OmegaConf

# Environment variable consumed by the config (presumably selects the local-system
# settings such as datasets_root) -- TODO confirm against ../cfg.
os.environ['CMB_ML_LOCAL_SYSTEM'] = "generic_lab"

# Hydra refuses to initialize twice in one process; clearing the global instance
# first makes this cell safe to re-run.
hydra.core.global_hydra.GlobalHydra.instance().clear()

initialize(config_path="../cfg", version_base=None)
cfg = compose(config_name="config_comp_models.yaml")

In [2]:
from pathlib import Path

# Base directory for every component map, located under the configured datasets root.
component_maps_dir = Path(cfg.local_system.datasets_root, "ComponentMaps")

## Imports and Set-Up

In [3]:
# For loading from cmbml directory (temporary solution):
# make the repository root (parent of the notebook's working directory) importable.
import sys
import os

repo_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
# Put repo_root first on sys.path exactly once. A plain insert(0, ...) would add a
# duplicate entry every time this cell (or an identical cell) is re-run.
sys.path[:] = [repo_root] + [entry for entry in sys.path if entry != repo_root]

In [4]:
import numpy as np
import pandas as pd

from tqdm import tqdm
import matplotlib.pyplot as plt

import healpy as hp
import pysm3
import pysm3.units as u

In [5]:
# Matplotlib's default colour cycle as a plain list of colour strings, for consistent plot colours.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

In [6]:
# Redirect log output (e.g. PySM3's downgrade warnings) to a file instead of the notebook.
# TODO: Why are there downgrade warnings?
import logging

class LoggingContextManager:
    """Temporarily route all root-logger output to a log file.

    While the context is active, the root logger's handlers are replaced by a single
    ``FileHandler`` that writes records at or above ``level`` to ``filename``, and the
    root logger's level is lowered to DEBUG so no record is dropped before reaching
    that handler. The first such record triggers a one-line notice on stdout; on exit,
    ``exit_msg`` (or a default message) is printed if any such record was written.
    The root logger's original handlers and level are restored on exit.

    Parameters
    ----------
    filename : str or path-like
        File that receives the captured records (opened at construction time).
    level : int, default logging.WARNING
        Minimum record level written to the file and counted as an "issue".
    exit_msg : str, optional
        Message printed on exit when at least one issue was logged.
    """

    def __init__(self, filename, level=logging.WARNING, exit_msg=None):
        self.filename = filename
        self.level = level
        self.exit_msg = exit_msg
        self.first_issue_notified = False
        self.issue_occurred = False
        self.logger = logging.getLogger()
        self.file_handler = logging.FileHandler(filename)
        self.file_handler.setLevel(level)
        self.file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        # Attach the notifier to the handler, not the logger: logger-level filters are
        # not consulted for records propagated from child loggers (e.g. a library's own
        # logger), whereas every record reaching this handler passes its filters.
        self.file_handler.addFilter(self.process_notification)
        self.original_handlers = None
        self.original_level = None

    def __enter__(self):
        self.original_handlers = self.logger.handlers[:]
        self.original_level = self.logger.level
        # Set the logger to the lowest possible level during the context to ensure all messages are processed
        self.logger.setLevel(logging.DEBUG)
        self.logger.handlers = []  # Remove existing handlers to avoid duplicate logs
        self.logger.addHandler(self.file_handler)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.logger.removeHandler(self.file_handler)
        self.logger.handlers = self.original_handlers  # Restore original handlers
        self.logger.setLevel(self.original_level)      # Restore original level
        self.file_handler.close()
        if self.issue_occurred:
            print(self.exit_msg or "End of processing: Issues were logged during the session.")
        # Returning None lets any exception raised inside the with-block propagate.

    def process_notification(self, record):
        """Handler filter: flag issues and print a notice on the first one.

        Always returns True so every record reaching the handler is written.
        """
        if record.levelno >= self.level:
            if not self.first_issue_notified:
                print(f"First issue encountered; check {self.filename} for more information.")
                self.first_issue_notified = True
            self.issue_occurred = True
        return True  # Always return True to ensure all messages are logged

# Setup basic configuration for logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

## Parameters

In [7]:
nside_sky = 2048  # HEALPix Nside at which PySM3 evaluates the sky (pysm3.Sky(nside=...))
nside_out = 1024  # HEALPix Nside of the maps saved to disk

# Build the output directory from the config root every time, rather than appending to
# `component_maps_dir`, so re-running this cell does not nest sky.../sky... directories.
component_maps_dir = Path(cfg.local_system.datasets_root) / "ComponentMaps" / f"sky{nside_sky}_out{nside_out}"
component_maps_dir.mkdir(exist_ok=True, parents=True)

In [8]:
# Target frequencies in GHz (the nine Planck LFI + HFI channels).
# For debugging, restrict to a single channel, e.g. target_freqs_ghz = [100].
target_freqs_ghz = [30, 44, 70, 100, 143, 217, 353, 545, 857]
target_freqs = [freq_ghz * u.GHz for freq_ghz in target_freqs_ghz]

In [9]:
# PySM3 preset strings to render, one sky model each; narrow to e.g. ["c1"] when debugging.
components = [
    "d9", "s4", "f1", "a1", "co1", "cib1",
    "ksz1", "tsz1", "rg1", "d1", "s1", "c1",
]

In [10]:
# Smoothing / downgrade settings passed to pysm3.apply_smoothing_and_coord_transform.
# lmax: 2.5 * nside_sky (an alternative considered was 3 * nside_sky - 1).
lmax = int(nside_sky * 2.5)
# Beam FWHM equal to three output-pixel widths (an alternative considered was a fixed 5 arcmin).
output_pixel_arcmin = hp.nside2resol(nside_out, arcmin=True)
beam_fwhm = 3 * output_pixel_arcmin * u.arcmin

# Producing Maps

Warning! At `nside_sky=2048`, this takes 2.5 hours!

In [None]:
# Produce maps and save each one to disk as soon as it is made, to limit RAM usage.
n_sims = 1  # Use n_sims > 1 so we can check for spatial variation. Spoiler: There is none.

# PySM3 throws a warning for each of the calls to apply_smoothing_and_coord_transform(). I'm ok with the lack of convergence.
with tqdm(total=n_sims * len(components) * len(target_freqs), desc="Processing components and frequencies") as pbar, \
        LoggingContextManager("pysm3_warnings.log", exit_msg="End of processing: Warnings were logged during the session."):
    for sim_num in range(n_sims):
        np.random.seed(sim_num)
        for comp in components:
            sky = pysm3.Sky(nside=nside_sky, preset_strings=[comp])
            for freq in target_freqs:
                sky_observed = sky.get_emission(freq)
                # Keep only component 0 of the emission (intensity / Stokes I, assuming
                # PySM3's (I, Q, U) ordering).
                sky_map = sky_observed[0]
                if nside_sky != nside_out:
                    # Downgrade the map to the output nside; PySM3 has this as a catch-all function,
                    # because it operates in alm space internally.
                    # NOTE(review): the beam smoothing is only applied when downgrading.
                    sky_map = pysm3.apply_smoothing_and_coord_transform(sky_map,
                                                                        fwhm=beam_fwhm,
                                                                        lmax=lmax,
                                                                        output_nside=nside_out)
                # Convert in every case, not only after downgrading. This keeps all saved maps
                # in uK_CMB and guarantees `sky_map` is defined when nside_sky == nside_out.
                sky_map = sky_map.to(u.uK_CMB, equivalencies=u.cmb_equivalencies(freq))
                np.save(component_maps_dir / f"sim{sim_num}_{comp}_{freq.value}GHz.npy", sky_map.value)
                pbar.update(1)
del sky, sky_observed, sky_map, freq, comp, pbar  # Clean up namespace