# Optimise export to H5Parm

## Setup

In [None]:
from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import load_ms
from ska_sdp_instrumental_calibration.data_managers.visibility import (
    read_dataset_from_zarr,
)
from ska_sdp_instrumental_calibration.workflow.utils import create_bandpass_table
from ska_sdp_instrumental_calibration.workflow.utils import with_chunks
from ska_sdp_instrumental_calibration.data_managers.data_export.export_to_h5parm import (
    export_gaintable_to_h5parm,
)

import h5py

import os
import shutil

from ska_sdp_datamodels.calibration.calibration_model import GainTable
from typing import Iterable, Literal

import dask
from distributed import Semaphore
import dask.array as da

import numpy as np
import xarray as xr
from ska_sdp_datamodels.visibility.vis_create import create_visibility

## Create gaintable with random gain and weight

In [None]:
from astropy.coordinates import Angle, SkyCoord
from astropy import units
from ska_sdp_datamodels.science_data_model import PolarisationFrame
from ska_sdp_datamodels.configuration.config_create import (
    create_named_configuration,
)

# Start from the named "LOWBD2" configuration and cut it down to the stations
# listed in AA1.
config = create_named_configuration("LOWBD2")
# Ids of the stations to keep; the trailing comments give the station labels.
# The "- 1" presumably converts 1-based station numbers to 0-based ids -- confirm.
AA1 = (
    np.concatenate(
        (
            345 + np.arange(6),  # S8-1:6
            351 + np.arange(4),  # S9-1:4
            429 + np.arange(6),  # S10-1:6
            465 + np.arange(4),  # S16-1:4
        )
    )
    - 1
)
# Boolean mask over the configuration's "id" values: True where the id is in AA1.
mask = np.isin(config.id.data, AA1)
nstations = config.stations.shape[0]
# Keep only the masked stations.
config = config.sel(indexers={"id": np.arange(nstations)[mask]})
# Reset relevant station parameters
nstations = config.stations.shape[0]
config.stations.data = np.arange(nstations).astype("str")
config = config.assign_coords(id=np.arange(nstations))
# config.attrs["name"] = config.name+"-AA1"
config.attrs["name"] = "AA1-Low"

# Size of the synthetic dataset: number of channels and number of time samples.
nfrequency = 3456
ntimes = 89

vis = create_visibility(
    config=config,
    # NOTE(review): looks like 0.9 s steps converted to hour angle in radians -- confirm
    times=np.arange(ntimes) * 0.9 / 3600 * np.pi / 12,
    frequency=150e6 + 1e6 * np.arange(nfrequency),
    channel_bandwidth=[1e6] * nfrequency,
    phasecentre=SkyCoord(ra=0, dec=-27, unit="degree"),
    polarisation_frame=PolarisationFrame("linear"),
    weight=1.0,
)
vis
# Pasted repr of the dimensions (presumably from an earlier run -- verify):
# time: 89, frequency: 3456, polarisation: 4, spatial: 3, baselineid: 820

In [None]:
# Chunk spec handed to with_chunks for the bandpass table
# (presumably dask-style: -1 = one chunk along time, 32 channels per chunk --
# confirm against with_chunks).
vis_chunks = {"time": -1, "frequency": 32}
gaintable = create_bandpass_table(vis).pipe(with_chunks, vis_chunks)

In [None]:
# Fixed seed so the synthetic gains and weights are reproducible.
np.random.seed(67)

# Complex gains with uniform [0, 1) real and imaginary parts. The real part is
# drawn first, then the imaginary part, then the weights.
gain_npr = np.random.rand(*gaintable.gain.shape) + 1j * np.random.rand(
    *gaintable.gain.shape
)
gain_xdr = xr.DataArray(
    gain_da := da.from_array(gain_npr, chunks=gaintable.gain.data.chunks),
    coords=gaintable.gain.coords,
)

# Real-valued weights, chunked like the existing weight variable.
weight_npr = np.random.rand(*gaintable.weight.shape)
weight_xdr = xr.DataArray(
    weight_da := da.from_array(weight_npr, chunks=gaintable.weight.data.chunks),
    coords=gaintable.weight.coords,
)

# Swap the random arrays in for the table's gain and weight variables.
gaintable = gaintable.assign(gain=gain_xdr, weight=weight_xdr)

## Export h5parm (original implementation)

In [None]:
# Scratch directory for the h5parm files written by this notebook.
gaintable_dir = "./gaintable_experiment/"

# Start from an empty directory: remove any previous one (errors ignored),
# then create it again.
shutil.rmtree(gaintable_dir, ignore_errors=True)
os.makedirs(gaintable_dir, exist_ok=True)

In [None]:
# Reference file: written by the existing exporter from the fully computed
# gain table. The verification cells compare the dask-based output against it.
expected_h5_path = f"{gaintable_dir}/expected.h5parm"

export_gaintable_to_h5parm(gaintable.compute(), expected_h5_path)

## Dask-friendly export

In [None]:
def _ndarray_of_null_terminated_bytes(strings: Iterable[str]):
    """Encode strings as ASCII bytes with a trailing NUL and pack them in an array.

    :param strings: text values to encode; each must be ASCII-encodable
    :return: numpy array built from the NUL-terminated byte strings
    """
    # NOTE: the extra NUL makes every entry one character longer, in keeping
    # with ska-sdp-batch-preprocess
    ascii_values = (text.encode("ascii") for text in strings)
    return np.asarray([raw + b"\0" for raw in ascii_values])

In [None]:
def create_soltab_group(
    solset: h5py.Group, solution_type: Literal["amplitude", "phase", "clock"]
) -> h5py.Group:
    """Create the "<solution_type>000" soltab group inside a solset.

    The new group's "TITLE" attribute is set to the solution type as bytes.

    :param solset: base-level HDF5 group to update
    :param solution_type: kind of solution stored in the table; the type hint
        also admits "clock", but only "amplitude" and "phase" are supported
        at present
    :return: HDF5 group for the "solution_type" data
    """
    group_name = f"{solution_type}000"
    title = np.bytes_(solution_type)

    soltab = solset.create_group(group_name)
    soltab.attrs["TITLE"] = title
    return soltab


def create_soltab_datasets(
    soltab: h5py.Group,
    gaintable: GainTable,
    chunks: "tuple[int, ...] | None" = None,
):
    """Add the axis datasets plus empty "val" and "weight" datasets to a soltab.

    One dataset is created per dimension of ``gaintable.gain``, holding the
    data of the table entry with that name. "val" and "weight" are only
    allocated here, both with the shape of ``gaintable.gain``; no gain or
    weight values are written by this function.

    :param soltab: HDF5 table to update
    :param gaintable: GainTable
    :param chunks: optional HDF5 chunk shape used for both "val" and
        "weight"; the default ``None`` leaves the layout to h5py, as before
    :return: tuple of the ("val", "weight") HDF5 datasets
    """
    # create a dataset for each dimension
    for dim in gaintable.gain.dims:
        soltab.create_dataset(dim, data=gaintable[dim].data)

    # Both datasets share the gain's shape and carry the comma-separated
    # dimension names in their "AXES" attribute.
    shape = gaintable.gain.shape
    axes = np.bytes_(",".join(gaintable.gain.sizes))

    # "val" holds real numbers, so it takes the dtype of the gain's real part.
    val = soltab.create_dataset(
        "val", shape=shape, chunks=chunks, dtype=gaintable.gain.real.dtype
    )
    val.attrs["AXES"] = axes

    weight = soltab.create_dataset(
        "weight", shape=shape, chunks=chunks, dtype=gaintable.weight.dtype
    )
    weight.attrs["AXES"] = axes

    return val, weight

In [None]:
def write_to_file(data, freq_index, filename, ds_name, sem):
    """Write one block of values into an existing HDF5 dataset and return it.

    :param data: block of values to store
    :param freq_index: positions of this block along the dataset's third
        axis; only ``freq_index[0][0]`` and ``freq_index[-1][0]`` are read
        (assumes a 2-D array of contiguous, ascending indices -- TODO confirm)
    :param filename: path of the HDF5 file, opened in append mode
    :param ds_name: name of the target dataset inside the file
    :param sem: context manager held for the duration of the write
    :return: ``data``, unchanged
    """
    first = freq_index[0][0]
    last = freq_index[-1][0]
    # Hold the semaphore for as long as the file is open.
    with sem, h5py.File(filename, "a") as h5file:
        h5file[ds_name][:, :, first : last + 1, :] = data
    return data

In [None]:
def create_file_with_meta(filename, gaintable) -> dict:
    """Create the H5Parm file skeleton and report where the data will go.

    The file is opened in "w" mode and gets a "sol000" group with an
    amplitude and a phase soltab, each populated by
    ``create_soltab_datasets``.

    :param filename: path of the HDF5 file to create
    :param gaintable: table passed on to ``create_soltab_datasets``
    :return: dict mapping "amp_val_dset", "amp_weight_dset",
        "phase_val_dset" and "phase_weight_dset" to the ``.name`` of the
        corresponding HDF5 dataset
    """
    with h5py.File(filename, "w") as h5file:
        solset = h5file.create_group("sol000")

        # Amplitude table first, then the phase table.
        amp_val, amp_weight = create_soltab_datasets(
            create_soltab_group(solset, "amplitude"), gaintable
        )
        phase_val, phase_weight = create_soltab_datasets(
            create_soltab_group(solset, "phase"), gaintable
        )

        # Read the dataset names while the file is still open.
        dataset_names = {
            "amp_val_dset": amp_val.name,
            "amp_weight_dset": amp_weight.name,
            "phase_val_dset": phase_val.name,
            "phase_weight_dset": phase_weight.name,
        }
    return dataset_names

In [None]:
def write_gaintable(filename, gaintable):
    """Create the H5Parm file and build the lazy writes of its four datasets.

    The file skeleton is created immediately through
    ``create_file_with_meta``. The amplitude values, amplitude weights,
    phase values and phase weights are each routed through
    ``write_to_file`` with ``xr.apply_ufunc(..., dask="parallelized")``.

    :param filename: path of the HDF5 file to create and fill
    :param gaintable: table with "gain" and "weight" variables and a "freq"
        coordinate
    :return: ``dask.delayed`` object wrapping the list of the four
        ``apply_ufunc`` results
    """
    # Create the file with its metadata up front.
    dataset_names = create_file_with_meta(filename, gaintable)

    # Position of every channel along "freq", chunked via with_chunks; this is
    # passed next to the data so write_to_file can derive its target slice.
    freq_index = xr.DataArray(
        np.arange(gaintable.freq.size), coords={"freq": gaintable.freq}
    ).pipe(with_chunks, gaintable.chunksizes)

    gain = gaintable["gain"]
    weight = gaintable["weight"]
    amplitude = np.absolute(gain)
    phase = xr.DataArray(da.angle(gain.data), coords=gain.coords)

    # required as h5py file can't be opened simultaneously by workers.
    sem = Semaphore(max_leases=1, name="gaintable-file")

    # (values, key of the target dataset); both soltabs get the same weights.
    targets = (
        (amplitude, "amp_val_dset"),
        (weight, "amp_weight_dset"),
        (phase, "phase_val_dset"),
        (weight, "phase_weight_dset"),
    )
    writes = [
        xr.apply_ufunc(
            write_to_file,
            values,
            freq_index,
            dask="parallelized",
            output_dtypes=[values.dtype],
            kwargs={
                "filename": filename,
                "ds_name": dataset_names[key],
                "sem": sem,
            },
        )
        for values, key in targets
    ]

    return dask.delayed(lambda *ufuncs: ufuncs)(writes)

In [None]:
import logging

from dask.delayed import Delayed

# Root logger: no name is passed to getLogger.
logger = logging.getLogger()


def export_gaintable_to_h5parm_dask(gaintable, filename: str, squeeze: bool = False):
    """Prepare a GainTable for H5Parm output and build its dask-based export.

    The table is renamed and stacked into the axis layout
    (time, ant, freq, pol), antenna indices are replaced by antenna names,
    and the result is handed to ``write_gaintable``.

    NOTE(review): ``write_to_file`` indexes the HDF5 datasets with four
    axes, so ``squeeze=True`` looks unsafe whenever an axis is actually
    dropped -- confirm before relying on it.

    :param gaintable: GainTable
    :param filename: Name of H5Parm file
    :param squeeze: If True, remove axes of length one from dataset
    :return: whatever ``write_gaintable`` returns for the prepared table
    :raises ValueError: if the gain dims, the polarisation order or the
        configuration are not as required
    """
    logger.info(f"exporting cal solutions to {filename}")

    # the renaming and stacking below rely on exactly this axis order
    gain_dims = gaintable.gain.dims
    if gain_dims != ("time", "antenna", "frequency", "receptor1", "receptor2"):
        raise ValueError(f"Unexpected dims: {gain_dims}")

    # switch to the H5Parm axis names and fold both receptor axes into "pol"
    h5_table = gaintable.rename({"antenna": "ant", "frequency": "freq"}).stack(
        pol=("receptor1", "receptor2")
    )
    pol_labels = _ndarray_of_null_terminated_bytes(
        [f"{p1}{p2}" for p1, p2 in h5_table["pol"].data]
    )
    h5_table = h5_table.assign_coords({"pol": pol_labels})

    # only the linear polarisation order is accepted
    linear_pols = _ndarray_of_null_terminated_bytes(["XX", "XY", "YX", "YY"])
    if not np.array_equal(h5_table["pol"].data, linear_pols):
        raise ValueError("Subsequent pipelines assume linear pol order")

    # TODO: lazily drop the cross terms (pol indices 1 and 2) when the sum of
    # their absolute weights is zero, i.e. keep only isel(pol=[0, 3]).

    # H5Parm needs antenna names, which come from the attached configuration
    configuration = h5_table.configuration
    if configuration is None:
        raise ValueError("Missing gt config. H5Parm requires antenna names")
    antenna_names = _ndarray_of_null_terminated_bytes(
        configuration.names.data[h5_table["ant"].data]
    )
    h5_table = h5_table.assign_coords({"ant": antenna_names})

    # remove axes of length one if required
    if squeeze:
        h5_table = h5_table.squeeze(drop=True)

    logger.info(f"output dimensions: {dict(h5_table.gain.sizes)}")

    return write_gaintable(filename, h5_table)

### Testing

In [None]:
import os

# Absolute path of the file written by the dask-based exporter.
actual_h5_path = os.path.abspath(f"{gaintable_dir}/actual.h5parm")

# Build the export; the returned object is computed in a later cell.
result = export_gaintable_to_h5parm_dask(gaintable, actual_h5_path)
result

### Run

In [None]:
from distributed import Client

# Start a local distributed client with processes=False.
client = Client(processes=False)

In [None]:
# Compute the object returned by export_gaintable_to_h5parm_dask.
result.compute()

## Read and verify

In [None]:
# Load the output of the dask-based export into memory; "[()]" reads each
# dataset in full, so the arrays stay usable after the file is closed.
with h5py.File(f"{gaintable_dir}/actual.h5parm", "r") as actual:
    # Amplitude soltab
    actual_amp_grp = actual["sol000"]["amplitude000"]
    actual_amp_val = actual_amp_grp["val"][()]
    actual_amp_freq = actual_amp_grp["freq"][()]
    actual_amp_pol = actual_amp_grp["pol"][()]
    actual_amp_time = actual_amp_grp["time"][()]
    actual_amp_weight = actual_amp_grp["weight"][()]

    # Phase soltab
    actual_phase_grp = actual["sol000"]["phase000"]
    actual_phase_val = actual_phase_grp["val"][()]
    actual_phase_freq = actual_phase_grp["freq"][()]
    actual_phase_pol = actual_phase_grp["pol"][()]
    actual_phase_time = actual_phase_grp["time"][()]
    actual_phase_weight = actual_phase_grp["weight"][()]

In [None]:
# Load the reference file written by the existing exporter, using the same
# group and dataset names as for the dask-based output.
with h5py.File(f"{gaintable_dir}/expected.h5parm", "r") as expected:
    # print(expected['sol000']["phase000"].keys())
    # Amplitude soltab
    expected_amp_grp = expected["sol000"]["amplitude000"]
    # print(expected_amp_grp.keys())
    expected_amp_val = expected_amp_grp["val"][()]
    expected_amp_freq = expected_amp_grp["freq"][()]
    expected_amp_pol = expected_amp_grp["pol"][()]
    expected_amp_time = expected_amp_grp["time"][()]
    expected_amp_weight = expected_amp_grp["weight"][()]

    # print()
    # Phase soltab
    expected_phase_grp = expected["sol000"]["phase000"]
    expected_phase_val = expected_phase_grp["val"][()]
    expected_phase_freq = expected_phase_grp["freq"][()]
    expected_phase_pol = expected_phase_grp["pol"][()]
    expected_phase_time = expected_phase_grp["time"][()]
    expected_phase_weight = expected_phase_grp["weight"][()]

In [None]:
# Amplitude soltab: values, frequency axis, time axis and weights must match
# the reference within assert_allclose's default tolerances.
# NOTE(review): the "pol" arrays loaded earlier are not compared here.
np.testing.assert_allclose(expected_amp_val, actual_amp_val)
np.testing.assert_allclose(expected_amp_freq, actual_amp_freq)
np.testing.assert_allclose(expected_amp_time, actual_amp_time)
np.testing.assert_allclose(expected_amp_weight, actual_amp_weight)


# Phase soltab: same checks.
np.testing.assert_allclose(expected_phase_val, actual_phase_val)
np.testing.assert_allclose(expected_phase_freq, actual_phase_freq)
np.testing.assert_allclose(expected_phase_time, actual_phase_time)
np.testing.assert_allclose(expected_phase_weight, actual_phase_weight)