# 4. Internal Linear Combination (ILC) applied to wavelet coefficient maps across all frequency channels at each scale

In [None]:
import math 
import os 
import jax
jax.config.update("jax_enable_x64", True)
import s2fft
import healpy as hp
import numpy as np
import s2wav
import matplotlib.pyplot as plt
%matplotlib inline 

### 4.1 Create the directory structure to store data

In [None]:

def check_and_create_ilc_directories():
    """
    Make sure the ILC output directory tree exists, creating any missing parts.

    The root directory is "ILC", relative to the current working directory, and it
    contains these subdirectories:
    - covariance_matrix
    - ILC_doubled_maps
    - ILC_processed_wavelet_maps
    - synthesized_ILC_MW_maps
    - wavelet_doubled
    - weight_vector_data

    Each directory's status (created / already present) is printed by create_directory.
    """
    root = "ILC"
    subdirectories = (
        "covariance_matrix",
        "ILC_doubled_maps",
        "ILC_processed_wavelet_maps",
        "synthesized_ILC_MW_maps",
        "wavelet_doubled",
        "weight_vector_data",
    )

    # The root is handled first so its status line is printed before the children's.
    create_directory(root)
    for name in subdirectories:
        create_directory(os.path.join(root, name))

def create_directory(dir_path):
    """
    Ensure that `dir_path` exists as a directory, creating it and any missing parents.

    Prints whether the directory was created or already existed.

    Parameters:
        dir_path (str): The path of the directory to check and create.

    Raises:
        FileExistsError: if `dir_path` exists but is not a directory (e.g. a regular file).
    """
    if os.path.isdir(dir_path):
        print(f"Directory already exists: {dir_path}")
        return
    # exist_ok=True tolerates the directory appearing between the check above and this
    # call; a non-directory already sitting at dir_path still raises FileExistsError.
    os.makedirs(dir_path, exist_ok=True)
    print(f"Created directory: {dir_path}")

# Ensure the ILC output directory tree (rooted at "ILC", relative to the current
# working directory) exists before any of the steps below save files into it.
check_and_create_ilc_directories()


### 4.2 ILC Functions 

In [None]:
def mw_alm_2_hp_alm(MW_alm, lmax):
    '''
    Converts MW alm coefficients to HEALPix alm coefficients.

    Arg:
        MW_alm: 2D array of shape (Lmax, 2*Lmax-1) (MW sampling, McEwen & Wiaux);
            column index lmax + m holds the coefficient of order m.
        lmax: maximum multipole moment of the MW alm (i.e. Lmax - 1).
    Returns:
        hp_alm: 1D complex array in HEALPix ordering, of length hp.Alm.getsize(lmax).

    Only the m >= 0 coefficients are copied. HEALPix alm arrays have a single slot per
    (l, |m|), so negative-m values would only be overwritten by the m >= 0 ones anyway.
    '''
    MW_alm = np.asarray(MW_alm)
    hp_alm = np.zeros(hp.Alm.getsize(lmax), dtype=np.complex128)

    # For each order m, copy all multipoles l = m..lmax in one vectorised assignment.
    for m in range(lmax + 1):
        ells = np.arange(m, lmax + 1)
        hp_alm[hp.Alm.getidx(lmax, ells, m)] = MW_alm[m:lmax + 1, lmax + m]

    return hp_alm

def visualize_wavelet_coefficient_map(MW_Pix_Map, title, variable, min=None, max=None):
    """
    Convert an MW-sampled map to a HEALPix map and display it with hp.mollview.

    Parameters:
       MW_Pix_Map: the MW wavelet coefficient map (MW sampling, McEwen & Wiaux), either
           2D of shape (L, 2L-1) or 3D of shape (1, L, 2L-1) as produced by s2wav.analysis.
       title (str): prefix of the plot title.
       variable (str): appended to `title` to form the full plot title.
       min, max (float or None): colour-scale limits forwarded to hp.mollview;
           None lets healpy choose them (its default).
    Returns:
        original_hp_map: the HEALPix map that was displayed (nside = (L - 1) // 2).
    """
    if MW_Pix_Map.shape[0] != 1:
        L_max = MW_Pix_Map.shape[0]
    else:
        # 3 dimensions MW_Pix_Map (product of s2wav.analysis)
        L_max = MW_Pix_Map.shape[1]

    original_map_alm = s2fft.forward(MW_Pix_Map, L=L_max)
    original_map_hp_alm = mw_alm_2_hp_alm(original_map_alm, L_max - 1)
    original_hp_map = hp.alm2map(original_map_hp_alm, nside=(L_max - 1) // 2)

    hp.mollview(
        original_hp_map,
        coord=["G"],
        title=title + variable,
        unit=r"K",
        # Previously accepted but ignored; None keeps healpy's automatic scaling.
        min=min,
        max=max,
    )

    plt.show()
    return original_hp_map

def Single_Map_doubleworker(MW_Pix_Map):
    '''
    Double the harmonic band-limit of an MW pixel map by zero-padding its alm.

    Arg:
        MW_Pix_Map: an MW wavelet coefficient pixel map of shape (1, L, 2L-1), as output
            by s2wav.analysis (e.g. scale 0 has shape (1, 4, 7)), or a 2D (L, 2L-1) map.

    Process:
        1. Convert the MW pixel map to MW alm space (shape (L, 2L-1)) with s2fft.forward.
        2. Embed the alm, centred on m = 0, in a zero array of shape (2L-1, 2(2L-1)-1).
           The true lmax goes from L-1 to 2(L-1), so the new band-limit is 2L-1.
        3. Convert the padded alm back to an MW pixel map with s2fft.inverse and keep
           the real part.

    Returns:
        MW_Pix_Map_doubled: the real MW pixel map at band-limit 2L-1.
    '''
    # The band-limit L is the second-to-last axis for both 2D and (1, L, 2L-1) inputs;
    # shape[1] would wrongly give 2L-1 for a 2D map.
    MW_alm = s2fft.forward(MW_Pix_Map, L=MW_Pix_Map.shape[-2])
    L = MW_alm.shape[0]
    L_doubled = 2 * L - 1
    padded_alm = np.zeros((L_doubled, 2 * L_doubled - 1), dtype=np.complex128)

    # Align the m = 0 column of the original alm with the m = 0 column of the padded one.
    inner_matrix_middle = MW_alm.shape[1] // 2
    outer_matrix_middle = padded_alm.shape[1] // 2
    start_col = outer_matrix_middle - inner_matrix_middle
    end_col = start_col + MW_alm.shape[1]  # exclusive
    padded_alm[:L, start_col:end_col] = MW_alm

    MW_Pix_Map_doubled = np.real(s2fft.inverse(padded_alm, L=L_doubled))
    return MW_Pix_Map_doubled

def smoothed_covariance(MW_Map1, MW_Map2):
    '''
    Pixel-wise product of two maps, smoothed in harmonic space with a Gaussian beam.

    Args:
        MW_Map1, MW_Map2: same-size 2D MW pixel wavelet maps (L, 2L-1) at two frequencies.
    Returns:
        R_covariance_map: real 2D MW pixel map of the smoothed product Re(map1) * Re(map2).
    '''
    smoothing_lmax = MW_Map1.shape[0]
    # Product of the real parts; the zero imaginary part gives s2fft a complex input.
    R_MW_Pixel_map = np.multiply(np.real(MW_Map1), np.real(MW_Map2)) + 0.j

    # Smoothing is done in harmonic space for efficiency.
    R_MW_alm = s2fft.forward(R_MW_Pixel_map, L=smoothing_lmax)

    nsamp = 1200.0
    lmax_at_scale_j = R_MW_alm.shape[0]
    # Smallest power of two >= int(L/2) (for L >= 2), used as a reference nside.
    reference_nside = 1 << (int(0.5 * lmax_at_scale_j) - 1).bit_length()
    npix = hp.nside2npix(reference_nside)
    # NOTE(review): hp.gauss_beam interprets fwhm in radians; confirm this is intended.
    scale_fwhm = 4.0 * math.sqrt(nsamp / npix)

    # One beam value per multipole l = 0..L-1, broadcast across every m column.
    gauss_smooth = hp.gauss_beam(scale_fwhm, lmax=smoothing_lmax - 1)
    MW_alm_beam_convolved = np.asarray(R_MW_alm) * gauss_smooth[:, np.newaxis]

    R_covariance_map = np.real(s2fft.inverse(MW_alm_beam_convolved, L=smoothing_lmax))
    return R_covariance_map

def load_frequency_data(base_path, file_template, frequencies, scales=None, realization = None):
    """
    Load NumPy arrays from file paths generated for each (frequency, scale) pair.

    Args:
        base_path (str): The directory containing the files.
        file_template (str): File-name template with three positional `{}` placeholders,
            filled with (frequency, scale, realization) in that order.
        frequencies (list): Frequency labels.
        scales (list): Scale indices. Required despite the None default (kept for
            backward compatibility of the signature).
        realization (int or str): Realization label, zero-padded to 4 characters.

    Returns:
        dict: Maps (frequency, scale) tuples to the loaded arrays. Files that fail to
        load are reported on stdout and omitted from the dict.

    Raises:
        ValueError: if `scales` is None.
    """
    if scales is None:
        raise ValueError("scales must be provided (a list of scale indices).")

    realization = str(realization).zfill(4)
    frequency_data = {}
    for frequency in frequencies:
        for scale in scales:
            path = f"{base_path}/{file_template.format(frequency, scale, realization)}"
            try:
                frequency_data[(frequency, scale)] = np.load(path)
            except Exception as e:
                # Best effort: report and keep loading the remaining maps.
                print(f"Error loading {path} for frequency {frequency} and scale {scale}: {e}, realization {realization}")
    return frequency_data

def double_and_save_wavelet_maps(original_wavelet_c_j, frequencies, scales, realization):
    """
    Double the resolution of each wavelet map and save it with the realization in the name.

    Args:
        original_wavelet_c_j (dict): Maps (frequency, scale) to original wavelet maps.
        frequencies (list): List of frequency strings.
        scales (list): List of scale indices.
        realization (int or str): The realization label for file naming.

    Side effects:
        Writes ILC/wavelet_doubled/Wav_Pix2_F{frequency}_S{scale}_R{realization:0>4}.npy
        for every (frequency, scale) pair.
    """
    # Zero-pad to 4 characters so the names match what load_frequency_data reads back
    # (it applies zfill(4)); an already padded string such as "0000" is unchanged.
    realization_label = str(realization).zfill(4)
    for frequency in frequencies:
        for scale in scales:
            doubled_map = Single_Map_doubleworker(original_wavelet_c_j[(frequency, scale)])
            np.save(f"ILC/wavelet_doubled/Wav_Pix2_F{frequency}_S{scale}_R{realization_label}.npy", doubled_map)

def calculate_covariance_matrix(frequencies, doubled_MW_wav_c_j, scale, realization):
    """
    Build the frequency-by-frequency smoothed covariance maps for one scale and save them.

    Args:
        frequencies (list): Frequency labels (strings; they are joined into the file name).
        doubled_MW_wav_c_j (dict): Maps (frequency, scale) to 2D doubled wavelet maps.
        scale (int): The scale.
        realization (int or str): The realization label used in the file name.

    Returns:
        full_array (np.ndarray): 4D array of shape (n_freq, n_freq, n_rows, n_cols) where
        full_array[a, b] is smoothed_covariance of frequencies a and b. It is symmetric in
        (a, b) and is saved to
        ILC/covariance_matrix/cov_MW_Pix2_F{joined freqs}_S{scale}_R{realization}_Full.npy.

    Raises:
        ValueError: if `frequencies` is empty.
    """
    if not frequencies:
        raise ValueError("Frequency list is empty.")

    # All maps at a scale share the shape of the first frequency's map.
    n_rows, n_cols = doubled_MW_wav_c_j[(frequencies[0], scale)].shape
    n_freq = len(frequencies)
    full_array = np.zeros((n_freq, n_freq, n_rows, n_cols))

    # Only pairs with b >= a are computed; each result is mirrored to (b, a).
    for a, freq_a in enumerate(frequencies):
        map_a = doubled_MW_wav_c_j[(freq_a, scale)]
        for b in range(a, n_freq):
            cov = smoothed_covariance(map_a, doubled_MW_wav_c_j[(frequencies[b], scale)])
            full_array[a, b] = cov
            full_array[b, a] = cov

    f = "_".join(frequencies)
    np.save(f"ILC/covariance_matrix/cov_MW_Pix2_F{f}_S{scale}_R{realization}_Full", full_array)
    return full_array

def compute_weight_vector(R, scale, realization):
    """
    Compute per-pixel ILC weight vectors from a 4D covariance array and save them.

    At each pixel the weights are w = R_p^{-1} 1 / (1^T R_p^{-1} 1), where R_p is the
    frequency-by-frequency covariance matrix at that pixel and 1 is a vector of ones.

    Args:
        R (np.ndarray): Covariance array of shape (n_freq, n_freq, n_rows, n_cols), as
            returned by calculate_covariance_matrix.
        scale (int): The scale (used in file names).
        realization (int or str): The realization label (used in file names).

    Returns:
        inverses (np.ndarray): (n_rows, n_cols, n_freq, n_freq) inverse covariance
            matrices; left as zeros at pixels whose matrix is singular.
        weight_vectors (np.ndarray): (n_rows, n_cols, n_freq) weights, i.e. the
            contribution of each frequency at each pixel; zeros at singular pixels.
        singular_matrices_location (list): (row, col) tuples of singular pixels.
        singular_matrices (list): The singular covariance matrices themselves.

    Side effects:
        Each singular matrix is saved as
        ILC/weight_vector_data/inverse_singular_matrix_{i}_{j}_S{scale}_R{realization}.npy.
        Despite the name, that file holds the singular covariance matrix, not an inverse.
        The full weight array is saved as
        ILC/weight_vector_data/weight_vector_S{scale}_R{realization}.npy.
    """
    # Move the frequency axes last so that R_Pix[i, j] is the matrix at pixel (i, j).
    R_Pix = np.swapaxes(np.swapaxes(R, 0, 2), 1, 3)
    dim1, dim2, subdim1, subdim2 = R_Pix.shape

    inverses = np.zeros((dim1, dim2, subdim1, subdim2))
    weight_vectors = np.zeros((dim1, dim2, subdim1))
    identity_vector = np.ones(subdim2, dtype=float)
    singular_matrices_location = []
    singular_matrices = []

    for i in range(dim1):
        for j in range(dim2):
            try:
                # Let LAPACK decide singularity. A `det == 0` test would also reject
                # invertible matrices whose determinant merely underflows to 0.0.
                inverse = np.linalg.inv(R_Pix[i, j])
            except np.linalg.LinAlgError:
                print("Pixel", i, j)
                print(R_Pix[i, j])
                print("The matrix is singular.")
                singular_matrices_location.append((i, j))
                singular_matrices.append(R_Pix[i, j])
                # inverses[i, j] and weight_vectors[i, j] stay zero for this pixel.
                singular_path = f"ILC/weight_vector_data/inverse_singular_matrix_{i}_{j}_S{scale}_R{realization}.npy"
                np.save(singular_path, R_Pix[i, j])
                print("saved at ", singular_path)
                continue

            inverses[i, j] = inverse
            numerator = np.dot(inverse, identity_vector)
            denominator = np.dot(numerator, identity_vector)
            weight_vectors[i, j] = numerator / denominator

    np.save(f"ILC/weight_vector_data/weight_vector_S{scale}_R{realization}", weight_vectors)

    return inverses, weight_vectors, singular_matrices_location, singular_matrices
 
def compute_ILC_for_pixel(i, j, frequencies, scale, weight_vector_load, doubled_MW_wav_c_j):
    """
    Return the ILC value at pixel (i, j): the weight vector dotted with the per-frequency samples.

    Args:
        i (int): Row index of the pixel.
        j (int): Column index of the pixel.
        frequencies (list): Frequency labels; their order must match the order of the weights.
        scale (int): The wavelet scale; used both as a dict key and as a list index.
        weight_vector_load (list): Indexed by scale; each entry is an array whose [i, j]
            element is that pixel's weight vector.
        doubled_MW_wav_c_j (dict): Maps (frequency, scale) to 2D wavelet coefficient maps.

    Returns:
        The scalar ILC value at (i, j).
    """
    # One sample per frequency channel, in the order given by `frequencies`.
    samples = np.array([doubled_MW_wav_c_j[(freq, scale)][i, j] for freq in frequencies])
    weights = weight_vector_load[scale][i, j]
    return np.dot(weights, samples)

def create_doubled_ILC_map(frequencies, scale, weight_vector_load, doubled_MW_wav_c_j, realization):
    """
    Create and save the doubled-resolution Internal Linear Combination (ILC) map for one scale.

    "Doubled" refers to the input wavelet coefficient maps, whose resolution was doubled.
    Each output pixel is the dot product of that pixel's weight vector with the
    per-frequency wavelet coefficients at the pixel.

    Args:
        frequencies (list): Frequency labels; their order must match the order of the weights.
        scale (int): The wavelet coefficient scale; used as a dict key and as a list index.
        weight_vector_load (list): Indexed by scale; each entry has shape
            (n_rows, n_cols, n_freq).
        doubled_MW_wav_c_j (dict): Maps (frequency, scale) to 2D arrays of shape (n_rows, n_cols).
        realization (int or str): The realization label used in the output file name.

    Returns:
        doubled_map (np.ndarray): The (n_rows, n_cols) float ILC map, also saved to
        ILC/ILC_doubled_maps/ILC_Map_S{scale}_R{realization}.npy.

    Raises:
        ValueError: if `frequencies` is empty.
    """
    if not frequencies:
        raise ValueError("Frequency list is empty.")

    # Stack the per-frequency maps on a trailing axis -> (n_rows, n_cols, n_freq), the
    # same layout as the weight vectors, so the per-pixel dot product is one vectorised op.
    pixel_stack = np.stack([doubled_MW_wav_c_j[(frequency, scale)] for frequency in frequencies], axis=-1)
    weights = weight_vector_load[scale]

    doubled_map = np.zeros(pixel_stack.shape[:2])
    doubled_map[...] = np.sum(weights * pixel_stack, axis=-1)

    np.save(f"ILC/ILC_doubled_maps/ILC_Map_S{scale}_R{realization}", doubled_map)
    return doubled_map

def trim_to_original(MW_Doubled_Map, scale, realization):
    '''
    Band-limit a doubled-resolution MW map back to its original resolution and save it.

    This undoes the zero-padding done by Single_Map_doubleworker.

    Input:
        MW_Doubled_Map: 2D MW pixel map with band-limit 2L-1, i.e. shape (2L-1, 2(2L-1)-1).
        scale (int): The scale (used in the output file name).
        realization (int or str): The realization label (used in the output file name).

    Process:
        1. Convert the doubled map to MW alm with s2fft.forward.
        2. Keep the first L rows (l < L) and the 2L-1 columns centred on m = 0.
        3. Convert the trimmed alm back to a pixel map with s2fft.inverse at band-limit L.

    Returns:
        MW_Pix_Map_original: the s2fft.inverse output with a leading axis of length 1
        added. It is also saved to
        ILC/ILC_processed_wavelet_maps/ILC_processed_wav_Map_S{scale}_R{realization}.npy.
    '''
    MW_alm_doubled = s2fft.forward(MW_Doubled_Map, L=MW_Doubled_Map.shape[0])

    # The doubled band-limit is 2L-1, so the original band-limit is L = (rows + 1) // 2.
    inner_matrix_vertical = (MW_Doubled_Map.shape[0] + 1) // 2
    inner_matrix_horizontal = 2 * inner_matrix_vertical - 1

    # Columns centred on m = 0 in both the original and the doubled layouts.
    inner_matrix_middle = inner_matrix_horizontal // 2
    outer_matrix_middle = MW_Doubled_Map.shape[1] // 2
    start_col = outer_matrix_middle - inner_matrix_middle
    end_col = start_col + inner_matrix_horizontal  # exclusive

    trimmed_alm = MW_alm_doubled[:inner_matrix_vertical, start_col:end_col]
    print("trimmed alm shape", trimmed_alm.shape)

    MW_Pix_Map_original = s2fft.inverse(trimmed_alm, L=trimmed_alm.shape[0])[np.newaxis, ...]
    np.save(f"ILC/ILC_processed_wavelet_maps/ILC_processed_wav_Map_S{scale}_R{realization}", MW_Pix_Map_original)
    return MW_Pix_Map_original
   

### 4.3 Combine all the Internal Linear Combination steps into a single pipeline

In [None]:
def process_wavelet_maps(base_path, file_template, frequencies, scales, realizations):
    """
    Run the full ILC pipeline for each realization.

    For every realization this:
        1. loads the per-frequency wavelet coefficient maps,
        2. doubles their resolution (saved under ILC/wavelet_doubled) and reloads them,
        3. computes the smoothed covariance matrices for each scale,
        4. computes the ILC weight vectors for each scale,
        5. builds the doubled-resolution ILC map for each scale,
        6. trims each ILC map back to the original resolution and plots it.
    The helpers called in steps 3-6 also save their results to disk.

    A realization is skipped when its last processed map (scale index len(scales) - 1)
    already exists.

    Args:
        base_path (str): Directory of the input wavelet maps.
        file_template (str): Input file-name template with (frequency, scale, realization)
            placeholders.
        frequencies (list): Frequency labels (strings).
        scales (list): Scale values. The covariance, weight and trimmed-map steps use the
            scale *index*, and compute_ILC_for_pixel indexes the weight list by scale
            value, so this is expected to be [0, 1, ..., n-1].
        realizations (list): Realization numbers to process.
    """
    n_scales = len(scales)
    for realization in realizations:
        realization_str = str(realization).zfill(4)
        print(f"Processing realization {realization_str}")

        # trim_to_original saves one file per scale index, so the last index marks completion.
        path = f"ILC/ILC_processed_wavelet_maps/ILC_processed_wav_Map_S{n_scales - 1}_R{realization_str}.npy"
        if os.path.exists(path):
            print(f"File {path} already exists.")
            continue

        original_wavelet_c_j = load_frequency_data(base_path, file_template, frequencies, scales, realization_str)

        # Double the resolution of the wavelet maps and read them back from disk.
        double_and_save_wavelet_maps(original_wavelet_c_j, frequencies, scales, realization_str)
        doubled_MW_wav_c_j = load_frequency_data("ILC/wavelet_doubled/", "Wav_Pix2_F{}_S{}_R{}.npy", frequencies, scales, realization_str)

        # Each helper saves its own output; the returned arrays equal the saved ones, so
        # reloading them from disk here is unnecessary.
        R_covariance = [
            calculate_covariance_matrix(frequencies, doubled_MW_wav_c_j, scale, realization_str)
            for scale in range(n_scales)
        ]

        # compute_weight_vector returns (inverses, weights, singular locations, singular matrices).
        weight_vector_load = [
            compute_weight_vector(R_scale, scale, realization_str)[1]
            for scale, R_scale in enumerate(R_covariance)
        ]

        doubled_maps = [
            create_doubled_ILC_map(frequencies, scales[i], weight_vector_load, doubled_MW_wav_c_j, realization=realization_str)
            for i in range(n_scales)
        ]

        # Trim the doubled-resolution ILC maps back to the original resolution.
        trimmed_maps = [trim_to_original(doubled_map, i, realization_str) for i, doubled_map in enumerate(doubled_maps)]

        title = "wavelet coefficient map at scale: "
        for i, trimmed_map in enumerate(trimmed_maps):
            visualize_wavelet_coefficient_map(trimmed_map, title, str(i))

In [None]:
# Inputs for the ILC pipeline.
# Directory holding the previously generated wavelet coefficient maps.
base_path = "wavelet_transform/wavelets/wav_MW_maps"

# File-name pattern; the placeholders are filled with (frequency, scale, realization).
file_template = "Wav_MW_Pix_F{}_S{}_R{}.npy"

# Frequency channel labels as they appear in the file names.
frequencies = ["030", "044", "070", "100", "143", "217", "353", "545", "857"]

# One entry per wavelet coefficient map produced by s2wav.analysis.
scales = list(range(6))

# Realizations to process; use list(range(10)) to run all ten.
realizations = [0]

process_wavelet_maps(base_path, file_template, frequencies, scales, realizations)


### 4.4 Visualize ILC combined Wavelet coefficient maps for each scale

In [None]:
# Band-limit and number of wavelet directions (not used in this cell).
L_max = 32
N_directions = 1
# Display the trimmed ILC wavelet maps of the first realization only.
for i in range(1):
    realization = str(realizations[i]).zfill(4)
    trimmed_maps = [
        np.load(f"ILC/ILC_processed_wavelet_maps/ILC_processed_wav_Map_S{scale_index}_R{realization}.npy")
        for scale_index in range(len(scales))
    ]
    for j, scale_map in enumerate(trimmed_maps):
        visualize_wavelet_coefficient_map(scale_map, "ILC wavelet coefficient map at scale: ", str(j))

### 4.5 Synthesize ILC wavelet coefficient maps of all scales together. 
Note: handling of the scaling coefficients is not implemented yet; a zero placeholder is passed to `s2wav.synthesis` for now.

In [None]:
import s2wav
from s2wav import filters
L_max = 32
N_directions = 1

# The filters depend only on L_max and N_directions, so build them once, outside the
# loop. They are named to avoid shadowing the built-in `filter`.
wavelet_filters = filters.filters_directional_vectorised(L_max, N_directions)

# TODO: scaling coefficients are not handled yet, so a zero placeholder is used, e.g.
# np.load(f"wavelet_transform/wavelets/scal_coeffs/Scal_MW_Pix_F{frequency}.npy").
f_scal = np.array([[0]])

for realization_id in realizations:
    realization = str(realization_id).zfill(4)
    ILC_trimmed_wav_maps = [
        np.load(f"ILC/ILC_processed_wavelet_maps/ILC_processed_wav_Map_S{scale}_R{realization}.npy")
        for scale in range(len(scales))
    ]

    MW_Pix = s2wav.synthesis(ILC_trimmed_wav_maps, L=L_max, f_scal=f_scal, filters=wavelet_filters, N=N_directions)
    title = "ILC CMB Map realization: "
    visualize_wavelet_coefficient_map(MW_Pix, title, str(realization_id))
    np.save(f"ILC/synthesized_ILC_MW_maps/ILC_MW_Map_R{realization}", MW_Pix)