# Validate Bandpass Polarisation Stage

## Imports

In [None]:
from ska_sdp_instrumental_calibration.workflow.stages.bandpass_calibration import (
    bandpass_calibration_stage,
)

from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import run_solver

from ska_sdp_instrumental_calibration.data_managers.visibility import (
    read_dataset_from_zarr,
    write_ms_to_zarr,
)
from ska_sdp_instrumental_calibration.workflow.utils import (
    create_bandpass_table,
    with_chunks,
)
from ska_sdp_datamodels.calibration.calibration_create import (
    create_gaintable_from_visibility,
)

from ska_sdp_instrumental_calibration.data_managers.data_export.export_to_h5parm import (
    export_gaintable_to_h5parm,
)

from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import (
    apply_gaintable_to_dataset,
    prediction_central_beams,
)


import dask

import numpy as np
import xarray as xr

import os
from typing import Literal

import glob
import h5py

%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
# Connect to an already-running dask scheduler.
# NOTE: the `client.compute(...)` cell further down uses this `client`
# object, so this cell has to be run before it.
from distributed import Client

# Address (host:port) of the scheduler to attach to; adjust to your setup.
scheduler = "localhost:34567"

client = Client(scheduler)

## Generate data

We use the scripts stored in `scripts/ska_low_sim` to generate the data.
Information about how to use the scripts is present in this [confluence page](https://confluence.skatelescope.org/x/eZHMF).

### Simulation config

The following is the custom configuration used for simulation (further referred to as `custom_sim.yaml`)

```yaml
scenario: "low40s-model"          # Scenario name (used for output folder prefix)

# ===============================
# Global simulation parameters
# ===============================

n_stations: 40                                         # Number of stations
tel_model: "./SKA-Low_AA2_40S_rigid-rotation_model.tm" # Telescope model directory

simulation_start_frequency_hz: 123.0e6                  # Start frequency (Hz)
simulation_end_frequency_hz: 153.0e6                    # End frequency (Hz)
correlated_channel_bandwidth_hz: 21.70138888888889e3    # Channel width (Hz)

observing_time_mins: 10                              # Observation duration (minutes)
sampling_time_sec: 3.3973862400000003                   # Dump/integration time (seconds)

fields:
  EoR2:
    Cal1:
      ra_deg: 197.914612
      dec_deg: -22.277973
      scan_id_start: 300
      transit_time: "2000-01-03 22:33:30.000"

# ==================================
# Options for generate_gaintable.py
# ==================================

generate_gaintable:
  output_gaintable: &gen_gaintable "./gaintables/custom_gaintable.h5"

  spline_data_path: "./SKA_Low_AA2_SP5175_spline_data.npz" # Bandpass spline data file
  station_offset: true              # Apply per-station amplitude/phase offsets
  time_variant: true                # Apply time-dependent effects

  rfi: false                        # Inject RFI band
  rfi_start_freq_hz: 154.25347222228538e6        # Hz
  rfi_end_freq_hz: 159.8090277778474e6           # Hz

  plot: true                        # Generate diagnostic plots
  plot_output_dir: "./gaintables/generation_plots/"

# ===============================
# Options for run_sim.py
# ===============================

run_sim:
  oskar_sif: "./OSKAR-2.11.1-Python3.sif" # Path to OSKAR Singularity image

  # GLEAM sky model. Optional. Comment to disable.
  gleam_file: "./sky-models/GLEAM_EGC.fits" # GLEAM catalogue FITS file
  field_radius_deg: 10.0            # Radius of field of view (degrees)

  # Corruptions to be applied. All are optional. Comment to disable.
  # gaintable: *gen_gaintable           # Gaintable containing bandpass corruptions
  # cable_delay: "./cable_delays/cable_length_error_40s.txt" # Cable delay error file
  # tec_screen: "./tec/calibrator_iono_tec.fits" # Ionospheric TEC screen FITS

  # Imaging parameters using wsclean. Optional. Comment to disable.
  create_dirty_image: true          # Whether to run wsclean imaging
  image_size: 1024                  # Image size (pixels)
  pixel_size: "2arcsec"             # Pixel size (angular units)

  # Extra parameters to pass directly to run_oskar.py
  run_oskar_extra_params: "--use-gpus --double-precision"
```

### Overview of steps followed

Assuming you have followed the steps on the [confluence page](https://confluence.skatelescope.org/x/eZHMF) to get the data and set up the Python environment, here are the general steps to follow to simulate the data needed to run this notebook:

1. Create a custom gaintable using the above config and the `generate_gaintable` script

    ```bash
    python3 generate_gaintable.py custom_sim.yaml
    ```

2. Convert gaintable to DP3 h5parm file

    ```bash
    python3 utils/h5parm_from_oskar_gains.py custom_sim.yaml "./gaintables/custom_gaintable.h5"
    ```

3. Create "model" visibilities

    ```bash
    python run_sim.py custom_sim.yaml
    ```

4. Apply the corruptions using DP3 to create corrupted visibilities

    ```bash
    msin="path_to_the_model_visibilities"
    parmdb="./gaintables/custom_gaintable.h5parm"
    msout="corrupted_visibilities.ms"
    DP3_CMD=${DP3_CMD:-"DP3"}

    ${DP3_CMD} steps="[applycal]" msin=${msin}\
	  applycal.invert=False applycal.parmdb=${parmdb}\
	  applycal.missingantennabehavior=flag applycal.updateweights=True\
	  applycal.steps="[phase, amplitude]" applycal.phase.correction=phase000\
	  applycal.amplitude.correction=amplitude000 msout=${msout}
    ```


## Setup common parameters

In [None]:
# Set these paths to the simulated data generated in previous steps

# Visibilities to calibrate (presumably the DP3-corrupted measurement set
# produced in step 4 above -- confirm for your data).
vis_ms = "/home/ska/Work/data-simulation/low40s-model-250925_120444/visibility.scan-300.corrupt.ms"
# Model visibilities (presumably the uncorrupted output of run_sim.py).
model_ms = (
    "/home/ska/Work/data-simulation/low40s-model-250925_120444/visibility.scan-300.ms"
)
# Simulator gaintable; read below through its `gain_xpol` / `gain_ypol`
# datasets.
sim_gaintable = "/home/ska/Work/data-simulation/low40s-model-250925_120444/customgaintable/gain_model_cal.h5"
# h5parm gaintable; read below through `sol000/amplitude000` and
# `sol000/phase000` and compared against `sim_gaintable`.
dp3_gaintable = "/home/ska/Work/data-simulation/low40s-model-250925_120444/customgaintable/gain_model_cal_dp3.h5parm"

# Zarr conversion params
# Chunk size along "frequency" (used by both zarr_chunks and vis_chunks).
nchannels_per_chunk = 20
# Chunk size along "time" (used by zarr_chunks only).
ntimes_per_ms_chunk = 20

# Everybeam related params
beam_type = "everybeam"
eb_coeffs = "/home/ska/Work/data/INST/sim/coeffs"
# The measurement set handed to the beam prediction is the visibility MS.
eb_ms = vis_ms

# Normalise visibility at beam centre
# When True, the "Normalise visibility" cell applies the predicted beams to
# both the visibilities and the model; when False that cell is a no-op.
normalise_at_beam_centre = True

# Solver params
# Each setting is declared on its own and then gathered, unchanged, into
# `run_solver_params`, which is what the calibration cells consume.
from typing import Optional

solver: str = "gain_substitution"
refant: int = 0
niter: int = 100
phase_only: bool = False
tol: float = 1e-06
crosspol: bool = False
# None is the value actually used for the next setting, so the annotation
# has to allow it.
normalise_gains: Optional[str] = None
jones_type: Literal["T", "G", "B"] = "B"
# Likewise None here, hence Optional.
timeslice: Optional[float] = None

run_solver_params = dict(
    solver=solver,
    refant=refant,
    niter=niter,
    phase_only=phase_only,
    tol=tol,
    crosspol=crosspol,
    normalise_gains=normalise_gains,
    jones_type=jones_type,
    timeslice=timeslice,
)

## Convert MS to Zarr files

In [None]:
# Dimensions that are left unchunked (-1) both in the zarr store and in the
# visibility dataset loaded from it.
non_chunked_dims = dict.fromkeys(("baselineid", "polarisation", "spatial"), -1)

# Chunking of the intermediate zarr file: split along both time and
# frequency.
zarr_chunks = dict(
    non_chunked_dims,
    time=ntimes_per_ms_chunk,
    frequency=nchannels_per_chunk,
)

# The pipeline only works on frequency chunks, so time is kept whole here.
# It's expected that later stages follow the same chunking pattern.
vis_chunks = dict(
    non_chunked_dims,
    time=-1,
    frequency=nchannels_per_chunk,
)

In [None]:
# All cache and output locations are resolved relative to the directory the
# notebook process is running in.
working_dir = os.getcwd()

# Zarr stores for the visibilities and the model; these are the directories
# that read_dataset_from_zarr loads in the following cells.
# NOTE(review): no cell visible in this notebook calls write_ms_to_zarr to
# populate these directories -- presumably they come from an earlier
# conversion run; confirm before executing the cells below.
vis_cache_directory = f"{working_dir}/cache/vis/"
model_cache_directory = f"{working_dir}/cache/model/"

## Load visibilities from zarr and create initial gaintable

In [None]:
# Both datasets are loaded with the same (frequency-only) chunking.
vis_ds = read_dataset_from_zarr(vis_cache_directory, vis_chunks)

# The model's time coordinate is replaced with the one from the
# visibilities, so that the time coords of the two datasets match.
model_ds = read_dataset_from_zarr(model_cache_directory, vis_chunks).assign_coords(
    time=vis_ds.time
)

In [None]:
# Initial bandpass gaintable built from the visibilities, then chunked with
# the same layout as the visibilities themselves.
gaintable = with_chunks(create_bandpass_table(vis_ds), vis_chunks)

# Alternative kept for reference:
# gaintable = create_gaintable_from_visibility(vis_ds, jones_type='B', timeslice=None)

## Normalise at beam centre

In [None]:
def plot_vis(
    vis_ds: xr.Dataset,
    time: int,
    baselineid: int,
    polarisation: int,
    figsize=(10, 6),
    savepath=None,
):
    """
    Draw the amplitude and phase of one visibility spectrum.

    A single (time, baseline, polarisation) sample is picked by position
    from ``vis_ds.vis`` and shown as two stacked panels sharing the x-axis:
    amplitude on top, phase in degrees underneath.

    Parameters
    ----------
    vis_ds : xr.Dataset
        Dataset holding the complex visibilities in its ``vis`` variable.
    time : int
        Positional index along the 'time' dimension.
    baselineid : int
        Positional index along the 'baselineid' dimension.
    polarisation : int
        Positional index along the 'polarisation' dimension.
    figsize : tuple, optional
        Figure size (width, height) in inches. Default is (10, 6).
    savepath : str, optional
        When given (and non-empty) the figure is written to this path and
        closed; otherwise it is displayed.
    """
    selection = {
        "time": time,
        "baselineid": baselineid,
        "polarisation": polarisation,
    }
    spectrum = vis_ds.vis.isel(**selection)

    # Use the frequency coordinate for the x-axis when the dataset has one,
    # otherwise fall back to a plain sample index.
    x_values = vis_ds.coords.get("frequency", np.arange(spectrum.size))
    x_label = "Frequency" if "frequency" in vis_ds.coords else "Channel index"

    amplitude = np.abs(spectrum)
    phase_deg = np.angle(spectrum, deg=True)

    fig, (amp_axis, phase_axis) = plt.subplots(
        2, 1, figsize=figsize, sharex=True
    )

    panels = (
        (amp_axis, amplitude, "tab:blue", "Amplitude"),
        (phase_axis, phase_deg, "tab:orange", "Phase [deg]"),
    )
    for axis, values, colour, y_label in panels:
        axis.plot(x_values, values, color=colour)
        axis.set_ylabel(y_label)

    amp_axis.set_title(f"Baseline {baselineid}, Time {time}, Pol {polarisation}")
    # The x-axis is shared, so only the bottom panel carries its label.
    phase_axis.set_xlabel(x_label)

    plt.tight_layout()

    if not savepath:
        plt.show()
        return

    plt.savefig(savepath, dpi=150)
    plt.close(fig)

In [None]:
# Positional indices (plot_vis passes them to `isel`) selecting the single
# time sample, baseline and polarisation shown in the diagnostic plots below.
time = 15
baselineid = 50
polarisation = 0

### Plotting visibility before normalising

In [None]:
plot_vis(vis_ds, time=time, baselineid=baselineid, polarisation=polarisation)

In [None]:
plot_vis(model_ds, time=time, baselineid=baselineid, polarisation=polarisation)

### Normalise visibility

In [None]:
# Optional step, switched by `normalise_at_beam_centre` from the common
# parameters; when it is False the datasets are left untouched.
if normalise_at_beam_centre:
    # The beams are predicted once and persisted, so that the two
    # applications below reuse the same result.
    beams = prediction_central_beams(
        vis_ds,
        beam_type=beam_type,
        eb_ms=eb_ms,
        eb_coeffs=eb_coeffs,
    ).persist()

    # The same beams are applied to both the visibilities and the model.
    # NOTE(review): inverse=True presumably divides the beam response out
    # rather than multiplying it in -- confirm against
    # apply_gaintable_to_dataset.
    vis_ds = apply_gaintable_to_dataset(vis_ds, beams, inverse=True)
    model_ds = apply_gaintable_to_dataset(model_ds, beams, inverse=True)

### Plotting after normalising

In [None]:
plot_vis(vis_ds, time=time, baselineid=baselineid, polarisation=polarisation)

In [None]:
plot_vis(model_ds, time=time, baselineid=baselineid, polarisation=polarisation)

## Running calibration and exporting gaintable

### run_solver

In [None]:
# output_dir = f"{working_dir}/run_solver_{solver}_{niter}/"

# os.makedirs(output_dir, exist_ok=True)

In [None]:
# gaintable = run_solver(
#     vis_ds,
#     model_ds,
#     gaintable,
#     **run_solver_params,
# )

In [None]:
# export_task = dask.delayed(export_gaintable_to_h5parm, pure=False)(
#     gaintable,
#     f"{output_dir}/gaintable.h5parm",
# )

In [None]:
# client.compute(export_task, sync=True, optimize_graph=True)

In [None]:
# client.restart()

### bandpass_calibration stage

In [None]:
# Stage outputs go to a directory under the working directory whose name
# records the solver and iteration count in use.
output_dir = f"{working_dir}/bandpass_stage_{solver}_{niter}/"

# exist_ok=True lets this cell be re-run when the directory already exists.
os.makedirs(output_dir, exist_ok=True)

In [None]:
from ska_sdp_instrumental_calibration.scheduler import UpstreamOutput

# Container handed to the stage; the datasets prepared above are stored in
# it under the keys "vis", "modelvis" and "gaintable".
upstream_output = UpstreamOutput()

upstream_output["vis"] = vis_ds
upstream_output["modelvis"] = model_ds
upstream_output["gaintable"] = gaintable

# NOTE(review): the arguments are passed positionally; the meanings noted
# below are inferred from the values -- confirm against the stage signature.
upout = bandpass_calibration_stage.stage_definition(
    upstream_output,
    run_solver_params,  # solver settings gathered in the common parameters
    {"plot_table": False, "fixed_axis": False},  # plotting options, both off
    "vis",  # presumably the key of the visibilities in upstream_output
    True,  # presumably enables the gaintable export globbed for later on
    output_dir,
)

In [None]:
# Run the tasks collected by the stage on the dask cluster.
# sync=True makes the call wait for, and return, the concrete results
# instead of futures.
client.compute(
    upout.compute_tasks,
    sync=True,
    optimize_graph=True,
)

In [None]:
# client.restart()

## Comparing h5parm files

In [None]:
# Locate the h5parm file found in the stage output directory.
# Sorting makes the pick deterministic when several files match, and the
# explicit check replaces a bare IndexError with an actionable message.
h5_candidates = sorted(glob.glob(f"{output_dir}/*.h5*"))
if not h5_candidates:
    raise FileNotFoundError(
        f"No '*.h5*' file found in '{output_dir}'; "
        "run the bandpass_calibration stage cells first."
    )
output_h5_path = h5_candidates[0]

actual_gaintable = output_h5_path

In [None]:
def suppress_exception(func, *args, **kwargs):
    """
    Call ``func(*args, **kwargs)`` and report, rather than raise, a failure.

    Used by the comparison cells so that one failed assertion does not stop
    the remaining checks from running.

    Parameters
    ----------
    func : callable
        Function to invoke.
    *args, **kwargs
        Forwarded to ``func`` unchanged.

    Returns
    -------
    object or None
        Whatever ``func`` returns; ``None`` when ``func`` raised an
        ``Exception``, in which case the exception is printed followed by a
        blank line.
    """
    try:
        outcome = func(*args, **kwargs)
    except Exception as error:
        print(error, end="\n\n")
        return None
    return outcome

#### Compare oskar gains with DP3 gains

In [None]:
# The simulator file stores one gain dataset per polarisation
# (`gain_xpol`, `gain_ypol`); the DP3 h5parm stores amplitude and phase in
# separate tables with the polarisation on the last axis. Amplitude and
# phase of the simulator gains are compared to the matching DP3 table.
with h5py.File(sim_gaintable) as sim_gain_f, \
        h5py.File(dp3_gaintable) as dp3_gain_f:
    for pol_id, pol in enumerate("xy"):
        for quantity, extract, soltab in (
            ("Amplitude", np.abs, "amplitude000"),
            ("Phase", np.angle, "phase000"),
        ):
            print(f"{quantity} for pol {pol}")
            simulated = extract(sim_gain_f[f"gain_{pol}pol"])
            dp3_values = dp3_gain_f["sol000"][soltab]["val"][..., pol_id]
            suppress_exception(
                np.testing.assert_allclose,
                simulated,
                # Swap the last two axes of the DP3 values before comparing
                # them with the simulator values.
                dp3_values.transpose(0, 2, 1),
            )

#### Compare actual vs DP3, by averaging expected DP3 gains across time

In [None]:
# The DP3 values are averaged along axis 0 (kept as a length-1 axis) and
# compared with the table written by the bandpass stage.
# NOTE(review): `pol_id % 2` is a no-op for indices 0 and 1. If the stage's
# h5parm stores four polarisations (XX, XY, YX, YY), YY would sit at index 3
# rather than 1 -- confirm the pol axis of the exported file.
# NOTE(review): the phase is averaged with a plain arithmetic mean, which
# does not account for wrap-around -- confirm this is acceptable here.
with h5py.File(dp3_gaintable) as dp3_gain_f, \
        h5py.File(actual_gaintable) as act_gain_f:
    for pol_id, pol in enumerate(("XX", "YY")):
        for quantity, soltab in (
            ("Amplitude", "amplitude000"),
            ("Phase", "phase000"),
        ):
            print(f"{quantity} for pol {pol}")
            dp3_mean = np.mean(
                dp3_gain_f["sol000"][soltab]["val"][..., pol_id % 2],
                axis=0,
                keepdims=True,
            )
            suppress_exception(
                np.testing.assert_allclose,
                # The last entry along the final axis of the DP3 values is
                # left out of the comparison.
                dp3_mean[..., :-1],
                act_gain_f["sol000"][soltab]["val"][..., pol_id],
            )

#### Compare actual vs DP3, by broadcasting expected DP3 gains across time

In [None]:
# The table written by the bandpass stage is repeated along axis 0 as many
# times as the DP3 table has entries on that axis, then the two are compared
# element by element.
# NOTE(review): `pol_id % 2` is a no-op for indices 0 and 1. If the stage's
# h5parm stores four polarisations (XX, XY, YX, YY), YY would sit at index 3
# rather than 1 -- confirm the pol axis of the exported file.
with h5py.File(dp3_gaintable) as dp3_gain_f, \
        h5py.File(actual_gaintable) as act_gain_f:
    for pol_id, pol in enumerate(("XX", "YY")):
        for quantity, soltab in (
            ("Amplitude", "amplitude000"),
            ("Phase", "phase000"),
        ):
            print(f"{quantity} for pol {pol}")
            dp3_val = dp3_gain_f["sol000"][soltab]["val"]

            # Read the stage output from disk once; the same array is then
            # repeated instead of being re-read for every repetition.
            act_values = act_gain_f["sol000"][soltab]["val"][..., pol_id]
            expanded_act = np.concatenate([act_values] * dp3_val.shape[0])

            suppress_exception(
                np.testing.assert_allclose,
                # The last entry along the final axis of the DP3 values is
                # left out of the comparison.
                dp3_val[..., pol_id % 2][..., :-1],
                expanded_act,
            )

## Generate images

In [None]:
# Put the spack-installed DP3 and wsclean bin directories in front of the
# existing PATH. These locations are machine-specific; adjust as needed.
os.environ["PATH"] = ":".join(
    (
        "/home/ska/spack/opt/spack/linux-ubuntu22.04-x86_64_v3/gcc-11.4.0/dp3-6.4.1-f6kotkx5la5ziuqge4vi2bwlnoizlqin/bin",
        "/home/ska/spack/opt/spack/linux-ubuntu22.04-x86_64_v3/gcc-11.4.0/wsclean-3.6.20250630-biqw4junkbmxbtxtfnle5gymjiyoukhd/bin/",
        os.environ["PATH"],
    )
)

In [None]:
# Resolve wsclean / DP3 commands (use env vars if set)
import shutil


def _require_command(env_var, default):
    """
    Return the command named by ``env_var`` (or ``default``) if it exists.

    Parameters
    ----------
    env_var : str
        Environment variable that may override the command.
    default : str
        Command name used when ``env_var`` is not set; also the tool name
        shown in the error message.

    Returns
    -------
    str
        The command that was looked up.

    Raises
    ------
    FileNotFoundError
        If ``shutil.which`` cannot find the command.
    """
    command = os.environ.get(env_var, default)
    if not shutil.which(command):
        raise FileNotFoundError(
            f"{default} command not found (looked for '{command}'). "
            f"Either add {default} to PATH, or set {env_var} environment variable pointing to the executable."
        )
    return command


# Early checks: fail here, before any imaging is attempted, if a tool is
# missing.
wsclean_cmd = _require_command("WSCLEAN_CMD", "wsclean")
dp3_cmd = _require_command("DP3_CMD", "DP3")

In [None]:
!cd $output_dir ; bash ../../../scripts/clean-with-gains -i $vis_ms -g $actual_gaintable -size 512 512 -scale 2arcsec &> log.txt