## C_l sampler initial notes and tests

In [2]:
import os
import glob
import sys

import numpy as np
import healpy as hp

from scipy import signal
from scipy.fftpack import fft, fft2, fftshift, fftfreq
from scipy import integrate
from scipy.stats import invgamma

from pygdsm import GlobalSkyModel16
from pygdsm import GlobalSkyModel


# Plotting
#import matplotlib.pyplot as plt
import cmocean
import cmocean.cm as cmo
import seaborn as sns
import pylab as plt
import matplotlib as mpl
import matplotlib.pyplot as pyplot
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.inset_locator import (inset_axes, InsetPosition, mark_inset)
from matplotlib import ticker
# tango colors for colorpairs
sys.path.append("/Users/user/Documents/hera/tango-colors")
# from tango_colors import Tango
# tango = Tango('HTML')
import corner


# HYDRA (for HERA antennas)
sys.path.append("/Users/user/Documents/hera/Hydra") # change this to your own path
import hydra
from hydra.utils import build_hex_array

In [3]:
plt.rcParams['font.size'] = '18'
plt.rcParams['savefig.facecolor']='white'
plt.rcParams['axes.titlepad'] = 16
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif"
})
plt.rcParams["errorbar.capsize"] = 5

# Function definitions

## Mode conversion functions

### `healpy2alms`

In [4]:
def healpy2alms(healpy_modes):
    """
    Takes a complex array of alms (positive modes only) and turns into
    a real array split as [real, imag] making sure to remove the 
    m=0 modes from the imag-part.
      
     Parameters
    ----------
    * healpy_modes (ndarray (complex)):
            Array of zeros except for the specified mode. 
            The array represents all positive (+m) modes including zeroth modes.
    
    Returns
    -------
    * alms (ndarray (floats))
            Array of zeros except for the specified mode. 
            The array represents all positive (+m) modes including zero 
            and is split into a real (first) and imag (second) part. The
            Imag part is smaller as the m=0 modes shouldn't contain and 
            imaginary part. 
    """
    lmax = hp.sphtfunc.Alm.getlmax(healpy_modes.size) # to remove the m=0 imag modes
    alms = np.concatenate((healpy_modes.real,healpy_modes.imag[(lmax+1):]))
        
    return alms    


### `alms2healpy`

In [5]:
def alms2healpy(alms, lmax):
    """
    Takes a real array split as [real, imag] (without the m=0 modes 
    imag-part) and turns it into a complex array of alms (positive 
    modes only) ordered as in HEALpy.
      
     Parameters
    ----------
    * alms (ndarray (floats))
            The array represents all positive (+m) modes including zero 
            and has double length, as real and imaginary values are split. 
            The first half is the real values.

    
    Returns
    -------
    * healpy_modes (ndarray (complex)):
            Array of zeros except for the specified mode. 
            The array represents all positive (+m) modes including zeroth modes.
            
    """
    
    real_imag_split_index = int((np.size(alms)+(lmax+1))/2)
    real = alms[:real_imag_split_index]
    
    add_imag_m0_modes = np.zeros(lmax+1)
    imag = np.concatenate((add_imag_m0_modes, alms[real_imag_split_index:]))
    
    healpy_modes = real + 1.j*imag
    
    return healpy_modes

### `lm_order_alms`

In [6]:
def lm_order_alms(alms, lmax):
    """
    Takes a real array split as [real, imag] (without the m=0 modes 
    imag-part) ordered as (m,l) and reorders as (l,m) (still omitting
    the m=0 imaginary modes)
      
     Parameters
    ----------
    * alms (ndarray (floats))
            The array represents all positive (+m) modes including zero 
            and has double length, as real and imaginary values are split. 
            The first half is the real values.

    * lmax (integer)
            The maximum ell-value 

    
    Returns
    -------
    * healpy_modes (ndarray (complex)):
            Array of zeros except for the specified mode. 
            The array represents all positive (+m) modes including zeroth modes.
            
    """
    
    healpy_modes = alms2healpy(alms,lmax)

    # intialise for-loop
    real_modes = np.zeros_like(healpy_modes.real)
    imag_modes = np.zeros_like(healpy_modes.imag[(lmax+1):])
    real_idx = 0
    imag_idx = 0
    
    for ell in np.arange(0,lmax+1):
        for em in np.arange(0,ell+1): 
            healpy_idx = hp.sphtfunc.Alm.getidx(lmax, ell, em)
            real_modes[real_idx] = healpy_modes[healpy_idx].real
            real_idx +=1
            
            if em != 0:
                imag_modes[imag_idx] = healpy_modes[healpy_idx].imag
                imag_idx +=1

    return np.concatenate((real_modes, imag_modes))


### `get_em_ell_idx`

In [7]:
def get_em_ell_idx(lmax):
    """
    Function to get the em, ell, and index of all the modes given the lmax. 
    (m,l)-ordering, (m-major ordering)

    Parameters
    ----------
    * lmax: (int)
        Maximum ell value for alms

    Returns
    -------
    * ems: (list (int))
        List of all the em values of the alms (m,l)-ordering (m-major)

    * ells: (list (int))
        List of all the ell values of the alms (m,l)-ordering (m-major)
        
    * idx: (list (int)) 
        List of all the indices for the alms

    """

    ells_list = np.arange(0,lmax+1)
    em_real = np.arange(0,lmax+1)
    em_imag = np.arange(1,lmax+1)
    
    Nreal = 0
    i = 0
    idx = []
    ems = []
    ells = []

    for em in em_real:
        for ell in ells_list:
            if ell >= em:
                idx.append(i)
                ems.append(em)
                ells.append(ell)
                
                Nreal += 1
                i +=1
    
    Nimag=0

    for em in em_imag:
        for ell in ells_list:
            if ell >= em:
                idx.append(i)
                ems.append(em)
                ells.append(ell)

                Nimag += 1
                i += 1

    return ems, ells, idx

### `find_common_true_index`

In [8]:
def find_common_true_index(arr_em, arr_ell, lmax):
    """
    Find the common index between two arrays of same length consisting of true and false.

    Parameters
    ----------
    * arr_em: (ndarray (boolean))
        The m array to compare, consisting of true and false

    * arr_ell: (ndarray (boolean))
        The ell array to compare, consisting of true and false

    Returns
    -------
    * idx_real: (int)
        The common index for the real part

    * idx_imag: (int)
        The common index for the imag part

    """
    real_imag_split_index = int(((lmax+1)**2 + (lmax+1))/2)

    real_idx = []
    imag_idx = []

    for idx in range(len(arr_em)):
        if arr_em[idx] and arr_ell[idx] and idx < real_imag_split_index:
            real_idx = idx
        elif arr_em[idx] and arr_ell[idx] and idx >= real_imag_split_index:
            imag_idx = idx

    return real_idx, imag_idx


### `get_idx_ml`

In [9]:
def get_idx_ml(em, ell, lmax):
    """
    Get the global index for the alms (m,l)-ordering (m-major) given a m 
    and ell value. 
    
    Parameters
    ----------
    * em: (int)
        The em value of the mode. Note, em cannot be greater than the ell value.

    * ell: (int)
        The ell value of the mode. Note, ell has to be larger or equal to the em value.

    * lmax: (int)
        The lmax of the modes

    Returns
    -------
    * common_idx_real: (int)
        The global index of the real part of the spherical harmonic mode

    * common_idx_imag: (int)
        The global index of the imaginary part of the spherical harmonic mode. 
        There are no m=0 imaginary mode, so in case of m=0 it returns and empty list [].

    """

    assert np.all(em <= ell), "m cannot be greater than the ell value"
    ems_idx, ells_idx, idx = get_em_ell_idx(lmax)

    em_check = np.array(ems_idx) == em
    ell_check = np.array(ells_idx) == ell

    common_idx_real, common_idx_imag = find_common_true_index(arr_em=em_check,
                                                              arr_ell=ell_check,
                                                              lmax=lmax)
    if common_idx_imag == []:
        idx_list = [common_idx_real]
    else:
        idx_list = [common_idx_real, common_idx_imag]
        
    for common_idx in idx_list:
        assert common_idx == idx[common_idx], "the global index does not match the index list"
        assert em == ems_idx[common_idx], "The em corresponding to the global index does not match the chosen em"
        assert ell == ells_idx[common_idx], "The ell corresponding to the global index does not match the vhosen ell"

    return common_idx_real, common_idx_imag

### `get_em_labels`

In [10]:
def get_em_labels(lmax):
    ells = np.arange(0,lmax+1)
    em_real = np.arange(0,lmax+1) 
    em_imag = np.arange(1,lmax+1)

    # ylabel = []
    # First append all real (l,m) values
    Nreal = 0
    i = 0
    idx = []
    ems = []
    for em in em_real:
        for ell in ells:
            if ell >= em:
                # ylabel.append((ell, em))
                if ell == em:
                    idx.append(i)
                    if ell == 0:
                        ems.append(fr'$m= {em}$')
                    else:
                        ems.append(fr'${em}$')
                Nreal += 1
                i += 1

    # Then all imaginary -- note: no m=0 modes!  
    Nimag = 0
    for em in em_imag:
        for ell in ells:
            if ell >= em:
                # ylabel.append((ell,em))
                if ell == em:
                    idx.append(i)
                    ems.append(fr'${em}$')
                Nimag += 1
                i += 1
    return ems, idx

### `get_healpy_from_gsm`

In [11]:
def get_healpy_from_gsm(freq, lmax, nside=64, resolution="low", output_model=False, output_map=False):
    """
    Generate an array of alms (HEALpy ordered) from gsm 2016 (https://github.com/telegraphic/pygdsm)
    
    Parameters
    ----------
    * freqs: (float or np.array)
        Frequency (in MHz) for which to return GSM model
        
    * lmax: (int)
        Maximum l value for alms
        
    * nside: (int)
        The NSIDE you want to upgrade/downgrade the map to. Default is nside=64.

    * resolution: (str)
        if "low/lo/l":  The GSM nside = 64  (default)
        if "hi/high/h": The GSM nside = 1024 

    * output_model: (Boolean) optional
        If output_model=True: Outputs model generated from the GSM data. 
        If output_model=False (default): no model output.

    * output_map: (Boolean) optional
        If output_map=True: Outputs map generated from the GSM data. 
        If output_map=False (default): no map output.

    Returns
    -------
    *healpy_modes: (np.array)
        Complex array of alms with same size and ordering as in healpy (m,l)
    
    *gsm_2016: (PyGDSM 2016 model) optional
        If output_model=True: Outputs model generated from the GSM data. 
        If output_model=False (default): no model output.

    *gsm_map: (healpy map) optional
        If output_map=True: Outputs map generated from the GSM data. 
        If output_map=False (default): no map output.
    
    """
    gsm_2016 = GlobalSkyModel16(freq_unit='MHz', resolution=resolution) 
    gsm_map = gsm_2016.generate(freqs=freq)
    gsm_upgrade = hp.ud_grade(gsm_map, nside)
    healpy_modes_gal = hp.map2alm(maps=gsm_upgrade,lmax=lmax)

    # Per default it is in gal-coordinates, convert to equatorial
    rot_gal2eq = hp.Rotator(coord="GC")
    healpy_modes_eq = rot_gal2eq.rotate_alm(healpy_modes_gal)

    if output_model == False and output_map == False: # default
        return healpy_modes_eq
    elif output_model == False and output_map == True:
        return healpy_modes_eq, gsm_map 
    elif output_model == True and output_map == False:
        return healpy_modes_eq, gsm_2016 
    else:
        return healpy_modes_eq, gsm_2016, gsm_map

## Plotting functions

### `discrete_cmap`

In [12]:
def discrete_cmap(N, base_cmap=None):
    """
    Create an N-bin discrete colormap from the specified input map

    Notes:
    
    If base_cmap is a string or None, you can simply do return plt.cm.get_cmap(base_cmap, N)
    The following works for string, None, or a colormap instance:
    """

    base = plt.colormaps.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    
    return base.from_list(cmap_name, color_list, N)

### `alm_plot`

In [13]:
def alm_plot(x_true, x_solns, lmax, title, filename, ylabel, xlabel, xlim=None, ylim=None, pdf=True, display = False):
    
    # Choose colours
    cmap20 = discrete_cmap(lmax,cmo.thermal)
    bgcolor = cmap20(16)
    linecolor = '0.5'
    
    # Convert to HEALpy alms for better intuition
    ells, ems = hp.sphtfunc.Alm.getlm(lmax)
    hp_true = alms2healpy(x_true,lmax)
    hp_solns = np.empty((x_solns.shape[0],hp_true.shape[0]),dtype=np.dtype(np.complex128))
    idx = 0
    for x_soln in x_solns:
        hp_soln = alms2healpy(x_soln,lmax)
        hp_solns[idx] = hp_soln
        idx += 1
        
    hp_soln_mean_real = np.mean(a=hp_solns.real, axis=0) 
    hp_soln_std_real = np.sqrt(np.var(a=hp_solns.real,axis=0))
    
    hp_soln_mean_imag = np.mean(a=hp_solns.imag, axis=0) 
    hp_soln_std_imag = np.sqrt(np.var(a=hp_solns.imag,axis=0))
    
    # Move points slightly to better see the different m-modes
    ell_em_labels = [float(ell) for ell in ells]
    em_idx = 0
    for em in ems:
        adjust = em/lmax*0.9
        ell_em_labels[em_idx] += adjust
        em_idx += 1
    
    ## PLOTTING SECTION
    fig, ax = plt.subplots(nrows=2, figsize=(16,6), sharex=True) 
    fig.subplots_adjust(hspace=0)
    
    for _ax in ax: 
        _ax.axhline(0,ls="-",color=linecolor, lw=2, alpha=0.5) 
        _ax.set_ylim([-109,109])
        _ax.set_xlim([-0.3,21.5])
        _ax.tick_params(length=6)

    # Plotting errorbars, marking out m=0 imag modes and outlier modes.
    outlier_ells_real = []
    ell_idx = 0
    for em in ems:
        # Real part
        diff_real = (hp_soln_mean_real-hp_true.real)[ell_idx]
        if diff_real > 100:
            real = ax[0].scatter(x=ell_em_labels[ell_idx], y=100, 
                                 marker='^', label='imag', color=cmap20(em))
            if np.any(np.isin(outlier_ells_real,ems[ell_idx])):
                ax[0].text(ell_em_labels[ell_idx]+0.2, 95, f'{diff_real:.0f}', va='center', rotation='horizontal')
            else: 
                ax[0].text(ell_em_labels[ell_idx]-0.3, 83, f'{diff_real:.0f}', va='center', rotation='horizontal')
                outlier_ells_real = np.append(outlier_ells_real,ells[ell_idx])
                        
        elif diff_real < -100:
            real = ax[0].scatter(x=ell_em_labels[ell_idx], y=-100, 
                                  marker='v', label='imag', color=cmap20(em))
            if np.any(np.isin(outlier_ells_real,ems[ell_idx])):
                ax[0].text(ell_em_labels[ell_idx]+0.2, -95, f'{diff_real:.0f}', va='center', rotation='horizontal')
            else: 
                ax[0].text(ell_em_labels[ell_idx]-0.3, -84, f'{diff_real:.0f}', va='center', rotation='horizontal')
                outlier_ells_real = np.append(outlier_ells_real,ells[ell_idx])
            
        else:
            real = ax[0].errorbar(x=ell_em_labels[ell_idx], y=diff_real, yerr=hp_soln_std_real[ell_idx], 
                                  fmt='o', label='real', color=cmap20(em))#,markerfacecolor=cmap20(em), markeredgecolor=bgcolor
                
        ## Imaginary. mark out the m=0 with an X
        diff_imag = (hp_soln_mean_imag-hp_true.imag)[ell_idx]
        if ems[ell_idx] == 0:
            imag = ax[1].scatter(x=ell_em_labels[ell_idx], y=diff_imag, 
              marker='x', label='imag', color=cmap20(em))#, color='white')#, markeredgecolor=cmap20(em))
        elif diff_imag > 100:
            imag = ax[1].scatter(x=ell_em_labels[ell_idx], y=100, 
                                 marker='^', label='imag', color=cmap20(em))
            ax[1].text(ell_em_labels[ell_idx]-0.4, 83, f'{diff_imag:.0f}', va='center', rotation='horizontal')
        elif diff_imag < -100:
            imag = ax[1].scatter(x=ell_em_labels[ell_idx], y=-100, 
                                  marker='v', label='imag', color=cmap20(em))
            ax[1].text(ell_em_labels[ell_idx]-0.4, -84, f'{diff_imag:.0f}', va='center', rotation='horizontal')
        else:
            imag = ax[1].errorbar(x=ell_em_labels[ell_idx], y=diff_imag, yerr=hp_soln_std_imag[ell_idx], 
                          fmt='o', label='imag', color=cmap20(em))#, markerfacecolor=cmap20(em), markeredgecolor=bgcolor)
        
        ell_idx+=1
        
    # customising ticks
    ax[0].tick_params(bottom=True, direction='inout')
    ax[1].tick_params(top=True, direction='inout')

    ax[0].tick_params(left=True, direction='out')
    ax[1].tick_params(left=True, direction='out')
    
    # Plot labels "real" or "imag"
    ax[0].text(20,-70,r'real', 
         bbox=dict(boxstyle="round",
                      fc=bgcolor, 
                      ec=None, 
                      lw=0, 
                      alpha=0.1))
    
    ax[1].text(19.8,-70,r'imag', 
         bbox=dict(boxstyle="round",
                  fc=bgcolor, 
                  ec=None, 
                  lw=0, 
                  alpha=0.1))
    
    # custom colorbar for m-values
    cax = fig.add_axes([0.585, 0.835, 0.3, 0.02]) #left, bottom, width, height
    norm = mpl.colors.Normalize(vmin=0, vmax=lmax)
    cbar = mpl.colorbar.ColorbarBase(cax, cmap=cmap20,
                                    norm=norm,
                                    orientation='horizontal')
    cbar.set_label(r'$m$-value')
    
    # title and labels
    ax[0].set_title(title)
    ax[1].set_xlabel(xlabel, labelpad=8)
    fig.text(0.06, 0.5, ylabel, va='center', rotation='vertical')
    ax[1].set_xticks(np.arange(0,lmax+1,1)) # set xticks and labels to be integer and stop at lmax
    
    if pdf:
        filetype = '.pdf'
    else:
        filetype = '.png'
    # ax[0].yscale('log')
    plt.savefig(figpath+filename+filetype,
            bbox_inches='tight',
            transparent=False,
            dpi=300)#fig.dpi)

    if display == False:
        plt.close()

    return hp_soln_std_real, hp_soln_std_imag

# Main

In [24]:
# Display available folders
parentpath = '/Users/user/Documents/hera/cl_sampler_analysis/data/'
# [file[52:] for file in sorted(glob.glob(parentpath+'*'))]

In [47]:
figpath = parentpath+'vis_response_freq_test_plots/'
if not os.path.isdir(figpath): os.makedirs(figpath)

for frequency in np.arange(560,810,10):
    folder = f'{frequency}MHz_1m/'

    # Set paths and create figure folder within data folder
    path = parentpath+folder
    # figpath = path+'figures/'
    # if not os.path.isdir(figpath): os.makedirs(figpath)

    # Load precomputed data and see which are available
    try:
        precomp = np.load(path+'precomputed_data.npz')
    except:
        precomp = np.load(path+'precomputed_data_20_1.npz')

    vis_response = precomp['vis_response']
    autos = precomp['autos']
    x_true = precomp['x_true']
    inv_noise_cov = precomp['inv_noise_cov']
    min_prior_std = precomp['min_prior_std']
    inv_prior_cov = precomp['inv_prior_cov']
    a_0 = precomp['a_0']
    data_seed = precomp['data_seed']
    prior_seed = precomp['prior_seed']
    wf_soln = precomp['wf_soln']
    nside = precomp['nside']
    lmax = precomp['lmax']
    dish_diameter = precomp['dish_diameter']
    freqs = precomp['freqs']
    lsts_hours = precomp['lsts_hours']
    precomp_time = precomp['precomp_time']

    ant_pos = np.load(path+'ant_pos.npz')


    # Baselines (didn't output this! So recalculating)
    autos_only = False
    include_autos = False
    
    dict_ants = dict(ant_pos)
    
    ants = [int(ant) for ant in dict_ants]
    antpairs = []
        
    if autos_only == False and include_autos == False:
        auto_ants = []
    for i in ants:
        for j in ants:
            # Toggle via keyword argument if you want to keep the auto baselines/only have autos
            if include_autos == True:
                if j >= i:
                    antpairs.append((ants[i],ants[j]))
            elif autos_only == True:
                if j == i:
                    antpairs.append((ants[i],ants[j]))
            else:
                if j == i:
                    auto_ants.append((ants[i],ants[j]))
                if j > i:
                    antpairs.append((ants[i],ants[j]))
                        
    ant_labels = antpairs
    ant_labels[0] = f'bl = {antpairs[0]}'


    Nlsts = len(lsts_hours)
    Nfreq = len(freqs)
    Nbl = vis_response.shape[0]//Nfreq//Nlsts
    assert(Nbl == len(antpairs))
    Nalm = vis_response.shape[1]
    Nvis = Nlsts*Nfreq*Nbl
    
    bl_border = vis_response.shape[0]/Nbl # bls  d
    freq_border = bl_border/Nfreq # nfreqs  c 
    lsts_border = freq_border/Nlsts # nlsts b
    
    idx_x = [i for i in range(0, Nvis, Nvis//Nbl)]
    ems, idx_y = get_em_labels(lmax)

    fig, ax1 = plt.subplots(nrows=1, figsize=(16,9)) 
    log_operator = np.log(np.abs(np.real(vis_response).T))
    
    # vmin = log_operator.min()
    vmin = -20
    vmax = 0#log_operator.max()
    # vmax = -5

    if vmin > log_operator.min() and vmax < log_operator.max():
        extend = 'both'
    elif vmin > log_operator.min():
        extend = 'min'
    elif vmax < log_operator.max():
        print(log_operator.max())
        extend = 'max'
    else:
        extend = 'neither'
    im1 = ax1.matshow(log_operator, cmap=cmo.tempo_r, vmin=vmin, vmax=vmax)
    
    # Labels and ticks on parent plot
    # fig.colorbar(im1, ax=ax1)
    plt.xlabel(r'$\texttt{NVIS} = \texttt{NLSTS}\times\texttt{NFREQ}\times\texttt{NBL}$',labelpad=-20)
    plt.ylabel(r'$\texttt{NMODES}$ $(m,\ell)$')#, labelpad=-15)
    plt.title(f'{frequency}MHz, 1m dish',
             loc='left',fontsize=20.7)
    # plt.title(r'\textbf{Operator that contains all necessary information about the interferometer response}'+'\n'+'to the individual spherical harmonic modes on the sky',
    #          loc='left',fontsize=20.7)
    
    # annotate y-axis with "real-" and "imag part"
    # ax1.text(-50,35,'real part',rotation='vertical',color='0.2')
    # ax1.text(-50,80,'imag part',rotation='vertical',color='0.2')
    ax1.text(-50,150,'real part',rotation='vertical',color='0.2')
    ax1.text(-50,360,'imag part',rotation='vertical',color='0.2')
    
    xlabels = pyplot.xticks(idx_x,ant_labels,rotation=80, 
                            ha="right", rotation_mode="anchor",
                           fontsize=12) 
    ylabels = pyplot.yticks(idx_y[::2],ems[::2],rotation=0, fontsize=12) 
    
    ax1.tick_params(top=False, labeltop=False, bottom=True, labelbottom=True)
    
    # plt.gcf().set_size_inches((18., 6.))
    
    # Set colorbar to graph size 
    divider = make_axes_locatable(ax1)
    cax = divider.append_axes("right", size="3%", pad=0.2)
    cbar = plt.colorbar(im1, cax=cax, extend=extend)
    cbar.set_label(r'$\textup{log}_{10}|\mathbf{X}_\textup{re}|$', labelpad=10)

    
    # Save final figure
    fig.savefig(figpath+f'{frequency}MHz_1m'+'.png', 
        bbox_inches='tight', 
        transparent=True, 
        dpi=200,
        facecolor='white')

    plt.close()
    


  log_operator = np.log(np.abs(np.real(vis_response).T))
  log_operator = np.log(np.abs(np.real(vis_response).T))
