# Refactor Predict

## Imports

In [None]:
import sys

# Machine-specific location of the EveryBeam data directory shipped inside the
# poetry virtualenv. NOTE(review): presumably this is what gets passed on as the
# `eb_coeffs` argument of the predict functions further down -- confirm.
eb_coeffs = "/home/maneesh/.cache/pypoetry/virtualenvs/ska-sdp-instrumental-calibration-vujiG8jS-py3.10/share/everybeam/"

# Alternative setup: use a locally built EveryBeam and its coefficient directory.
# sys.path.insert(0, "/home/maneesh/Work/SKAO/EveryBeam/build/python")
# eb_coeffs = "/home/maneesh/Work/SKAO/EveryBeam/coeffs"

In [None]:
import gc
import importlib
import os
import shutil
from typing import Literal, Optional, Union

import dask
import dask.array as da
from distributed import Client, LocalCluster, performance_report
import everybeam as eb
import numpy as np
import numpy.typing as npt
import xarray as xr
from astropy import constants as const
from astropy.coordinates import ITRS, AltAz, SkyCoord
from astropy.time import Time
from ska_sdp_datamodels.calibration import GainTable
from ska_sdp_datamodels.configuration import Configuration
from ska_sdp_datamodels.sky_model import SkyComponent
from ska_sdp_datamodels.visibility import Visibility
from ska_sdp_func_python.imaging.dft import dft_skycomponent_visibility
from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import (
    predict_vis,
    prediction_central_beams,
    restore_baselines_dim,
    simplify_baselines_dim,
)
from ska_sdp_instrumental_calibration.data_managers.visibility import (
    read_dataset_from_zarr,
)
from ska_sdp_instrumental_calibration.logger import setup_logger
from ska_sdp_instrumental_calibration.processing_tasks.beams import radec_to_xyz
from ska_sdp_instrumental_calibration.processing_tasks.lsm import (
    Component,
    convert_model_to_skycomponents,
    generate_lsm_from_csv,
    generate_lsm_from_gleamegc,
)
from ska_sdp_instrumental_calibration.processing_tasks.predict import (
    GenericBeams,
    dft_skycomponent_local,
    gaussian_tapers,
    generate_rotation_matrices,
    predict_from_components,
)
from ska_sdp_instrumental_calibration.scheduler import UpstreamOutput
from ska_sdp_instrumental_calibration.workflow.stages.load_data import load_data_stage
from ska_sdp_instrumental_calibration.workflow.utils import (
    create_bandpass_table,
    with_chunks,
)
from xarray.core.utils import Frozen
from ska_sdp_instrumental_calibration.workflow.utils import create_solint_slices

from ska_sdp_piper.piper.utils.log_util import LogPlugin

from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import (
    apply_gaintable_to_dataset,
)

from ska_sdp_datamodels.science_data_model import ReceptorFrame

from ska_sdp_instrumental_calibration.workflow.utils import (
    get_indices_from_grouped_bins,
    get_intervals_from_grouped_bins,
)


# Fix the global NumPy RNG seed so that randomly generated inputs
# (e.g. the station_rm values drawn with np.random.rand below) are reproducible.
np.random.seed(57)

# Notebook-wide logger; the functions defined below log through it.
logger = setup_logger("predict_ipynb")

## Setup Dask (optional)

In [None]:
# Option A (disabled): spin up a LocalCluster from within the notebook.
# # if "local_cluster" not in globals():
# #     local_cluster = LocalCluster(
# #         n_workers=4, threads_per_worker=3, dashboard_address=":30088"
# #     )

# # if "dask_client" not in globals():
# #     # dask_client = local_cluster.get_client()
# Option B (active): connect to a scheduler that is already listening on
# localhost:34567. This cell fails if no such scheduler is running.
dask_client = Client("localhost:34567")
# Forward log records emitted on the workers back to this client session.
dask_client.forward_logging()
# Register the piper logging plugin on the cluster (non-verbose mode).
log_configure_plugin = LogPlugin(verbose=False)
dask_client.register_plugin(log_configure_plugin)

# # print(local_cluster.dashboard_link)

In [None]:
# if 'dask_client' in globals():
#     dask_client.close()
#     del dask_client

# if 'local_cluster' in globals():
#     local_cluster.close()
#     del local_cluster

## Utils

In [None]:
def compare_arrays(
    actual: np.ndarray,
    expected: np.ndarray,
    rtol=1e-6,
    atol=1e-12,
    meta="Data",
    output: Literal["print", "log", "raise"] = "print",
):
    """
    Compare two arrays (NumPy or Dask) element-wise with tolerance thresholds.

    This function checks absolute and relative differences between two arrays and
    reports the maximum absolute difference, maximum relative difference, and the
    number and percentage of elements that differ beyond the specified tolerances.

    An element is considered matching when
    ``|actual - expected| <= atol + rtol * |expected|`` (the rule used by
    ``numpy.isclose``) or when both values compare exactly equal (which covers
    identical infinities). Any element involving NaN never satisfies either
    condition and is therefore counted as different.

    If either input is a Dask array, all required values are finalized in a single
    `dask.compute` call to minimize graph execution overhead.

    Parameters
    ----------
    actual : numpy.ndarray or dask.array.Array
        The array containing computed or observed values.
    expected : numpy.ndarray or dask.array.Array
        The reference array to compare against.
    rtol : float, default 1e-6
        Relative tolerance.
    atol : float, default 1e-12
        Absolute tolerance.
    meta : str, default "Data"
        Label used for output messages.
    output: str, Literal["print", "log", "raise"]
        Determines how to present the output
        print: Prints the comparison message
        log: Logs the comparison message
        raise: Raises AssertionError if values don't match

    Raises
    ------
    AssertionError
        If ``output == "raise"`` and at least one element differs.
    KeyError
        If ``output`` is not one of the supported modes.
    """

    def _report(level: str, message: str) -> None:
        # Single place deciding how a message of a given severity is surfaced.
        if output == "print":
            print(message)
        elif output == "log":
            if level == "error":
                logger.error(message)
            else:
                logger.info(message)
        elif output == "raise":
            # Informational messages are silently dropped in "raise" mode.
            if level == "error":
                raise AssertionError(message)
        else:
            raise KeyError(output)

    is_dask = isinstance(actual, da.Array) or isinstance(expected, da.Array)

    if actual.dtype != expected.dtype:
        _report(
            "info",
            f"{meta}: dtype mismatch : actual: {actual.dtype}, expected: {expected.dtype}.\n"
            f"Comparison may be influenced by different precision.",
        )

    diff = actual - expected

    # Use the broadcast size so the reported percentage can never exceed 100%.
    total = diff.size
    if total == 0:
        # Nothing to compare; the reductions below would fail on empty input.
        _report("info", f"{meta}: match within atol={atol}, rtol={rtol}")
        return

    abs_diff = np.abs(diff)
    abs_expected = np.abs(expected)
    abs_diff_max = abs_diff.max()

    # element-wise relative difference (safe for expected == 0)
    rel_diff = abs_diff / np.maximum(abs_expected, 1e-30)
    rel_diff_max = rel_diff.max()

    # numpy-style closeness rule. The mask is built from the *positive*
    # condition and then negated, so that NaN (for which every ordered
    # comparison is False) is counted as a difference instead of a match.
    within_tol = (abs_diff <= (atol + rtol * abs_expected)) | (actual == expected)
    num_diff = (~within_tol).sum()

    if is_dask:
        abs_diff_max, rel_diff_max, num_diff = da.compute(
            abs_diff_max, rel_diff_max, num_diff
        )

    if num_diff > 0:
        msg = (
            f"{meta}: do not match for atol={atol}, rtol={rtol}\n"
            f"\tmax abs diff = {abs_diff_max}\n"
            f"\tmax rel diff = {rel_diff_max}\n"
            f"\tdifferent elements: {num_diff} / {total} ({num_diff / total * 100:.6f}%)"
        )
        _report("error", msg)
    else:
        msg = f"{meta}: match within atol={atol}, rtol={rtol}"
        _report("info", msg)

In [None]:
# def create_gaintable_from_vis_new(
#     vis: Visibility,
#     timeslice: Union[float, Literal["auto"], None] = None,
#     jones_type: Literal["T", "G", "B"] = "T",
# ):
#     # TODO: Write full logic
#     # Reference the original create_gaintable_from_visibility function

#     # TODO: To simplify dask operations,
#     # if timeslice is provided such that 1 < len(solution_time) < len(vis.time)
#     # then still the gaintable must have len(solution_time) = len(vis.time)
#     # where the solution_time values are duplicated
#     # this can be optimized in the future

#     gaintable = create_bandpass_table(vis)
#     # Since its bandpass table, single interval contains whole time chunk
#     gaintable.attrs["soln_interval_slices"] = create_solint_slices(
#         vis.time, timeslice, True
#     )

#     # This should be part of the dim creation logic itself
#     gaintable = gaintable.rename(time="solution_time", frequency="solution_frequency")

#     # # Following chunking logics should be part of the data creation logic itself
#     # This is when we are sure that rest of the code can handle different solutions times
#     gaintable = gaintable.chunk({"solution_time": 1})

#     # if gaintable.jones_type == "B":
#     if gaintable.solution_frequency.size == vis.frequency.size:
#         gaintable = gaintable.chunk({"solution_frequency": vis.chunksizes["frequency"]})

#     return gaintable

In [None]:
# # Test for the new gain interval logic
# time_diff = np.diff(vis.time.data)[0]

# # Full time in single timeslice
# timeslice = np.max(vis.time) - np.min(vis.time)

# gain_time_bins = create_solint_slices(vis.time, timeslice, False)
# gain_time = gain_time_bins.mean().data

# gain_interval = get_intervals_from_grouped_bins(gain_time_bins)


# idx = 0
# time = gain_time[idx]

# time_slice = {
#         "time": slice(
#             time - gain_interval[idx] / 2,
#             time + gain_interval[idx] / 2,
#         )
#     }

# assert np.all(vis.time.sel(time_slice).data == vis.time.data)

In [None]:
def create_gaintable_from_vis_new(
    vis: Visibility,
    timeslice: Union[float, Literal["auto"], None] = None,
    jones_type: Literal["T", "G", "B"] = "T",
    lower_precision: bool = True,
):
    """
    Build an identity GainTable matching ``vis``, backed by dask arrays.

    Behaves like create_gaintable_from_vis, with two differences:
    1. The data variables (gain, weight, residual) are dask arrays.
    2. ``lower_precision`` switches the data variables between 4-byte
       (complex64/float32) and 8-byte (complex128/float64) component types.

    :param vis: Visibility dataset the gain table is derived from.
    :param timeslice: Solution interval passed to create_solint_slices.
        "auto" is accepted for backward compatibility and treated as None.
    :param jones_type: "B" keeps one solution per visibility channel;
        "G" and "T" use a single solution at the mean frequency.
    :param lower_precision: Use complex64/float32 when True.
    :return: GainTable with dims renamed to solution_time/solution_frequency,
        chunked per solution time, and "soln_interval_slices" set in attrs.
    :raises ValueError: If ``jones_type`` is not one of "B", "G", "T".
    """
    # Backward compatibility: "auto" means "no explicit timeslice"
    solint = None if timeslice == "auto" else timeslice

    # TODO: review this time slice creation logic
    time_bins = create_solint_slices(vis.time, solint, False)
    gain_time = time_bins.mean().data
    gain_interval = get_intervals_from_grouped_bins(time_bins)

    # Frequency sampling of the solutions
    if jones_type not in ("B", "G", "T"):
        raise ValueError(f"Unknown Jones type {jones_type}")
    per_channel = jones_type == "B"
    if per_channel:
        gain_frequency = vis.frequency.data
        nfrequency = len(gain_frequency)
    else:
        gain_frequency = np.mean(vis.frequency.data, keepdims=True)
        nfrequency = 1

    # Visibility carries a single receptor frame; it serves as both
    # receptor1 and receptor2 of the gain table.
    receptor_frame = ReceptorFrame(vis.visibility_acc.polarisation_frame.type)
    nrec = receptor_frame.nrec

    ntimes = len(gain_time)
    nants = vis.visibility_acc.nants
    gain_shape = [ntimes, nants, nfrequency, nrec, nrec]
    residual_shape = [ntimes, nfrequency, nrec, nrec]

    # Component types of the dask backed data variables
    if lower_precision:
        comp_dtype, fl_dtype = np.complex64, np.float32
    else:
        comp_dtype, fl_dtype = np.complex128, np.float64

    gain_table = GainTable.constructor(
        gain=da.broadcast_to(da.eye(nrec, dtype=comp_dtype), gain_shape),
        time=gain_time,
        interval=gain_interval,
        weight=da.ones(gain_shape, dtype=fl_dtype),
        residual=da.zeros(residual_shape, dtype=fl_dtype),
        frequency=gain_frequency,
        receptor_frame=receptor_frame,
        phasecentre=vis.phasecentre,
        configuration=vis.configuration,
        jones_type=jones_type,
    )

    # Rename dimensions to be more explicit about their usage
    gain_table = gain_table.rename(
        time="solution_time", frequency="solution_frequency"
    )

    # Chunk data variables: one chunk per solution time, and follow the
    # visibility frequency chunking when the channel counts agree.
    gain_table = gain_table.chunk({"solution_time": 1})
    if gain_table.solution_frequency.size == vis.frequency.size:
        frequency_chunks = {"solution_frequency": vis.chunksizes["frequency"]}
        gain_table = gain_table.chunk(frequency_chunks)

    # Logic duplicated from create_solint_slices
    gain_table.attrs["soln_interval_slices"] = get_indices_from_grouped_bins(
        time_bins
    )

    return gain_table

## Setup inputs

In [None]:
# Measurement set used as input (machine-specific path). The commented lines are
# alternative datasets kept for quick switching.
input_ms_path = "/home/ska/Work/data/INST/lg3/cal_bpp_vis-lg3-rotated.small.ms"
# input_ms_path = "/home/ska/Work/data/INST/lg3/cal_bpp_vis-lg3-rotated.ms"
# input_ms_path = "/home/maneesh/Work/SKAO/ska-sdp-instrumental-calibration/data/demo.ms"

# Directory handed to the load_data stage as its visibility cache location.
vis_cache_dir = "/home/maneesh/Work/SKAO/ska-sdp-instrumental-calibration/cache"

# Measurement set used to initialise the EveryBeam telescope; here the input MS itself.
# eb_ms = "/home/ska/Work/data/INST/sim/OSKAR_MOCK.ms"
eb_ms = input_ms_path

# Sky model sources: a CSV model and the GLEAM-EGC catalogue file.
lsm_csv_path = "/home/ska/Work/data/INST/lg3/sky_model_cal.csv"
gleamfile = "/home/ska/Work/data/INST/sim/gleamegc.dat"

# Local sky model selection parameters, forwarded to the generate_lsm_* helpers.
# NOTE(review): presumably fov is in degrees and flux_limit in Jy -- confirm
# against the generate_lsm_from_gleamegc documentation.
fov = 10.0
flux_limit = 1.0
alpha0 = -0.78

In [None]:
# Start from an empty upstream output and let the load_data stage populate it.
upstream_output = UpstreamOutput()

# Chunk sizes forwarded to the load_data stage.
nchannels_per_chunk = 64
ntimes_per_ms_chunk = 16

# NOTE(review): the arguments are passed positionally; the meaning of the
# literals (False, "DATA", 0, 0, ".") depends on the stage_definition signature,
# which is not visible here -- confirm before changing their order.
upstream_output = load_data_stage.stage_definition(
    upstream_output,
    nchannels_per_chunk,
    ntimes_per_ms_chunk,
    vis_cache_dir,
    False,
    "DATA",
    0,
    0,
    # None, # NOTE: DHR-338
    # None, # NOTE: DHR-338
    # None, # NOTE: DHR-338
    {"input": input_ms_path},
    ".",
)
vis = upstream_output.vis

# vis = read_dataset_from_zarr("/home/maneesh/Work/SKAO/ska-sdp-instrumental-calibration/dhr_338cache/cal_bpp_vis-lg3-rotated.small.ms_fid0_ddid0", {})

# NOTE: DHR-338
# Explicitly load the antenna index arrays, in case they are dask arrays
vis.antenna1.load()
vis.antenna2.load()

# One solution interval spanning the whole observation; jones_type="B" keeps
# one gain solution per visibility channel.
timeslice = np.max(vis.time.data) - np.min(vis.time.data)
gaintable = create_gaintable_from_vis_new(
    vis, timeslice=timeslice, jones_type="B", lower_precision=True
)

# Local sky model built from the GLEAM-EGC catalogue around the phase centre.
lsm = generate_lsm_from_gleamegc(
    gleamfile=gleamfile,
    phasecentre=vis.phasecentre,
    fov=fov,
    flux_limit=flux_limit,
    alpha0=alpha0,
)

# Alternative: build the local sky model from the CSV file instead.
# lsm = generate_lsm_from_csv(
#     csvfile=lsm_csv_path,
#     phasecentre=vis.phasecentre,
#     fov=fov,
#     flux_limit=flux_limit,
# )

In [None]:
# Beam model selector for the predict functions: "everybeam" applies the
# EveryBeam station response, None disables the beam.
beam_type = "everybeam"
# beam_type = None

In [None]:
# Synthetic per-antenna rotation measures: uniform random values in [0, 1)
# (reproducible through the seed set at the top), wrapped in a dask array.
station_rm = da.from_array(np.random.rand(gaintable.antenna.size))
# Same values as a DataArray labelled with the gaintable's antenna coordinate.
station_rm_xdr = xr.DataArray(
    station_rm, name="station_rm", coords={"antenna": gaintable.antenna}
)

# Uncomment to run the prediction without Faraday rotation.
# station_rm = None
# station_rm_xdr = None

## Generate Rotation Matrices

In [None]:
def generate_rotation_matrices_new(
    rm: np.ndarray,
    frequency: np.ndarray,
    output_dtype: type = np.complex64,
) -> np.ndarray:
    """Build per-station, per-channel rotation matrices from RM values.

    For every station and channel the rotation angle is
    ``phi = rm * (c / frequency) ** 2`` and the returned 2x2 block is
    ``[[cos(phi), -sin(phi)], [sin(phi), cos(phi)]]``.

    :param rm: 1D array of rotation measure values [nstation].
    :param frequency: 1D array of frequency values [nfrequency].
    :param output_dtype: dtype the result is cast to.

    :return: 4D array of rotation matrices: [nstation, nfrequency, 2, 2].
    """
    wavelength_sq = np.power((const.c.value / frequency), 2)

    # Rotation angle for each (station, channel) pair
    angle = rm[..., np.newaxis] * wavelength_sq

    # cos(phi) * identity + sin(phi) * quarter-turn gives the rotation matrix
    identity = np.array([[1, 0], [0, 1]])
    quarter_turn = np.array([[0, -1], [1, 0]])

    cos_part = np.cos(angle)[:, :, np.newaxis, np.newaxis] * identity
    sin_part = np.sin(angle)[:, :, np.newaxis, np.newaxis] * quarter_turn

    rotation = cos_part + sin_part
    return rotation.astype(output_dtype)

## Generic Beams

In [None]:
class PointingBelowHorizon(Exception):
    """Raised when a direction lies below the horizon at a solution time."""


class BeamsLow:
    """A beam class specific to handling low beams.

    Wraps an EveryBeam telescope loaded from a measurement set and evaluates
    the station response of every station, for every frequency channel, at
    each solution time the instance was created with.
    """

    def __init__(
        self,
        configuration: Configuration,
        direction: SkyCoord,
        frequency: np.ndarray,
        ms_path: str,
        soln_time: np.ndarray | float,
    ):
        """Load the telescope and precompute per-solution-time quantities.

        :param configuration: Array configuration; provides the number of
            stations (``id.size``) and the array location.
        :param direction: Beam pointing direction; also used as the delay
            direction when evaluating station responses.
        :param frequency: 1D array of channel frequencies.
        :param ms_path: Measurement set passed to ``everybeam.load_telescope``.
        :param soln_time: Solution time(s) in MJD seconds (scalar or array).
        :raises PointingBelowHorizon: If ``direction`` is below the horizon
            at any of the solution times.
        """
        self.nstations = configuration.id.size
        self.array_location = configuration.location
        self.beam_direction = direction
        self.frequency = frequency
        self.beam_ms = ms_path

        # Placeholders; all three are assigned further down in this constructor.
        self.delay_dir_itrf = None
        self.telescope = None
        self.scale = None

        # Always an array-valued Time, even for a scalar soln_time
        self.solution_time = self.convert_time_to_solution_time(soln_time)
        # NOTE: HACK: Uncomment below line, and run the pipeline
        # as a check to remove subtle floating point differences
        # between datetime64 and mjd time values
        # self.solution_time = np.mean(self.convert_time_to_solution_time(vis.time.data), keepdims=True)

        # station_response below is called with times in MJD seconds
        self.solution_time_mjd_seconds = self.solution_time.mjd * 86400

        # Check beam pointing direction for all solution times
        self.validate_direction(self.beam_direction)

        # Coordinates of beam centre
        self.delay_dir_itrf = np.array(
            [
                radec_to_xyz(self.beam_direction, self.solution_time[time_idx])
                for time_idx in range(self.solution_time.size)
            ]
        )

        self.telescope = eb.load_telescope(self.beam_ms)

        # Per (solution time, channel) normalisation factor; stays at 1 unless
        # the telescope is an OSKAR one (see below).
        self.scale = np.ones(
            (self.solution_time.size, self.frequency.size), dtype=self.frequency.dtype
        )
        if type(self.telescope) is eb.OSKAR:
            """
            Set normalisation scaling to the Frobenius norm of the zenith
            response divided by sqrt(2).
            Should be the same for all stations so pick one. Should use the
            station location rather than central array location, e.g. using
            the following code, but some functions (e.g. ska-sdp-datamodels
            function create_named_configuration -- at least for some
            configurations) set xyz coordinates to ENU rather than the
            geocentric coordinates. So use the array location and a central
            station for now. Note that OSKAR datasets have correct geocentric
            coordinates, but also have the array location set to the first
            station xyz, so using array_location with stn=0 works.
                xyz = vis.configuration.xyz.data[stn, :]
                self.antenna_locations.append(
                    EarthLocation.from_geocentric(
                        xyz[0], xyz[1], xyz[2], unit="m",
                    )
                )
            """
            logger.info("Setting beam normalisation for OSKAR data")

            # Reference station used to evaluate the zenith response
            stn = 0
            for time_idx in range(self.solution_time.size):
                # Zenith direction (alt=90 deg) at the array location for this time
                dir_itrf_zen = radec_to_xyz(
                    SkyCoord(
                        alt=90,
                        az=0,
                        unit="deg",
                        frame="altaz",
                        obstime=self.solution_time[time_idx],
                        location=self.array_location,
                    ),
                    self.solution_time[time_idx],
                )

                for chan, freq in enumerate(self.frequency):
                    # Response towards zenith while also pointing at zenith
                    J = self.telescope.station_response(
                        self.solution_time_mjd_seconds[time_idx],
                        stn,
                        freq,
                        dir_itrf_zen,
                        dir_itrf_zen,
                    )
                    self.scale[time_idx, chan] = np.sqrt(2) / np.linalg.norm(J)

    @staticmethod
    def convert_time_to_solution_time(time: float | np.ndarray) -> Time:
        """
        Convert time value(s) in MJD seconds into an array-valued Time.

        :param time: Scalar or array of times in seconds (divided by 86400
            to obtain MJD days, UTC scale).
        :return Time object containing an array of time values; a scalar
            input is wrapped into a one-element Time array.
        """
        # This Time->Time conversion is what was originally done in INST
        # Although Time can be treated as a "Value Object", there are subtle
        # differences in the order in which floating point operations are performed
        # thus the results may slightly vary if you try to simplify this
        # by removing the additional datetime64 to time conversion
        solution_time = Time(Time(time / 86400.0, format="mjd", scale="utc").datetime64)
        return Time([solution_time]) if solution_time.isscalar else solution_time

    def validate_direction(self, direction: SkyCoord) -> AltAz:
        """
        Calculate Altitude-Azimuth for a direction, and ensure that the
        direction is valid for the given solution interval of the beam

        raises PointingBelowHorizon exception if direction is below horizon
        at any of the solution times

        :param direction: Direction to check.
        :return Altaz[list] if direction is valid
        """
        altaz = direction.transform_to(
            AltAz(obstime=self.solution_time, location=self.array_location)
        )
        # Since self.solution_time is always a Time[list], altaz will also be an AltAz[list]
        if (altaz.alt.degree < 0).any():
            # TODO: More verbose error
            raise PointingBelowHorizon(
                f"Pointing below horizon for some of the solution times"
            )
        return altaz

    def array_response(
        self,
        direction: SkyCoord,
    ) -> np.ndarray:
        """Return the response of each antenna or station in a given direction

        Every response is multiplied by the per (solution time, channel)
        normalisation factor stored in ``self.scale``.

        :param direction: Direction of desired response
        :return: np.complex128 array of beam matrices [soln_time, nant, nfreq, 2, 2]
        """
        # Get the component direction in ITRF
        dir_itrf = np.array(
            [
                radec_to_xyz(direction, self.solution_time[time_idx])
                for time_idx in range(self.solution_time.size)
            ]
        )

        beams = np.empty(
            (self.solution_time.size, self.nstations, self.frequency.size, 2, 2),
            dtype=np.complex128,
        )

        # One station_response call per (solution time, station, channel)
        for time_idx in range(self.solution_time.size):
            for stn in range(self.nstations):
                for chan, freq in enumerate(self.frequency):
                    beams[time_idx, stn, chan, :, :] = (
                        self.telescope.station_response(
                            self.solution_time_mjd_seconds[time_idx],
                            stn,
                            freq,
                            dir_itrf[time_idx],
                            self.delay_dir_itrf[time_idx],
                        )
                        * self.scale[time_idx, chan]
                    )

        return beams

## DFT

In [None]:
def gaussian_tapers_new(
    u: np.ndarray,
    v: np.ndarray,
    params: dict[float],
) -> np.ndarray:
    """Compute visibility amplitude tapers for a Gaussian component.

    Note: this needs to be tested. Generate, image and fit a model component?

    :param u: u coordinates of the baselines.
    :param v: v coordinates of the baselines, same shape as ``u``.
    :param params: Mapping with keys "bpa", "bmaj" and "bmin"; each value is
        converted from degrees to radians before use.
    :return: Taper values with the broadcast shape of ``u`` and ``v``.
    """
    # Degrees -> radians for position angle, major and minor axis
    bpa, bmaj, bmin = (
        params[key] * np.pi / 180 for key in ("bpa", "bmaj", "bmin")
    )

    # Rotate baselines to the major/minor axes
    cos_pa = np.cos(bpa)
    sin_pa = np.sin(bpa)
    up = cos_pa * u + sin_pa * v
    vp = -sin_pa * u + cos_pa * v

    # exp(-a*x^2) transforms to exp(-pi^2*u^2/a)
    # a = 4log(2)/FWHM^2 so scaling = pi^2 * FWHM^2 / (4log(2))
    scale = -(np.pi * np.pi) / (4 * np.log(2.0))

    major_term = bmaj * bmaj * up * up
    minor_term = bmin * bmin * vp * vp
    return np.exp((major_term + minor_term) * scale)


def dft_skycomponent_new(
    uvw: np.ndarray,
    skycomponent: SkyComponent,
    phase_centre: SkyCoord,
) -> np.ndarray:
    """
    Direct Fourier transform of a single sky component onto the baselines.

    uvw: (time, baselineid, spatial)
    skycomponent.frequency: (frequency,)
    skycomponent.flux: (frequency, polarisation)

    Components whose shape is "GAUSSIAN" are additionally multiplied by the
    taper returned from gaussian_tapers_new.

    returns: (time, baselineid, frequency, polarisation)
    """
    # Baseline coordinates scaled by frequency / c for every channel
    uvw_lambda = np.einsum(
        "tbs,f->tbfs",
        uvw,
        skycomponent.frequency / const.c.value,  # pylint: disable=no-member
    )
    u_lambda = uvw_lambda[..., 0]
    v_lambda = uvw_lambda[..., 1]
    w_lambda = uvw_lambda[..., 2]

    # Phase centre and component direction terms
    comp_dir = skycomponent.direction
    delta_ra = comp_dir.ra.radian - phase_centre.ra.radian
    sin_dec0 = np.sin(phase_centre.dec.radian)
    cos_dec0 = np.cos(phase_centre.dec.radian)
    sin_dec = np.sin(comp_dir.dec.radian)
    cos_dec = np.cos(comp_dir.dec.radian)
    cos_dra = np.cos(delta_ra)

    # Direction cosines of the component relative to the phase centre
    l_comp = cos_dec * np.sin(delta_ra)
    m_comp = sin_dec * cos_dec0 - cos_dec * sin_dec0 * cos_dra
    n_comp = sin_dec * sin_dec0 + cos_dec * cos_dec0 * cos_dra

    fringe = np.exp(
        -2j
        * np.pi
        * (u_lambda * l_comp + v_lambda * m_comp + w_lambda * (n_comp - 1))
    )

    if skycomponent.shape == "GAUSSIAN":
        taper = gaussian_tapers_new(u_lambda, v_lambda, skycomponent.params)
        fringe = fringe * taper

    # Scale the fringe of every channel by the per-polarisation flux
    return np.einsum(
        "tbf,fp->tbfp",
        fringe,
        skycomponent.flux,
    )

## Predict from components

In [None]:
def apply_antenna_gains_to_visibility(vis, gains, antenna1, antenna2, inverse=False):
    """
    Apply per-antenna Jones matrices to visibilities: V' = G1 . V . G2^H

    The number of receptors is taken from the trailing dimension of ``gains``,
    so the function works for any square receptor layout (e.g. 1x1 or 2x2),
    as long as the polarisation axis of ``vis`` holds nrec * nrec values.

    vis: (time, baselineid, frequency, polarisation)
    gains: (time, antennas, frequency, nrec1, nrec2)
    antenna1: (baselineid)
        Indices of the antenna1 in all baseline pairs
    antenna2: (baselineid)
        Indices of the antenna2 in all baseline pairs
    inverse: bool
        Whether to inverse the gains before applying
        (uses the pseudo-inverse of every Jones matrix)

    returns: (time, baselineid, frequency, polarisation), same shape as vis
    raises: ValueError if the polarisation axis of vis does not hold
        nrec * nrec values
    """
    if inverse:
        gains = np.linalg.pinv(gains)

    # Receptor count comes from the gains instead of a hard-coded 2
    nrec = gains.shape[-1]
    if vis.shape[-1] != nrec * nrec:
        raise ValueError(
            f"vis has {vis.shape[-1]} polarisations but gains describe "
            f"{nrec} receptors ({nrec * nrec} polarisation products expected)"
        )

    vis_old_shape = vis.shape
    vis_matrix_shape = vis.shape[:3] + (nrec, nrec)

    # p/q index the output receptors, x/y the receptors of the input visibility
    return np.einsum(  # pylint: disable=too-many-function-args
        "tbfpx,tbfxy,tbfqy->tbfpq",
        gains[:, antenna1, ...],
        vis.reshape(vis_matrix_shape),
        gains[:, antenna2, ...].conj(),
    ).reshape(vis_old_shape)

In [None]:
def predict_from_components_new(
    uvw: np.ndarray,
    frequency: np.ndarray,
    polarisation: np.ndarray,
    antenna1: np.ndarray,
    antenna2: np.ndarray,
    configuration: Configuration,
    phasecentre: SkyCoord,
    lsm: list[Component],
    beam_type: Optional[str] = "everybeam",
    eb_coeffs: Optional[str] = None,
    eb_ms: Optional[str] = None,
    soln_time: float = None,
    station_rm: np.ndarray = None,
    output_dtype: type = np.complex64,
) -> np.ndarray:
    """Predict model visibilities from a Component List.

    Every component is Fourier transformed onto the baselines, optionally
    distorted by the station beam and/or Faraday rotation, and accumulated.
    Components below the horizon (beam model only) are skipped with a warning.

    :param uvw: (time, baselineid, spatial)
    :param frequency: (frequency,)
    :param polarisation: (polarisation,)
    :param antenna1: (baselineid,)
        Index of the first antenna of every baseline.
    :param antenna2: (baselineid,)
        Index of the second antenna of every baseline.
    :param configuration: object
    :param phasecentre: object
    :param lsm: Component List containing the local sky model
    :param beam_type: str
        Type of beam model to use. Default is "everybeam". If set
        to None, no beam will be applied.
    :param eb_coeffs: str
        Everybeam coeffs datadir containing beam coefficients.
        Required if beam_type is "everybeam".
    :param eb_ms: str
        Measurement set need to initialise the everybeam telescope.
        Required if beam_type is "everybeam".
    :param soln_time: float
        "Solution time" value of the gain solution. Used for initialising Beams
        for that current time slice. Required if beam_type is "everybeam".
        Must be a single time value.
    :param station_rm: (nant,)
        Station rotation measure values. Default is None.
    :param output_dtype: Type
        dtype of the returned visibilities. The accumulation itself runs at
        the precision of the per-component visibilities; the result is cast
        once at the end.

    :raises ValueError: if beam_type is "everybeam" and eb_coeffs, eb_ms or
        soln_time is missing.

    returns: (time, baselineid, frequency, polarisation)
    """
    skycomponents = convert_model_to_skycomponents(lsm, frequency)
    # TODO: Do this check outside when we compute lsm
    # if len(skycomponents) == 0:
    #     logger.warning("No sky model components to predict")
    #     return
    # TODO : Do these checks outside
    # if len(station_rm) != len(configuration.id):
    #     raise ValueError("unexpected length for station_rm")

    # Set up the beam model
    beams = None
    if beam_type == "everybeam":
        logger.info("Using EveryBeam model in predict")
        if eb_coeffs is None or eb_ms is None:
            raise ValueError("eb_coeffs and eb_ms required for everybeam")

        if soln_time is None:
            raise ValueError("solution time must be provided for Beam calculation")

        # Could do this once externally, but don't want to pass around
        # exotic data types.
        os.environ["EVERYBEAM_DATADIR"] = eb_coeffs

        beams = BeamsLow(
            configuration=configuration,
            direction=phasecentre,
            frequency=frequency,
            ms_path=eb_ms,
            soln_time=soln_time,
        )
    else:
        logger.info("No beam model used in predict")

    # Set up the Faraday rotation model
    faraday_rot_matrix = None
    if station_rm is not None:
        faraday_rot_matrix = generate_rotation_matrices_new(
            station_rm, frequency, output_dtype
        )[
            np.newaxis, ...
        ]  # Add time axis at the start

    predicted_vis = np.zeros(
        (*uvw.shape[:2], frequency.size, polarisation.size), dtype=output_dtype
    )

    for comp in skycomponents:
        # Without a beam model the only antenna-based effect is Faraday rotation
        effective_antenna_response = faraday_rot_matrix

        # Apply beam distortions and add to combined model visibilities
        if beams is not None:
            # Check component direction
            try:
                beams.validate_direction(comp.direction)
            except PointingBelowHorizon:
                logger.warning("LSM component [%s] below horizon", comp.name)
                continue

            # NOTE: This ID mapping will not always work when the eb_ms file is
            # different. Should restrict the form of the eb_ms files allowed,
            # or preferably deprecate the eb_ms option.
            component_array_response = beams.array_response(direction=comp.direction)[
                :, configuration.id
            ]

            if faraday_rot_matrix is not None:
                effective_antenna_response = (
                    component_array_response @ faraday_rot_matrix
                )
            else:
                effective_antenna_response = component_array_response

        sky_comp_vis = dft_skycomponent_new(
            uvw=uvw, skycomponent=comp, phase_centre=phasecentre
        )

        if effective_antenna_response is not None:
            # NOTE: DHR-338: the cos(zenith angle)^2 projection factor that
            # used to scale this term is intentionally not applied.
            sky_comp_vis = apply_antenna_gains_to_visibility(
                sky_comp_vis,
                effective_antenna_response,
                antenna1,
                antenna2,
            )

        predicted_vis = predicted_vis + sky_comp_vis

    # Adding higher precision component visibilities promotes the accumulator;
    # cast back so the result always has the requested dtype.
    return predicted_vis.astype(output_dtype, copy=False)

In [None]:
def null_sanitize(data: np.ndarray) -> np.ndarray | None:
    if (data == None).any():
        return None
    return data


def _predict_from_components_ufunc(
    uvw: np.ndarray,
    frequency: np.ndarray,
    station_rm: np.ndarray | None,
    polarisation: np.ndarray,
    antenna1: np.ndarray,
    antenna2: np.ndarray,
    configuration: Configuration,
    phasecentre: SkyCoord,
    lsm: list[Component],
    beam_type: Optional[str] = "everybeam",
    eb_coeffs: Optional[str] = None,
    eb_ms: Optional[str] = None,
    soln_time: float = None,
    output_dtype: type = np.complex64,
):
    """
    A helper function which bridges the gap between
    predict_from_components_new and predict_vis_new functions

    It drops the length-1 frequency axis of ``uvw``, delegates to
    predict_from_components_new and swaps the baseline and frequency axes
    of the result.

    :param uvw: (time, frequency, baselineid, spatial)
        The frequency axis must have length 1. A 3D array
        (time, baselineid, spatial) is passed through unchanged.
    :param frequency: (frequency,)
    :param station_rm: (nant,) or None
    :param polarisation: (polarisation,)
    :param antenna1: (baselineid,)
    :param antenna2: (baselineid,)
    :param configuration: object
    :param phasecentre: object
    :param lsm: Component List containing the local sky model
    :param beam_type: str
        Type of beam model to use. Default is "everybeam". If set
        to None, no beam will be applied.
    :param eb_coeffs: str
        Everybeam coeffs datadir containing beam coefficients.
        Required if beam_type is "everybeam".
    :param eb_ms: str
        Measurement set need to initialise the everybeam telescope.
        Required if beam_type is "everybeam".
    :param soln_time: float
        "Solution time" value of the gain solution. Used for initialising Beams
        for that current time slice. Required if beam_type is "everybeam".
        Must be a single time value.
    :param output_dtype: Type

    returns: (time, frequency, baselineid, polarisation)
    """
    # Need to remove extra frequency dimension from uvw.
    # Only that axis is squeezed: a bare squeeze() would also drop a time or
    # baseline axis of length 1 (e.g. a chunk holding a single time sample).
    if uvw.ndim == 4:
        uvw = uvw.squeeze(axis=1)

    return predict_from_components_new(
        uvw,
        frequency,
        polarisation,
        antenna1,
        antenna2,
        configuration,
        phasecentre,
        lsm,
        beam_type=beam_type,
        eb_coeffs=eb_coeffs,
        eb_ms=eb_ms,
        soln_time=soln_time,
        station_rm=station_rm,
        output_dtype=output_dtype,
    ).transpose(
        0, 2, 1, 3  #  time, frequency, baselineid, polarisation
    )


def predict_vis_new(
    vis: xr.Dataset,
    lsm: list,
    beam_type: Optional[str] = "everybeam",
    eb_ms: Optional[str] = None,
    eb_coeffs: Optional[str] = None,
    gaintable: GainTable = None,
    station_rm: xr.DataArray = None,
) -> xr.Dataset:
    """
    Distributed Visibility predict.
    Supports chunking across frequency and time.

    One ``xr.apply_ufunc`` call is issued per solution interval, and the
    per-interval results are concatenated along ``time``.

    :param vis: Visibility dataset containing observed data to be modelled.
        Should be chunked in frequency.
    :param lsm: List of LSM components. This is an intermediate format between
        the GSM and the evaluated SkyComponent list.
    :param beam_type: Type of beam model to use. Default is "everybeam".
    :param eb_ms: Pathname of Everybeam mock Measurement Set.
        Required if beam_type is "everybeam".
    :param eb_coeffs: Path to Everybeam coeffs directory.
        Required if beam_type is "everybeam".
    :param gaintable: GainTable whose ``solution_time`` and
        ``soln_interval_slices`` define the solution intervals.
        Required if beam_type is "everybeam"; not read otherwise.
    :param station_rm: Station rotation measure values. Default is None.
    :return: Predicted Visibility dataset
    :raises ValueError: If beam_type is "everybeam" and any of gaintable,
        eb_ms or eb_coeffs is None, or if the number of solution interval
        slices differs from the number of solution times.
    """
    # Validate beam related parameters before building any dask graph
    if beam_type == "everybeam":
        if any(x is None for x in (gaintable, eb_ms, eb_coeffs)):
            raise ValueError(
                "gaintable, eb_ms and eb_coeffs must be provided "
                "for beam_type = everybeam"
            )
        soln_time = gaintable.solution_time.data
        soln_interval_slices = gaintable.soln_interval_slices
        # Explicit check (not an assert) so it survives ``python -O``
        if len(soln_interval_slices) != len(soln_time):
            raise ValueError(
                f"gaintable has {len(soln_interval_slices)} solution interval "
                f"slices but {len(soln_time)} solution times"
            )
    else:
        # No beam: a single interval covering the whole time axis, with no
        # solution time attached to it
        soln_time = [None]
        soln_interval_slices = [slice(None)]

    # Positional inputs shared by every apply_ufunc call (after uvw), and
    # their core dimensions
    common_input_args = []
    common_input_core_dims = []

    # Keyword arguments forwarded unchanged to the ufunc on every call
    input_kwargs = dict(
        polarisation=vis.polarisation,
        antenna1=vis.antenna1,
        antenna2=vis.antenna2,
        configuration=vis.configuration,
        phasecentre=vis.phasecentre,
        lsm=lsm,
        beam_type=beam_type,
        eb_coeffs=eb_coeffs,
        eb_ms=eb_ms,
        output_dtype=vis.vis.dtype,
    )

    # Process frequency
    # Convert frequency to a chunked dask array
    frequency_xdr = xr.DataArray(vis.frequency, name="frequency_xdr").pipe(
        with_chunks, vis.chunksizes
    )
    common_input_args.append(frequency_xdr)
    common_input_core_dims.append([])

    # Process station_rm
    if station_rm is not None:
        # Ensure that it is not chunked across any dim
        # It can still be a dask array
        station_rm = station_rm.chunk(-1)
        common_input_args.append(station_rm)
        # All of its dims are core dims, so every block sees the full array
        common_input_core_dims.append(list(station_rm.dims))
    else:
        # Passed by keyword instead of positionally when absent
        input_kwargs["station_rm"] = None

    # Call predict ufunc, once per solution interval
    predicted_across_soln_time = []
    for interval_time, slc in zip(soln_time, soln_interval_slices):
        predicted_per_soln_time: xr.DataArray = xr.apply_ufunc(
            _predict_from_components_ufunc,
            vis.uvw.isel(time=slc),
            *common_input_args,
            input_core_dims=[
                ["baselineid", "spatial"],
                *common_input_core_dims,
            ],
            output_core_dims=[
                ["baselineid", "polarisation"],
            ],
            dask="parallelized",
            output_dtypes=[vis.vis.dtype],
            dask_gufunc_kwargs=dict(
                output_sizes={
                    "baselineid": vis.baselineid.size,
                    "polarisation": vis.polarisation.size,
                }
            ),
            kwargs=dict(
                **input_kwargs,
                soln_time=interval_time,
            ),
        )
        # apply_ufunc puts core dims last; restore the visibility dim order
        predicted_per_soln_time = predicted_per_soln_time.transpose(
            "time", "baselineid", "frequency", "polarisation"
        )
        predicted_across_soln_time.append(predicted_per_soln_time)

    predicted: xr.DataArray = xr.concat(predicted_across_soln_time, dim="time")

    predicted = predicted.assign_attrs(vis.vis.attrs)
    return vis.assign({"vis": predicted})

## Compare predicted vis

In [None]:
# Reference result: run the existing `predict_vis` wrapper (imported from
# data_managers.dask_wrappers) on a deep copy of `vis`, rechunked so the whole
# time axis sits in a single chunk. `vis`, `lsm`, `eb_ms` and
# `station_rm_xdr` are defined in earlier cells.
expected_model_vis = predict_vis(
    vis.copy(deep=True).chunk(time=-1),
    lsm,
    beam_type="everybeam",
    eb_ms=eb_ms,
    eb_coeffs=eb_coeffs,
    station_rm=station_rm_xdr,
)

In [None]:
# Persist the reference visibilities to a zarr store in the current working
# directory, removing any store left over from a previous run first.
expected_model_vis_path = f"{os.getcwd()}/expected_model_vis.vis.zarr"
shutil.rmtree(expected_model_vis_path, ignore_errors=True)

# compute=False only builds the write task; dask.compute then executes it.
writer = expected_model_vis.vis.to_zarr(
    expected_model_vis_path, mode="w", compute=False
)
dask.compute(writer, optimize_graph=True)

In [None]:
# Result under test: the refactored `predict_vis_new`, run on `vis` with its
# existing chunking and with `gaintable` supplying the solution intervals.
actual_model_vis = predict_vis_new(
    vis,
    lsm,
    beam_type="everybeam",
    eb_ms=eb_ms,
    eb_coeffs=eb_coeffs,
    gaintable=gaintable,
    station_rm=station_rm_xdr,
)

In [None]:
# Persist the refactored result to its own zarr store in the current working
# directory, removing any store left over from a previous run first.
actual_model_vis_path = f"{os.getcwd()}/actual_model_vis.vis.zarr"
shutil.rmtree(actual_model_vis_path, ignore_errors=True)

# compute=False only builds the write task; dask.compute then executes it.
writer = actual_model_vis.vis.to_zarr(actual_model_vis_path, mode="w", compute=False)
dask.compute(writer, optimize_graph=True)

In [None]:
# Re-open both stores from disk so the comparison reads the written data
# rather than re-running the prediction graphs. chunks={} asks xarray for
# dask-backed arrays.
actual_model_vis_zarr = xr.open_dataarray(
    actual_model_vis_path, engine="zarr", chunks={}
)
expected_model_vis_zarr = xr.open_dataarray(
    expected_model_vis_path, engine="zarr", chunks={}
)

In [None]:
# Compare refactored vs reference visibilities, real and imaginary parts
# separately. `compare_arrays` is defined outside this section.
# NOTE(review): rtol=1e-16 with atol=0 is below float64 machine epsilon, so
# this presumably demands a bit-exact match -- confirm against the
# compare_arrays implementation.
compare_arrays(
    np.real(actual_model_vis_zarr.data),
    np.real(expected_model_vis_zarr.data),
    rtol=1e-16,
    atol=0,
    meta="Real values",
)

compare_arrays(
    np.imag(actual_model_vis_zarr.data),
    np.imag(expected_model_vis_zarr.data),
    rtol=1e-16,
    atol=0,
    meta="Imag values",
)

## Prediction central beams

In [None]:
def generate_central_beams_new(
    soln_time: np.ndarray,
    frequency: np.ndarray,
    configuration: Configuration,
    phasecentre: SkyCoord,
    eb_ms: str,
    eb_coeffs: str,
):
    """
    Evaluate the beam response towards the beam direction for the stations
    listed in ``configuration``.

    Side effect: sets the ``EVERYBEAM_DATADIR`` environment variable to
    ``eb_coeffs`` before the beams object is constructed.

    Parameters
    ----------
    soln_time: np.ndarray (solution_time,)
    frequency: np.ndarray, (frequency,)
    configuration: Configuration
    phasecentre: SkyCoord
        Used as the beam pointing direction.
    eb_ms: str
        Passed to the beams object as ``ms_path``.
    eb_coeffs: str
        Directory exported as ``EVERYBEAM_DATADIR``.

    Returns
    -------
    np.ndarray (solution_time, antenna, frequency, nrec1, nrec2)
    """
    os.environ["EVERYBEAM_DATADIR"] = eb_coeffs

    beam_kwargs = dict(
        configuration=configuration,
        direction=phasecentre,
        frequency=frequency,
        ms_path=eb_ms,
        soln_time=soln_time,
    )
    low_beams = BeamsLow(**beam_kwargs)

    # Response of every station towards the beam's own pointing direction
    all_station_response = low_beams.array_response(
        direction=low_beams.beam_direction
    )

    # NOTE: This ID mapping will not always work when the eb_ms file is
    # different. Should restrict the form of the eb_ms files allowed,
    # or preferably deprecate the eb_ms option.
    return all_station_response[:, configuration.id]

In [None]:
def _generate_central_beams_ufunc(
    soln_time: np.ndarray,
    frequency: np.ndarray,
    configuration: Configuration,
    phasecentre: SkyCoord,
    eb_ms: str,
    eb_coeffs: str,
):
    """
    Adapter between ``xr.apply_ufunc`` and ``generate_central_beams_new``.

    Flattens the solution-time block to 1-D, delegates to
    ``generate_central_beams_new`` and swaps the antenna and frequency axes
    of the result.

    Parameters
    ----------
    soln_time: np.ndarray (solution_time,)
        May arrive with an extra length-1 axis added for broadcasting
        against frequency.
    frequency: np.ndarray, (frequency,)
    configuration: Configuration
    phasecentre: SkyCoord
    eb_ms: str
    eb_coeffs: str

    Returns
    -------
    np.ndarray (solution_time, frequency, antenna, nrec1, nrec2)
    """
    # xarray adds a new dimension for broadcasting with frequency and it
    # needs to be removed. Flatten rather than squeeze(): a bare squeeze()
    # would also drop the solution_time axis when the block holds a single
    # solution time, yielding a 0-d array instead of shape (1,).
    soln_time = np.reshape(soln_time, -1)

    return generate_central_beams_new(
        soln_time,
        frequency,
        configuration,
        phasecentre,
        eb_ms,
        eb_coeffs,
    ).transpose(
        0, 2, 1, 3, 4
    )  # time, frequency, antenna, rec1, rec2


def prediction_central_beams_new(
    vis: Visibility,
    gaintable: GainTable,
    beam_type: str = "everybeam",
    eb_ms=None,
    eb_coeffs=None,
) -> GainTable:
    """
    Build a gaintable whose ``gain`` holds the central beam response.

    For ``beam_type == "everybeam"`` the response is computed lazily with
    ``xr.apply_ufunc`` over the gaintable's solution times and the
    visibility frequencies. For any other beam type every gain matrix is
    set to the 2x2 identity.

    Parameters
    ----------
    vis: Visibility
        Supplies frequency, chunk sizes, configuration and phase centre.
    gaintable: GainTable
        Supplies solution times, chunk sizes, and the coords/attrs copied
        onto the response.
    beam_type: str
        "everybeam" to evaluate beams; anything else yields identities.
    eb_ms: str
        Required if beam_type is "everybeam".
    eb_coeffs: str
        Required if beam_type is "everybeam".

    Returns
    -------
    Gaintable
        ``gaintable`` with its ``gain`` variable replaced by the response.

    Raises
    ------
    ValueError
        If beam_type is "everybeam" and eb_ms or eb_coeffs is None.
    """
    if beam_type != "everybeam":
        logger.info("No beam model to predict central beams")

        # TODO : Verify whether this has to be zeros or eye
        response = xr.zeros_like(gaintable.gain)
        response[..., :, :] = np.eye(2)
    else:
        # Do validations early on
        if eb_ms is None or eb_coeffs is None:
            raise ValueError(
                "eb_ms and eb_coeffs must be provided for beam_type = everybeam"
            )

        # Solution times as a dask array chunked like the gaintable
        named_soln_time = xr.DataArray(gaintable.solution_time).rename(
            "solution_time_xdr"
        )
        soln_time = with_chunks(named_soln_time, gaintable.chunksizes)

        # Central beam response is needed across the entire frequency axis;
        # chunk like the visibilities, then rename to the gaintable's dim
        chunked_frequency = with_chunks(
            xr.DataArray(vis.frequency, name="frequency_xdr"), vis.chunksizes
        )
        frequency_xdr = chunked_frequency.rename(frequency="solution_frequency")

        new_dim_sizes = {
            "antenna": gaintable.antenna.size,
            "receptor1": gaintable.receptor1.size,
            "receptor2": gaintable.receptor2.size,
        }
        beam_kwargs = dict(
            configuration=vis.configuration,
            phasecentre=vis.phasecentre,
            eb_ms=eb_ms,
            eb_coeffs=eb_coeffs,
        )

        response: xr.DataArray = xr.apply_ufunc(
            _generate_central_beams_ufunc,
            soln_time,
            frequency_xdr,
            output_core_dims=[("antenna", "receptor1", "receptor2")],
            dask="parallelized",
            output_dtypes=[np.complex128],
            dask_gufunc_kwargs=dict(output_sizes=new_dim_sizes),
            kwargs=beam_kwargs,
        )

        # apply_ufunc puts the core dims last; restore the gain dim order
        response = response.transpose(
            "solution_time", "antenna", "solution_frequency", "receptor1", "receptor2"
        )

    # Carry the original gain's coordinates and attributes over
    response = response.assign_coords(gaintable.gain.coords).assign_attrs(
        gaintable.gain.attrs
    )

    return gaintable.assign({"gain": response})

## Test central beams prediction

In [None]:
# Reference result: the existing `prediction_central_beams` wrapper (imported
# from data_managers.dask_wrappers), run on a deep copy of `vis` with the time
# axis in a single chunk. `beam_type` is defined in an earlier cell.
expected_beams = prediction_central_beams(
    vis.copy(deep=True).chunk(time=-1),
    beam_type=beam_type,
    eb_ms=eb_ms,
    eb_coeffs=eb_coeffs,
)

# Materialise in memory; the trailing semicolon suppresses notebook output.
expected_beams.load();

In [None]:
# Result under test: the refactored `prediction_central_beams_new`, which
# additionally takes the gaintable that defines the solution times.
actual_beams = prediction_central_beams_new(
    vis,
    gaintable,
    beam_type=beam_type,
    eb_ms=eb_ms,
    eb_coeffs=eb_coeffs,
)

# Materialise in memory; the trailing semicolon suppresses notebook output.
actual_beams.load();

In [None]:
# Compare refactored vs reference central beams: real part, imaginary part,
# then the full complex values. Only the refactored gains are cast to
# complex64 before comparing.
# NOTE(review): presumably the reference gains are already complex64 --
# confirm. rtol=1e-16 with atol=0 presumably demands an exact match; confirm
# against the compare_arrays implementation.
# actual_beams = actual_beams.astype(np.complex64)
# expected_beams = expected_beams.astype(np.complex64)

compare_arrays(
    np.real(actual_beams.gain.data.astype(np.complex64)),
    np.real(expected_beams.gain.data),
    rtol=1e-16,
    atol=0,
    meta="Real values",
)

compare_arrays(
    np.imag(actual_beams.gain.data.astype(np.complex64)),
    np.imag(expected_beams.gain.data),
    rtol=1e-16,
    atol=0,
    meta="Imag values",
)

compare_arrays(
    (actual_beams.gain.data.astype(np.complex64)),
    (expected_beams.gain.data),
    rtol=1e-16,
    atol=0,
    meta="Complex values",
)

## Apply antenna gains

In [None]:
def _apply_antenna_gains_to_visibility_ufunc(
    vis: np.ndarray,
    gains: np.ndarray,
    antenna1: np.ndarray,
    antenna2: np.ndarray,
    inverse: bool = False,
):
    """
    Apply per-antenna 2x2 gain matrices to visibilities laid out as
    (time, frequency, baselineid, polarisation).

    For every baseline the 2x2 visibility matrix V is replaced by
    ``G1 @ V @ G2^H``, where G1 and G2 are the gains of the baseline's two
    antennas.

    Parameters
    ----------
    vis: (time, frequency, baselineid, polarisation)
        The polarisation axis is reshaped to (2, 2) internally.
    gains: (frequency, antennas, nrec1, nrec2) or (antennas, nrec1, nrec2)
        Without a frequency axis the same gains are used for every
        frequency channel.
    antenna1: (baselineid)
        Indices of the antenna1 in all baseline pairs
    antenna2: (baselineid)
        Indices of the antenna2 in all baseline pairs
    inverse: bool
        Whether to (pseudo-)invert the gains before applying

    Returns
    -------
    Array with the same shape as ``vis``.

    Raises
    ------
    ValueError
        If ``gains`` is neither 3- nor 4-dimensional.

    Notes
    -----
    apply_ufunc moves the distributed dimensions to the front, e.g. vis
    becomes (time, frequency, baselineid, polarisation). Reusing a function
    written for another axis order would need a transpose before and after
    the call. This function implements the same logic directly for this
    order to avoid that double transpose, at the cost of some duplicated
    code whose edge cases have to be kept in sync.
    """
    # einsum labels for the leading gain axes: frequency + baseline, or
    # baseline only when the gains have no frequency axis
    if gains.ndim == 4:
        gain_axes = "fb"
    elif gains.ndim == 3:
        gain_axes = "b"
    else:
        raise ValueError(
            "gains must have shape (frequency, antenna, nrec1, nrec2) or "
            f"(antenna, nrec1, nrec2), got {gains.shape}"
        )

    if inverse:
        # pinv operates on the trailing (nrec1, nrec2) matrices
        gains = np.linalg.pinv(gains)

    vis_old_shape = vis.shape
    vis_new_shape = vis.shape[:3] + (2, 2)

    # The antenna axis is always third from last, so index it relative to
    # the end to cover both gain layouts
    return np.einsum(  # pylint: disable=too-many-function-args
        f"{gain_axes}px,tfbxy,{gain_axes}qy->tfbpq",
        gains[..., antenna1, :, :],
        vis.reshape(vis_new_shape),
        gains[..., antenna2, :, :].conj(),
    ).reshape(vis_old_shape)


def apply_gaintable_to_dataset_new(
    vis: Visibility,
    gaintable: GainTable,
    inverse=False,
) -> Visibility:
    """
    Apply the gains of a gaintable to a visibility dataset, one solution
    interval at a time.

    For each entry of ``gaintable.soln_interval_slices`` the matching time
    slice of ``vis.vis`` is combined with the gains of that solution time
    through ``xr.apply_ufunc``; the pieces are then concatenated along
    ``time``.

    :param vis: Visibility dataset whose ``vis`` variable is corrected.
    :param gaintable: GainTable providing ``gain``, ``jones_type`` and
        ``soln_interval_slices``. For jones_type "B" the solution
        frequencies are taken to be the visibility frequencies; for any
        other jones_type there must be exactly one solution frequency.
    :param inverse: Whether to invert the gains before applying them.
    :return: ``vis`` with its ``vis`` variable replaced by the result.
    :raises ValueError: If jones_type is not "B" and the gaintable has more
        than one solution frequency.
    """
    gains = gaintable.gain
    if gaintable.jones_type == "B":
        gains = gains.rename({"solution_frequency": "frequency"})
        # solution frequency same as vis frequency
        # Chunking, just to be sure that they match
        gains = gains.chunk({"frequency": vis.chunksizes["frequency"]})
    else:  # jones_type == T or G
        if gains.solution_frequency.size != 1:
            raise ValueError(
                "Gaintable solution frequency "
                "must either match to visibility frequency, or must be of size 1"
            )
        # Drop the length-1 frequency axis so that the gains carry no
        # frequency dimension into apply_ufunc
        gains = gains.isel(solution_frequency=0, drop=True)

    soln_interval_slices = gaintable.soln_interval_slices

    applied_vis_across_solutions = []
    for idx, slc in enumerate(soln_interval_slices):
        applied_vis_per_soln_interval = xr.apply_ufunc(
            _apply_antenna_gains_to_visibility_ufunc,
            vis.vis.isel(time=slc),
            gains.isel(solution_time=idx, drop=True),
            input_core_dims=[
                ["baselineid", "polarisation"],
                ["antenna", "receptor1", "receptor2"],
            ],
            output_core_dims=[
                ["baselineid", "polarisation"],
            ],
            dask="parallelized",
            output_dtypes=[vis.vis.dtype],
            dask_gufunc_kwargs=dict(
                output_sizes={
                    "baselineid": vis.baselineid.size,
                    "polarisation": vis.polarisation.size,
                }
            ),
            kwargs={
                "antenna1": vis.antenna1,
                "antenna2": vis.antenna2,
                "inverse": inverse,
            },
        )
        # apply_ufunc puts core dims last; restore the visibility dim order
        applied_vis_per_soln_interval = applied_vis_per_soln_interval.transpose(
            "time", "baselineid", "frequency", "polarisation"
        )
        applied_vis_across_solutions.append(applied_vis_per_soln_interval)

    applied: xr.DataArray = xr.concat(applied_vis_across_solutions, dim="time")

    applied = applied.assign_attrs(vis.vis.attrs)
    return vis.assign({"vis": applied})

## Test apply central beams

In [None]:
# Build the gaintable fed to the reference implementation: the structure of
# `expected_beams` with its gain values overwritten by those of
# `actual_beams`, so both code paths apply numerically identical gains.
beams_to_apply = expected_beams.copy(deep=True)
# Applying actual beams itself to test this function in isolation
beams_to_apply.gain.data = actual_beams.gain.data
# Match the frequency chunking of the visibilities
beams_to_apply = beams_to_apply.chunk(frequency=vis.chunksizes["frequency"])

# Reference result. `apply_gaintable_to_dataset` is defined outside this
# section; inverse=True asks for the gains to be inverted before applying.
expected_beam_corr_vis = apply_gaintable_to_dataset(
    vis.copy(deep=True).chunk(time=-1), beams_to_apply, inverse=True
)

In [None]:
# Persist the reference beam-corrected visibilities to a zarr store in the
# current working directory, removing any store from a previous run first.
expected_beam_corr_vis_path = f"{os.getcwd()}/expected_beam_corr_vis.vis.zarr"
shutil.rmtree(expected_beam_corr_vis_path, ignore_errors=True)

# compute=False only builds the write task; dask.compute then executes it.
writer = expected_beam_corr_vis.vis.to_zarr(
    expected_beam_corr_vis_path, mode="w", compute=False
)
dask.compute(writer, optimize_graph=True)

In [None]:
# Result under test: the refactored apply function, run on `vis` with its
# existing chunking and the refactored central beams, with inverted gains.
actual_beam_corr_vis = apply_gaintable_to_dataset_new(vis, actual_beams, inverse=True)

In [None]:
# Persist the refactored beam-corrected visibilities to their own zarr store
# in the current working directory, removing any store from a previous run.
actual_beam_corr_vis_path = f"{os.getcwd()}/actual_beam_corr_vis.vis.zarr"
shutil.rmtree(actual_beam_corr_vis_path, ignore_errors=True)

# compute=False only builds the write task; dask.compute then executes it.
writer = actual_beam_corr_vis.vis.to_zarr(
    actual_beam_corr_vis_path, mode="w", compute=False
)
dask.compute(writer, optimize_graph=True)

In [None]:
# Re-open both stores from disk so the comparison reads the written data
# rather than re-running the graphs. chunks={} asks xarray for dask-backed
# arrays.
actual_beam_corr_vis_zarr = xr.open_dataarray(
    actual_beam_corr_vis_path, engine="zarr", chunks={}
)
expected_beam_corr_vis_zarr = xr.open_dataarray(
    expected_beam_corr_vis_path, engine="zarr", chunks={}
)

In [None]:
# Compare refactored vs reference beam-corrected visibilities: real part,
# imaginary part, then the full complex values.
# NOTE(review): rtol=1e-32 with atol=0 is far below floating-point
# precision, so this presumably demands a bit-exact match -- confirm against
# the compare_arrays implementation.
compare_arrays(
    np.real(actual_beam_corr_vis_zarr.data),
    np.real(expected_beam_corr_vis_zarr.data),
    rtol=1e-32,
    atol=0,
    meta="Real values",
)

compare_arrays(
    np.imag(actual_beam_corr_vis_zarr.data),
    np.imag(expected_beam_corr_vis_zarr.data),
    rtol=1e-32,
    atol=0,
    meta="Imag values",
)

compare_arrays(
    (actual_beam_corr_vis_zarr.data),
    (expected_beam_corr_vis_zarr.data),
    rtol=1e-32,
    atol=0,
    meta="Complex values",
)