# SILC Tutorial

## SILC Pipeline Overview

The **Scale-discretised, directional Internal Linear Combination (SILC)** pipeline performs component separation in CMB analysis. 

### Pipeline

| Step | Process | Description |
|------|---------|-------------|
| **1** | **Data Acquisition** | Download CMB, noise, and foreground maps, currently using Planck simulations → `download.py` |
| **2** | **Map Processing** | Handle instrumental beams, reduce resolution, and convert from HEALPix to McEwen-Wiaux (MW) sampling → `map_tools.py`, `map_processing.py` |
| **3** | **Wavelet Analysis** | Apply multi-scale wavelet transforms with customisable filter scales and directional components → `map_processing.py` |
| **4** | **SILC Algorithm** | Compute the Scale-discretised, directional Internal Linear Combination on each wavelet scale → `ilc.py` |
| **5** | **Map Synthesis** | Combine all scales into a single, clean ILC component map → `ilc.py` |

- The `visualise.py` module computes and plots maps and power spectra.
- The `file_templates.py` module defines the directory structure used to store and load data throughout the pipeline.

### Current Capabilities

> **Primary Target**: CMB extraction 

> **Future goals**: Generalisation to extract any astrophysical component (e.g., thermal Sunyaev-Zel'dovich effect, synchrotron emission)

This notebook walks through the pipeline stage by stage, showing the main processes at each stage for a single realisation.

## Import 

**Note:** To run modules in the terminal, use e.g. `python3 -m skyclean.silc.map_processing` while in the Skyclean home directory.

In [None]:
import sys
import os
from matplotlib import pyplot as plt
import numpy as np
import healpy as hp
import jax

# Add the parent directory to Python path for proper module resolution
# Get current working directory and navigate to parent (assumes notebook is in examples/ folder)
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)

print(f"Current directory: {current_dir}")
print(f"Added to Python path: {parent_dir}")

# Check JAX devices
print(f"\nJAX Configuration:")
print(f"JAX devices available: {jax.devices()}")
print(f"JAX default backend: {jax.default_backend()}")

# Import SILC modules
from skyclean.silc.download import DownloadData
from skyclean.silc.file_templates import FileTemplates
from skyclean.silc.map_tools import MWTools, HPTools, SamplingConverters
from skyclean.silc.map_processing import ProcessMaps
from skyclean.silc.ilc import SILCTools, ProduceSILC
from skyclean.silc.pipeline import Pipeline
from skyclean.silc.utils import *
from skyclean.silc.visualise import Visualise

## Download 

Skyclean currently downloads maps from the [Planck simulation archive](https://pla.esac.esa.int/#home). The downloaded maps cover noise and foregrounds; the noise maps are available as 300 random realisations.

The Planck CMB simulations posed a problem: after deconvolving the instrumental beams, the CMB maps differed between frequency channels. This breaks the ILC assumption that the CMB is independent of frequency channel. Skyclean therefore generates random CMB realisations from `cmb_spectrum.txt` instead. Make sure this file is in your data directory before running downloads.
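
To illustrate why generating the CMB from a power spectrum keeps it identical across channels, here is a minimal NumPy sketch. It draws one set of Gaussian harmonic coefficients from a spectrum and estimates the spectrum back from them. The toy spectrum `cl_toy` and the helpers `draw_alm` and `estimate_cl` are hypothetical names introduced for this example; they do not read `cmb_spectrum.txt` and are not Skyclean's implementation. In practice a routine such as `healpy.synfast` does this job.

```python
import numpy as np

def draw_alm(cl, rng):
    """Draw Gaussian a_lm (m >= 0 only) from a power spectrum C_ell."""
    lmax = len(cl) - 1
    alm = {}
    for ell in range(lmax + 1):
        a = np.zeros(ell + 1, dtype=complex)
        # m = 0 is real with variance C_ell
        a[0] = rng.normal(0.0, np.sqrt(cl[ell]))
        # m > 0: real and imaginary parts each carry variance C_ell / 2
        sigma = np.sqrt(cl[ell] / 2.0)
        a[1:] = rng.normal(0.0, sigma, ell) + 1j * rng.normal(0.0, sigma, ell)
        alm[ell] = a
    return alm

def estimate_cl(alm):
    """Estimate C_ell from a_lm stored for m >= 0."""
    lmax = max(alm)
    cl_hat = np.zeros(lmax + 1)
    for ell, a in alm.items():
        # Negative m contribute the same power as positive m, hence the factor 2
        power = np.abs(a[0]) ** 2 + 2.0 * np.sum(np.abs(a[1:]) ** 2)
        cl_hat[ell] = power / (2 * ell + 1)
    return cl_hat

lmax = 64
ell = np.arange(lmax + 1)
cl_toy = np.zeros(lmax + 1)
cl_toy[2:] = 1.0 / (ell[2:] * (ell[2:] + 1.0))  # toy spectrum, no monopole or dipole

rng = np.random.default_rng(0)
cmb_alm = draw_alm(cl_toy, rng)   # one realisation, shared by every channel
cl_hat = estimate_cl(cmb_alm)
```

Because every frequency channel would be built from the same `cmb_alm`, the CMB signal is the same in each channel by construction.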

In [None]:
## INITIALISE PARAMS

# Initialise file templates
data_directory = "data/"   # Local data storage
files = FileTemplates(data_directory)
file_templates = files.file_templates

# Define components to download
components = ["cmb", "sync", "dust", "noise"]

# Define frequencies (Planck channels in GHz)
# frequencies = ["030", "044", "070", "100", "143", "217", "353", "545", "857"]  # this is the full set
frequencies = ["030", "044", "070"] # let's work with a subset. 

# Download configuration
realisations = 1           # Download just one realisation for this example
start_realisation = 0      # Start from realisation 0

print(f"Target: {realisations} realisation(s) starting from #{start_realisation}")
print(f"Storage: {data_directory}")

# Set processing parameters
realisation = 0
desired_lmax = 64  # Let's choose a lower lmax for faster processing
standard_fwhm_rad = np.radians(5/60)  # 5 arcminute beam

# Set wavelet parameters
lam = 4.0  # Lambda parameter controlling wavelet scaling
N_directions = 1  # Number of directional components (1 = axisymmetric)
L = desired_lmax + 1  # L parameter for s2wav (lmax + 1)
wavelet_components = ["cfn"] # only take wavelet transform of CFN

In [None]:
# Initialise the downloader
downloader = DownloadData(
    components=components,
    frequencies=frequencies, 
    realisations=realisations,
    start_realisation=start_realisation,
    directory=data_directory
)

# Download everything at once
downloader.download_all()  

print("Download complete. Ready for map processing.")
print("Check your data/CMB_realisations directory for the files.")

## Process Maps

Map processing handles instrumental beams, reduces resolution, and converts between HEALPix and McEwen-Wiaux sampling schemes. This step creates the total maps (CFN: CMB + Foreground + Noise) that will be used for wavelet analysis.
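
As a rough illustration of what handling beams involves, the sketch below builds Gaussian beam window functions and the harmonic-space ratio that would swap one beam for another. It assumes Gaussian beams and a hypothetical 32 arcmin instrument beam; the helper `gaussian_beam` is introduced here for illustration and is not the `HPTools` implementation.

```python
import numpy as np

def gaussian_beam(fwhm_rad, lmax):
    """Gaussian beam window b_ell = exp(-ell (ell + 1) sigma^2 / 2)."""
    sigma = fwhm_rad / np.sqrt(8.0 * np.log(2.0))  # convert FWHM to sigma
    ell = np.arange(lmax + 1)
    return np.exp(-0.5 * ell * (ell + 1) * sigma**2)

lmax = 64
instrument_fwhm = np.radians(32.0 / 60.0)  # hypothetical channel beam
standard_fwhm = np.radians(5.0 / 60.0)     # 5 arcmin standard beam, as in this notebook

b_inst = gaussian_beam(instrument_fwhm, lmax)
b_std = gaussian_beam(standard_fwhm, lmax)

# Multiplying a_lm by this ratio removes the instrument beam and applies the standard one
beam_ratio = b_std / b_inst
```

When the standard beam is narrower than the instrument beam, the ratio exceeds one at high multipoles, so this step amplifies small-scale noise. That is one reason to limit `lmax`.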

We will start by showing the map processing and wavelet transform tools available in `skyclean`, and then show how the pipeline can run these processes automatically.
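
To give a feel for how `lam` and `L` set the wavelet scales, the following sketch lists approximate harmonic supports per scale. It assumes the usual scale-discretised convention in which the highest scale is $J = \lceil \log_\lambda L \rceil$ and wavelet $j$ is supported between $\lambda^{j-1}$ and $\lambda^{j+1}$. The helper `wavelet_scale_supports` is hypothetical, and the conventions used inside `MWTools` may differ in detail.

```python
import math

def wavelet_scale_supports(L, lam):
    """List (j, ell_min, ell_max) per wavelet scale under the assumed convention."""
    j_max = math.ceil(math.log(L) / math.log(lam))
    supports = []
    for j in range(j_max + 1):
        ell_min = lam ** (j - 1)
        ell_max = min(lam ** (j + 1), L - 1)  # clip to the band-limit
        supports.append((j, ell_min, ell_max))
    return j_max, supports

j_max, supports = wavelet_scale_supports(L=65, lam=4.0)
for j, ell_min, ell_max in supports:
    print(f"scale j={j}: support roughly {ell_min:g} < ell < {ell_max:g}")
```

A larger `lam` gives fewer, broader scales; a smaller `lam` gives more, narrower ones at a higher computational cost.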

**Note:** In a real experiment, the CFN maps would be observed directly, but the Skyclean pipeline builds them from individual components. This modularity is necessary for model training, which requires the true CMB to be known.

In [None]:

# Map Processing Example
# This demonstrates the key map processing steps

print("Map Processing Parameters:")
print(f"  Components: {components}")
print(f"  Frequencies: {frequencies}")
print(f"  Target lmax: {desired_lmax}")
print(f"  Standard beam FWHM: {np.degrees(standard_fwhm_rad)*60:.1f} arcmin")


# Calculate target nside from lmax
nside = HPTools.get_nside_from_lmax(desired_lmax)
print(f"  Target nside: {nside}")

print("PROCESSING INDIVIDUAL COMPONENTS")

# Store processed maps for visualisation
processed_maps = {}
original_maps = {}

for freq in frequencies:
    processed_maps[freq] = {}
    original_maps[freq] = {}
    
    print(f"\nProcessing frequency: {freq} GHz")
    
    for comp in components:
        print(f"  Processing {comp}...")
        filepath = file_templates[comp].format(frequency=freq, realisation=realisation)
    
        original_map = hp.read_map(filepath, verbose=False) # load in the .fits file
        original_maps[freq][comp] = original_map
        
        # Apply unit conversion if needed
        converted_map = HPTools.unit_convert(original_map.copy(), freq)
        
        # Process the map
        if comp == "noise":
            # Noise: only reduce resolution (no beam convolution)
            processed_map, _ = HPTools.reduce_hp_map_resolution(converted_map, lmax=desired_lmax, nside=nside)
        else:
            # CMB and foregrounds: beam convolution + resolution reduction
            processed_map = HPTools.convolve_and_reduce(converted_map, lmax=desired_lmax, nside=nside, standard_fwhm_rad=standard_fwhm_rad)
        
        processed_maps[freq][comp] = processed_map
    


print("CREATING TOTAL MAPS (CFN)")
# Create CFN (CMB + Foreground + Noise) maps; this functionality exists in map_processing.py but we will show it explicitly here.
cfn_maps = {}

for freq in frequencies:
    print(f"\nCreating CFN map for {freq} GHz...")
    
    # Initialise empty map
    cfn_map = np.zeros(hp.nside2npix(nside), dtype=np.float64)
    
    # Sum all components
    for comp in components:
        cfn_map += processed_maps[freq][comp]
        print(f"  + Added {comp}")
    
    cfn_maps[freq] = cfn_map
    print(f"  Total RMS: {np.std(cfn_map):.2e}")
    
    # Save CFN map to the standard location (following map_processing.py pattern)
    cfn_filepath = file_templates["cfn"].format(frequency=freq, realisation=realisation, lmax=desired_lmax)
    hp.write_map(cfn_filepath, cfn_map, overwrite=True)
    print(f"  Saved CFN map to: {cfn_filepath}")

print("\nCONVERTING CFN MAPS TO MW SAMPLING")
# Convert HP CFN maps to MW sampling (used for wavelet transforms)
cfn_mw_maps = {}

for freq in frequencies:
    print(f"Converting CFN map to MW sampling for {freq} GHz...")
    
    # Convert HEALPix CFN map to MW sampling
    cfn_mw_map = SamplingConverters.hp_map_2_mw_map(cfn_maps[freq], lmax=desired_lmax)
    cfn_mw_maps[freq] = cfn_mw_map
    
    # Convert back to HP for verification
    cfn_hp_reconstructed = SamplingConverters.mw_map_2_hp_map(cfn_mw_map, lmax=desired_lmax)
    
    # Report conversion statistics
    original_rms = np.std(cfn_maps[freq])
    reconstructed_rms = np.std(cfn_hp_reconstructed)
    conversion_error = np.std(cfn_maps[freq] - cfn_hp_reconstructed)
    
    print(f"  MW map shape: {cfn_mw_map.shape}")
    print(f"  RMS before conversion: {original_rms:.2e}")
    print(f"  RMS after HP→MW→HP: {reconstructed_rms:.2e}")
    print(f"  Conversion error RMS: {conversion_error:.2e}")


print("VISUALISATION")


# Plot original maps (before processing)
print("\nPlotting original maps...")
n_comp = len(components)
n_freq = len(frequencies)

fig = plt.figure(figsize=(5*n_freq, 4*n_comp))
for i, comp in enumerate(components):
    for j, freq in enumerate(frequencies):
        panel = i*n_freq + j + 1
        hp.mollview(
            original_maps[freq][comp],
            fig=fig.number,
            sub=(n_comp, n_freq, panel),
            title=f"Original {comp} @ {freq} GHz",
            unit="K",
            cbar=True
        )
plt.suptitle("Original Maps (Before Processing)", fontsize=16, y=0.95)
plt.tight_layout()
plt.show()

# Plot processed maps (after beam convolution and resolution reduction)
print("\nPlotting processed maps...")
fig = plt.figure(figsize=(5*n_freq, 4*n_comp))
for i, comp in enumerate(components):
    for j, freq in enumerate(frequencies):
        panel = i*n_freq + j + 1
        hp.mollview(
            processed_maps[freq][comp],
            fig=fig.number,
            sub=(n_comp, n_freq, panel),
            title=f"Processed {comp} @ {freq} GHz",
            unit="K",
            cbar=True
        )
plt.suptitle("Processed Maps (After Beam Convolution & Resolution Reduction)", fontsize=16, y=0.95)
plt.tight_layout()
plt.show()

# Plot CFN maps
print("\nPlotting CFN maps...")
fig = plt.figure(figsize=(5*n_freq, 4))
for j, freq in enumerate(frequencies):
    panel = j + 1
    hp.mollview(
        cfn_maps[freq],
        fig=fig.number,
        sub=(1, n_freq, panel),
        title=f"CFN @ {freq} GHz",
        unit="K",
        cbar=True
    )
plt.suptitle("Total Maps (CMB + Foregrounds + Noise)", fontsize=16, y=0.9)
plt.tight_layout()
plt.show()

# Visualise MW CFN maps using MWTools.visualise_mw_map
print("\nVisualising MW CFN maps using MWTools...")
for freq in frequencies:
    print(f"\nPlotting CFN @ {freq} GHz in MW sampling...")
    MWTools.visualise_mw_map(
        cfn_mw_maps[freq], 
        title=f"CFN MW Sampling @ {freq} GHz",
        directional=False  # Using directional=False for regular MW maps
    )

# Plot difference maps
print("\nPlotting conversion difference maps...")
fig = plt.figure(figsize=(5*n_freq, 4))
for j, freq in enumerate(frequencies):
    panel = j + 1
    cfn_hp_from_mw = SamplingConverters.mw_map_2_hp_map(cfn_mw_maps[freq], lmax=desired_lmax)
    difference_map = cfn_maps[freq] - cfn_hp_from_mw
    hp.mollview(
        difference_map,
        fig=fig.number,
        sub=(1, n_freq, panel),
        title=f"HP - MW→HP Difference @ {freq} GHz",
        unit="K",
        cbar=True
    )
plt.suptitle("Conversion Error Maps (Original HP - MW→HP)", fontsize=16, y=0.9)
plt.tight_layout()
plt.show()

print(f"Original resolution: nside={hp.get_nside(original_maps[frequencies[0]][components[0]])}")
print(f"Processed resolution: nside={nside}")
print(f"Standard beam applied: {np.degrees(standard_fwhm_rad)*60:.1f} arcmin FWHM")
print(f"MW map dimensions: {cfn_mw_maps[frequencies[0]].shape}")

# Clear large variables
del original_maps, processed_maps, cfn_maps, cfn_mw_maps

In [None]:
# Wavelet Transform Demonstration
# Apply multi-scale wavelet transforms to the 030 GHz CFN map

# Load the saved CFN map for 030 GHz
target_freq = "030"
print(f"Loading CFN map for {target_freq} GHz...")

cfn_filepath = file_templates["cfn"].format(frequency=target_freq, realisation=realisation, lmax=desired_lmax)
cfn_map = hp.read_map(cfn_filepath, verbose=False)

# Convert to MW sampling for wavelet transforms
cfn_mw_map = SamplingConverters.hp_map_2_mw_map(cfn_map, lmax=desired_lmax)
print(f"Converted CFN @ {target_freq} GHz to MW sampling, shape: {cfn_mw_map.shape}")

# WAVELET FILTER VISUALISATION
print(f"\n" + "=" * 60)
print("AXISYMMETRIC WAVELET FILTER VISUALISATION")
print("=" * 60)

# Visualise the axisymmetric wavelet filters
print("Displaying axisymmetric wavelet filters used in the decomposition...")
MWTools.visualise_axisym_wavelets(L=L, lam=lam)

# Apply wavelet transform
print(f"\nApplying wavelet transform to CFN @ {target_freq} GHz...")
wavelet_coeffs, _ = MWTools.wavelet_transform_from_map(
    cfn_mw_map, 
    L=L, 
    N_directions=N_directions, 
    lam=lam
)

# Report decomposition structure
n_scales = len(wavelet_coeffs)  # Total number of scales (scaling coeffs = scale 0, then wavelet scales 1, 2, ...)
print(f"Total number of scales: {n_scales}")
print(f"Components: scale 0 (scaling) + scales 1-{n_scales-1} (wavelets)")

for i, coeff in enumerate(wavelet_coeffs):
    scale_lmax = coeff.shape[1] - 1  # MW map shape is (L, 2*L-1), so lmax = L-1
    if i == 0:
        print(f"Scale 0 (scaling) coefficients shape: {coeff.shape}, effective lmax: {scale_lmax}")
    else:
        print(f"Scale {i} (wavelet) coefficients shape: {coeff.shape}, effective lmax: {scale_lmax}")

# Visualise original CFN map in MW sampling
print(f"\nOriginal CFN map @ {target_freq} GHz:")
MWTools.visualise_mw_map(
    cfn_mw_map, 
    title=f"Original CFN {target_freq} GHz",
    directional=False
)

# Visualise all wavelet components (scaling coefficients + wavelet scales)
print(f"\nWavelet decomposition components:")
for scale in range(len(wavelet_coeffs)):
    scale_lmax = wavelet_coeffs[scale].shape[1] - 1
    if scale == 0:
        continue  # for some reason scale 0 won't plot
    else:
        print(f"Scale {scale} - Wavelet coefficients (lmax={scale_lmax}):")
        title = f"Scale {scale} Wavelet Coefficients {target_freq} GHz (lmax={scale_lmax})"
    
    MWTools.visualise_mw_map(
        wavelet_coeffs[scale][0],  # First (and only) direction for axisymmetric case
        title=title,
        directional=False
    )

# Test synthesis (reconstruction) from wavelet coefficients
print(f"\nTesting wavelet synthesis (reconstruction)...")
reconstructed_mw = MWTools.inverse_wavelet_transform(
    wavelet_coeffs, 
    L=L, 
    N_directions=N_directions, 
    lam=lam
)

# Visualise reconstructed map
print(f"\nReconstructed CFN map:")
MWTools.visualise_mw_map(
    reconstructed_mw, 
    title=f"Reconstructed CFN {target_freq} GHz",
    directional=False
)

# Compute and visualise reconstruction error
reconstruction_error_mw = cfn_mw_map - reconstructed_mw
print(f"\nReconstruction error map:")
MWTools.visualise_mw_map(
    reconstruction_error_mw, 
    title=f"Reconstruction Error {target_freq} GHz",
    directional=False
)

max_error = np.max(np.abs(reconstruction_error_mw))
print(f"Maximum absolute error: {max_error:.2e}")
print(f"\nWavelet transform demonstration complete for {target_freq} GHz")
print(f"Successfully decomposed map into {n_scales} total scales:")
print(f"  - Scale 0: Scaling coefficients (large-scale structure)")
print(f"  - Scales 1-{n_scales-1}: Wavelet coefficients (progressively finer scales)")
print(f"Each scale captures different angular scales of the CMB and foreground")

# Clear large variables
del cfn_map, cfn_mw_map, wavelet_coeffs, reconstructed_mw, reconstruction_error_mw

In [None]:
processor = ProcessMaps(
    components,
    wavelet_components,
    frequencies,
    1,
    start_realisation=0,
    desired_lmax=desired_lmax,
    directory=data_directory,
)
processor.produce_and_save_cfns()

processor.produce_and_save_wavelet_transforms(
        N_directions,
        lam,
    )

# Clear processor to free memory
print("Clearing processor to free memory...")
del processor

## SILC Stage
 
After decomposing the CFN maps into wavelet scales, the goal is to perform an internal linear combination (ILC) on each scale. The ILC computes a weight for each frequency channel such that the variance of the foregrounds and noise in the combined map is minimised. This is subject to the constraint $\sum_i w_i(p) a_i = 1$, where $i$ indexes the frequency channels, $p$ is the pixel, $w_i$ are the weights, and $a_i$ is the spectral energy distribution (frequency scaling) of the component we want to preserve (currently the CMB, for which $a_i = 1 \ \forall \, i$).
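
The constrained minimisation above has a closed-form solution. As a sketch, write the covariance of the frequency channels around pixel $p$ at a given wavelet scale as $\mathbf{C}$ (notation introduced here, not taken from the Skyclean code), and collect the weights and the $a_i$ into vectors $\mathbf{w}$ and $\mathbf{a}$. Introducing a Lagrange multiplier $\mu$:

$$
\mathcal{L}(\mathbf{w}, \mu) = \mathbf{w}^{T} \mathbf{C} \mathbf{w} - 2\mu \left( \mathbf{w}^{T} \mathbf{a} - 1 \right)
$$

$$
\frac{\partial \mathcal{L}}{\partial \mathbf{w}} = 2 \mathbf{C} \mathbf{w} - 2 \mu \mathbf{a} = 0 \quad \Rightarrow \quad \mathbf{w} = \mu \, \mathbf{C}^{-1} \mathbf{a}
$$

$$
\mathbf{w}^{T} \mathbf{a} = 1 \quad \Rightarrow \quad \mu = \frac{1}{\mathbf{a}^{T} \mathbf{C}^{-1} \mathbf{a}}, \qquad \mathbf{w} = \frac{\mathbf{C}^{-1} \mathbf{a}}{\mathbf{a}^{T} \mathbf{C}^{-1} \mathbf{a}}
$$

For the CMB, where $a_i = 1$, the weights are the row sums of $\mathbf{C}^{-1}$ divided by the sum of all its elements.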

First, we run the SILC stage of the pipeline on the GPU using the `ProduceSILC` class. The results are saved to your data directory, so this only needs to be done once. Then, we confirm at an example scale that $\sum_i w_i(p) = 1$, as demanded by the ILC procedure. Next, we show the ILC map synthesised across the scales. Finally, we compute the power spectra of the CMB, CFN and ILC synthesised maps. If the procedure is successful, we expect the ILC power spectrum to be closer to the CMB than the CFN is, as a function of multipole. However, the ILC only uses information up to second-order moments, so it cannot be expected to remove all contamination. This builds the case for applying an ML enhancement stage, which is covered in the next notebook.
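
The expected behaviour can be checked on a toy problem before running the full pipeline. The sketch below is a global (single-covariance) ILC on synthetic data, not the per-scale, per-pixel computation done by `ProduceSILC`. The three channels, their noise levels, and all variable names are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(1)
n_pix = 20000
noise_sigma = np.array([1.0, 2.0, 3.0])  # hypothetical per-channel noise levels

# The same CMB appears in every channel (a_i = 1); the noise is independent per channel
cmb = rng.normal(0.0, 1.0, n_pix)
noise = rng.normal(0.0, 1.0, (3, n_pix)) * noise_sigma[:, None]
channels = cmb[None, :] + noise

# Empirical channel-channel covariance, shape (3, 3)
cov = np.cov(channels)

# ILC weights: w = C^{-1} a / (a^T C^{-1} a), with a = 1 for the CMB
a = np.ones(3)
cinv_a = np.linalg.solve(cov, a)
weights = cinv_a / (a @ cinv_a)

ilc = weights @ channels
residual_var = np.var(ilc - cmb)

print(f"weights: {weights}, sum: {weights.sum():.6f}")
print(f"residual variance: {residual_var:.3f}")
```

The weights sum to one by construction, and the residual variance of the combined map falls below the noise variance of the cleanest single channel. The weight-sum check and spectra comparison in this section test the same properties on the real maps.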

In [None]:
# SILC Algorithm - Run the Scale-discretised Internal Linear Combination
print("Setting up SILC parameters...")

# Check what device JAX is using for SILC
print(f"JAX devices available: {jax.devices()}")
print(f"JAX default backend: {jax.default_backend()}")

# Initialise SILC producer
ilc_producer = ProduceSILC(
    ilc_components=["cfn"],  # Process CFN maps to extract CMB
    frequencies=frequencies,
    realisations=1,
    start_realisation=0,
    lmax=desired_lmax,
    N_directions=N_directions,
    lam=lam,
    synthesise=True,
    directory=data_directory,
)


# Run the SILC pipeline
print("\nRunning SILC pipeline...")
print("This will:")
print("  1. Load wavelet coefficients for each frequency and scale")
print("  2. Double resolution of wavelet maps")
print("  3. Calculate covariance matrices at each scale")
print("  4. Compute ILC weight vectors")
print("  5. Create ILC maps at each scale")
print("  6. Trim back to original resolution")
print("  7. Synthesise final ILC map from all scales")

ilc_producer.process_wavelet_maps(
    save_intermediates=True,  # Save intermediate results for analysis
    visualise=False  # We'll do visualisation separately
)

# Note: Keep ilc_producer for the next cell's analysis

In [None]:
# Comprehensive SILC Results Visualisation
print("SILC Results Analysis and Visualisation")
print("=" * 60)

# Use the middle scale as the example scale for plotting the weight maps
middle_scale_idx = len(ilc_producer.scales)// 2
middle_scale = list(ilc_producer.scales)[middle_scale_idx]
print(f"Analysing scale {middle_scale} (middle scale out of {len(ilc_producer.scales)} total scales)")

# 1. WEIGHT MAP ANALYSIS
print("\n" + "=" * 60)
print("1. ILC WEIGHT MAP ANALYSIS")
print("=" * 60)

# Load and visualise weight maps for the middle scale
weight_vector_path = file_templates['weight_vector_matrices'].format(
    scale=middle_scale, 
    realisation=realisation, 
    lmax=desired_lmax, 
    lam=lam
)

if os.path.exists(weight_vector_path):
    weight_vector = np.load(weight_vector_path)
    print(f"Weight vector shape: {weight_vector.shape}")
    print(f"Weight vector represents contributions from {len(frequencies)} frequencies")
    
    # Visualise weight maps for each frequency
    for freq_idx, freq in enumerate(frequencies):
        print(f"\nWeight map for {freq} GHz:")
        MWTools.visualise_mw_map(
            weight_vector[:, :, freq_idx],
            title=f"ILC Weight Map - {freq} GHz - Scale {middle_scale}",
            directional=False
        )
        
    # 2. VERIFY WEIGHT SUM = 1 (ILC constraint)
    print(f"\n" + "-" * 40)
    print("2. ILC CONSTRAINT VERIFICATION: sum_i w_i = 1")
    print("-" * 40)
    
    # Sum weights across all frequencies
    weight_sum = np.sum(weight_vector, axis=2)
    
    # Calculate weight statistics
    print(f"\nWeight Statistics:")
    print(f"  Mean weight sum: {np.mean(weight_sum):.6f}")
    print(f"  Std weight sum: {np.std(weight_sum):.6f}")
    print(f"  Target (ideal): 1")
    
    # Individual frequency weight statistics
    print(f"\nIndividual Frequency Weight Statistics:")
    for freq_idx, freq in enumerate(frequencies):
        freq_weights = weight_vector[:, :, freq_idx]
        print(f"  {freq} GHz - Mean: {np.mean(freq_weights):.6f}, Std: {np.std(freq_weights):.6f}")
    
    # Check constraint satisfaction
    weight_error = np.abs(weight_sum - 1.0)
    max_error = np.max(weight_error)
    mean_error = np.mean(weight_error)
    
    print(f"\nILC Constraint Error Analysis:")
    print(f"  Mean absolute error: {mean_error:.2e}")
    print(f"  Max absolute error: {max_error:.2e}")
    
    # Visualise weight sum map to check ILC constraint
    print(f"\nVisualising weight sum map (should be ~ 1 everywhere):")
    MWTools.visualise_mw_map(
        weight_sum,
        title=f"Weight Sum Map - Scale {middle_scale} (ILC Constraint Check)",
        directional=False
    )
        
else:
    print(f"Weight vector file not found: {weight_vector_path}")

# 3. ILC SYNTHESISED MAP
print(f"\n" + "=" * 60)
print("3. ILC SYNTHESISED MAP VISUALISATION")
print("=" * 60)

# Load and visualise the final ILC synthesised map
ilc_synth_path = file_templates['ilc_synth'].format(
    realisation=realisation, 
    lmax=desired_lmax, 
    lam=lam
)

if os.path.exists(ilc_synth_path):
    ilc_synth_map = np.load(ilc_synth_path)
    print(f"ILC synthesised map shape: {ilc_synth_map.shape}")
    
    # Convert to HEALPix for visualisation
    if len(ilc_synth_map.shape) == 2:  # MW sampling
        ilc_hp_map = SamplingConverters.mw_map_2_hp_map(ilc_synth_map, lmax=desired_lmax)
    else:  # Already HEALPix
        ilc_hp_map = ilc_synth_map
    
    
    # Visualise ILC map
    fig = plt.figure(figsize=(12, 8))
    hp.mollview(
        ilc_hp_map,
        title="SILC Extracted CMB Map",
        unit="K",
        cbar=True
    )
    plt.show()
    
else:
    print(f"ILC synthesised map not found: {ilc_synth_path}")

# 4. POWER SPECTRA COMPARISON
print(f"\n" + "=" * 60)
print("4. POWER SPECTRA ANALYSIS")
print("=" * 60)

# Initialise visualiser for power spectra
visualiser = Visualise(
    frequencies=frequencies,
    realisation=realisation,
    lmax=desired_lmax,
    lam_list=[lam],
    directory=data_directory
)

# Components to compare: CMB, CFN, and ILC
components_to_plot = ['cmb', 'cfn', 'ilc_synth']
print(f"Plotting power spectra for: {components_to_plot}")

# Visualise power spectra
visualiser.visualise_power_spectra(
    comps=components_to_plot,
)


# Clear large variables to prevent memory overload
if 'weight_vector' in locals():
    del weight_vector
if 'weight_sum' in locals():
    del weight_sum
if 'ilc_synth_map' in locals():
    del ilc_synth_map
if 'ilc_hp_map' in locals():
    del ilc_hp_map
del ilc_producer, visualiser