In [None]:
from draco.core import containers
from drift.core import manager
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from scipy.stats import circvar

# Set some global variables for eigencal. We will probably want a way of feeding these in programmatically. 

In [None]:
# Where the calibration outputs are located
EIGENCAL_DIR = "/project/rrg-acliu/mwilensk/eigencal_to_chord/cal_tests/20251117/single_source_nside_2048"
# Product manager used for eigencal
PRODUCT_DIR = "/project/rrg-kmsmith/ssiegel/bt/chord64_airy_rfsoc4k/taua/freqa/"
# The visibilities to analyze
VIS_FILE = "/project/rrg-acliu/mwilensk/eigencal_to_chord/sims/20251117/nside_2048_pol_zero/freqa/taua/vis_sky/vis_n2_wnoise_jy_taua.h5"
# Common string found in all calibration outputs
TAG = "tau_a_lsd_0"

# Reference frequency for plots in antenna or uv-coordinates (same units as vis_container.freq)
REF_FREQ = 400.

# Reference right ascension in degrees -- may be unnecessary in practice; I think we need a high SNR time in simulation
# Default to tau_a (RA 05h34m31.97s, converted from hours to degrees) from https://arxiv.org/pdf/2002.10431
# FIXME: there's probably a catalog that can be used
REF_RA = (5 + 34/60 + 31.97 / 3600) / 24 * 360

GAIN_FILES = {"eigencal": f"{EIGENCAL_DIR}/gain_transit_{TAG}.h5", "corrcal": None}

pm = manager.ProductManager.from_config(PRODUCT_DIR)
tel = pm.telescope

vis_container = containers.SiderealStream.from_file(VIS_FILE, distributed=True)
# Index of the frequency channel closest to REF_FREQ
ref_chan = np.argmin(np.abs(vis_container.freq - REF_FREQ))
# Index of the RA sample closest to REF_RA. The difference is wrapped into [-180, 180)
# degrees so the nearest sample is found even for sources near the 0/360 deg boundary.
ref_tind = np.argmin(np.abs((vis_container.ra - REF_RA + 180.) % 360. - 180.))

In [None]:


# One gain container per calibration method, keyed by method name, so later cells can
# loop over the methods uniformly (and avoid typos). A method with no gain file
# configured in GAIN_FILES keeps the value None.
cal_methods = ["eigencal", "corrcal"]
gain_containers = dict.fromkeys(cal_methods)
for cal_method in cal_methods:
    gain_file = GAIN_FILES[cal_method]
    if gain_file is None:
        continue
    gain_containers[cal_method] = containers.StaticGainData.from_file(gain_file)


### Plot the gains.

Right now there are only two frequencies. When we have many frequencies, we will want a different type of plot, e.g. a 2D color plot of input vs. frequency.

If using the simulation inputs, there is a phase ramp in the eigencal gains due to a disagreement between the HEALPix simulation and the calibration model.

In [None]:
def four_panel_ticker(ax, xlabel, ylabels):
    """Arrange tick and axis labels on a 2x2 grid of axes.

    The left column receives one y-axis label per row; the right column has its y ticks
    and tick labels moved to the right-hand side. The bottom row receives the x-axis
    label; the top row has its x ticks and tick labels moved to the top.

    args:
        ax: 2x2 grid of matplotlib Axes, indexable as ax[row, col].
        xlabel (str): x-axis label applied to both bottom-row axes.
        ylabels (sequence of str): y-axis labels for the top and bottom rows.

    returns:
        None
    """
    for row_ind in range(2):
        left_ax, right_ax = ax[row_ind, 0], ax[row_ind, 1]
        left_ax.set_ylabel(ylabels[row_ind])
        right_ax.tick_params(axis="y", left=False, right=True, labelleft=False, labelright=True)
        for this_ax in (left_ax, right_ax):
            if row_ind == 0:
                this_ax.tick_params(axis="x", bottom=False, top=True, labelbottom=False, labeltop=True)
            else:
                this_ax.set_xlabel(xlabel)
    return

# Gain amplitude (top row) and phase (bottom row) against input index, one column per
# calibration method.
fig, ax = plt.subplots(figsize=[12, 9], nrows=2, ncols=2)

for col_ind, cal_method in enumerate(cal_methods):
    ax[0, col_ind].set_title(cal_method)
    gain_container = gain_containers[cal_method]
    if gain_container is None:
        continue
    # Gains are indexed (freq, input). Draw one trace per frequency so each gets its own
    # readable legend entry (vis_container.freq is presumably in MHz, cf. REF_FREQ).
    gains = gain_container.gain.local_data
    assert gains.shape[0] == len(vis_container.freq), "gain and visibility frequency axes differ"
    for freq, gains_this_freq in zip(vis_container.freq, gains):
        ax[0, col_ind].plot(np.abs(gains_this_freq), label=f"{freq:.1f} MHz")
        ax[1, col_ind].plot(np.angle(gains_this_freq))
    # Only panels that actually received labelled traces get a legend
    ax[0, col_ind].legend()

four_panel_ticker(ax, "Input", ["Gain Amplitude", "Gain Phase (rad)"])

fig.tight_layout(h_pad=0)

### Plot the gain phases in antenna position for the reference frequency.

Adjust the `PHASE_VMIN`, `PHASE_VMAX`, and `PHASE_CMAP` variables in the next cell to change the color settings. A cyclic colormap (e.g. `hsv`) is a better choice if the phase wraps across the array.

In [None]:
# Colour settings for the phase maps (radians)
PHASE_VMIN = -0.1
PHASE_VMAX = 0.1
PHASE_CMAP = "coolwarm" # 'hsv' is a good choice if phase wraps across array.

# One panel row per polarisation label (np.unique returns them sorted), one column per method
pols = np.unique(tel.polarisation)
fig, ax = plt.subplots(figsize=[12, 9], ncols=2, nrows=2)

for col_ind, cal_method in enumerate(cal_methods):
    ax[0, col_ind].set_title(cal_method)
    gain_container = gain_containers[cal_method]
    if gain_container is None:
        continue
    gains = gain_container.gain.local_data[ref_chan]
    for row_ind, pol in enumerate(pols):
        this_pol = tel.polarisation == pol
        # feedpositions has shape [input, 2]; transpose to unpack the two coordinates
        xpos, ypos = tel.feedpositions[this_pol].T
        ax[row_ind, col_ind].scatter(
            xpos, ypos, c=np.angle(gains[this_pol]), cmap=PHASE_CMAP, vmin=PHASE_VMIN, vmax=PHASE_VMAX
        )

four_panel_ticker(ax, "EW Feed Position (m)", 2 * ["NS Feed Position (m)"])
fig.tight_layout(h_pad=0, w_pad=0)
# A single mappable with the same norm/cmap as the scatters drives every colorbar
sm = ScalarMappable(norm=Normalize(vmin=PHASE_VMIN, vmax=PHASE_VMAX), cmap=PHASE_CMAP)
for row_ind, pol in enumerate(pols):
    fig.colorbar(sm, ax=ax[row_ind].ravel().tolist(), label=f"{pol} Feed Gain Phase (rad)")


### Form calibrated visibilities, plus a "gain product" (baseline gain) container in case it's useful

In [None]:
# Make dictionaries holding, per calibration method, the product of gains for every
# visibility product (i.e. baseline gains) and the calibrated visibilities.
# Keeping a standard of initializing to None so that NoneType errors are a useful signal
# when a method has no gains.
#
# index_map["prod"] entries are structured np.void records, which numpy will not use for
# fancy indexing. Round-tripping through .tolist() gives a plain (nprod, 2) integer array
# of (input_a, input_b) pairs, so all gain products can be formed in one vectorised step.
prod_inputs = np.array(vis_container.index_map["prod"].tolist(), dtype=int).reshape(-1, 2)
prod_input_a, prod_input_b = prod_inputs[:, 0], prod_inputs[:, 1]

gain_products = {}
cal_vis = {}
for cal_method in cal_methods:
    gain_products[cal_method] = None
    cal_vis[cal_method] = None
    gain_container = gain_containers[cal_method]
    if gain_container is None:
        continue
    # Gains are indexed (freq, input), so the baseline gains come out as (freq, prod)
    gains = gain_container.gain.local_data
    # FIXME: Verify conjugation convention
    gain_products[cal_method] = gains[:, prod_input_a] * gains[:, prod_input_b].conj()
    # FIXME: Verify whether it should be multiply or divide _by method_
    # NOTE(review): with distributed=True, vis.local_data is presumably only this rank's
    # frequency slice while the gains are read undistributed; the broadcast below assumes
    # a single rank -- confirm before running under MPI.
    cal_vis[cal_method] = vis_container.vis.local_data * gain_products[cal_method][:, :, None]

### Plot amplitude and phase redundancy by redundant group. 

Uses `tel._feedmap`, an Ninput × Ninput array whose entries map each input pair to an index into `tel.baselines`. `tel.baselines` is a list of unique baseline spacings, with some repetition due to polarization pairs: each cross baseline in the right half of the uv-plane shows up four times (XX, XY, YX, YY), while the autos show up only three times, since the YX autos are not present.



In [None]:
# Group the calibrated visibilities at (ref_chan, ref_tind) by redundant baseline.
# Keys are (u, v, pol): (u, v) is the unique baseline tel._feedmap assigns to the product
# and pol is that baseline's polarisation pair (e.g. "XY").
# Methods without calibrated visibilities map to None (independently of the product loop)
# so downstream cells can skip them.
cal_vis_by_group = {
    cal_method: ({} if cal_vis[cal_method] is not None else None) for cal_method in cal_methods
}
for blind, (input_a, input_b) in enumerate(vis_container.index_map["prod"]):
    bl_ind = tel._feedmap[input_a, input_b]
    if bl_ind < 0:
        # NOTE(review): assumes driftscan marks feed pairs excluded from the unique baselines
        # with a negative index; tel.baselines[-1] would otherwise silently pick the last one.
        continue
    # NOTE(review): pairs flagged in tel._feedconj are assumed to measure the complex conjugate
    # of their unique baseline with the feed order reversed, so they are conjugated and their
    # pol label swapped before being grouped -- otherwise phase redundancy is corrupted and
    # XY/YX products land in the wrong group. Confirm against the driftscan convention.
    is_conj = bool(tel._feedconj[input_a, input_b])
    first, second = (input_b, input_a) if is_conj else (input_a, input_b)
    uvpol = tuple(tel.baselines[bl_ind]) + (tel.polarisation[first] + tel.polarisation[second],)
    for cal_method, groups in cal_vis_by_group.items():
        if groups is None:
            continue
        vis = cal_vis[cal_method][ref_chan, blind, ref_tind]
        groups.setdefault(uvpol, []).append(np.conj(vis) if is_conj else vis)
    


The cross-polarisation groups (XY, YX) are noisy. **TODO:** decide whether to include them in these plots.

In [None]:
# Polarisation pairs in row-major panel order: XX, XY, YX, YY
# (plain str concatenation also works on NumPy < 2, unlike np.add on str arrays)
pol_combos = [pol_a + pol_b for pol_a in ("X", "Y") for pol_b in ("X", "Y")]


def plot_redundancy_check(mode):
    """
    Plot a redundancy metric for every redundant baseline group in the (u, v) plane.

    One subfigure per calibration method, each a 2x2 grid with one panel per entry of
    pol_combos. Reads the module-level cal_methods and cal_vis_by_group.

    args:
        mode (str): 'amp' for an amplitude metric (coefficient of variation), 'phase' for a phase metric (circular variance)

    returns:
        None

    raises:
        ValueError: if mode is not 'amp' or 'phase'
    """
    if mode == "amp":
        def metric_function(vis):
            abs_vis = np.abs(vis)
            return np.std(abs_vis) / np.mean(abs_vis)
        title = "coefficient of variation"
    elif mode == "phase":
        def metric_function(vis):
            return circvar(np.angle(vis))
        title = "circular variance"
    else:
        raise ValueError(f"mode must be 'amp' or 'phase', got {mode!r}")

    fig = plt.figure(figsize=[12, 9])
    # squeeze=False keeps an indexable array even when there is a single method
    subfigs = fig.subfigures(1, len(cal_methods), squeeze=False)[0]

    for subfig, cal_method in zip(subfigs, cal_methods):
        subfig.suptitle(f"{cal_method} {title}")
        ax = subfig.subplots(2, 2, sharex=True, sharey=True)
        groups = cal_vis_by_group[cal_method]
        if groups is None:
            continue
        metric = {pol_combo: [] for pol_combo in pol_combos}
        for (u, v, pol), vis in groups.items():
            metric[pol].append([u, v, metric_function(vis)])
        for pol_ind, pol in enumerate(pol_combos):
            row_ind, col_ind = divmod(pol_ind, 2)
            this_ax = ax[row_ind, col_ind]
            this_ax.set_title(pol)
            if row_ind:
                this_ax.set_xlabel("U (m)")
            if not col_ind:
                this_ax.set_ylabel("V (m)")
            if not metric[pol]:
                # No groups for this polarisation pair; nothing to scatter
                continue
            u, v, c = np.array(metric[pol]).T
            # Pin the scatter colour scale to [0, max(c)] and build the colorbar from the
            # scatter itself so the colours and the colorbar always agree.
            points = this_ax.scatter(u, v, c=c, vmin=0, vmax=np.amax(c), cmap="viridis")
            subfig.colorbar(points, ax=this_ax)

    return

In [None]:
# Amplitude redundancy: coefficient of variation of |V| within each redundant group
plot_redundancy_check(mode="amp")

In [None]:
# Phase redundancy: circular variance of the visibility phase within each redundant group
plot_redundancy_check(mode="phase")