In [None]:
import os
import xarray as xr
import numpy
import os
import dask
# dask.config.set(scheduler='threads')  # or 'synchronous' for step-by-step

# If you want a dashboard:
# from dask.distributed import Client
# client = Client()
from ska_sdp_instrumental_calibration.workflow.stages import (
    load_data_stage,
    predict_vis_stage,
    bandpass_calibration_stage
)

from ska_sdp_instrumental_calibration.scheduler import UpstreamOutput
from copy import deepcopy
from ska_sdp_instrumental_calibration.data_managers.visibility import _generate_file_paths_for_vis_zarr_file
import pickle

from ska_sdp_instrumental_calibration.workflow.utils import with_chunks

import logging
# Module-level logger used by the solver functions defined in this notebook.
log = logging.getLogger("func-python-logger")

## Helper functions

In [None]:
def export_model_vis(vis, output_dir, zarr_chunks):
    """
    Persist model visibilities under ``output_dir`` for ``import_model_vis``.

    Three artefacts are written, at the paths returned by
    ``_generate_file_paths_for_vis_zarr_file``:

    * a pickle of a deep copy of ``vis.attrs``;
    * a pickle of the computed ``baselines`` variable;
    * a zarr store of the remaining data (attributes and ``baselines``
      dropped), rechunked with ``zarr_chunks`` and overwritten if present.

    :param vis: xarray Dataset holding the model visibilities.
    :param output_dir: Directory to write into (created if missing).
    :param zarr_chunks: Chunk mapping passed to ``with_chunks`` before writing.
    """
    os.makedirs(output_dir, exist_ok=True)

    attrs_path, baselines_path, zarr_path = (
        _generate_file_paths_for_vis_zarr_file(output_dir)
    )

    def dump(obj, path):
        # The object is fully built by the caller before the file is opened.
        with open(path, "wb") as handle:
            pickle.dump(obj, handle)

    dump(deepcopy(vis.attrs), attrs_path)
    dump(deepcopy(vis.baselines).compute(), baselines_path)

    stripped = vis.drop_attrs().drop_vars("baselines")
    delayed_store = stripped.pipe(with_chunks, zarr_chunks).to_zarr(
        zarr_path, mode="w", compute=False
    )

    dask.compute(delayed_store)  # type: ignore

def import_model_vis(input_dir, vis_chunks):
    """
    Load model visibilities previously written by ``export_model_vis``.

    The zarr store is opened lazily with ``vis_chunks``; the pickled
    attributes are attached with ``assign_attrs`` and the pickled
    ``baselines`` are added back as a variable.

    NOTE: ``pickle.load`` can execute arbitrary code — only point this at
    exports produced by ``export_model_vis`` from a trusted source.

    :param input_dir: Directory that ``export_model_vis`` wrote to.
    :param vis_chunks: Chunk mapping for ``xr.open_dataset``.
    :return: The reassembled xarray Dataset.
    """
    attrs_path, baselines_path, zarr_path = (
        _generate_file_paths_for_vis_zarr_file(input_dir)
    )

    def load(path):
        with open(path, "rb") as handle:
            return pickle.load(handle)

    dataset = xr.open_dataset(zarr_path, chunks=vis_chunks, engine="zarr")
    with_attrs = dataset.assign_attrs(load(attrs_path))
    return with_attrs.assign({"baselines": load(baselines_path)})



In [None]:
def compute_and_clear_tasks(upstream_output: UpstreamOutput):
    """
    Run every pending compute task on ``upstream_output``, then reset its
    ``stage_compute_tasks`` to an empty list.

    Tasks are executed one after another, in order, through their own
    ``compute()`` method.

    NOTE(review): tasks are read from ``compute_tasks`` but the reset is
    applied to ``stage_compute_tasks`` — confirm that clearing the latter
    also empties the former on ``UpstreamOutput``.
    """
    pending_tasks = upstream_output.compute_tasks
    for pending in pending_tasks:
        pending.compute()

    # Everything above has been executed; drop the task list.
    upstream_output.stage_compute_tasks = []

## Setup and configuration

In [None]:
# Input data locations. Each one can be overridden with an environment
# variable, so the notebook runs on another machine without editing this cell.
# NOTE(review): the cache_dir default lives under /home/nitin while the other
# defaults live under /home/ska — confirm which location is intended.
vis_ms = os.environ.get(
    "INST_CAL_VIS_MS",
    "/home/ska/Work/data/INST/lg3/cal_bpp_vis-lg3-rotated.small.ms/",
)
cache_dir = os.environ.get(
    "INST_CAL_CACHE_DIR",
    "/home/nitin/Work/ska/sdp/ska-sdp-instrumental-calibration/cache",
)
eb_coeffs = os.environ.get(
    "INST_CAL_EB_COEFFS", "/home/ska/Work/data/INST/sim/coeffs"
)
lsm_csv_path = os.environ.get(
    "INST_CAL_LSM_CSV", "/home/ska/Work/data/INST/lg3/sky_model_cal.csv"
)

In [None]:
# Load-data stage parameters
_cli_args_ = {"input": vis_ms}
nchannels_per_chunk = 32
ntimes_per_ms_chunk = 5

upstream_output = UpstreamOutput()
ack = False
datacolumn = "DATA"
field_id = 0
data_desc_id = 0

# Predict visibilities stage
beam_type = "everybeam"
normalise_at_beam_centre = True
eb_ms = None
gleamfile = None
fov = 5.0
flux_limit = 1.0
alpha0 = -0.78


# NOTE(review): the bandpass solver parameters actually used are set in the
# "Testing" section; this commented block is kept for reference only.
# # bandpass calibration stage
# solver = "gain_substitution"
# refant = 0
# niter = 150
# phase_only = True
# tol = 1.0e-06
# crosspol = False
# normalise_gains = "mean"
# timeslice = None
# plot_config = {
#     "plot_table": False,
#     "fixed_axis": False,
# }
# visibility_key = "vis"
# export_gaintable = True
# run_solver_config = {
#     "solver": solver,
#     "niter": niter,
#     "phase_only": phase_only,
#     "tol": tol,
#     "crosspol": crosspol,
#     "refant": refant,
#     "timeslice": timeslice,
#     "normalise_gains": normalise_gains,
# }

# Directory used for the stage outputs and the exported model visibilities.
output_dir = "solve_gaintable_output"

In [None]:
# Dimensions that are never split: each chunk spans the whole dimension (-1).
non_chunked_dims = dict.fromkeys(("baselineid", "polarisation", "spatial"), -1)

# Chunking of the intermediate zarr file.
zarr_chunks = dict(
    non_chunked_dims,
    time=ntimes_per_ms_chunk,
    frequency=nchannels_per_chunk,
)

# Chunking used when reading the exported model visibilities back:
# the time dimension in a single chunk, frequency split as above.
vis_chunks = dict(
    non_chunked_dims,
    time=-1,
    frequency=nchannels_per_chunk,
)

In [None]:
# Run the load-data stage on `_cli_args_["input"]`; its outputs are read
# back from `upstream_output` right below.
load_data_stage.stage_definition( # type: ignore
    upstream_output,
    nchannels_per_chunk,
    ntimes_per_ms_chunk,
    cache_dir,
    ack,
    datacolumn,
    field_id,
    data_desc_id,
    _cli_args_,
    output_dir,
)
# Observed visibilities produced by the load stage.
input_vis = upstream_output["vis"]
# Deep copy of the stage's initial gaintable; used as the starting point
# for the solver comparisons further down.
initial_gaintable = upstream_output["gaintable"].copy(deep=True)

In [None]:
# Re-use previously exported model visibilities when every exported file is
# present; otherwise the next cell predicts and exports them first, so a
# fresh "Restart & Run All" works. Set `force_model_vis_predict = True` to
# always re-run the (expensive) prediction.
force_model_vis_predict = False
is_model_vis_exported = not force_model_vis_predict and all(
    os.path.exists(path)
    for path in _generate_file_paths_for_vis_zarr_file(output_dir)
)

In [None]:
# Predict and export the model visibilities only when no export exists yet.
if not is_model_vis_exported:
    predict_vis_stage.stage_definition( # type: ignore
        upstream_output,
        beam_type,
        normalise_at_beam_centre,
        eb_ms,
        eb_coeffs,
        gleamfile,
        lsm_csv_path,
        fov,
        flux_limit,
        alpha0,
        _cli_args_,
    )

    input_modelvis = upstream_output["modelvis"]
    # Execute the stages' pending compute tasks before writing the export.
    compute_and_clear_tasks(upstream_output)
    export_model_vis(input_modelvis, output_dir, zarr_chunks)

# Always re-open the model visibilities from the zarr export (chunked with
# `vis_chunks`) and put that copy back on `upstream_output`.
input_modelvis = import_model_vis(output_dir, vis_chunks)
upstream_output["modelvis"] = input_modelvis
    


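## Refactoring `solve_gaintable`

The cells below re-implement `solve_gaintable` from `ska_sdp_func_python` as functions and solver classes that operate on plain numpy arrays (visibilities, flags, weights and gains) rather than on `Visibility` and `GainTable` objects.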

In [None]:
from ska_sdp_datamodels.calibration.calibration_model import GainTable
from ska_sdp_datamodels.visibility.vis_model import Visibility
import logging
import scipy
from ska_sdp_func_python.calibration.solvers import (
    _solve_antenna_gains_itsubs_nocrossdata,
    _solve_antenna_gains_itsubs_matrix,
    _solve_antenna_gains_itsubs_scalar
)
from ska_sdp_func_python.calibration.alternative_solvers import(
    _jones_sub_solve,
_normal_equation_solve,
_normal_equation_solve_with_presumming,
)

In [None]:
def _solve_with_mask(
    crosspol,
    gaintable_gain: numpy.ndarray,
    gaintable_weight: numpy.ndarray,
    gaintable_residual: numpy.ndarray,
    mask,
    niter,
    phase_only,
    row,
    tol,
    npol,
    x,
    xwt,
    refant,
    refant_sort,
):
    """
    Normalise the antenna-by-antenna sums and solve one row of the gains.

    Extracted from ``solve_gaintable`` to reduce its complexity; callers only
    use it when ``mask`` selects at least one sample.

    ``x`` and ``xwt`` are modified in place: where ``mask`` is True, ``x``
    becomes ``x / xwt`` and ``xwt`` is scaled by ``1 / max(xwt[mask])``;
    everywhere else both are set to zero. The solution is written into
    ``gaintable_gain``, ``gaintable_weight`` and ``gaintable_residual`` at
    index ``row``.

    Solver selection:

    * ``npol == 2``, or ``npol == 4`` without ``crosspol`` -> nocrossdata
    * ``npol == 4`` with ``crosspol``                        -> matrix
    * anything else                                          -> scalar
    """
    unmasked = ~mask
    # x must be divided by the *un-normalised* weights, so it goes first.
    x[mask] = x[mask] / xwt[mask]
    x[unmasked] = 0.0
    xwt[mask] = xwt[mask] / numpy.max(xwt[mask])
    xwt[unmasked] = 0.0

    if npol == 2 or (npol == 4 and not crosspol):
        antenna_solver = _solve_antenna_gains_itsubs_nocrossdata
    elif npol == 4 and crosspol:
        antenna_solver = _solve_antenna_gains_itsubs_matrix
    else:
        antenna_solver = _solve_antenna_gains_itsubs_scalar

    solution = antenna_solver(
        gaintable_gain[row, ...],
        gaintable_weight[row, ...],
        x,
        xwt,
        phase_only=phase_only,
        niter=niter,
        tol=tol,
        refant=refant,
        refant_sort=refant_sort,
    )
    (
        gaintable_gain[row, ...],
        gaintable_weight[row, ...],
        gaintable_residual[row, ...],
    ) = solution

In [None]:
def solve_with_alternative_algorithm(
    solver,
    vis_vis: numpy.ndarray,
    vis_flags: numpy.ndarray,
    vis_weight: numpy.ndarray,
    model_vis: numpy.ndarray,
    ant1: numpy.ndarray,
    ant2: numpy.ndarray,
    gaintable_gain: numpy.ndarray,
    niter=100,
    row=0,
    tol=1e-6,
):
    """
    Solve this row (time slice) of the gain table with a Jones-matrix solver.

    :param solver: Calibration solver to use. Options are:
        "jones_substitution", "normal_equations" and "normal_equations_presum".
        "gain_substitution" is rejected; it has its own code path.
    :param vis_vis: Observed visibilities, (time, baseline, chan, 4); the
        4 polarisations are reshaped into 2x2 matrices.
    :param vis_flags: Flags for ``vis_vis``; a whole 2x2 matrix gets zero
        weight if any of its elements is flagged (or has zero weight).
    :param vis_weight: Weights for ``vis_vis``.
    :param model_vis: Model visibilities, same shape as ``vis_vis``.
    :param ant1: First antenna index of each baseline.
    :param ant2: Second antenna index of each baseline.
    :param gaintable_gain: Gains indexed [row, antenna, chan, 2, 2]; row
        ``row`` is updated in place.
    :param niter: Maximum number of iterations (default=100)
    :param row: Time slice to be calibrated (default=0)
    :param tol: Iteration stops when the fractional change
        in the gain solution is below this tolerance (default=1e-6)
    :return: ``gaintable_gain``, containing the solution in row ``row``
    :raises ValueError: if ``solver`` is not one of the supported options.
    """
    # Validate the solver name before any (potentially expensive) array work.
    if solver == "gain_substitution":
        raise ValueError(
            "solve_with_alternative_algorithm: "
            + f"solver {solver} cannot be used in this function",
        )
    if solver not in (
        "jones_substitution",
        "normal_equations",
        "normal_equations_presum",
    ):
        raise ValueError(
            f"solve_with_alternative_algorithm: unknown solver: {solver}",
        )
    solver_fn = {
        "jones_substitution": _jones_sub_solve,
        "normal_equations": _normal_equation_solve,
        "normal_equations_presum": _normal_equation_solve_with_presumming,
    }[solver]

    gain = gaintable_gain[row, ...]
    _, nchan_gt, nrec1, nrec2 = gain.shape
    ntime, nbl, nchan_vis, npol_vis = vis_vis.shape
    assert nrec1 == 2
    assert nrec1 == nrec2
    assert nrec1 * nrec2 == npol_vis
    assert nchan_gt in (1, nchan_vis)

    # incorporate flags into weights
    wgt = vis_weight * (1 - vis_flags)
    # flag the whole Jones matrix if any element is flagged
    wgt *= numpy.all(wgt > 0, axis=-1, keepdims=True)
    # reduce the dimension to a single weight per matrix
    #  - could weight pols separately, but may be better not to
    wgt = wgt[..., 0]

    vmdl = model_vis.reshape(ntime, nbl, nchan_vis, 2, 2)
    vobs = vis_vis.reshape(ntime, nbl, nchan_vis, 2, 2)

    # Update model if a starting solution is given.
    if numpy.any(gain[..., :, :] != numpy.eye(2)):
        vmdl = numpy.einsum(
            "bfpi,tbfij,bfqj->tbfpq",
            gain[ant1],
            vmdl,
            gain[ant2].conj(),
        )

    log.debug(
        "solve_with_alternative_algorithm: "
        + "solving for %d chan in %d sub-band[s] using solver %s",
        nchan_vis,
        nchan_gt,
        solver,
    )

    for ch in range(nchan_gt):
        # select channels to average over. Just the current one if solving
        # each channel separately, or all of them if this is a joint solution.
        chan_vis = [ch] if nchan_gt == nchan_vis else range(nchan_vis)

        log.debug(
            "solve_with_alternative_algorithm: "
            + "sub-band %d, processing %d channels:",
            ch,
            len(chan_vis),
        )

        # The solver functions update `gain` in place.
        solver_fn(
            vobs[:, :, chan_vis],
            vmdl[:, :, chan_vis],
            wgt[:, :, chan_vis],
            ant1,
            ant2,
            gain,
            ch,
            niter,
            tol,
        )

    # _jones_phase_referencing(gain, refant)

    # update gain_table in case the data reference became a copy
    gaintable_gain[row, ...] = gain

    return gaintable_gain

In [None]:
def find_best_refant_from_vis(
    flagged_vis: numpy.ndarray,
    flagged_weight: numpy.ndarray,
    ant1: numpy.ndarray,
    ant2: numpy.ndarray,
    nants: int,
):
    """
    Rank antennas as reference-antenna candidates, best first.

    This method comes from katsdpcal.
    (https://github.com/ska-sa/katsdpcal/blob/
    200c2f6e60b2540f0a89e7b655b26a2b04a8f360/katsdpcal/calprocs.py#L332)
    Determine antenna whose FFT has the maximum peak to noise ratio (PNR) by
    taking the median PNR of the FFT over all baselines to each antenna.

    When the input vis has only one channel, antennas are instead ranked by
    the total weight of their baselines (ties favour the lower index).

    Only cross-correlation baselines count towards an antenna; the XOR mask
    excludes autocorrelations. In the multi-channel case an antenna with no
    such baselines is ranked last (its median PNR would otherwise be NaN,
    which the final argsort would place first).

    :param flagged_vis: Flag-applied visibilities [time, baseline, chan, pol]
    :param flagged_weight: Flag-applied weights, same layout as ``flagged_vis``
    :param ant1: First antenna index of each baseline
    :param ant2: Second antenna index of each baseline
    :param nants: Number of antennas to rank
    :return: Array of antenna indices in decreasing order of score
    """
    _, _, nchan, _ = flagged_vis.shape
    scores = numpy.zeros(nants)

    if nchan == 1:
        for ant in range(nants):
            mask = (ant1 == ant) ^ (ant2 == ant)
            scores[ant] = numpy.sum(flagged_weight[:, mask])
        # Small, decreasing offsets break ties in favour of lower indices.
        scores += numpy.linspace(1e-8, 1e-9, nants)
        return numpy.argsort(scores)[::-1]

    ft_vis = scipy.fftpack.fft(flagged_vis, axis=2)
    abs_ft = numpy.abs(ft_vis)

    # Rotate every spectrum so its peak lands in channel 0. This is the
    # vectorised form of numpy.roll(range(nchan), -peak) per spectrum.
    peak_chan = numpy.argmax(abs_ft, axis=2)  # [time, baseline, pol]
    index = (peak_chan[..., numpy.newaxis] + numpy.arange(nchan)) % nchan
    index = numpy.transpose(index, (0, 1, 3, 2))  # [time, baseline, chan, pol]
    abs_ft = numpy.take_along_axis(abs_ft, index, axis=2)

    peak = numpy.max(abs_ft, axis=2)

    chan_slice = numpy.s_[
        nchan // 2 - nchan // 4 : nchan // 2 + nchan // 4 + 1  # noqa E203
    ]
    mean = numpy.mean(abs_ft[:, :, chan_slice], axis=2)
    std = numpy.std(abs_ft[:, :, chan_slice], axis=2) + 1e-9

    for ant in range(nants):
        mask = (ant1 == ant) ^ (ant2 == ant)
        if not numpy.any(mask):
            scores[ant] = -numpy.inf
            continue
        pnr = (peak[:, mask] - mean[:, mask]) / std[:, mask]
        scores[ant] = numpy.median(pnr)
    return numpy.argsort(scores)[::-1]

In [None]:
def divide_visibility(
    vis_flagged: numpy.ndarray,
    vis_flagged_weight: numpy.ndarray,
    model_flagged_vis: numpy.ndarray,
):
    """
    Divide observed visibilities by the model, giving the visibilities of an
    equivalent point source.

    This is a useful intermediate product for calibration: time/frequency
    structure coming from the sky model is removed, so the data can be
    averaged down to the limit set by instrumental stability. The weights
    are scaled by ``|model|**2`` to compensate for the division.

    Samples whose scaled weight is not positive (e.g. a zero model value or
    a zero weight) are not divided; their output visibility is zero.

    :param vis_flagged: Flag-applied observed visibilities.
    :param vis_flagged_weight: Flag-applied weights for ``vis_flagged``.
    :param model_flagged_vis: Flag-applied model visibilities.
    :return: ``(point_vis, point_weight)`` tuple of arrays.
    """
    point_weight = numpy.abs(model_flagged_vis) ** 2 * vis_flagged_weight
    usable = point_weight > 0.0

    point_vis = numpy.zeros_like(vis_flagged)
    point_vis[usable] = vis_flagged[usable] / model_flagged_vis[usable]

    return point_vis, point_weight

In [None]:
def _apply_flag(x: numpy.ndarray, flags: numpy.ndarray):
    return x * (1 - flags)

In [None]:
def solve_gaintable_new(
    # chunked
    vis_vis: numpy.ndarray,
    vis_flags: numpy.ndarray,
    vis_weight: numpy.ndarray,
    model_vis: numpy.ndarray,  # [time, baseline, freq, pol]
    model_flags: numpy.ndarray,  # [time, baseline, freq, pol]
    gain: numpy.ndarray,
    gain_weights: numpy.ndarray,
    gain_residual: numpy.ndarray,
    # non chunked
    baselines: numpy.ndarray,
    ant1: numpy.ndarray,
    ant2: numpy.ndarray,
    # kwargs
    phase_only=True,
    niter=30,
    tol=1e-6,
    crosspol=False,
    normalise_gains="mean",
    solver="gain_substitution",
    refant=0,
):
    """
    Solve for gains by fitting observed visibilities to model visibilities.

    Array-level version of ``solve_gaintable``: row 0 of ``gain``,
    ``gain_weights`` and ``gain_residual`` is updated in place.

    If ``model_vis`` is None, a point source model is assumed (only
    supported by the "gain_substitution" solver).

    :param vis_vis: Observed visibilities [time, baseline, freq, pol]
    :param vis_flags: Flags for ``vis_vis`` (1 = flagged)
    :param vis_weight: Weights for ``vis_vis``
    :param model_vis: Model visibilities [time, baseline, freq, pol] or None
    :param model_flags: Flags for ``model_vis``; must be given exactly when
        ``model_vis`` is given
    :param gain: Gains indexed [time, antenna, freq, ...]; updated in place
    :param gain_weights: Gain weights, same layout as ``gain``; updated in place
    :param gain_residual: Gain residuals; updated in place
    :param baselines: Not used; kept for backward compatibility. Baselines
        are described by ``ant1`` and ``ant2``.
    :param ant1: First antenna index of each baseline
    :param ant2: Second antenna index of each baseline
    :param phase_only: Solve only for the phases. default=True when
        solver="gain_substitution", otherwise it is reset to False.
    :param niter: Maximum number of iterations (default=30)
    :param tol: Iteration stops when the fractional change
        in the gain solution is below this tolerance (default=1e-6)
    :param crosspol: Do solutions including cross polarisations
        i.e. XY, YX or RL, LR. Only used by the gain_substitution solver.
    :param normalise_gains: Normalises the gains (default="mean")
        options are None, "mean", "median".
        None means no normalization. Only available with gain_substitution,
        and not applied when ``phase_only`` is True.
    :param solver: Calibration algorithm to use (default="gain_substitution")
        options are:
        "gain_substitution" - original substitution algorithm with separate
        solutions for each polarisation term.
        "jones_substitution" - solve antenna-based Jones matrices as a whole,
        with independent updates within each iteration.
        "normal_equations" - solve normal equations within each iteration
        formed from linearisation with respect to antenna-based gain and
        leakage terms.
        "normal_equations_presum" - the same as the normal_equations option
        but with an initial accumulation of visibility products over time
        and frequency for each solution interval. This can be much faster
        for large datasets and solution intervals.
    :param refant: Reference antenna (default 0). Currently only activated for
        the gain_substitution solver.
    :return: ``gain`` (the same array, updated in place)
    :raises ValueError: on inconsistent model inputs, an all-zero model,
        or an option unsupported by the chosen solver.
    """
    log.info("solve_gaintable: Starting calibration")
    log.info("solve_gaintable: Using solver %s", solver)

    if (model_vis is not None) ^ (model_flags is not None):
        raise ValueError("solve_gaintable: model_vis and model_flag are required")

    if model_vis is not None:
        # pylint: disable=unneeded-not
        if not numpy.max(numpy.abs(model_vis)) > 0.0:
            raise ValueError("solve_gaintable: Model visibility is zero")

    if model_vis is None and solver != "gain_substitution":
        raise ValueError(f"solve_gaintable: model_vis required for solver {solver}")

    if normalise_gains is not None and solver != "gain_substitution":
        raise ValueError(
            f"solve_gaintable: normalise_gains unsupported with {solver}",
        )

    if phase_only is True and solver != "gain_substitution":
        log.warning(
            "solve_gaintable: resetting phase_only to False for solver %s",
            solver,
        )
        phase_only = False

    # this is only needed by the gain_substitution solver
    if phase_only:
        log.debug("solve_gaintable: Solving for phase only")
    else:
        log.debug("solve_gaintable: Solving for complex gain")

    nants = gain.shape[1]
    nchan = gain.shape[2]
    npol = vis_vis.shape[-1]

    # Sum over time, and over frequency too for a single-channel gain table.
    axes = (0, 2) if nchan == 1 else 0

    if solver == "gain_substitution":
        vis_flagged = _apply_flag(vis_vis, vis_flags)
        vis_flagged_weight = _apply_flag(vis_weight, vis_flags)

        (pointvis_vis, pointvis_weight) = (
            divide_visibility(
                vis_flagged, vis_flagged_weight, _apply_flag(model_vis, model_flags)
            )
            if model_vis is not None
            else (vis_vis, vis_weight)
        )

        pointvis_flags = vis_flags
        pointvis_flagged_vis = _apply_flag(pointvis_vis, pointvis_flags)
        pointvis_flagged_weight = _apply_flag(pointvis_weight, pointvis_flags)

        refant_sort = find_best_refant_from_vis(
            pointvis_flagged_vis, pointvis_flagged_weight, ant1, ant2, nants
        )
        x_b = numpy.sum(
            (pointvis_vis * pointvis_weight) * (1 - pointvis_flags),
            axis=axes,
        )
        xwt_b = numpy.sum(
            pointvis_weight * (1 - pointvis_flags),
            axis=axes,
        )
        # Antenna x antenna matrices: [a1, a2] holds the conjugate of the
        # baseline sum and [a2, a1] the sum itself.
        x = numpy.zeros([nants, nants, nchan, npol], dtype="complex")
        xwt = numpy.zeros([nants, nants, nchan, npol])
        for ibaseline, (a1, a2) in enumerate(zip(ant1, ant2)):
            x[a1, a2, ...] = numpy.conjugate(x_b[ibaseline, ...])
            xwt[a1, a2, ...] = xwt_b[ibaseline, ...]
            x[a2, a1, ...] = x_b[ibaseline, ...]
            xwt[a2, a1, ...] = xwt_b[ibaseline, ...]

        mask = numpy.abs(xwt) > 0.0
        if numpy.sum(mask) > 0:
            _solve_with_mask(
                crosspol,
                gain,
                gain_weights,
                gain_residual,
                mask,
                niter,
                phase_only,
                0,
                tol,
                npol,
                x,
                xwt,
                refant,
                refant_sort,
            )
        else:
            # No usable data: reset this row to unit gains with zero weight.
            gain[0, ...] = 1.0 + 0.0j
            gain_weights[0, ...] = 0.0
            gain_residual[0, ...] = 0.0

    else:
        # the remaining solvers require separate observed and model vis
        solve_with_alternative_algorithm(
            solver,
            vis_vis,
            vis_flags,
            vis_weight,
            model_vis,
            ant1,
            ant2,
            gain,
            niter,
            0,
            tol,
        )

    if normalise_gains in ["median", "mean"] and not phase_only:
        normaliser = {
            "median": numpy.median,
            "mean": numpy.mean,
        }
        gabs = normaliser[normalise_gains](numpy.abs(gain[:]))
        gain[:] /= gabs

    log.info("solve_gaintable: Finished calibration")

    return gain

## Solvers

In [None]:
from typing import Tuple
from ska_sdp_func_python.calibration.solvers import solve_gaintable


class Solver:
    """
    Base class for the gain-table solvers.

    Stores the iteration controls shared by every solver; concrete solvers
    override :meth:`solve`.
    """

    def __init__(self, niter, tol):
        # Maximum number of solver iterations.
        self.niter = niter
        # Convergence tolerance on the fractional change of the gains.
        self.tol = tol

    def solve(
        self,
        vis_vis: numpy.ndarray,
        vis_flags: numpy.ndarray,
        vis_weight: numpy.ndarray,
        model_vis: numpy.ndarray,  # [time, baseline, freq, pol]
        model_flags: numpy.ndarray,  # [time, baseline, freq, pol]
        gain_gain: numpy.ndarray,
        gain_weight: numpy.ndarray,
        gain_residual: numpy.ndarray,
        ant1: numpy.ndarray,
        ant2: numpy.ndarray,
    ) -> Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]:
        """
        Solve for gains; must be provided by subclasses.

        :return: ``(gain, gain_weight, gain_residual)`` arrays.
        :raises NotImplementedError: always, in this base class.
        """
        message = "Not implemented"
        raise NotImplementedError(message)

In [None]:
def normalise_gains(gain: numpy.ndarray, type: str) -> numpy.ndarray:
    """
    Divide ``gain`` by the mean or median of its amplitudes.

    :param gain: Gain array.
    :param type: "mean" or "median" selects the statistic; any other value
        (e.g. None) means no normalisation.
    :return: A new normalised array, or ``gain`` itself (the same object)
        when no normalisation is requested.
    """
    if type not in ("median", "mean"):
        return gain

    reducers = {"median": numpy.median, "mean": numpy.mean}
    amplitude_scale = reducers[type](numpy.abs(gain))
    return gain / amplitude_scale

In [None]:
from typing import Tuple


from numpy import ndarray


def gain_substitution(
    gain: numpy.ndarray,
    gain_weight: numpy.ndarray,
    gain_residual: numpy.ndarray,
    pointvis_vis: numpy.ndarray,
    pointvis_flags: numpy.ndarray,
    pointvis_weight: numpy.ndarray,
    ant1: numpy.ndarray,
    ant2: numpy.ndarray,
    crosspol: bool = False,
    niter: int = 30,
    phase_only: bool = True,
    tol: float = 1e-6,
    refant: int = 0,
) -> Tuple[ndarray, ndarray, ndarray]:
    """
    Solve gains with the original "gain_substitution" algorithm.

    The weighted, flag-applied point-source visibilities are summed over
    time (and over frequency when the gain table has one channel). They are
    scattered into antenna x antenna matrices and passed to
    ``_solve_with_mask``, which solves row 0 of the gain arrays.

    The input arrays are not modified; solved copies are returned.

    :param gain: Initial gains, indexed [time, antenna, freq, ...]
    :param gain_weight: Gain weights, same layout as ``gain``
    :param gain_residual: Gain residuals
    :param pointvis_vis: Point-source visibilities [time, baseline, freq, pol]
    :param pointvis_flags: Flags for ``pointvis_vis`` (1 = flagged)
    :param pointvis_weight: Weights for ``pointvis_vis``
    :param ant1: First antenna index of each baseline (integer array)
    :param ant2: Second antenna index of each baseline (integer array)
    :param crosspol: Include cross-polarisation terms in the solution
    :param niter: Maximum number of iterations
    :param phase_only: Solve only for the phases
    :param tol: Convergence tolerance
    :param refant: Reference antenna index
    :return: ``(gain, gain_weight, gain_residual)`` copies holding the
        solution. If no sample has non-zero weight, the gains are set to 1
        and the weights and residuals to 0.
    """
    _gain = gain.copy()
    _gain_weight = gain_weight.copy()
    _gain_residual = gain_residual.copy()

    nants = _gain.shape[1]  # TODO: should get antennas number from Configuration?
    nchan = _gain.shape[2]
    npol = pointvis_vis.shape[-1]

    # Sum over time, and over frequency too for a single-channel gain table.
    axes = (0, 2) if nchan == 1 else 0

    pointvis_flagged_vis = _apply_flag(pointvis_vis, pointvis_flags)
    pointvis_flagged_weight = _apply_flag(pointvis_weight, pointvis_flags)

    refant_sort = find_best_refant_from_vis(
        pointvis_flagged_vis, pointvis_flagged_weight, ant1, ant2, nants
    )
    x_b = numpy.sum(
        (pointvis_vis * pointvis_weight) * (1 - pointvis_flags),
        axis=axes,
    )
    xwt_b = numpy.sum(
        pointvis_weight * (1 - pointvis_flags),
        axis=axes,
    )
    if nchan == 1:
        # The frequency axis was summed away; restore it so each baseline's
        # slice is [freq, pol], matching the targets x[ant1, ant2].
        x_b = x_b[:, numpy.newaxis, :]
        xwt_b = xwt_b[:, numpy.newaxis, :]

    # Antenna x antenna matrices, filled for all baselines at once:
    # [a1, a2] holds the conjugate of the baseline sum, [a2, a1] the sum
    # itself. The [a2, a1] assignment runs last, so an autocorrelation
    # (a1 == a2) keeps the unconjugated value.
    x = numpy.zeros([nants, nants, nchan, npol], dtype="complex")
    xwt = numpy.zeros([nants, nants, nchan, npol])
    x[ant1, ant2, ...] = numpy.conjugate(x_b)
    xwt[ant1, ant2, ...] = xwt_b
    x[ant2, ant1, ...] = x_b
    xwt[ant2, ant1, ...] = xwt_b

    mask = numpy.abs(xwt) > 0.0
    if numpy.sum(mask) > 0:
        _solve_with_mask(
            crosspol,
            _gain,
            _gain_weight,
            _gain_residual,
            mask,
            niter,
            phase_only,
            0,
            tol,
            npol,
            x,
            xwt,
            refant,
            refant_sort,
        )
    else:
        _gain[...] = 1.0 + 0.0j
        _gain_weight[...] = 0.0
        _gain_residual[...] = 0.0

    return (_gain, _gain_weight, _gain_residual)


def create_point_vis(
    vis_vis: numpy.ndarray,
    vis_flags: numpy.ndarray,
    vis_weight: numpy.ndarray,
    model_vis: numpy.ndarray,
    model_flags: numpy.ndarray,
):
    """
    Build point-source visibilities and weights for the gain-substitution solver.

    With a model, the flag-applied observations are divided by the
    flag-applied model (see ``divide_visibility``). Without a model
    (``model_vis is None``), the original ``vis_vis`` and ``vis_weight``
    are returned unchanged.

    :return: ``(pointvis_vis, pointvis_weight)`` tuple.
    """
    observed = _apply_flag(vis_vis, vis_flags)
    observed_weight = _apply_flag(vis_weight, vis_flags)

    if model_vis is None:
        return (vis_vis, vis_weight)

    flagged_model = _apply_flag(model_vis, model_flags)
    return divide_visibility(observed, observed_weight, flagged_model)


class GainSubstitution(Solver):
    """
    Solver that runs :func:`gain_substitution` on point-source visibilities.

    The observations are divided by the model (via ``create_point_vis``) and
    the result is solved with the original gain-substitution algorithm.

    Gain normalisation is not applied here. It should be done on the whole
    data set, so it is left to the caller.
    """

    def __init__(self, refant, phase_only, crosspol, **kwargs):
        super().__init__(**kwargs)
        # Reference antenna index; checked against the gain table in solve().
        self.refant = refant
        # Solve only for phases when True.
        self.phase_only = phase_only
        # Include cross-polarisation terms in the solution.
        self.crosspol = crosspol

    def solve(
        self,
        vis_vis: ndarray,
        vis_flags: ndarray,
        vis_weight: ndarray,
        model_vis: ndarray,
        model_flags: ndarray,
        gain_gain: ndarray,
        gain_weight: ndarray,
        gain_residual: ndarray,
        ant1: ndarray,
        ant2: ndarray,
    ) -> Tuple[ndarray, ndarray, ndarray]:
        """
        Solve the gains with gain substitution.

        :return: ``(gain, gain_weight, gain_residual)`` as returned by
            :func:`gain_substitution`.
        :raises ValueError: if ``refant`` is outside ``[0, nants)``, if the
            model has no non-zero amplitude, or if exactly one of
            ``model_vis`` / ``model_flags`` is given.
        """
        nants = gain_gain.shape[1]
        if not 0 <= self.refant < nants:
            raise ValueError(f"Invalid refant: {self.refant}")

        model_given = model_vis is not None
        # pylint: disable=unneeded-not
        if model_given and not numpy.max(numpy.abs(model_vis)) > 0.0:
            raise ValueError("solve_gaintable: Model visibility is zero")

        if model_given != (model_flags is not None):
            raise ValueError("solve_gaintable: model_vis and model_flag are required")

        pointvis_vis, pointvis_weight = create_point_vis(
            vis_vis, vis_flags, vis_weight, model_vis, model_flags
        )

        return gain_substitution(
            gain=gain_gain,
            gain_weight=gain_weight,
            gain_residual=gain_residual,
            pointvis_vis=pointvis_vis,
            pointvis_flags=vis_flags,
            pointvis_weight=pointvis_weight,
            ant1=ant1,
            ant2=ant2,
            crosspol=self.crosspol,
            niter=self.niter,
            phase_only=self.phase_only,
            tol=self.tol,
            refant=self.refant,
        )

In [None]:
from numpy import ndarray


class AlternativeSolver(Solver):
    """
    Base class for the Jones-matrix solvers ("jones_substitution",
    "normal_equations", "normal_equations_presum").

    Subclasses set ``self.solver_fn`` to a per-channel solver. It is called
    as ``solver_fn(vobs, vmdl, wgt, ant1, ant2, gain, ch, niter, tol)``, and
    its return value is ignored, so it must update ``gain`` in place.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Per-channel solver callable; provided by subclasses.
        self.solver_fn = None

    def solve(
        self,
        vis_vis: ndarray,
        vis_flags: ndarray,
        vis_weight: ndarray,
        model_vis: ndarray,
        model_flags: ndarray,
        gain_gain: ndarray,
        gain_weight: ndarray,
        gain_residual: ndarray,
        ant1: ndarray,
        ant2: ndarray,
    ) -> Tuple[ndarray, ndarray, ndarray]:
        """
        Solve time row 0 of the gains with ``self.solver_fn``.

        :param vis_vis: Observed visibilities (time, baseline, chan, 4)
        :param vis_flags: Flags for ``vis_vis``
        :param vis_weight: Weights for ``vis_vis``
        :param model_vis: Model visibilities, same shape as ``vis_vis``
        :param model_flags: Not used by these solvers
        :param gain_gain: Gains [time, antenna, chan, 2, 2]; not modified
        :param gain_weight: Returned unchanged
        :param gain_residual: Returned unchanged
        :param ant1: First antenna index of each baseline
        :param ant2: Second antenna index of each baseline
        :return: ``(gains, gain_weight, gain_residual)``, where ``gains`` is a
            copy of ``gain_gain`` with row 0 solved
        :raises ValueError: if ``solver_fn`` is not set or ``model_vis`` is None
        """
        if self.solver_fn is None:
            raise ValueError(
                "AlternativeSolver: alternative solver function to be used is not provided."
            )
        if model_vis is None:
            raise ValueError(
                "AlternativeSolver: model_vis is required for this solver."
            )

        _gain_gain = gain_gain.copy()
        gain = _gain_gain[0]  # select first time (a view into the copy)

        _, nchan_gt, nrec1, nrec2 = gain.shape
        ntime, nbl, nchan_vis, npol_vis = vis_vis.shape
        assert nrec1 == 2
        assert nrec1 == nrec2
        assert nrec1 * nrec2 == npol_vis
        assert nchan_gt in (1, nchan_vis)

        # incorporate flags into weights
        wgt = vis_weight * (1 - vis_flags)
        # flag the whole Jones matrix if any element is flagged
        wgt *= numpy.all(wgt > 0, axis=-1, keepdims=True)
        # reduce the dimension to a single weight per matrix
        #  - could weight pols separately, but may be better not to
        wgt = wgt[..., 0]

        vmdl = model_vis.reshape(ntime, nbl, nchan_vis, 2, 2)
        vobs = vis_vis.reshape(ntime, nbl, nchan_vis, 2, 2)

        # Update model if a starting solution is given.
        I2 = numpy.eye(2)
        if numpy.any(gain[..., :, :] != I2):
            vmdl = numpy.einsum(
                "bfpi,tbfij,bfqj->tbfpq",
                gain[ant1],
                vmdl,
                gain[ant2].conj(),
            )

        log.debug(
            "solve_with_alternative_algorithm: "
            + "solving for %d chan in %d sub-band[s] using solver %s",
            nchan_vis,
            nchan_gt,
            getattr(self.solver_fn, "__name__", self.solver_fn),
        )

        for ch in range(nchan_gt):
            # select channels to average over. Just the current one if solving
            # each channel separately, or all of them if this is a joint solution.
            chan_vis = [ch] if nchan_gt == nchan_vis else range(nchan_vis)

            log.debug(
                "solve_with_alternative_algorithm: "
                + "sub-band %d, processing %d channels:",
                ch,
                len(chan_vis),
            )

            self.solver_fn(
                vobs[:, :, chan_vis],
                vmdl[:, :, chan_vis],
                wgt[:, :, chan_vis],
                ant1,
                ant2,
                gain,
                ch,
                self.niter,
                self.tol,
            )

        # Write back in case `gain` stopped being a view of the copy.
        _gain_gain[0, ...] = gain

        return (_gain_gain, gain_weight, gain_residual)

In [None]:
from numpy import ndarray


class JonesSubtitution(AlternativeSolver):
    """
    "jones_substitution" solver: solves antenna-based Jones matrices as a
    whole, using ``_jones_sub_solve`` for each channel.

    NOTE: the class name is misspelled ("Subtitution"); it is kept because
    ``SolverFactory`` refers to it by this name.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.solver_fn = _jones_sub_solve

In [None]:
class NormalEquation(AlternativeSolver):
    """
    "normal_equations" solver: solves the normal equations formed from
    linearisation with respect to antenna gain and leakage terms, via
    ``_normal_equation_solve``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.solver_fn = _normal_equation_solve

In [None]:

class NormalEquationsPreSum(AlternativeSolver):
    """
    "normal_equations_presum" solver: like "normal_equations", but first
    accumulates visibility products over each solution interval, via
    ``_normal_equation_solve_with_presumming``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.solver_fn = _normal_equation_solve_with_presumming

In [None]:
class SolverFactory:
    """Look up and instantiate calibration solvers by name."""

    # Registered solver name -> solver class.
    _solvers = {
        "gain_substitution": GainSubstitution,
        "jones_substitution": JonesSubtitution,
        "normal_equations": NormalEquation,
        "normal_equations_presum": NormalEquationsPreSum,
    }

    @classmethod
    def get_solver(cls, solver, *args, **kwargs):
        """Instantiate the solver registered under ``solver``.

        :param solver: registered solver name, e.g. "gain_substitution"
        :param args: positional arguments forwarded to the solver class
        :param kwargs: keyword arguments forwarded to the solver class
        :return: a new solver instance
        :raises KeyError: if ``solver`` is not a registered name; the
            message lists the valid names
        """
        try:
            solver_cls = cls._solvers[solver]
        except KeyError:
            valid = ", ".join(sorted(cls._solvers))
            # Still a KeyError so existing callers keep working, but with
            # a message that tells the user what the valid choices are.
            raise KeyError(
                f"Unknown solver {solver!r}; expected one of: {valid}"
            ) from None
        return solver_cls(*args, **kwargs)

-------------

## Testing

### Integration test

In [None]:
# Settings shared by the integration tests below.
# `solver` holds the default solver name; the test cells later rebind it
# to solver instances.
solver = "gain_substitution"
refant = 0  # reference antenna passed to GainSubstitution
phase_only = True
crosspol = False
niter = 150  # iteration limit passed to every solver
tol = 1e-6  # convergence tolerance passed to every solver
# NOTE(review): the reference solve_gaintable calls below pass
# normalise_gains=None explicitly, and `timeslice` is not used by them.
normalise_gains = "mean"
timeslice = None

#### Solver - `gain_substitution`

In [None]:
from ska_sdp_func_python.calibration.solvers import solve_gaintable
from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import restore_baselines_dim

In [None]:
# Materialise both datasets in memory once. The baseline-restored copies
# are what the solve_gaintable reference calls below consume; the raw
# computed datasets feed the new solvers.
input_vis, input_modelvis = input_vis.compute(), input_modelvis.compute()

input_vis_restored, input_modelvis_restored = (
    restore_baselines_dim(dataset) for dataset in (input_vis, input_modelvis)
)

In [None]:

# Reference solution from ska_sdp_func_python's solve_gaintable, computed
# on a deep copy of the initial gain table with gain normalisation off.
expected_gaintable_gain_substitution = solve_gaintable(
    solver="gain_substitution",
    jones_type="B",
    vis=input_vis_restored,
    modelvis=input_modelvis_restored,
    gain_table=initial_gaintable.compute().copy(deep=True),
    niter=niter,
    tol=tol,
    phase_only=phase_only,
    crosspol=crosspol,
    normalise_gains=None,
)

In [None]:
solver = GainSubstitution(
    refant=refant, phase_only=phase_only, crosspol=crosspol, tol=tol, niter=niter
)

# Solve on a private, in-memory copy of the initial gain table and leave
# `initial_gaintable` itself untouched. This keeps the cell idempotent, and
# later cells never start from gains this solver may have written into
# the arrays it was handed.
gaintable_input_gain_substitution = initial_gaintable.compute().copy(deep=True)

# solve returns the (gain, weight, residual) arrays of the solved table.
(
    actual_gaintable_gain_subtitution,
    actual_gaintable_gain_subtitution_weight,
    actual_gaintable_gain_subtitution_residual,
) = solver.solve(
    vis_vis=input_vis.vis.data,
    vis_flags=input_vis.flags.data,
    vis_weight=input_vis.weight.data,
    model_vis=input_modelvis.vis.data,
    model_flags=input_modelvis.flags.data,
    gain_gain=gaintable_input_gain_substitution["gain"].values,
    gain_weight=gaintable_input_gain_substitution["weight"].values,
    gain_residual=gaintable_input_gain_substitution["residual"].values,
    ant1=input_vis.antenna1.data,
    ant2=input_vis.antenna2.data,
)

In [None]:
# assert_allclose(actual, desired): the tolerance is scaled by `desired`
# and failure messages label the sides, so the solver output goes first
# and the ska_sdp_func_python reference second.
numpy.testing.assert_allclose(
    actual_gaintable_gain_subtitution,
    expected_gaintable_gain_substitution["gain"],
)
print("Gaintable gain values match expected results.")


In [None]:
# assert_allclose(actual, desired): the tolerance is scaled by `desired`
# and failure messages label the sides, so the solver output goes first
# and the ska_sdp_func_python reference second.
numpy.testing.assert_allclose(
    actual_gaintable_gain_subtitution_weight,
    expected_gaintable_gain_substitution["weight"],
)
print("Gaintable weight values match expected results.")


In [None]:
# assert_allclose(actual, desired): the tolerance is scaled by `desired`
# and failure messages label the sides, so the solver output goes first
# and the ska_sdp_func_python reference second.
numpy.testing.assert_allclose(
    actual_gaintable_gain_subtitution_residual,
    expected_gaintable_gain_substitution["residual"],
)
print("Gaintable residual values match expected results.")


#### Solver - `jones_substitution`

In [None]:
# Reference solution from ska_sdp_func_python's solve_gaintable, computed
# on a deep copy of the initial gain table with gain normalisation off.
expected_gaintable_jones_substitution = solve_gaintable(
    solver="jones_substitution",
    jones_type="B",
    vis=input_vis_restored,
    modelvis=input_modelvis_restored,
    gain_table=initial_gaintable.compute().copy(deep=True),
    niter=niter,
    tol=tol,
    phase_only=phase_only,
    crosspol=crosspol,
    normalise_gains=None,
)

In [None]:
solver = JonesSubtitution(tol=tol, niter=niter)

# Solve on a private, in-memory copy of the initial gain table and leave
# `initial_gaintable` itself untouched. This keeps the cell idempotent, and
# later cells never start from gains this solver may have written into
# the arrays it was handed.
gaintable_input_jones_substitution = initial_gaintable.compute().copy(deep=True)

# solve returns the (gain, weight, residual) arrays of the solved table.
(
    actual_gaintable_jones_subtitution_gain,
    actual_gaintable_jones_subtitution_weight,
    actual_gaintable_jones_subtitution_residual,
) = solver.solve(
    vis_vis=input_vis.vis.data,
    vis_flags=input_vis.flags.data,
    vis_weight=input_vis.weight.data,
    model_vis=input_modelvis.vis.data,
    model_flags=input_modelvis.flags.data,
    gain_gain=gaintable_input_jones_substitution["gain"].values,
    gain_weight=gaintable_input_jones_substitution["weight"].values,
    gain_residual=gaintable_input_jones_substitution["residual"].values,
    ant1=input_vis.antenna1.data,
    ant2=input_vis.antenna2.data,
)

In [None]:
# assert_allclose(actual, desired): the tolerance is scaled by `desired`
# and failure messages label the sides, so the solver output goes first
# and the ska_sdp_func_python reference second.
numpy.testing.assert_allclose(
    actual_gaintable_jones_subtitution_gain,
    expected_gaintable_jones_substitution["gain"],
)
print("Gaintable gain values match expected results.")


In [None]:
# assert_allclose(actual, desired): the tolerance is scaled by `desired`
# and failure messages label the sides, so the solver output goes first
# and the ska_sdp_func_python reference second.
numpy.testing.assert_allclose(
    actual_gaintable_jones_subtitution_weight,
    expected_gaintable_jones_substitution["weight"],
)
print("Gaintable weight values match expected results.")


In [None]:
# assert_allclose(actual, desired): the tolerance is scaled by `desired`
# and failure messages label the sides, so the solver output goes first
# and the ska_sdp_func_python reference second.
numpy.testing.assert_allclose(
    actual_gaintable_jones_subtitution_residual,
    expected_gaintable_jones_substitution["residual"],
)
print("Gaintable residual values match expected results.")


#### Solver - `normal_equations`

In [None]:
# Reference solution from ska_sdp_func_python's solve_gaintable, computed
# on a deep copy of the initial gain table with gain normalisation off.
expected_gaintable_normal_equations = solve_gaintable(
    solver="normal_equations",
    jones_type="B",
    vis=input_vis_restored,
    modelvis=input_modelvis_restored,
    gain_table=initial_gaintable.compute().copy(deep=True),
    niter=niter,
    tol=tol,
    phase_only=phase_only,
    crosspol=crosspol,
    normalise_gains=None,
)

In [None]:
solver = NormalEquation(tol=tol, niter=niter)

# Solve on a private, in-memory copy of the initial gain table and leave
# `initial_gaintable` itself untouched. This keeps the cell idempotent, and
# later cells never start from gains this solver may have written into
# the arrays it was handed.
gaintable_input_normal_equations = initial_gaintable.compute().copy(deep=True)

# solve returns the (gain, weight, residual) arrays of the solved table.
(
    actual_gaintable_normal_equations_gain,
    actual_gaintable_normal_equations_weight,
    actual_gaintable_normal_equations_residual,
) = solver.solve(
    vis_vis=input_vis.vis.data,
    vis_flags=input_vis.flags.data,
    vis_weight=input_vis.weight.data,
    model_vis=input_modelvis.vis.data,
    model_flags=input_modelvis.flags.data,
    gain_gain=gaintable_input_normal_equations["gain"].values,
    gain_weight=gaintable_input_normal_equations["weight"].values,
    gain_residual=gaintable_input_normal_equations["residual"].values,
    ant1=input_vis.antenna1.data,
    ant2=input_vis.antenna2.data,
)

In [None]:
# assert_allclose(actual, desired): the tolerance is scaled by `desired`
# and failure messages label the sides, so the solver output goes first
# and the ska_sdp_func_python reference second.
numpy.testing.assert_allclose(
    actual_gaintable_normal_equations_gain,
    expected_gaintable_normal_equations["gain"],
)
print("Gaintable gain values match expected results.")


In [None]:
# assert_allclose(actual, desired): the tolerance is scaled by `desired`
# and failure messages label the sides, so the solver output goes first
# and the ska_sdp_func_python reference second.
numpy.testing.assert_allclose(
    actual_gaintable_normal_equations_weight,
    expected_gaintable_normal_equations["weight"],
)
print("Gaintable weight values match expected results.")


In [None]:
# assert_allclose(actual, desired): the tolerance is scaled by `desired`
# and failure messages label the sides, so the solver output goes first
# and the ska_sdp_func_python reference second.
numpy.testing.assert_allclose(
    actual_gaintable_normal_equations_residual,
    expected_gaintable_normal_equations["residual"],
)
print("Gaintable residual values match expected results.")


-------------------

### Tests from `ska_sdp_func_python`

In [None]:
from ska_sdp_datamodels.configuration import create_named_configuration
from astropy.coordinates import SkyCoord
from ska_sdp_datamodels.science_data_model import PolarisationFrame
from ska_sdp_datamodels.sky_model import SkyComponent
from ska_sdp_func_python.imaging import dft_skycomponent_visibility
from ska_sdp_datamodels.visibility import create_visibility
from astropy import units as u
import numpy

import logging

# Module-level logger used by the helper functions below (e.g.
# simulate_gaintable). logging.getLogger returns the same object for the
# same name, so this is the logger already created in the first cell.
log = logging.getLogger(name="func-python-logger")


#### Helper functions for tests

In [None]:
def simulate_gaintable(
    gain_table,
    phase_error=0.1,
    amplitude_error=0.0,
    leakage=0.0,
    seed=1805550721,
):
    """
    Simulate a gain table with random phase/amplitude errors and leakage.

    Random values come from ``numpy.random.default_rng(seed)``, so repeated
    calls with the same arguments produce identical gains.

    :type gain_table: GainTable
    :param gain_table: table whose ``gain`` data is overwritten
    :param phase_error: std of normal distribution of phases, zero mean;
        0 means no phase error (zero phase)
    :param amplitude_error: std of log normal distribution of amplitudes;
        0 means unit amplitude
    :param leakage: std of cross hand leakage (only used when there is
        more than one receptor)
    :param seed: seed for the random generator; the default keeps the
        historical value so existing results are reproduced
    :return: updated GainTable (the same object that was passed in)
    """
    # pylint: disable=import-outside-toplevel
    from numpy.random import default_rng

    rng = default_rng(seed)

    log.debug(
        "simulate_gaintable: Simulating amplitude "
        "error = %.4f, phase error = %.4f",
        amplitude_error,
        phase_error,
    )
    shape = gain_table["gain"].data.shape

    # Ideal gains (unit amplitude, zero phase) with the full table shape.
    # Full-shape arrays guarantee the result always fits the table, even
    # when both error terms are zero.
    amps = numpy.ones(shape)
    phases = numpy.zeros(shape)

    # Phases are drawn before amplitudes: the order of RNG calls fixes
    # which random values each quantity receives for a given seed.
    if phase_error > 0.0:
        phases = rng.normal(0, phase_error, shape)

    if amplitude_error > 0.0:
        amps = rng.lognormal(0.0, amplitude_error, shape)

    gain = amps * numpy.exp(1j * phases)

    nrec = shape[-1]
    if nrec > 1:
        if leakage > 0.0:
            # Each off-diagonal term is the matching diagonal gain scaled
            # by a complex, normally distributed leakage factor.
            diag_shape = gain[..., 0, 0].shape
            leak = rng.normal(0, leakage, diag_shape) + 1j * rng.normal(
                0, leakage, diag_shape
            )
            gain[..., 0, 1] = gain[..., 0, 0] * leak
            leak = rng.normal(0, leakage, diag_shape) + 1j * rng.normal(
                0, leakage, diag_shape
            )
            gain[..., 1, 0] = gain[..., 1, 1] * leak
        else:
            gain[..., 0, 1] = 0.0
            gain[..., 1, 0] = 0.0

    gain_table["gain"].data = gain
    return gain_table



In [None]:
def vis_with_component_data(
    sky_pol_frame, data_pol_frame, flux_array, **kwargs
):
    """
    Generate Visibility data of a single point source for testing.

    :param sky_pol_frame: PolarisationFrame name of the SkyComponent
    :param data_pol_frame: PolarisationFrame name of the Visibility data
    :param flux_array: Flux of the SkyComponent at 100 MHz, one value per
        sky polarisation. Ignored (replaced by [100.0]) when
        sky_pol_frame == "stokesI".
    :param kwargs: optional settings:
            ntimes: number of time intervals (default 3); note that
                    ntimes + 1 time samples are generated
            rmax: maximum distance of antenna from centre
                  when configuration is determined (default 300)
            nchan: number of frequency channels (default 1)
    :return: Visibility predicted from the component by
        dft_skycomponent_visibility
    """
    ntimes = kwargs.get("ntimes", 3)
    rmax = kwargs.get("rmax", 300)
    nchan = kwargs.get("nchan", 1)

    lowcore = create_named_configuration("LOWBD2", rmax=rmax)
    # pi / 43200 rad per second turns 0..30 s into radians; presumably
    # hour angles as expected by create_visibility (TODO confirm units).
    times = (numpy.pi / 43200.0) * numpy.linspace(0.0, 30.0, 1 + ntimes)

    if nchan > 1:
        # Evenly spaced channels over 100-110 MHz; each channel's
        # bandwidth equals the channel spacing.
        frequency = numpy.linspace(1.0e8, 1.1e8, nchan)
        channel_bandwidth = numpy.full(nchan, frequency[1] - frequency[0])
    else:
        frequency = 1e8 * numpy.ones([1])
        channel_bandwidth = 1e7 * numpy.ones([1])

    # Both directions are absolute ICRS coordinates; the component lies
    # +1 deg in RA from the phase centre at the same declination.
    phasecentre = SkyCoord(
        ra=+180.0 * u.deg, dec=-35.0 * u.deg, frame="icrs", equinox="J2000"
    )
    compabsdirection = SkyCoord(
        ra=+181.0 * u.deg, dec=-35.0 * u.deg, frame="icrs", equinox="J2000"
    )

    # stokesI components always use a fixed flux of 100.0.
    component_flux = [100.0] if sky_pol_frame == "stokesI" else flux_array
    # Power-law spectrum with index -0.7, normalised at 100 MHz:
    # one row per channel, one column per sky polarisation.
    flux = numpy.outer(
        numpy.array([numpy.power(freq / 1e8, -0.7) for freq in frequency]),
        component_flux,
    )

    comp = SkyComponent(
        direction=compabsdirection,
        frequency=frequency,
        flux=flux,
        polarisation_frame=PolarisationFrame(sky_pol_frame),
    )
    vis = create_visibility(
        lowcore,
        times,
        frequency,
        phasecentre=phasecentre,
        channel_bandwidth=channel_bandwidth,
        weight=1.0,
        polarisation_frame=PolarisationFrame(data_pol_frame),
    )
    vis = dft_skycomponent_visibility(vis, comp)
    return vis



In [None]:
from ska_sdp_datamodels.calibration.calibration_create import (
    create_gaintable_from_visibility,
)
from ska_sdp_func_python.calibration.operations import apply_gaintable

#### Test 1 — phase-only `gain_substitution`, circular polarisation, Jones type `T`

In [None]:
from ska_sdp_func_python.calibration.solvers import (
    solve_gaintable,
)

from ska_sdp_instrumental_calibration.workflow.utils import create_bandpass_table

sky_pol_frame = "stokesIQUV"
data_pol_frame = "circular"
flux_array = [100.0, 0.0, 0.0, 50.0]
phase_error = 10.0

# Presumably the reference sum from the upstream ska_sdp_func_python test;
# kept for information only, this cell compares against solve_gaintable.
# expected_gain_sum = (-2.3575149649, -19.50250306305245)

jones_type = "T"

# Single-integration dataset: visibilities of one point source, averaged
# over time (time dimension kept with length 1).
vis = vis_with_component_data(sky_pol_frame, data_pol_frame, flux_array)
vis = vis.mean(dim="time", keep_attrs=True, keepdims=True)

gain_table = create_gaintable_from_visibility(vis, jones_type=jones_type)
original_gaintable = create_gaintable_from_visibility(vis, jones_type=jones_type)

# Corrupt the data with random phase-only gains.
gain_table = simulate_gaintable(
    gain_table,
    phase_error=phase_error,
    amplitude_error=0.0,
    leakage=0.0,
)

original = vis.copy(deep=True)
vis = apply_gaintable(vis, gain_table)

expected_gain_table = solve_gaintable(
    vis=vis.compute(),
    modelvis=original.compute(),
    gain_table=original_gaintable.copy(deep=True).compute(),
    phase_only=True,
    niter=200,
    crosspol=False,
    tol=1e-6,
    normalise_gains=None,
    jones_type=jones_type,
)

solver = GainSubstitution(
    refant=refant,
    phase_only=True,
    crosspol=False,
    niter=200,
    tol=1e-6,
)

# solve returns a (gain, weight, residual) tuple. The initial arrays are
# passed as copies so `original_gaintable` is not modified if solve writes
# into its inputs.
actual_gain_table = solver.solve(
    vis_vis=vis.vis.data,
    vis_flags=vis.flags.data,
    vis_weight=vis.weight.data,
    model_vis=original.vis.data,
    model_flags=original.flags.data,
    gain_gain=original_gaintable["gain"].values.copy(),
    gain_weight=original_gaintable["weight"].values.copy(),
    gain_residual=original_gaintable["residual"].values.copy(),
    ant1=vis.antenna1.data,
    ant2=vis.antenna2.data,
)

# Compare the summed complex gains to within 1e-10. An absolute tolerance
# avoids the flakiness of comparing values rounded to 10 decimals, where
# two nearly equal numbers can round to different results.
numpy.testing.assert_allclose(
    actual_gain_table[0].sum(),
    expected_gain_table["gain"].sum().values,
    rtol=0,
    atol=1e-10,
)