In [1]:
import os
import numpy as np
import xarray as xr
import os
import dask
# dask.config.set(scheduler='threads')  # or 'synchronous' for step-by-step

# If you want a dashboard:
# from dask.distributed import Client
# client = Client()
from ska_sdp_instrumental_calibration.workflow.stages import (
    load_data_stage,
    predict_vis_stage,
    bandpass_calibration_stage
)

from ska_sdp_instrumental_calibration.scheduler import UpstreamOutput
from copy import deepcopy
from ska_sdp_instrumental_calibration.data_managers.visibility import _generate_file_paths_for_vis_zarr_file
import pickle

from ska_sdp_instrumental_calibration.workflow.utils import with_chunks



  cls = super().__new__(mcls, name, bases, namespace, **kwargs)


In [2]:
def export_model_vis(vis, output_dir, zarr_chunks):
    """Persist a model-visibility dataset under ``output_dir`` for later reuse.

    The dataset is split into three artefacts whose paths come from
    ``_generate_file_paths_for_vis_zarr_file(output_dir)``:

    * a zarr store with every variable except ``baselines`` (attributes
      dropped), rechunked with ``zarr_chunks`` via ``with_chunks``;
    * a pickle of ``vis.attrs``;
    * a pickle of the computed ``vis.baselines`` variable.

    The zarr store is written first and the pickles are (re)written only
    after that write has completed, so fresh side-car files never sit next
    to a half-written store.

    Parameters
    ----------
    vis : xarray.Dataset
        Model visibilities; must contain a ``baselines`` variable.
    output_dir : str
        Directory handed to ``_generate_file_paths_for_vis_zarr_file``.
    zarr_chunks : dict
        Per-dimension chunk sizes for the on-disk zarr store.
    """
    attributes_file, baselines_file, vis_zarr_file = (
        _generate_file_paths_for_vis_zarr_file(output_dir)
    )

    # open(..., "wb") does not create missing directories, so make sure the
    # parents of the pickle side-car files exist before writing them.
    for path in (attributes_file, baselines_file):
        parent = os.path.dirname(os.fspath(path))
        if parent:
            os.makedirs(parent, exist_ok=True)

    # Build the lazy zarr write and run it; dask.compute blocks until done.
    writer = (
        vis.drop_attrs()
        .drop_vars("baselines")
        .pipe(with_chunks, zarr_chunks)
        .to_zarr(vis_zarr_file, mode="w", compute=False)
    )
    dask.compute(writer)

    # pickle.dump does not mutate its argument, so no defensive copy is needed.
    with open(attributes_file, "wb") as file:
        pickle.dump(vis.attrs, file)

    baselines = vis.baselines.compute()
    with open(baselines_file, "wb") as file:
        pickle.dump(baselines, file)

def import_model_vis(input_dir, vis_chunks):
    """Reassemble model visibilities previously written by ``export_model_vis``.

    Opens the zarr store lazily with ``vis_chunks``, then re-attaches the
    pickled dataset attributes and the pickled ``baselines`` variable.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point
    ``input_dir`` at files this notebook produced itself.

    Parameters
    ----------
    input_dir : str
        Directory handed to ``_generate_file_paths_for_vis_zarr_file``.
    vis_chunks : dict
        Dask chunk sizes used when opening the zarr store.

    Returns
    -------
    xarray.Dataset
        The reassembled dataset with attributes and ``baselines`` restored.
    """
    paths = _generate_file_paths_for_vis_zarr_file(input_dir)
    attrs_path, baselines_path, zarr_path = paths

    dataset = xr.open_dataset(zarr_path, chunks=vis_chunks, engine="zarr")

    with open(attrs_path, "rb") as attrs_fh:
        saved_attrs = pickle.load(attrs_fh)
    dataset = dataset.assign_attrs(saved_attrs)

    with open(baselines_path, "rb") as baselines_fh:
        saved_baselines = pickle.load(baselines_fh)

    return dataset.assign({"baselines": saved_baselines})



In [3]:
def compute_and_clear_tasks(upstream_output: UpstreamOutput):
    """Execute every pending compute task on ``upstream_output``, then clear them.

    All tasks go through a single ``dask.compute`` call, so graph nodes that
    several tasks share are evaluated once rather than once per task.

    Parameters
    ----------
    upstream_output : UpstreamOutput
        Pipeline state whose ``compute_tasks`` are executed.
    """
    pending_tasks = list(upstream_output.compute_tasks)
    if pending_tasks:
        dask.compute(*pending_tasks)

    # NOTE(review): tasks are read from ``compute_tasks``, but the original
    # code only reset ``stage_compute_tasks``. Whether that attribute backs
    # ``compute_tasks`` is not visible here. So empty the list returned by
    # ``compute_tasks`` in place when it is a list, and keep the original
    # reset as well. Confirm against UpstreamOutput.
    if isinstance(upstream_output.compute_tasks, list):
        upstream_output.compute_tasks.clear()
    upstream_output.stage_compute_tasks = []

In [4]:
# --- Load data stage ------------------------------------------------------
# Data and cache locations. The defaults are the original machine-specific
# paths; set INST_DATA_DIR / INST_CAL_CACHE_DIR to run the notebook elsewhere.
inst_data_dir = os.environ.get("INST_DATA_DIR", "/data/SKA/inst_data")
vis_ms = os.path.join(inst_data_dir, "cal_bpp_vis-lg3-rotated.small.ms/")
cache_dir = os.environ.get(
    "INST_CAL_CACHE_DIR",
    "/home/ska/projects/ska/ska-sdp-instrumental-calibration/notebook_temp/cache",
)
_cli_args_ = {"input": vis_ms}
nchannels_per_chunk = 32
ntimes_per_ms_chunk = 5

# Shared pipeline state handed to every stage_definition call below.
upstream_output = UpstreamOutput()
ack = False
datacolumn = "DATA"
field_id = 0
data_desc_id = 0

# --- Predict visibilities stage -------------------------------------------
beam_type = "everybeam"
normalise_at_beam_centre = True
eb_coeffs = os.path.join(inst_data_dir, "DHR-305", "coeffs")
eb_ms = None
lsm_csv_path = os.path.join(inst_data_dir, "DHR-305", "sky_model_cal.csv")
gleamfile = None
fov = 5.0
flux_limit = 1.0
alpha0 = -0.78

# --- Bandpass calibration stage -------------------------------------------
solver = "gain_substitution"
refant = 0
niter = 150
phase_only = False
tol = 1.0e-06
crosspol = False
normalise_gains = None
timeslice = None
plot_config = {
    "plot_table": False,
    "fixed_axis": False,
}
visibility_key = "vis"
export_gaintable = True
run_solver_config = {
    "solver": solver,
    "niter": niter,
    "phase_only": phase_only,
    "tol": tol,
    "crosspol": crosspol,
    "refant": refant,
    "timeslice": timeslice,
    "normalise_gains": normalise_gains,
}

# Relative to the current working directory; the name encodes solver settings.
output_dir = f"bandpass_stage_{solver}_{niter}"

In [5]:
# Dimensions kept as a single chunk (-1) in both chunking schemes below.
non_chunked_dims = dict.fromkeys(["baselineid", "polarisation", "spatial"], -1)

# Chunking of the intermediate zarr store written by export_model_vis.
zarr_chunks = dict(
    non_chunked_dims,
    time=ntimes_per_ms_chunk,
    frequency=nchannels_per_chunk,
)

# Chunking used when the zarr store is read back with import_model_vis:
# the whole time axis in one chunk, frequency split as above.
vis_chunks = dict(non_chunked_dims, time=-1, frequency=nchannels_per_chunk)

In [6]:
# Run the load-data stage, then pull its outputs out of the shared state.
load_data_stage.stage_definition(
    upstream_output,
    nchannels_per_chunk,
    ntimes_per_ms_chunk,
    cache_dir,
    ack,
    datacolumn,
    field_id,
    data_desc_id,
    _cli_args_,
    output_dir,
)
# Keep a handle to the pre-calibration gaintable: upstream_output["gaintable"]
# is read again after the bandpass stage as that stage's result.
vis, initial_gaintable = upstream_output["vis"], upstream_output["gaintable"]

1|2025-11-12T11:22:24.166Z|INFO|MainThread|load_data_stage|load_data.py#181||Reading cached visibilities from path /home/ska/projects/ska/ska-sdp-instrumental-calibration/notebook_temp/cache/cal_bpp_vis-lg3-rotated.small.ms_fid0_ddid0


In [7]:
# Reuse exported model visibilities only when every artefact written by
# export_model_vis is present; otherwise the next cell predicts and exports
# them. Set this to False by hand to force a fresh prediction.
is_model_vis_exported = all(
    os.path.exists(path)
    for path in _generate_file_paths_for_vis_zarr_file(output_dir)
)

In [8]:
if not is_model_vis_exported:
    # Build the lazy model visibilities and persist them under output_dir.
    predict_vis_stage.stage_definition(
        upstream_output,
        beam_type,
        normalise_at_beam_centre,
        eb_ms,
        eb_coeffs,
        gleamfile,
        lsm_csv_path,
        fov,
        flux_limit,
        alpha0,
        _cli_args_,
    )

    predicted_modelvis = upstream_output["modelvis"]
    compute_and_clear_tasks(upstream_output)
    export_model_vis(predicted_modelvis, output_dir, zarr_chunks)

# In both cases read the model visibilities back from the exported store.
# Later cells then use the persisted data instead of re-evaluating the
# prediction graph, and fresh and cached runs see identically chunked input.
modelvis = import_model_vis(output_dir, vis_chunks)
upstream_output["modelvis"] = modelvis
    


In [9]:

# Run the bandpass calibration stage and materialise the gaintable it stores.
bandpass_calibration_stage.stage_definition(
    upstream_output,
    run_solver_config,
    plot_config,
    visibility_key,
    export_gaintable,
    output_dir,
)
gaintable_expected = upstream_output["gaintable"].compute()

# NOTE(review): the pending tasks (e.g. the h5parm export logged below) may
# re-evaluate the calibration graph already computed above. Computing them
# together with the gaintable would avoid a duplicate solve.
compute_and_clear_tasks(upstream_output)

1|2025-11-12T11:22:24.260Z|INFO|MainThread|bandpass_calibration_stage|bandpass_calibration.py#134||Using vis for calibration.
1|2025-11-12T11:22:24.315Z|INFO|MainThread|bandpass_calibration_stage|bandpass_calibration.py#169||exporting gaintable to h5parm format.
1|2025-11-12T11:22:24.404Z|INFO|ThreadPoolExecutor-0_12|solve_gaintable|solvers.py#142||solve_gaintable: Starting calibration
1|2025-11-12T11:22:24.408Z|INFO|ThreadPoolExecutor-0_8|solve_gaintable|solvers.py#142||solve_gaintable: Starting calibration
1|2025-11-12T11:22:24.414Z|INFO|ThreadPoolExecutor-0_12|solve_gaintable|solvers.py#143||solve_gaintable: Using solver gain_substitution
1|2025-11-12T11:22:24.421Z|INFO|ThreadPoolExecutor-0_8|solve_gaintable|solvers.py#143||solve_gaintable: Using solver gain_substitution
1|2025-11-12T11:22:24.429Z|INFO|ThreadPoolExecutor-0_14|solve_gaintable|solvers.py#142||solve_gaintable: Starting calibration
1|2025-11-12T11:22:24.443Z|INFO|ThreadPoolExecutor-0_7|solve_gaintable|solvers.py#142||so

In [10]:
# Ensure the expected gaintable is in memory, then print its gain array.
gaintable_expected = gaintable_expected.compute()
expected_gain = gaintable_expected["gain"]
print(expected_gain)


<xarray.DataArray 'gain' (time: 1, antenna: 40, frequency: 432, receptor1: 2,
                          receptor2: 2)> Size: 553kB
array([[[[[0.7311617 +1.2231389e-18j, 0.        +0.0000000e+00j],
          [0.        +0.0000000e+00j, 0.6947208 +3.1986262e-18j]],

         [[0.7300085 +1.6386358e-18j, 0.        +0.0000000e+00j],
          [0.        +0.0000000e+00j, 0.69366795-1.4478307e-18j]],

         [[0.72868466-5.3031827e-18j, 0.        +0.0000000e+00j],
          [0.        +0.0000000e+00j, 0.69247955-4.3580482e-18j]],

         ...,

         [[0.91482896-1.8799090e-18j, 0.        +0.0000000e+00j],
          [0.        +0.0000000e+00j, 0.9102958 +6.9888144e-19j]],

         [[0.9117713 -2.4636278e-18j, 0.        +0.0000000e+00j],
          [0.        +0.0000000e+00j, 0.9072989 -1.2682657e-18j]],

         [[0.9099243 +1.9993446e-18j, 0.        +0.0000000e+00j],
          [0.        +0.0000000e+00j, 0.90550447-4.6141897e-19j]]],

...
        [[[0.6984385 -3.1077141e-02j, 0.     

In [11]:
from ska_sdp_func_python.calibration.solvers import solve_gaintable
from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import (
    restore_baselines_dim,
)

# Reference result: call solve_gaintable directly on in-memory data.
# New names are used instead of overwriting `vis` / `modelvis`, so re-running
# this cell never applies restore_baselines_dim twice to the same data.
vis_reference = restore_baselines_dim(vis).compute()
modelvis_reference = restore_baselines_dim(modelvis).compute()
# Deep copy in case solve_gaintable updates gain_table in place; the
# initial gaintable from the load stage must stay untouched.
initial_gaintable_copy = initial_gaintable.compute().copy(deep=True)

gaintable_actual = solve_gaintable(
    vis=vis_reference,
    modelvis=modelvis_reference,
    gain_table=initial_gaintable_copy,
    solver=solver,
    phase_only=phase_only,
    niter=niter,
    tol=tol,
    crosspol=crosspol,
    normalise_gains=normalise_gains,
    jones_type="B",
    refant=refant,
    timeslice=timeslice,
)

1|2025-11-12T11:22:39.980Z|INFO|MainThread|solve_gaintable|solvers.py#142||solve_gaintable: Starting calibration
1|2025-11-12T11:22:39.981Z|INFO|MainThread|solve_gaintable|solvers.py#143||solve_gaintable: Using solver gain_substitution
1|2025-11-12T11:22:42.888Z|INFO|MainThread|solve_gaintable|solvers.py#285||solve_gaintable: Finished calibration


In [12]:
# assert_allclose(actual, desired) checks
# |actual - desired| <= atol + rtol * |desired|, so the relative tolerance is
# measured against the second argument: the expected values go there.
np.testing.assert_allclose(
    gaintable_actual["gain"].values,
    gaintable_expected["gain"].values,
    rtol=1.0e-6,
    atol=1.0e-6,
    err_msg="Bandpass stage gains differ from a direct solve_gaintable run.",
)
print("Gaintable values match expected results.")


Gaintable values match expected results.
