# Validate Bandpass Polarisation Stage

## Imports

In [None]:
from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import run_solver

from ska_sdp_instrumental_calibration.data_managers.visibility import (
    read_dataset_from_zarr,
    write_ms_to_zarr,
)
from ska_sdp_instrumental_calibration.workflow.utils import (
    create_bandpass_table,
    with_chunks,
)
from ska_sdp_datamodels.calibration.calibration_create import (
    create_gaintable_from_visibility,
)

from ska_sdp_instrumental_calibration.data_managers.data_export.export_to_h5parm import (
    export_gaintable_to_h5parm,
)

from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import (
    apply_gaintable_to_dataset,
    prediction_central_beams,
)

from utils import (
    get_frequency,
    get_polarisations,
    get_antennas,
    get_values,
    REMOVE_LAST_ITEM,
)

import shutil

import dask

import numpy as np
import xarray as xr

import os
from typing import Literal

import glob
import h5py

%matplotlib inline
import matplotlib.pyplot as plt

import pandas
from ska_sdp_instrumental_calibration.data_managers.visibility import (
    load_ms_as_dataset_with_time_chunks,
)

In [None]:
# If you need to connect to a dask client
from distributed import Client

# Address of the dask scheduler to connect to. The same value is passed to the
# pipeline CLI through `--dask-scheduler` further down in this notebook.
scheduler = "localhost:34567"

client = Client(scheduler)

## Generate data

We use the scripts stored in `scripts/ska_low_sim` to generate the data.
Information about how to use the scripts is present in this [confluence page](https://confluence.skatelescope.org/x/eZHMF).

### Simulation config

The following is the custom configuration used for simulation (further referred to as `custom_sim.yaml`)

```yaml
scenario: "low40s-model"          # Scenario name (used for output folder prefix)

# ===============================
# Global simulation parameters
# ===============================

n_stations: 40                                         # Number of stations
tel_model: "./SKA-Low_AA2_40S_rigid-rotation_model.tm" # Telescope model directory

simulation_start_frequency_hz: 123.0e6                  # Start frequency (Hz)
simulation_end_frequency_hz: 153.0e6                    # End frequency (Hz)
correlated_channel_bandwidth_hz: 21.70138888888889e3    # Channel width (Hz)

observing_time_mins: 10                              # Observation duration (minutes)
sampling_time_sec: 3.3973862400000003                   # Dump/integration time (seconds)

fields:
  EoR2:
    Cal1:
      ra_deg: 197.914612
      dec_deg: -22.277973
      scan_id_start: 300
      transit_time: "2000-01-03 22:33:30.000"

# ==================================
# Options for generate_gaintable.py
# ==================================

generate_gaintable:
  output_gaintable: &gen_gaintable "./gaintables/custom_gaintable.h5"

  spline_data_path: "./SKA_Low_AA2_SP5175_spline_data.npz" # Bandpass spline data file
  station_offset: true              # Apply per-station amplitude/phase offsets
  time_variant: true                # Apply time-dependent effects

  rfi: false                        # Inject RFI band
  rfi_start_freq_hz: 154.25347222228538e6        # Hz
  rfi_end_freq_hz: 159.8090277778474e6           # Hz

  plot: true                        # Generate diagnostic plots
  plot_output_dir: "./gaintables/generation_plots/"

# ===============================
# Options for run_sim.py
# ===============================

run_sim:
  oskar_sif: "./OSKAR-2.11.1-Python3.sif" # Path to OSKAR Singularity image

  # GLEAM sky model. Optional. Comment to disable.
  gleam_file: "./sky-models/GLEAM_EGC.fits" # GLEAM catalogue FITS file
  field_radius_deg: 10.0            # Radius of field of view (degrees)

  # Corruptions to be applied. All are optional. Comment to disable.
  # gaintable: *gen_gaintable           # Gaintable containing bandpass corruptions
  # cable_delay: "./cable_delays/cable_length_error_40s.txt" # Cable delay error file
  # tec_screen: "./tec/calibrator_iono_tec.fits" # Ionospheric TEC screen FITS

  # Imaging parameters using wsclean. Optional. Comment to disable.
  create_dirty_image: true          # Whether to run wsclean imaging
  image_size: 1024                  # Image size (pixels)
  pixel_size: "2arcsec"             # Pixel size (angular units)

  # Extra parameters to pass directly to run_oskar.py
  run_oskar_extra_params: "--use-gpus --double-precision"
  ```

### Overview of steps followed

Assuming you have followed the steps on the [confluence page](https://confluence.skatelescope.org/x/eZHMF) to get the data and set up the python environment, here are the general steps to follow to simulate the data needed to run this notebook:

1. Create a custom gaintable using the above config and the generate_gaintable script

    ```bash
    python3 generate_gaintable.py custom_sim.yaml
    ```

2. Convert gaintable to DP3 h5parm file

    ```bash
    python3 utils/h5parm_from_oskar_gains.py custom_sim.yaml "./gaintables/custom_gaintable.h5"
    ```

3. Create "model" visibilities

    ```bash
    python run_sim.py custom_sim.yaml
    ```

4. Apply the corruptions using DP3 to create corrupted visibilities

    ```bash
    msin="path_to_the_model_visibilities"
    parmdb="./gaintables/custom_gaintable.h5parm"
    msout="corrupted_visibilities.ms"
    DP3_CMD=${DP3_CMD:-"DP3"}

    ${DP3_CMD} steps="[applycal]" msin=${msin}\
	  applycal.invert=False applycal.parmdb=${parmdb}\
	  applycal.missingantennabehavior=flag applycal.updateweights=True\
	  applycal.steps="[phase, amplitude]" applycal.phase.correction=phase000\
	  applycal.amplitude.correction=amplitude000 msout=${msout}
    ```


## Setup common parameters

In [None]:
# Set these paths to the simulated data generated in previous steps
# `vis_ms` holds the corrupted visibilities fed to the pipeline and `model_ms`
# the model visibilities produced by the simulation.
vis_ms = "/home/ska/Work/data-simulation/only-bandpass-simulation/visibility.scan-300.corrupt.ms"
model_ms = (
    "/home/ska/Work/data-simulation/only-bandpass-simulation/visibility.scan-300.ms"
)
# Simulated gain model (HDF5) and the DP3 h5parm file converted from it; both
# are used as references in the comparison cells later in this notebook.
sim_gaintable = "/home/ska/Work/data-simulation/only-bandpass-simulation/customgaintable/gain_model_cal.h5"
dp3_gaintable = "/home/ska/Work/data-simulation/only-bandpass-simulation/customgaintable/gain_model_cal_dp3.h5parm"

# Zarr conversion params
nchannels_per_chunk = 32
ntimes_per_ms_chunk = 16

# predict stage related parameters
# NOTE(review): `_cli_args_` is not referenced anywhere else in the visible
# notebook - confirm whether it is still needed.
_cli_args_ = {"input": vis_ms}
beam_type = "everybeam"
normalise_at_beam_centre = True
eb_coeffs = "/home/ska/Work/data/INST/sim/coeffs"
eb_ms = vis_ms
lsm_csv_path = "/home/ska/Work/data/INST/lg3/sky_model_cal.csv"
fov = 10.0
flux_limit = 1.0
alpha0 = -0.78

# Solver params
solver: str = "gain_substitution"
refant: int = 0
niter: int = 200
phase_only: bool = False
tol: float = 1e-06
crosspol: bool = False
normalise_gains: str = None  # defaults to None despite the `str` annotation
jones_type: Literal["T", "G", "B"] = "B"
timeslice: float = None  # defaults to None despite the `float` annotation

# Keyword arguments for a direct `run_solver` call. In this notebook only the
# commented-out "Testing run_solver directly" cells consume this dict; the
# pipeline config written later lists the same values individually and does
# not include `jones_type`.
run_solver_params = dict(
    solver=solver,
    refant=refant,
    niter=niter,
    phase_only=phase_only,
    tol=tol,
    crosspol=crosspol,
    normalise_gains=normalise_gains,
    jones_type=jones_type,
    timeslice=timeslice,
)

# All outputs and caches are created relative to the notebook's working
# directory.
working_dir = os.getcwd()
cache_dir = f"{working_dir}/cache"

In [None]:
# Prepend the spack-installed DP3 and wsclean `bin` directories to PATH so
# that commands launched from this notebook can find both executables.
# The paths are machine-specific; adjust them to the local installation.
os.environ["PATH"] = (
    f"/home/ska/spack/opt/spack/linux-ubuntu22.04-x86_64_v3/gcc-11.4.0/dp3-6.5.1-ihprmbbwfxaeb6bu73kvoyxvydmd3qq7/bin:/home/ska/spack/opt/spack/linux-ubuntu22.04-x86_64_v3/gcc-11.4.0/wsclean-3.6.20250630-cd2552n6cesup3jr3ja7vj2vgggmbyfp/bin/:{os.environ['PATH']}"
)

In [None]:
# Dimensions common to the zarr store and the loaded visibility dataset.
# Each is mapped to -1, i.e. it is never split into chunks.
non_chunked_dims = dict.fromkeys(("baselineid", "polarisation", "spatial"), -1)

# Chunking of the intermediate zarr file: chunked in both time and frequency.
zarr_chunks = dict(
    non_chunked_dims,
    time=ntimes_per_ms_chunk,
    frequency=nchannels_per_chunk,
)

# The pipeline only works on frequency chunks, so time is kept whole here.
# It's expected that later stages follow the same chunking pattern.
vis_chunks = dict(
    non_chunked_dims,
    time=-1,
    frequency=nchannels_per_chunk,
)

## Testing run_solver directly

In [None]:
# vis_cache_directory = f"{working_dir}/cache/vis/"
# model_cache_directory = f"{working_dir}/cache/model/"

In [None]:
# os.makedirs(vis_cache_directory, exist_ok=True)
# write_ms_to_zarr(vis_ms, vis_cache_directory, zarr_chunks)

In [None]:
# os.makedirs(model_cache_directory, exist_ok=True)
# write_ms_to_zarr(model_ms, model_cache_directory, zarr_chunks)

### Load visibilities from zarr and create initial gaintable

In [None]:
# vis_ds = read_dataset_from_zarr(vis_cache_directory, vis_chunks)

# # Uncomment to read model_ds
# model_ds = read_dataset_from_zarr(model_cache_directory, vis_chunks)
# #Fix to ensure that time coords match
# model_ds = model_ds.assign_coords({"time": vis_ds.time})

In [None]:
# gaintable = create_bandpass_table(vis_ds).pipe(with_chunks, vis_chunks)

# # gaintable = create_gaintable_from_visibility(vis_ds, jones_type='B', timeslice=None)

### Normalise at beam center

In [None]:
# def plot_vis(
#     vis_ds: xr.Dataset,
#     time: int,
#     baselineid: int,
#     polarisation: int,
#     figsize=(10, 6),
#     savepath=None,
# ):
#     """
#     Plot amplitude and phase of complex visibilities from an xarray Dataset.

#     Parameters
#     ----------
#     vis_ds : xr.Dataset
#         Dataset containing a complex-valued DataArray named 'vis'.
#     time : int
#         Index along the 'time' dimension.
#     baselineid : int
#         Index along the 'baselineid' dimension.
#     polarisation : int
#         Index along the 'polarisation' dimension.
#     figsize : tuple, optional
#         Figure size (width, height) in inches. Default is (10, 6).
#     savepath : str, optional
#         If provided, saves the plot to the given path.
#     """
#     # Extract the selected visibility slice
#     vis = vis_ds.vis.isel(time=time, baselineid=baselineid, polarisation=polarisation)

#     # Ensure we have a frequency dimension for the x-axis
#     freq = vis_ds.coords.get("frequency", np.arange(vis.size))

#     amp = np.abs(vis)
#     phase = np.angle(vis, deg=True)  # Convert phase to degrees

#     fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize, sharex=True)

#     # Plot amplitude
#     ax1.plot(freq, amp, color="tab:blue")
#     ax1.set_ylabel("Amplitude")
#     ax1.set_title(f"Baseline {baselineid}, Time {time}, Pol {polarisation}")

#     # Plot phase
#     ax2.plot(freq, phase, color="tab:orange")
#     ax2.set_ylabel("Phase [deg]")
#     ax2.set_xlabel("Frequency" if "frequency" in vis_ds.coords else "Channel index")

#     plt.tight_layout()

#     if savepath:
#         plt.savefig(savepath, dpi=150)
#         plt.close(fig)
#     else:
#         plt.show()

In [None]:
# time = 15
# baselineid = 50
# polarisation = 0

<!-- #### Plotting visibility before normalising -->

In [None]:
# plot_vis(vis_ds, time=time, baselineid=baselineid, polarisation=polarisation)

In [None]:
# plot_vis(model_ds, time=time, baselineid=baselineid, polarisation=polarisation)

<!-- #### Normalise visibility -->

In [None]:
# if normalise_at_beam_centre:
#     beams = prediction_central_beams(
#         vis_ds,
#         beam_type=beam_type,
#         eb_ms=eb_ms,
#         eb_coeffs=eb_coeffs,
#     ).persist()

#     vis_ds = apply_gaintable_to_dataset(vis_ds, beams, inverse=True)
#     model_ds = apply_gaintable_to_dataset(model_ds, beams, inverse=True)

<!-- #### Plotting after normalising -->

In [None]:
# plot_vis(vis_ds, time=time, baselineid=baselineid, polarisation=polarisation)

In [None]:
# plot_vis(model_ds, time=time, baselineid=baselineid, polarisation=polarisation)

### Running calibration and exporting gaintable

In [None]:
# output_dir = f"{working_dir}/run_solver_{solver}_{niter}/"

# os.makedirs(output_dir, exist_ok=True)

In [None]:
# gaintable = run_solver(
#     vis_ds,
#     model_ds,
#     gaintable,
#     **run_solver_params,
# )

In [None]:
# export_task = dask.delayed(export_gaintable_to_h5parm, pure=False)(
#     gaintable,
#     f"{output_dir}/gaintable.h5parm",
# )

In [None]:
# client.compute(export_task, sync=True, optimize_graph=True)

In [None]:
# client.restart()

## Testing bandpass_calibration stage with others

This code tests the following flow:

load_data -> predict_vis -> bandpass_calibration

In [None]:
# The output directory name encodes the solver and the iteration count used
# for this run; the pipeline config file is written inside it.
output_dir = f"{working_dir}/bandpass_stage_{solver}_{niter}"
config_path = f"{output_dir}/inst.yml"

os.makedirs(output_dir, exist_ok=True)

In [None]:
# Pipeline configuration for the flow load_data -> predict_vis ->
# bandpass_calibration. Every value comes from the parameters defined in the
# setup cells, so changing a parameter there and re-running this cell is
# enough to regenerate the config file.
config = {
    "global_parameters": {
        "experimental": {
            "pipeline": [
                {
                    "load_data": {
                        "nchannels_per_chunk": nchannels_per_chunk,
                        "ntimes_per_ms_chunk": ntimes_per_ms_chunk,
                        "cache_directory": cache_dir,
                        "ack": False,
                        "datacolumn": "DATA",
                        "field_id": 0,
                        "data_desc_id": 0,
                    }
                },
                {
                    "predict_vis": {
                        "beam_type": beam_type,
                        "normalise_at_beam_centre": normalise_at_beam_centre,
                        "eb_coeffs": eb_coeffs,
                        "lsm_csv_path": lsm_csv_path,
                        "fov": fov,
                        "flux_limit": flux_limit,
                        "alpha0": alpha0,
                    }
                },
                {
                    "bandpass_calibration": {
                        # Same solver settings as `run_solver_params`, except
                        # that `jones_type` is not passed to this stage.
                        "run_solver_config": {
                            "solver": solver,
                            "refant": refant,
                            "niter": niter,
                            "phase_only": phase_only,
                            "tol": tol,
                            "crosspol": crosspol,
                            "normalise_gains": normalise_gains,
                            "timeslice": timeslice,
                        },
                        "plot_config": {
                            "plot_table": True,
                            "fixed_axis": False,
                        },
                        "visibility_key": "vis",
                        "export_gaintable": True,
                    }
                },
            ]
        }
    }
}

# Imported here, where it is first needed, rather than in the imports cell.
import yaml

# Write the config to `config_path`, which the CLI call in the next cell
# receives through `--config`.
with open(config_path, "w") as f:
    yaml.safe_dump(config, f)

In [None]:
!time ska-sdp-instrumental-calibration experimental --input $vis_ms --config $config_path --dask-scheduler $scheduler --output $output_dir --no-unique-output-subdir

## Comparing h5parm files

In [None]:
# Locate the gaintable exported by the pipeline run. Matches are sorted so the
# choice is deterministic if more than one file fits the pattern, and a
# missing export is reported explicitly instead of as a bare IndexError.
_gaintable_pattern = f"{output_dir}/gaintables/bandpass.gaintable.h5*"
_gaintable_matches = sorted(glob.glob(_gaintable_pattern))
if not _gaintable_matches:
    raise FileNotFoundError(
        f"No exported bandpass gaintable matches '{_gaintable_pattern}'. "
        "Check that the pipeline run completed and exported a gaintable."
    )
output_h5_path = _gaintable_matches[0]

# Alias used by the comparison and plotting cells.
actual_gaintable = output_h5_path

In [None]:
def suppress_exception(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and report failures instead of raising.

    Used so that a failing comparison is printed without aborting the cell,
    letting the remaining comparisons still run.

    Parameters
    ----------
    func : callable
        Function to invoke.
    *args, **kwargs
        Forwarded to ``func`` unchanged.

    Returns
    -------
    The return value of ``func``, or ``None`` if it raised an ``Exception``
    (the exception is printed followed by a blank line).
    """
    try:
        result = func(*args, **kwargs)
    except Exception as err:
        print(err, end="\n\n")
        return None
    return result

### Compare oskar gains with DP3 gains

In [None]:
# Check that the simulated gain model and its DP3 h5parm conversion agree.
# Each entry is (label printed, function applied to the complex simulated
# gains, name of the h5parm table to read).
gain_comparisons = (
    ("Amplitude", np.abs, "amplitude000"),
    ("Phase", np.angle, "phase000"),
)

with h5py.File(sim_gaintable) as sim_gain_f:
    for pol_id, pol in enumerate(["x", "y"]):
        for quantity, extract, table_name in gain_comparisons:
            print(f"{quantity} for pol {pol}")
            # The DP3 values have their last two axes swapped before the
            # comparison - presumably to match the axis order of the
            # simulated gain datasets (TODO confirm).
            suppress_exception(
                np.testing.assert_allclose,
                extract(sim_gain_f[f"gain_{pol}pol"]),
                get_values(dp3_gaintable, solset=table_name, pol=pol_id).transpose(
                    0, 2, 1
                ),
            )

### Compare actual vs DP3, by broadcasting expected DP3 gains across time

In [None]:
# Compare the gains solved by the pipeline against the DP3 reference gains
# for the parallel-hand polarisations only.
pols = get_polarisations(actual_gaintable)
parallel_hands = ["XX", "YY"]
indices = [pols.index(item) for item in parallel_hands]


for pol_id, pol in zip(indices, parallel_hands):
    # `pol_id` indexes the pipeline gaintable; the DP3 table is indexed with
    # `pol_id % 2`.
    dp3_pol_id = pol_id % 2

    print(f"Amplitude for pol {pol}")

    exp_amp = get_values(
        dp3_gaintable,
        solset="amplitude000",
        pol=dp3_pol_id,
        frequency=REMOVE_LAST_ITEM,
    )

    # The pipeline values are read once and repeated along axis 0 so that
    # their shape matches the expected values (assert_allclose does not
    # broadcast arrays of different shapes).
    act_amp = np.concatenate(
        [get_values(actual_gaintable, solset="amplitude000", pol=pol_id)]
        * exp_amp.shape[0]
    )

    suppress_exception(
        np.testing.assert_allclose,
        act_amp,
        exp_amp,
    )

    print(f"Phase for pol {pol}")
    exp_phase = get_values(
        dp3_gaintable,
        solset="phase000",
        pol=dp3_pol_id,
        frequency=REMOVE_LAST_ITEM,
    )
    # Reference the expected phases to entry `refant` of axis 1 (presumably
    # the antenna axis - TODO confirm), so they can be compared with the
    # solver output, which was run with the same `refant`.
    exp_phase = exp_phase - exp_phase[:, [refant], :]

    # Read the pipeline phases a single time and repeat them, instead of
    # re-reading the same values from the file once per repetition.
    act_phase = np.concatenate(
        [get_values(actual_gaintable, solset="phase000", pol=pol_id)]
        * exp_phase.shape[0]
    )

    suppress_exception(
        np.testing.assert_allclose,
        np.degrees(act_phase),
        np.degrees(exp_phase),
    )

### Comparing plots

In [None]:
# Selection of the single gain solution to plot.
# NOTE(review): `expected_gaintable` is not referenced by the plotting cell
# that follows, which reads `dp3_gaintable` directly.
expected_gaintable = dp3_gaintable
time_idx = 0
antenna_idx = 20
pol_idx = 0  # 0 for XX, or 3 for YY
# Rebinds the `refant` defined with the solver parameters; keep both values
# identical so the expected phases are referenced to the antenna the solver used.
refant = 0

In [None]:
# Load expected (DP3) and actual (pipeline) gains for the selected time,
# antenna and polarisation.
frequency = get_frequency(actual_gaintable)
channels = np.arange(frequency.size)

pols = get_polarisations(actual_gaintable)
stations = get_antennas(actual_gaintable)

# `pol_idx` (0 for XX, 3 for YY) indexes the pipeline gaintable. The DP3 table
# is indexed with `pol_idx % 2`, mirroring the comparison cell earlier in the
# notebook.
dp3_pol_idx = pol_idx % 2

exp_amp = get_values(
    dp3_gaintable,
    solset="amplitude000",
    time=time_idx,
    antenna=antenna_idx,
    pol=dp3_pol_idx,
    frequency=REMOVE_LAST_ITEM,
)

exp_phase = get_values(
    dp3_gaintable,
    solset="phase000",
    time=time_idx,
    antenna=antenna_idx,
    pol=dp3_pol_idx,
    frequency=REMOVE_LAST_ITEM,
)
# Phase of the reference antenna, subtracted so that the expected phases are
# relative to `refant`.
reference_phase = get_values(
    dp3_gaintable,
    solset="phase000",
    time=time_idx,
    antenna=refant,
    pol=dp3_pol_idx,
    frequency=REMOVE_LAST_ITEM,
)
exp_phase = exp_phase - reference_phase

# The pipeline gaintable is read with the un-folded `pol_idx`: the comparison
# cell reads this table with the position of "XX"/"YY" in `pols`, so folding
# the index with `% 2` here would select a different entry whenever
# `pol_idx` is 3.
act_amp = get_values(
    actual_gaintable,
    solset="amplitude000",
    time=time_idx,
    antenna=antenna_idx,
    pol=pol_idx,
)
act_phase = get_values(
    actual_gaintable,
    solset="phase000",
    time=time_idx,
    antenna=antenna_idx,
    pol=pol_idx,
)


def channel_to_freq(channel):
    """Convert channel indices to frequencies on the module-level ``frequency`` grid.

    Channel ``i`` corresponds to ``frequency[i]``; values in between are
    linearly interpolated with ``np.interp``. Used as the forward transform
    of the secondary (top) x-axis.
    """
    channel_grid = np.arange(len(frequency))
    return np.interp(channel, channel_grid, frequency)


def freq_to_channel(freq):
    """Convert frequencies to channel indices on the module-level ``frequency`` grid.

    The reverse mapping of ``channel_to_freq``: ``frequency[i]`` corresponds
    to channel ``i``, with linear interpolation in between via ``np.interp``
    (which expects ``frequency`` to be increasing). Used as the inverse
    transform of the secondary (top) x-axis.
    """
    channel_grid = np.arange(len(frequency))
    return np.interp(freq, frequency, channel_grid)


# Side-by-side scatter plots of actual vs expected gain amplitude (left) and
# phase (right) against channel number, with frequency on a secondary top axis.
fig = plt.figure(layout="constrained", figsize=(18, 9))
amp_ax, phase_ax = fig.subplots(nrows=1, ncols=2)

amp_ax.scatter(channels, act_amp, label="Actual")
amp_ax.scatter(channels, exp_amp, label="Expected")
amp_ax.set_ylabel("Amplitude")
amp_ax.set_xlabel("Channel")
# NOTE(review): the label states MHz - confirm the unit returned by
# `get_frequency`.
amp_ax.secondary_xaxis(
    "top",
    functions=(channel_to_freq, freq_to_channel),
).set_xlabel("Frequency [MHz]")

phase_ax.scatter(channels, np.rad2deg(act_phase), label="Actual")
phase_ax.scatter(channels, np.rad2deg(exp_phase), label="Expected")
phase_ax.set_ylabel("Phase (degree)")
# Label the phase panel's own x-axis (previously the amplitude axis was
# labelled a second time, leaving this panel without an x-label).
phase_ax.set_xlabel("Channel")
phase_ax.secondary_xaxis(
    "top",
    functions=(channel_to_freq, freq_to_channel),
).set_xlabel("Frequency [MHz]")

# Both panels use the same labels, so one figure-level legend built from the
# amplitude axis is enough.
primary_axes = amp_ax
handles, labels = primary_axes.get_legend_handles_labels()
fig.legend(handles, labels, loc="outside upper right")

plt.show()
plt.close(fig)

## Generate images

In [None]:
# Resolve wsclean / DP3 commands (use env vars if set)


def _ensure_on_path(tool, command, env_var):
    """Raise if ``command`` cannot be located with ``shutil.which``.

    ``tool`` is the display name used in the error message and ``env_var``
    the environment variable the user can set to override the command.
    """
    if shutil.which(command):
        return
    raise Exception(
        f"{tool} command not found (looked for '{command}'). "
        f"Either add {tool} to PATH, or set {env_var} environment variable pointing to the executable."
    )


# Fail early if either executable is missing.
wsclean_cmd = os.environ.get("WSCLEAN_CMD", "wsclean")
_ensure_on_path("wsclean", wsclean_cmd, "WSCLEAN_CMD")

dp3_cmd = os.environ.get("DP3_CMD", "DP3")
_ensure_on_path("DP3", dp3_cmd, "DP3_CMD")

In [None]:
# Directory for the imaging step. The imaging command in the next cell is run
# from inside it, and the corrected MS is later looked up here.
image_output_dir = f"{output_dir}/images"
os.makedirs(image_output_dir, exist_ok=True)

In [None]:
!cd $image_output_dir ; bash ../../../../scripts/clean-with-gains -i $vis_ms -g $actual_gaintable -size 512 512 -scale 2arcsec &> log.txt

### Plotting visibility amplitude vs frequency for all baselines of antenna

In [None]:
# Locate the corrected measurement set expected in the image output
# directory, and report a missing file explicitly instead of as a bare
# IndexError.
_corrected_ms_matches = glob.glob(f"{image_output_dir}/corrected.ms")
if not _corrected_ms_matches:
    raise FileNotFoundError(
        f"'{image_output_dir}/corrected.ms' was not found. "
        f"Check '{image_output_dir}/log.txt' for errors from the imaging step."
    )
corrected_ms_path = _corrected_ms_matches[0]

In [None]:
# Selection used by the visibility amplitude plot below.
time_idx = 0
antenna_idx = 10
pol_idx = 0  # 0 for XX, or 3 for YY
# NOTE(review): `refant` is not used by the plotting cell that follows;
# setting it here only rebinds the value defined earlier in the notebook.
refant = 0

In [None]:
# Load the corrected visibilities using the same time chunk size as the
# intermediate zarr file.
correct_vis = load_ms_as_dataset_with_time_chunks(
    corrected_ms_path, zarr_chunks["time"]
)

nantennas = correct_vis.configuration.id.size

# Enumerate (antenna1, antenna2) pairs from the upper triangle including the
# diagonal (k=0), i.e. autocorrelations are part of the index.
# NOTE(review): assumes the dataset's `baselineid` axis follows this same
# ordering - TODO confirm.
vis_baseline_indices = pandas.MultiIndex.from_arrays(
    np.triu_indices(nantennas, k=0), names=("antenna1", "antenna2")
)
# All pairs that involve `antenna_idx`, written with the smaller antenna index
# first; the second range starts at `antenna_idx`, so the pair
# (antenna_idx, antenna_idx) is included.
antenna_idx_to_baselines_idx = [
    (other, antenna_idx) for other in range(antenna_idx)
] + [(antenna_idx, other) for other in range(antenna_idx, nantennas)]
# Position of each selected pair within `vis_baseline_indices`.
baseline_idx = np.array(
    [vis_baseline_indices.get_loc(indices) for indices in antenna_idx_to_baselines_idx]
)

# Select one time step and one polarisation for the chosen baselines, then
# load the selection into memory before plotting.
data_to_plot = correct_vis.vis.isel(
    time=time_idx, baselineid=baseline_idx, polarisation=pol_idx
)
data_to_plot.load()

amp = np.abs(data_to_plot)

fig = plt.figure(layout="constrained", figsize=(18, 9))

# Plot only every fifth selected baseline to keep the figure readable; the
# legend label is "<antenna1>-<antenna2>".
for idx, bl in enumerate(antenna_idx_to_baselines_idx):
    if idx % 5 == 0:
        plt.plot(amp.isel(baselineid=idx), label=f"{bl[0]}-{bl[1]}")

plt.xlabel("Frequency Channel")
plt.ylabel("Amplitude")
plt.title(f"Amp vs Freq for Antenna {antenna_idx}")
plt.legend()
plt.grid(True)
plt.show()
plt.close(fig)