## Setup

In [None]:
import logging
import os

import numpy
import pandas

import numpy as np

from casacore.tables import table, taql

import logging
import dask

from ska_ser_logging import configure_logging

from ska_sdp_datamodels.visibility.vis_io_ms import (
    create_visibility_from_ms,
)
from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import (
    load_ms as load_ms_og,
)

# Configure logging through ska_ser_logging at INFO level; this produces the
# pipe-separated log lines captured in the cell outputs below.
configure_logging(level=logging.INFO)

In [None]:
# Root data directories. The defaults are the author's machines; override them
# with environment variables instead of editing every absolute path below.
INST_DATA_DIR = os.environ.get("INST_DATA_DIR", "/home/ska/Work/data/INST")
REPO_DATA_DIR = os.environ.get(
    "INST_REPO_DATA_DIR",
    "/home/maneesh/Work/SKAO/ska-sdp-instrumental-calibration/data",
)

# MS whose baselines are stored reversed (antenna1 > antenna2, see log below)
rev_input_ms_path = os.path.join(
    REPO_DATA_DIR, "cal_bpp_vis-lg3-rotated.small.reversed.ms"
)
input_ms_path = os.path.join(
    INST_DATA_DIR, "lg3", "cal_bpp_vis-lg3-rotated.small.ms"
)
# input_ms_path = os.path.join(INST_DATA_DIR, "lg3", "cal_bpp_vis-lg3-rotated.ms")
# input_ms_path = os.path.join(REPO_DATA_DIR, "demo.ms")

# eb_ms = input_ms_path
eb_ms = os.path.join(INST_DATA_DIR, "sim", "OSKAR_MOCK.ms")

# EveryBeam coefficient files; the default points into one specific Poetry venv
eb_coeffs = os.environ.get(
    "EVERYBEAM_COEFFS_DIR",
    "/home/maneesh/.cache/pypoetry/virtualenvs/ska-sdp-instrumental-calibration-vujiG8jS-py3.11/share/everybeam/",
)

lsm_csv_path = os.path.join(INST_DATA_DIR, "lg3", "sky_model_cal.csv")
gleamfile = os.path.join(INST_DATA_DIR, "sim", "gleamegc.dat")




## Playground

### Loading data using casacore

In [None]:
# Column to read and the inclusive channel range explored below
datacolumn = "DATA"
start_chan, end_chan = 1000, 1064

In [None]:
msname = input_ms_path
ack = False  # casacore `ack`: whether to print an acknowledgement on open
field = 0
dd = 0

# Open the main table once and run the selections on that handle, so every
# open honours `ack` and no second, never-closed handle to the MS is created.
tab = table(msname, ack=ack)

fields = numpy.unique(tab.getcol("FIELD_ID"))
dds = numpy.unique(tab.getcol("DATA_DESC_ID"))

ftab = tab.query(f"FIELD_ID=={field}", style="")

# Only two scalar ids are needed from DATA_DESCRIPTION, so close it right away.
ddtab = table(f"{msname}/DATA_DESCRIPTION", ack=ack)
spwid = ddtab.getcol("SPECTRAL_WINDOW_ID")[dd]
polid = ddtab.getcol("POLARIZATION_ID")[dd]
ddtab.close()

meta = {"MSV2": {"FIELD_ID": field, "DATA_DESC_ID": dd}}
ms = ftab.query(f"DATA_DESC_ID=={dd}", style="")

otime = ms.getcol("TIME")
nrows = otime.shape[0]
# Read a single row only to learn the per-row shape of the data column.
datacol = ms.getcol(datacolumn, nrow=1)
datacol_shape = list(datacol.shape)
channels = datacol.shape[-2]

# Inclusive bottom-left / top-right corners of a slice over the last two axes
# (presumably channel, polarisation), e.g. for ms.getcolslice(datacolumn, blc, trc).
blc = [start_chan, 0]
trc = [end_chan, datacol_shape[-1] - 1]
channum = range(start_chan, end_chan + 1)

In [None]:
# Display the reference table selected for the chosen FIELD_ID / DATA_DESC_ID
ms

In [None]:
# Column description of the data column configured in `datacolumn`
# (previously hard-coded to "DATA", which would go stale if the column changed)
ms.getcoldesc(datacolumn)

## Implementation

In [None]:
# Moved to ska_sdp_instrumental_calibration.data_managers.visibility module
from ska_sdp_instrumental_calibration.data_managers.visibility import (
    load_ms_as_dataset_with_time_chunks,
    check_if_cache_files_exist,
    read_visibility_from_zarr,
    write_ms_to_zarr,
)

## Testing

### load_ms vs load_ms_as_dataset_with_time_chunks

In [13]:
# New loader on the reversed-baseline MS; the second argument (5) is presumably
# the number of time samples per chunk, going by the function name — confirm.
# vis = load_ms_as_dataset_with_time_chunks(input_ms_path, 5)
vis = load_ms_as_dataset_with_time_chunks(rev_input_ms_path, 5)

# Materialise the lazy data; trailing ";" suppresses the dataset repr
vis.load();

1|2025-08-11T04:58:57.363Z|INFO|MainThread|_load_data_vars|visibility.py#574||Does measurement set contain autocorrelations? False
1|2025-08-11T04:58:57.364Z|INFO|MainThread|_load_data_vars|visibility.py#589||In the measurement set, is the baseline antenna order reversed (i.e. is antenna1 > antenna2)? True


1|2025-08-11T04:59:11.186Z|INFO|ThreadPoolExecutor-0_8|create_visibility_from_ms|vis_io_ms.py#339||Reading uni. fields [0], uni. data descs [0]
1|2025-08-11T04:59:11.188Z|INFO|ThreadPoolExecutor-0_1|create_visibility_from_ms|vis_io_ms.py#339||Reading uni. fields [0], uni. data descs [0]
1|2025-08-11T04:59:11.189Z|INFO|ThreadPoolExecutor-0_7|create_visibility_from_ms|vis_io_ms.py#339||Reading uni. fields [0], uni. data descs [0]
1|2025-08-11T04:59:11.269Z|INFO|ThreadPoolExecutor-0_6|create_visibility_from_ms|vis_io_ms.py#339||Reading uni. fields [0], uni. data descs [0]
1|2025-08-11T04:59:11.271Z|INFO|ThreadPoolExecutor-0_2|create_visibility_from_ms|vis_io_ms.py#339||Reading uni. fields [0], uni. data descs [0]
1|2025-08-11T04:59:39.216Z|INFO|ThreadPoolExecutor-0_2|create_visibility_from_ms|vis_io_ms.py#339||Reading uni. fields [0], uni. data descs [0]


In [None]:
# Reference loader (dask_wrappers.load_ms) on the same reversed-baseline MS;
# fchunk=64 is presumably the number of frequency channels per chunk.
og_vis = load_ms_og(rev_input_ms_path, fchunk=64)

# Materialise the lazy data; trailing ";" suppresses the dataset repr
og_vis.load();

1|2025-08-11T04:59:38.976Z|INFO|MainThread|create_visibility_from_ms|vis_io_ms.py#339||Reading uni. fields [0], uni. data descs [0]
1|2025-08-11T04:59:39.218Z|INFO|ThreadPoolExecutor-0_23|create_visibility_from_ms|vis_io_ms.py#339||Reading uni. fields [0], uni. data descs [0]
1|2025-08-11T04:59:39.220Z|INFO|ThreadPoolExecutor-0_21|create_visibility_from_ms|vis_io_ms.py#339||Reading uni. fields [0], uni. data descs [0]


In [18]:
# Compare real and imaginary parts separately, so a failure says which differs.
# The reference (og_vis) is permuted along its last axis with [0, 2, 1, 3],
# i.e. the two middle entries are swapped -- presumably the cross-polarisation
# products, reordered because this MS stores baselines reversed (TODO confirm).
# Only comparison failures (AssertionError) are reported and tolerated; any
# other error (undefined variable, unexpected axis length, ...) propagates.
og_vis_reordered = og_vis.vis.data[..., [0, 2, 1, 3]]
for part_name, part in (("real", np.real), ("imaginary", np.imag)):
    try:
        np.testing.assert_allclose(
            part(og_vis_reordered),
            part(vis.vis.data),
            err_msg=f"mismatch in {part_name} values",
        )
    except AssertionError as e:
        print(e)

In [39]:
# Flags are presumably integer/boolean, so compare them exactly, not with a
# tolerance. Only comparison failures are reported; other errors propagate.
# NOTE(review): vis is compared with og_vis polarisations permuted [0, 2, 1, 3]
# above, but flags are compared unpermuted -- confirm this is intended.
try:
    np.testing.assert_array_equal(
        og_vis.flags.data,
        vis.flags.data,
        err_msg="mismatch in flags values",
    )
except AssertionError as e:
    print(e)

In [40]:
# Only comparison failures (AssertionError) are reported; other errors propagate.
# NOTE(review): vis is compared with og_vis polarisations permuted [0, 2, 1, 3]
# above, but weights are compared unpermuted -- confirm this is intended.
try:
    np.testing.assert_allclose(
        og_vis.weight.data,
        vis.weight.data,
        err_msg="mismatch in weight values",
    )
except AssertionError as e:
    print(e)

In [41]:
# Only comparison failures (AssertionError) are reported; other errors propagate.
try:
    np.testing.assert_allclose(
        og_vis.uvw.data,
        vis.uvw.data,
        err_msg="mismatch in uvw values",
    )
except AssertionError as e:
    print(e)

In [42]:
# Whole-dataset equality between the two loaders (for xarray Datasets, equals()
# compares variables and coordinates, not attributes).
# NOTE(review): the cells above compare vis against og_vis with polarisations
# permuted [0, 2, 1, 3]; this check applies no permutation, so it can only pass
# if that permutation is a no-op for this data -- confirm intent.
assert og_vis.equals(vis)

### Test baseline ordering

In [None]:
nantennas = 5

ms_is_baseline_order_reversed = False
ms_contains_autocorrelations = False

# Simulated MS rows for one time step, derived from the two flags above:
# upper-triangular pairs (diagonal included only with autocorrelations), with
# antenna1/antenna2 swapped for an MS that stores baselines reversed.
antenna1, antenna2 = np.triu_indices(
    nantennas, k=0 if ms_contains_autocorrelations else 1
)
if ms_is_baseline_order_reversed:
    antenna1, antenna2 = antenna2, antenna1

nbaselines = len(antenna1)

# NOTE: this overwrites the `nrows` computed in the casacore playground above.
nrows = nbaselines * 1  # Time is 1

# Baseline index of the visibility dataset: always upper-triangular including
# autocorrelations (k=0), regardless of how the MS stores its rows.
vis_baseline_indices = pandas.MultiIndex.from_arrays(
    np.triu_indices(nantennas, k=0), names=("antenna1", "antenna2")
)

antenna1, antenna2

In [None]:
baselines = vis_baseline_indices

# Reference result, reproducing the older per-row loop: each (antenna1, antenna2)
# row is looked up in the upper-triangular visibility baseline index, with the
# pair swapped first when antenna1 > antenna2.
expected = []
for row in range(nrows):
    first, second = antenna1[row], antenna2[row]
    if first <= second:
        ibaseline = baselines.get_loc((first, second))
        print("ant1<=ant2", first, second, ibaseline)
    elif first > second:
        ibaseline = baselines.get_loc((second, first))
        print("ant1>ant2", first, second, ibaseline)
    else:
        # Only reachable if the ordering comparison is undefined
        raise ValueError("Bad antenna pair")
    expected.append(ibaseline)

In [None]:
# Diagonal (autocorrelation) pairs appear in the MS only when present
diag_offset = 0 if ms_contains_autocorrelations else 1

# Reversing the (rows, cols) tuple returned by triu_indices turns the
# upper-triangular ordering into antenna1 >= antenna2 ordering.
indices_order = (
    slice(None, None, -1) if ms_is_baseline_order_reversed else slice(None, None, None)
)

ms_baseline_indices = pandas.MultiIndex.from_arrays(
    np.triu_indices(nantennas, k=diag_offset)[indices_order],
    names=("antenna1", "antenna2"),
)

# Undo the reversal per pair, since the visibility index is upper-triangular,
# and record where each MS baseline lands in it.
update_locations = []
for ms_pair in ms_baseline_indices:
    update_locations.append(vis_baseline_indices.get_loc(ms_pair[indices_order]))
vis_baseline_indices_to_update = np.array(update_locations)

print(vis_baseline_indices_to_update)

In [None]:
# The MultiIndex-based mapping must reproduce the reference per-row loop exactly
np.testing.assert_equal(np.array(expected), vis_baseline_indices_to_update)

### create_visibility_from_ms vs read_visibility_from_zarr

In [None]:
fchunk = 64
times_per_ms_chunk = 24

# Dimensions kept as a single chunk (-1) in both the zarr store and the loaded
# visibility dataset
non_chunked_dims = dict.fromkeys(("baselineid", "polarisation", "spatial"), -1)

# Chunking of the intermediate zarr file
zarr_chunks = dict(non_chunked_dims, time=times_per_ms_chunk, frequency=fchunk)

# The pipeline only works on frequency chunks; later stages are expected to
# follow the same chunking pattern
vis_chunks = dict(non_chunked_dims, time=-1, frequency=fchunk)

In [None]:
# Write the input MS into an on-disk zarr cache chunked as `zarr_chunks`.
sim_cache_dir = "./sim_cache"
os.makedirs(sim_cache_dir, exist_ok=True)

# Tasks created inside are annotated as needing one "process" resource;
# presumably this limits concurrent MS access on a scheduler that defines
# such a resource -- TODO confirm.
with dask.annotate(resources={"process": 1}):
    write_ms_to_zarr(
        input_ms_path,
        sim_cache_dir,
        zarr_chunks,
    )

In [None]:
# The cache must have been populated by write_ms_to_zarr in the cell above.
assert check_if_cache_files_exist(sim_cache_dir)

# Read the cache back; vis_chunks presumably sets the chunking of the result.
zarr_vis = read_visibility_from_zarr(sim_cache_dir, vis_chunks)

In [None]:
# Reference: read the same MS directly, keeping the first returned Visibility
og_ms_vis = create_visibility_from_ms(input_ms_path)[0]

In [None]:
# The zarr round trip must reproduce every data variable read directly from MS
for var_name in ("vis", "flags", "uvw", "weight"):
    np.testing.assert_allclose(
        getattr(og_ms_vis, var_name), getattr(zarr_vis, var_name)
    )

### Generated MS comparison

In [None]:
import numpy as np
from astropy.coordinates import SkyCoord
from ska_sdp_datamodels.calibration.calibration_create import (
    create_gaintable_from_visibility,
)
from ska_sdp_datamodels.configuration.config_create import (
    create_named_configuration,
)
from ska_sdp_datamodels.science_data_model import PolarisationFrame
from ska_sdp_datamodels.visibility.vis_create import create_visibility

from ska_sdp_datamodels.visibility.vis_io_ms import export_visibility_to_ms

In [None]:
def generate_vis(seed=None):
    """Build a small AA1-Low Visibility dataset and a matching GainTable.

    Adapted from a test fixture. Every visibility sample is set to
    [1, 0, 0, 1] (a point source at the phase centre). The "B" gain table
    starts with diagonal Jones terms 1 - 0.1j and 3 + 0j, then has complex
    Gaussian noise (sigma 0.2 on real and imaginary parts) added to every entry.

    :param seed: Optional seed for the gain noise so the GainTable is
        reproducible across runs. ``None`` (default) keeps the previous
        behaviour of drawing from NumPy's global random state.
    :return: ``(vis, jones)`` -- the Visibility and GainTable datasets.
    """
    # Start from the LOWBD2 layout and keep only the AA1 stations
    config = create_named_configuration("LOWBD2")
    AA1 = (
        np.concatenate(
            (
                345 + np.arange(6),  # S8-1:6
                351 + np.arange(4),  # S9-1:4
                429 + np.arange(6),  # S10-1:6
                465 + np.arange(4),  # S16-1:4
            )
        )
        - 1  # presumably 1-based station numbers -> 0-based config ids
    )
    mask = np.isin(config.id.data, AA1)
    nstations = config.stations.shape[0]
    config = config.sel(indexers={"id": np.arange(nstations)[mask]})
    # Reset relevant station parameters so the subset is numbered 0..n-1
    nstations = config.stations.shape[0]
    config.stations.data = np.arange(nstations).astype("str")
    config = config.assign_coords(id=np.arange(nstations))
    config.attrs["name"] = "AA1-Low"

    # Three samples 0.9 s apart, converted s -> h (/3600) -> rad (*pi/12);
    # presumably hour angles as create_visibility expects -- confirm.
    vis = create_visibility(
        config=config,
        times=np.arange(3) * 0.9 / 3600 * np.pi / 12,
        frequency=150e6 + 1e6 * np.arange(4),
        channel_bandwidth=[1e6] * 4,
        phasecentre=SkyCoord(ra=0, dec=-27, unit="degree"),
        polarisation_frame=PolarisationFrame("linear"),
        weight=1.0,
    )
    # Put a point source at phase centre
    vis.vis.data[..., :] = [1, 0, 0, 1]

    # Create the GainTable dataset
    jones = create_gaintable_from_visibility(vis, jones_type="B")
    jones.gain.data[..., 0, 0] = 1 - 0.1j
    jones.gain.data[..., 1, 1] = 3 + 0j

    # Both np.random and a seeded Generator provide normal(loc, scale, size)
    rng = np.random if seed is None else np.random.default_rng(seed)
    jones.gain.data += rng.normal(0, 0.2, jones.gain.shape)
    jones.gain.data += rng.normal(0, 0.2, jones.gain.shape) * 1j

    return vis, jones


# def generate_ms(tmp_path, vis):
#     """Create and later delete test MSv2."""
#     ms_path = f"{tmp_path}/{ms_name}"
#     export_visibility_to_ms(ms_path, [vis])

#     yield ms_path

#     shutil.rmtree(ms_path)

In [None]:
# Generate visibilities and write them to a small MSv2 on disk; the gain table
# returned by generate_vis is not needed for these comparisons.
gen_vis, _ = generate_vis()

ms_path = "./test.ms"

# NOTE(review): re-running this cell targets an existing ./test.ms; confirm
# export_visibility_to_ms overwrites it rather than failing.
export_visibility_to_ms(ms_path, [gen_vis])

In [None]:
test_cache_dir = "./test_cache"
os.makedirs(test_cache_dir, exist_ok=True)

# Dimensions stored as a single chunk (-1) in both the zarr cache and the
# dataset read back from it
non_chunked_dims = dict.fromkeys(("baselineid", "polarisation", "spatial"), -1)

# The generated MS has 3 times x 4 channels, so these small chunks split the
# intermediate zarr file into several chunks along both axes
zarr_chunks = dict(non_chunked_dims, time=1, frequency=2)
vis_chunks = dict(non_chunked_dims, time=-1, frequency=2)

with dask.annotate(resources={"process": 1}):
    write_ms_to_zarr(ms_path, test_cache_dir, zarr_chunks)

In [None]:
# The cache must have been populated by write_ms_to_zarr in the cell above.
assert check_if_cache_files_exist(test_cache_dir)

# Use the zarr reader imported in the Implementation section
# (`read_dataset_from_zarr` is not defined in this notebook).
zarr_vis = read_visibility_from_zarr(test_cache_dir, vis_chunks)

In [None]:
# Reference: read the generated MS directly, bypassing the zarr cache
create_vis = create_visibility_from_ms(ms_path)[0]

In [None]:
# zarr round trip vs. direct MS read
for var_name in ("vis", "uvw", "flags", "weight"):
    np.testing.assert_allclose(
        getattr(zarr_vis, var_name), getattr(create_vis, var_name)
    )

In [None]:
# In-memory generated data vs. zarr round trip
for var_name in ("vis", "uvw", "flags", "weight"):
    np.testing.assert_allclose(
        getattr(gen_vis, var_name), getattr(zarr_vis, var_name)
    )

In [None]:
# In-memory generated data vs. direct MS read
for var_name in ("vis", "uvw", "flags", "weight"):
    np.testing.assert_allclose(
        getattr(gen_vis, var_name), getattr(create_vis, var_name)
    )