DSC-037 step 3: Assessment & Verification
---
DSC-037: Cable reflection systematics for EoR science

### Authors:
SHAO EoR Group and Teal Team

### Documentation on confluence:

Summary: This notebook is a first implementation of DSC-037 to compute and plot rudimentary (FFT, absolute value, then square) delay power spectra.

DSC description page: https://confluence.skatelescope.org/x/0rs6F
Chronological walkthrough: https://confluence.skatelescope.org/x/osw6F
Implementation: https://confluence.skatelescope.org/x/n8LMF
GitHub repo: https://github.com/uksrc-developers/dsc-037-eor


Ticket: TEAL-1128 https://jira.skatelescope.org/browse/TEAL-1128

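**Dependencies:** `numpy`, `matplotlib`, `python-casacore` (for Measurement Set input), `pyuvdata` (for UVFITS input; brings in `astropy`)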

Last updated: 2025-10-16

## 1. Plot Delay-Time Dynamic Spectrum (Waterfall)

This notebook calculates and plots the delay–time dynamic spectrum (waterfall) for a chosen baseline and polarization from a Measurement Set (MS) or UVFITS file.

## Data Pipeline

For each time sample:
1.  $V(t, f)$  --FFT over frequency ($f$)-->  $\tilde{V}(t, \tau)$
2.  Delay power = $|\tilde{V}(t, \tau)|^2$

**Plot Axes:**
* **Y-axis:** Delay [µs]
* **X-axis:** Time [s]
* **Color:** $|\mathrm{FFT}|^2$ (log scale by default)

## Notes

* Uses a **RAW FFT** (no taper/window function).
* Flagged channels are **zeroed** (unless `ignore_flags` is `True`) so they don't poison the FFT.

In [1]:
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

try:
    from casacore.tables import table
    CASACORE_AVAILABLE = True
except ImportError:
    CASACORE_AVAILABLE = False
    print("WARNING: casacore.tables not found. MS loading will fail.")

try:
    from pyuvdata import UVData
    PYUVDATA_AVAILABLE = True
except ImportError:
    PYUVDATA_AVAILABLE = False
    print("WARNING: pyuvdata not found. UVFITS loading will fail.")

### Helper Functions (Measurement Set)

In [2]:
# ----------------------------
# MS helpers (self-contained)
# ----------------------------
CORR_MAP_NUM2STR = {5: 'RR', 6: 'RL', 7: 'LR', 8: 'LL', 9: 'XX', 10: 'XY', 11: 'YX', 12: 'YY'}

def ms_get_corr_index(ms_path: str, want_corr: str) -> int:
    """Return the position of a correlation product along the MS correlation axis.

    The CORR_TYPE codes of the first row of the POLARIZATION sub-table are
    translated to labels through CORR_MAP_NUM2STR (codes missing from the map
    are kept as their decimal string), and the requested label is looked up
    case-insensitively.

    Parameters
    ----------
    ms_path : path of the Measurement Set directory.
    want_corr : correlation label such as "XX" or "rr".

    Raises
    ------
    ImportError  : casacore.tables could not be imported.
    RuntimeError : the requested correlation is not present.
    """
    if not CASACORE_AVAILABLE:
        raise ImportError("casacore.tables is required to read Measurement Sets.")

    pol_table = table(f"{ms_path}/POLARIZATION", readonly=True, ack=False)
    first_row_codes = pol_table.getcol('CORR_TYPE')[0]   # e.g. [9,10,11,12] for XX,XY,YX,YY
    pol_table.close()

    available = []
    for code in list(first_row_codes):
        available.append(CORR_MAP_NUM2STR.get(int(code), str(code)))

    requested = want_corr.upper()
    if requested in available:
        return available.index(requested)
    raise RuntimeError(f"Requested correlation {requested} not found. Available: {available}")

def ms_get_freq_axis_hz(ms_path: str, ddid: int) -> np.ndarray:
    """Return the channel frequencies for one DATA_DESC_ID as a float array.

    The DATA_DESCRIPTION sub-table maps `ddid` to a SPECTRAL_WINDOW_ID; the
    CHAN_FREQ cell of that spectral-window row is then returned.

    Parameters
    ----------
    ms_path : path of the Measurement Set directory.
    ddid : row number in the DATA_DESCRIPTION sub-table.

    Raises
    ------
    ImportError : casacore.tables could not be imported.
    """
    if not CASACORE_AVAILABLE:
        raise ImportError("casacore.tables is required to read Measurement Sets.")

    # Step 1: data-description row -> spectral-window row.
    dd_table = table(f"{ms_path}/DATA_DESCRIPTION", readonly=True, ack=False)
    window_row = int(dd_table.getcell('SPECTRAL_WINDOW_ID', ddid))
    dd_table.close()

    # Step 2: spectral-window row -> channel centre frequencies (Hz, shape [nchan]).
    spw_table = table(f"{ms_path}/SPECTRAL_WINDOW", readonly=True, ack=False)
    channel_freqs = spw_table.getcell('CHAN_FREQ', window_row)
    spw_table.close()

    return np.asarray(channel_freqs, dtype=float)

def ms_ant_to_index(ms_path: str, name_or_index) -> int:
    """Resolve an antenna given by name or by index to its ANTENNA-table row index.

    `name_or_index` is converted with str(). If the result consists only of
    digits it is taken as a row index and range-checked; otherwise it is
    matched exactly against the NAME column. Note that an antenna whose name
    is purely numeric can therefore only be addressed by index.

    Parameters
    ----------
    ms_path : path of the Measurement Set directory.
    name_or_index : antenna name (str) or row index (int or digit string).

    Raises
    ------
    ImportError  : casacore.tables could not be imported.
    RuntimeError : index out of range, or name not found.
    """
    if not CASACORE_AVAILABLE:
        raise ImportError("casacore.tables is required to read Measurement Sets.")

    ant_table = table(f"{ms_path}/ANTENNA", readonly=True, ack=False)
    try:
        # Coerce to an array: if getcol hands back a plain list, `names == s`
        # would be a single False instead of an elementwise comparison and
        # every name lookup would fail.
        names = np.asarray(ant_table.getcol('NAME'))
    finally:
        ant_table.close()

    s = str(name_or_index)
    if s.isdigit():
        idx = int(s)
        if idx < 0 or idx >= len(names):
            raise RuntimeError(f"Antenna index {idx} out of range [0,{len(names)-1}]")
        return idx

    where = np.where(names == s)[0]
    if where.size == 0:
        raise RuntimeError(f"Antenna '{name_or_index}' not found. Available: {names.tolist()}")
    return int(where[0])

def ms_list_baselines(ms_path: str):
    """Print the antennas and the unique (ANTENNA1, ANTENNA2) pairs of an MS.

    Parameters
    ----------
    ms_path : path of the Measurement Set directory.

    Returns
    -------
    bls : (Nbl, 2) integer array of unique antenna-index pairs found in the
          main table.
    ants : antenna names as returned by the ANTENNA sub-table's NAME column.

    Raises
    ------
    ImportError : casacore.tables could not be imported.
    """
    if not CASACORE_AVAILABLE:
        raise ImportError("casacore.tables is required to read Measurement Sets.")

    # Close both tables even if a column read fails.
    t = table(ms_path, readonly=True, ack=False)
    try:
        a1 = t.getcol('ANTENNA1')
        a2 = t.getcol('ANTENNA2')
    finally:
        t.close()

    ant_table = table(f"{ms_path}/ANTENNA", readonly=True, ack=False)
    try:
        ants = ant_table.getcol('NAME')
    finally:
        ant_table.close()

    # One row per visibility -> deduplicate the (ant1, ant2) pairs.
    bls = np.unique(np.vstack([a1, a2]).T, axis=0)

    print("\nAntenna indices & names:")
    for i, nm in enumerate(ants):
        print(f"  {i}: {nm}")
    print("\nBaselines (ANT1-ANT2):")
    for u, v in bls:
        print(f"  {u}-{v}  ({ants[u]} - {ants[v]})")
    print(f"\nTotal unique baselines: {len(bls)}\n")
    return bls, ants

###  Core Delay Transform

In [3]:
def fft_delay_power(vis_tf: np.ndarray, df_hz: float) -> tuple[np.ndarray, np.ndarray]:
    """
    Un-windowed delay transform of a stack of spectra.

    vis_tf : (Nt, Nf) complex array, one spectrum per row (flagged channels
             are expected to have been zeroed or deliberately kept by the caller)
    df_hz  : channel spacing in Hz, assumed uniform

    Returns:
      tau_s     : (Nf,) delay axis in seconds, zero delay centred (fftshift-ed)
      p_t_tau   : (Nt, Nf) delay power |FFT|^2, shifted the same way along axis 1
    """
    n_chan = vis_tf.shape[1]

    # Delay axis: FFT sample frequencies for spacing df_hz, reordered so that
    # negative delays come first.
    delays = np.fft.fftshift(np.fft.fftfreq(n_chan, d=df_hz))

    # RAW FFT (no window) along the frequency axis, centred, then |.|^2.
    spectrum = np.fft.fftshift(np.fft.fft(vis_tf, axis=1), axes=1)
    power = np.abs(spectrum)**2

    return delays, power

### Data Loaders

In [4]:
def load_ms_delay_waterfall(ms_path: str, ant1, ant2, corr: str, col: str = "DATA",
                            ddid_override: int | None = None,
                            timebin: int = 1, chanbin: int = 1,
                            ignore_flags: bool = False):
    """Read one baseline/correlation from a Measurement Set and delay-transform it.

    Rows with ANTENNA1/ANTENNA2 equal to the (sorted) antenna pair are read in
    chunks, optionally averaged in frequency and time, flagged samples are
    zeroed (unless `ignore_flags`), and each row is passed through
    `fft_delay_power`.

    Parameters
    ----------
    ms_path : path of the Measurement Set directory.
    ant1, ant2 : antenna names or indices, resolved by `ms_ant_to_index`.
    corr : correlation label, e.g. "XX" (case-insensitive).
    col : visibility column to read.
    ddid_override : DATA_DESC_ID to use; by default the lowest one present
        for the baseline is used (with a warning if there are several).
    timebin : number of consecutive rows averaged before the FFT. A trailing
        group shorter than `timebin` is dropped; if the whole selection is
        shorter than `timebin`, no averaging is done. Rows are averaged in
        table order (assumed time-ordered).
    chanbin : number of consecutive channels averaged before the FFT; trailing
        channels that do not fill a bin are dropped.
    ignore_flags : keep flagged samples instead of zeroing them.

    Returns
    -------
    time_sec : (Nt,) TIME values minus their minimum.
    delay_us : (Ntau,) delay axis in microseconds.
    P_t_tau  : (Nt, Ntau) delay power.
    blname   : "<lo>-<hi>" antenna-index label.
    polname  : upper-cased correlation label.

    Raises
    ------
    ImportError  : casacore.tables could not be imported.
    RuntimeError : no rows for the baseline, or `ddid_override` not present.
    """
    if not CASACORE_AVAILABLE:
        raise ImportError("casacore.tables is required to read Measurement Sets.")

    a1 = ms_ant_to_index(ms_path, ant1)
    a2 = ms_ant_to_index(ms_path, ant2)
    lo, hi = sorted([a1, a2])

    ant_table = table(f"{ms_path}/ANTENNA", readonly=True, ack=False)
    try:
        ants = ant_table.getcol('NAME')
    finally:
        ant_table.close()

    times_list = []
    delay_power_rows = []
    total, flagged = 0, 0
    tau_s = None

    # Every table/selection opened below is registered here so that the
    # finally-clause releases all of them on success and on any error.
    open_tables = []
    try:
        t = table(ms_path, readonly=True, ack=False)
        open_tables.append(t)

        q = t.query(f"ANTENNA1=={lo} && ANTENNA2=={hi}")
        open_tables.append(q)
        if q.nrows() == 0:
            raise RuntimeError(f"No rows for baseline {lo}-{hi} ({ants[lo]} - {ants[hi]}). Use --list-baselines.")

        ddids = np.unique(q.getcol('DATA_DESC_ID'))
        if ddids.size == 0:
            raise RuntimeError("No DATA_DESC_IDs in selection.")
        if ddid_override is not None:
            if ddid_override not in ddids:
                raise RuntimeError(f"--ddid {ddid_override} not in {ddids}")
            ddid = int(ddid_override)
        else:
            if ddids.size > 1:
                print(f"[WARNING] Multiple DDIDs {ddids}; using first: {ddids[0]}")
            ddid = int(ddids[0])

        # Restrict the rows to the chosen DDID; otherwise rows belonging to
        # other spectral windows would be stacked under this DDID's
        # frequency axis.
        if ddids.size > 1:
            q = q.query(f"DATA_DESC_ID=={ddid}")
            open_tables.append(q)

        freqs_hz_full = ms_get_freq_axis_hz(ms_path, ddid)
        # Check uniformity
        dnu = np.diff(freqs_hz_full)
        df_nom = np.median(dnu)
        if not np.allclose(dnu, df_nom, rtol=1e-3, atol=0):
            print("[WARNING] Non-uniform channel spacing detected; using median Δν for FFT.")

        corr_idx = ms_get_corr_index(ms_path, corr)
        have_flag = "FLAG" in t.colnames()

        # Chunk size is a multiple of timebin so that time bins never
        # straddle (or get truncated at) a chunk boundary.
        step = 100000
        if timebin > 1:
            step = max(1, step // timebin) * timebin

        nrows = q.nrows()
        for start in range(0, nrows, step):
            nr = min(step, nrows - start)
            data = q.getcol(col, start, nr)[:, :, corr_idx]              # (nr, nchan)
            flags = q.getcol('FLAG', start, nr)[:, :, corr_idx] if have_flag else np.zeros_like(data, bool)
            time = q.getcol('TIME', start, nr)                            # (nr,)

            total += flags.size
            flagged += np.sum(flags)

            # Optional channel binning BEFORE FFT (changes resolution and delay range)
            if chanbin > 1:
                nch2 = (data.shape[1] // chanbin) * chanbin
                data = data[:, :nch2].reshape(nr, -1, chanbin).mean(axis=2)
                # A binned channel is flagged if any contributing channel is.
                flags = flags[:, :nch2].reshape(nr, -1, chanbin).any(axis=2)
                # New effective frequency axis & spacing
                freqs_hz = freqs_hz_full[:nch2].reshape(-1, chanbin).mean(axis=1)
                df_eff = np.median(np.diff(freqs_hz))
            else:
                df_eff = df_nom

            # Time binning BEFORE FFT
            if timebin > 1:
                if data.shape[0] >= timebin:
                    k = (data.shape[0] // timebin) * timebin
                    data = data[:k].reshape(-1, timebin, data.shape[1]).mean(axis=1)
                    flags = flags[:k].reshape(-1, timebin, flags.shape[1]).any(axis=1)
                    time = time[:k].reshape(-1, timebin).mean(axis=1)
                elif start > 0:
                    # Short trailing chunk after binned chunks: drop it, the
                    # same way an incomplete trailing bin is dropped above.
                    continue

            # Apply flags: zero-out flagged channels unless ignoring flags
            vis_block = data if ignore_flags else np.where(~flags, data, 0.0 + 0.0j)

            # Compute delay power for the whole block at once (axis=1 FFT)
            tau_s, p_block = fft_delay_power(vis_block, df_eff)  # p_block: (nrow_block, nchan_eff)

            delay_power_rows.append(p_block)
            times_list.append(time)
    finally:
        for tab in reversed(open_tables):
            tab.close()

    flag_pct = (flagged / total * 100) if total > 0 else 0.0
    print(f"[INFO] Flagged samples for baseline {lo}-{hi}: {flag_pct:.2f}% "
          f"({'IGNORED' if ignore_flags else 'applied as zeros'})")

    P_t_tau = np.vstack(delay_power_rows)    # (Nt, Nτ)
    times_all = np.concatenate(times_list)   # (Nt,)

    # Seconds from first time
    t0 = np.nanmin(times_all)
    time_sec = times_all - t0
    delay_us = tau_s * 1e6

    blname = f"{lo}-{hi}"
    polname = corr.upper()
    return time_sec, delay_us, P_t_tau, blname, polname

In [5]:
def load_uvfits_delay_waterfall(uvfits_path: str, ant1, ant2, corr: str,
                                timebin: int = 1, chanbin: int = 1,
                                ignore_flags: bool = False):
    """Read one baseline/polarization from a UVFITS file and delay-transform it.

    Parameters
    ----------
    uvfits_path : path of the UVFITS file.
    ant1, ant2 : antenna numbers (int or digit string) or antenna names.
    corr : polarization label, e.g. "XX"; matched case-insensitively against
        the labels reported by the file.
    timebin : number of consecutive time samples averaged before the FFT; a
        trailing incomplete bin is dropped. No averaging is done if there are
        fewer samples than `timebin`.
    chanbin : number of consecutive channels averaged before the FFT; trailing
        channels that do not fill a bin are dropped.
    ignore_flags : keep flagged samples instead of zeroing them.

    Returns
    -------
    time_sec : (Nt,) seconds since the first sample (times are treated as JD).
    delay_us : (Ntau,) delay axis in microseconds.
    P_t_tau  : (Nt, Ntau) delay power.
    blname   : "<ant1>-<ant2>" antenna-number label.
    polname  : upper-cased polarization label.

    Raises
    ------
    RuntimeError : pyuvdata missing, antenna name unknown, or polarization
        not present.
    """
    if not PYUVDATA_AVAILABLE:
        raise RuntimeError("pyuvdata required for UVFITS input.")
    uv = UVData()
    uv.read(uvfits_path)

    # Antenna mapping: baseline keys are built from antenna *numbers*, so a
    # name has to be translated through antenna_numbers rather than returned
    # as its position in antenna_names.
    def uv_ant_to_number(uvo, a):
        try:
            return int(a)
        except (TypeError, ValueError):
            names = np.array(uvo.antenna_names)
            where = np.where(names == a)[0]
            if where.size == 0:
                raise RuntimeError(f"Antenna '{a}' not found in UVFITS.") from None
            return int(np.asarray(uvo.antenna_numbers)[where[0]])

    a1 = uv_ant_to_number(uv, ant1)
    a2 = uv_ant_to_number(uv, ant2)

    # Compare labels case-insensitively: the request is upper-cased, while
    # the labels reported by the file may be lower case.
    pols = list(uv.get_pols())
    pols_upper = [str(p).upper() for p in pols]
    want = corr.upper()
    if want not in pols_upper:
        raise RuntimeError(f"Polarization {want} not in {pols}")
    pol_code = uv.polarization_array[pols_upper.index(want)]

    vis = uv.get_data((a1, a2, pol_code))    # (Nt, Nf)
    flags = uv.get_flags((a1, a2, pol_code)) # (Nt, Nf)
    # Older layouts carry a leading spectral-window axis on freq_array.
    freqs_hz_full = uv.freq_array[0] if uv.freq_array.ndim == 2 else uv.freq_array
    freqs_hz_full = np.asarray(freqs_hz_full, float)

    # Channel binning BEFORE FFT
    if chanbin > 1:
        nch2 = (vis.shape[1] // chanbin) * chanbin
        vis = vis[:, :nch2].reshape(vis.shape[0], -1, chanbin).mean(axis=2)
        # A binned channel is flagged if any contributing channel is.
        flags = flags[:, :nch2].reshape(flags.shape[0], -1, chanbin).any(axis=2)
        freqs_hz = freqs_hz_full[:nch2].reshape(-1, chanbin).mean(axis=1)
    else:
        freqs_hz = freqs_hz_full

    # Time binning BEFORE FFT
    times_jd = uv.get_times((a1, a2, pol_code))
    if timebin > 1 and vis.shape[0] >= timebin:
        k = (vis.shape[0] // timebin) * timebin
        vis = vis[:k].reshape(-1, timebin, vis.shape[1]).mean(axis=1)
        flags = flags[:k].reshape(-1, timebin, flags.shape[1]).any(axis=1)
        times_jd = times_jd[:k].reshape(-1, timebin).mean(axis=1)

    # Flag handling: zero flagged channels unless ignoring
    vis_proc = vis if ignore_flags else np.where(~flags, vis, 0.0 + 0.0j)

    # Δν (after any channel binning)
    dnu = np.diff(freqs_hz)
    df_nom = np.median(dnu)
    if not np.allclose(dnu, df_nom, rtol=1e-3, atol=0):
        print("[WARNING] Non-uniform channel spacing detected; using median Δν for FFT.")

    # Delay power per time
    tau_s, P_t_tau = fft_delay_power(vis_proc, df_nom)
    delay_us = tau_s * 1e6

    # Time (seconds from first): days -> seconds
    time_sec = (times_jd - np.min(times_jd)) * 86400.0

    blname = f"{a1}-{a2}"
    polname = want
    return time_sec, delay_us, P_t_tau, blname, polname

## 4. Plotting Function

In [6]:
def plot_delay_waterfall(time_sec, delay_us, power_t_tau, title, out_pdf,
                         vmin=None, vmax=None, log=True):
    """Draw a delay-vs-time power image, save it as a PDF and display it.

    Parameters
    ----------
    time_sec : (Nt,) sample times, used as cell centres on the x-axis.
    delay_us : (Ntau,) delays, used as cell centres on the y-axis.
    power_t_tau : (Nt, Ntau) array of delay power.
    title : figure title.
    out_pdf : output path; the figure is written in PDF format.
    vmin, vmax : colour limits. When None they default to the 5th / 95th
        percentile of the finite values (vmin is floored at 1e-12 so that a
        logarithmic scale stays valid).
    log : use a logarithmic colour scale.
    """
    power_t_tau = np.asarray(power_t_tau)
    vals = power_t_tau[np.isfinite(power_t_tau)]
    # Robust defaults: avoid zeros with LogNorm
    if vals.size == 0:
        vals = np.array([1e-12])
    if vmin is None:
        vmin = max(np.percentile(vals, 5), 1e-12)
    if vmax is None:
        vmax = np.percentile(vals, 95)
        if vmax <= vmin:  # guard
            vmax = vmin * 10.0

    def edges(x):
        """Turn N cell centres into N+1 cell edges."""
        if x.size == 1:
            # A single sample has no spacing to infer; give it unit width so
            # a one-row waterfall can still be drawn.
            return np.array([x[0] - 0.5, x[0] + 0.5])
        dx = np.diff(x)
        dx = np.r_[dx[:1], dx]   # reuse the first spacing for the first cell
        return np.r_[x - dx/2, x[-1] + dx[-1]/2]

    t_edges = edges(np.asarray(time_sec, float))
    d_edges = edges(np.asarray(delay_us, float))

    plt.figure(figsize=(7.6, 5.6))
    mesh = plt.pcolormesh(
        t_edges, d_edges, power_t_tau.T, shading="auto",
        norm=LogNorm(vmin=vmin, vmax=vmax) if log else None,
        cmap="Spectral_r"
    )
    cbar = plt.colorbar(mesh)
    cbar.set_label("Delay Power |Ṽ|² (arb.)")

    plt.xlabel("Time [s]")
    plt.ylabel("Delay [µs]")
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_pdf, format="pdf", dpi=300, bbox_inches="tight")
    plt.show()  # display the figure; comment this line out for batch runs
    plt.close()
    print(f"[INFO] Saved delay waterfall PDF: {out_pdf}")

 ### Configuration & Execution for MWA data

Set the parameters in the cell below and run the following cells to generate the plot.

In [None]:
def main_execution_logic(
    input_file,
    ant1,
    ant2,
    corr,
    col,
    ddid,
    list_baselines,
    ignore_flags,
    timebin,
    chanbin,
    vmin,
    vmax,
    no_log,
    plot_title,
    out_file
):
    """Run the waterfall pipeline for one input and report errors by printing.

    A directory is treated as a Measurement Set, a regular file as UVFITS.
    For a Measurement Set with `list_baselines` set, the baselines are listed
    and nothing is plotted. RuntimeError, ImportError and FileNotFoundError
    raised along the way are caught and printed rather than propagated.
    """

    def _save_and_show(loaded):
        # Common tail of both input branches: derive file name and title,
        # then hand over to the plotting routine.
        t_axis, tau_axis, power, bl_label, pol_label = loaded
        pdf_path = out_file if out_file.endswith(".pdf") else out_file + ".pdf"
        heading = plot_title if plot_title else f"Delay–Time Waterfall: {bl_label}, {pol_label}"
        plot_delay_waterfall(t_axis, tau_axis, power, heading, pdf_path,
                             vmin=vmin, vmax=vmax,
                             log=(not no_log))

    try:
        # MS or UVFITS?
        if os.path.isdir(input_file):
            if list_baselines:
                print(f"Listing baselines for {input_file}...")
                ms_list_baselines(input_file)
                return

            print(f"Loading MS: {input_file}")
            _save_and_show(load_ms_delay_waterfall(
                input_file, ant1, ant2, corr,
                col=col, ddid_override=ddid,
                timebin=timebin, chanbin=chanbin,
                ignore_flags=ignore_flags
            ))
        elif os.path.isfile(input_file):
            print(f"Loading UVFITS: {input_file}")
            _save_and_show(load_uvfits_delay_waterfall(
                input_file, ant1, ant2, corr,
                timebin=timebin, chanbin=chanbin,
                ignore_flags=ignore_flags
            ))
        elif not list_baselines:  # Don't warn if we're just trying to list
            print(f"[ERROR] Input file not found or is not a directory (MS) or file (UVFITS): {input_file}")

    except (RuntimeError, ImportError, FileNotFoundError) as e:
        print(f"[ERROR] An error occurred: {e}")


In [8]:
# --- Parameters ---
# Edit the values below, then run this cell: it calls main_execution_logic()
# with exactly these settings.

# Input path: a directory is read as a Measurement Set, a regular file as UVFITS
input_file = "L253456_SAP000_002_time1.flagged.5ch8s.dical.MS" 

# Antenna 1 (index or name; an all-digit string is interpreted as an index)
ant1 = "2"

# Antenna 2 (index or name; an all-digit string is interpreted as an index)
ant2 = "6"

# Polarization: XX, YY, RR, LL, XY, YX, ...
corr = "XX"

# [MS ONLY] Column: DATA, CORRECTED_DATA, MODEL_DATA
col = "DATA"

# [MS ONLY] DATA_DESC_ID override (default: None = first one found for the baseline)
ddid = None

# [MS ONLY] List baselines instead of plotting (set to True, run, then set back to False)
list_baselines = False

# --- Processing ---

# True: keep flagged samples as they are; False: zero them before the FFT
ignore_flags = False

# Average this many consecutive time samples before FFT (1 = no averaging)
timebin = 1

# Average this many consecutive channels before FFT (1 = no averaging)
chanbin = 1

# --- Plotting ---

# Color scale minimum (default: None for robust 5th pct)
vmin = None

# Color scale maximum (default: None for robust 95th pct)
vmax = None

# True: linear color scale; False: logarithmic color scale
no_log = False

# Plot title (set to None for an automatic "baseline, polarization" title)
plot_title = "Delay–Time Waterfall"

# Output PDF filename (".pdf" is appended if missing)
out_file = "delay_waterfall.pdf"

# Run the pipeline; errors are printed rather than raised.
main_execution_logic(
    input_file=input_file,
    ant1=ant1,
    ant2=ant2,
    corr=corr,
    col=col,
    ddid=ddid,
    list_baselines=list_baselines,
    ignore_flags=ignore_flags,
    timebin=timebin,
    chanbin=chanbin,
    vmin=vmin,
    vmax=vmax,
    no_log=no_log,
    plot_title=plot_title,
    out_file=out_file
)


TypeError: main_execution_logic() got an unexpected keyword argument 'out_file'