In [None]:
import os
import xarray as xr
import numpy
import os
import dask
# dask.config.set(scheduler='threads')  # or 'synchronous' for step-by-step

# If you want a dashboard:
# from dask.distributed import Client
# client = Client()
from ska_sdp_instrumental_calibration.workflow.stages import (
    load_data_stage,
    predict_vis_stage,
    bandpass_calibration_stage
)

from ska_sdp_instrumental_calibration.scheduler import UpstreamOutput
from copy import deepcopy
from ska_sdp_instrumental_calibration.data_managers.visibility import _generate_file_paths_for_vis_zarr_file
import pickle

from ska_sdp_instrumental_calibration.workflow.utils import with_chunks



## Helper functions

In [None]:
def export_model_vis(vis, output_dir, zarr_chunks):
    """
    Persist a model visibility dataset under ``output_dir``.

    Three artefacts are written, at the paths returned by
    ``_generate_file_paths_for_vis_zarr_file(output_dir)``:

    * a pickle of a deep copy of ``vis.attrs``,
    * a pickle of the computed ``vis.baselines``,
    * a zarr store holding the dataset with its attributes and the
      ``baselines`` variable dropped, re-chunked through ``with_chunks``.

    :param vis: Dataset to export; must expose ``attrs``, ``baselines``,
        ``drop_attrs`` and ``drop_vars``.
    :param output_dir: Directory to write into (created if missing).
    :param zarr_chunks: Chunk specification forwarded to ``with_chunks``.
    :return: None
    """
    os.makedirs(output_dir, exist_ok=True)

    attrs_path, baselines_path, zarr_path = (
        _generate_file_paths_for_vis_zarr_file(output_dir)
    )

    # Attributes and baselines are stored as pickles next to the zarr store.
    attrs_snapshot = deepcopy(vis.attrs)
    with open(attrs_path, "wb") as attrs_stream:
        pickle.dump(attrs_snapshot, attrs_stream)

    baselines_snapshot = deepcopy(vis.baselines).compute()
    with open(baselines_path, "wb") as baselines_stream:
        pickle.dump(baselines_snapshot, baselines_stream)

    # Build the zarr write lazily (compute=False), then trigger it explicitly.
    stripped = vis.drop_attrs().drop_vars("baselines")
    rechunked = with_chunks(stripped, zarr_chunks)
    delayed_write = rechunked.to_zarr(zarr_path, mode="w", compute=False)

    dask.compute(delayed_write)  # type: ignore

def import_model_vis(input_dir, vis_chunks):
    """
    Load a model visibility dataset from the files found under ``input_dir``.

    The zarr store is opened with the requested chunking, then the pickled
    attributes are attached and the pickled baselines are assigned back as
    the ``baselines`` variable.

    :param input_dir: Directory holding the attributes pickle, the baselines
        pickle and the zarr store, at the paths returned by
        ``_generate_file_paths_for_vis_zarr_file(input_dir)``.
    :param vis_chunks: Chunk specification passed to ``xr.open_dataset``.
    :return: The opened dataset with attributes and ``baselines`` set.
    """
    attrs_path, baselines_path, zarr_path = (
        _generate_file_paths_for_vis_zarr_file(input_dir)
    )

    dataset = xr.open_dataset(zarr_path, chunks=vis_chunks, engine="zarr")

    # NOTE: pickle.load executes arbitrary code from the file; only point
    # this at directories whose contents you trust.
    with open(attrs_path, "rb") as attrs_stream:
        stored_attrs = pickle.load(attrs_stream)
    dataset = dataset.assign_attrs(stored_attrs)

    with open(baselines_path, "rb") as baselines_stream:
        stored_baselines = pickle.load(baselines_stream)

    return dataset.assign({"baselines": stored_baselines})



In [None]:
def compute_and_clear_tasks(upstream_output: UpstreamOutput):
    """
    Call ``.compute()`` on every entry of ``upstream_output.compute_tasks``
    and then assign an empty list to ``upstream_output.stage_compute_tasks``.

    :param upstream_output: Pipeline output container holding the tasks.
    :return: None
    """
    pending_tasks = upstream_output.compute_tasks
    for pending_task in pending_tasks:
        pending_task.compute()

    # NOTE(review): the loop reads `compute_tasks` but the reset writes
    # `stage_compute_tasks` -- confirm against UpstreamOutput that this
    # really empties the list that was just computed.
    upstream_output.stage_compute_tasks = []

## Setup and configs

In [None]:
# Machine-specific input locations: edit these before running on another host.
# NOTE(review): cache_dir lives under /home/nitin while every other path is
# under /home/ska -- confirm that this mix is intentional.

# Measurement set used as the pipeline "input" CLI argument.
vis_ms = "/home/ska/Work/data/INST/lg3/cal_bpp_vis-lg3-rotated.small.ms/"
# Cache directory handed to the load-data stage.
cache_dir = "/home/nitin/Work/ska/sdp/ska-sdp-instrumental-calibration/cache"
# Passed to the predict-visibilities stage as eb_coeffs.
eb_coeffs = "/home/ska/Work/data/INST/sim/coeffs"
# Passed to the predict-visibilities stage as lsm_csv_path.
lsm_csv_path = "/home/ska/Work/data/INST/lg3/sky_model_cal.csv"

In [None]:
# CLI-style arguments forwarded to the stage definitions.
_cli_args_ = {"input": vis_ms}
# Chunk sizes used by the load-data stage and by the chunk dictionaries
# defined in the next cell.
nchannels_per_chunk = 32
ntimes_per_ms_chunk = 5

# Shared container that every stage reads from and writes to.
upstream_output = UpstreamOutput()
# Load-data stage options (passed positionally to
# load_data_stage.stage_definition).
ack = False
datacolumn = "DATA"
field_id = 0
data_desc_id = 0

# Predict visibilities stage
beam_type = "everybeam"
normalise_at_beam_centre = True
eb_ms = None
gleamfile = None
fov = 5.0
flux_limit = 1.0
alpha0 = -0.78


# bandpass calibration stage
# The same solver settings are given to the pipeline stage (through
# run_solver_config) and to solve_gaintable_new, so both runs are comparable.
solver = "gain_substitution"
refant = 0
niter = 150
phase_only = True
tol = 1.0e-06
crosspol = False
normalise_gains = None
timeslice = None
plot_config = {
    "plot_table": False,
    "fixed_axis": False,
}
# Key under which the stage looks up the visibilities in upstream_output.
visibility_key = "vis"
export_gaintable = True
run_solver_config = {
    "solver": solver,
    "niter": niter,
    "phase_only": phase_only,
    "tol": tol,
    "crosspol": crosspol,
    "refant": refant,
    "timeslice": timeslice,
    "normalise_gains": normalise_gains,
}

# Relative directory (resolved against the current working directory). It is
# used both as the stage output directory and as the location of the exported
# model visibilities.
output_dir = f"bandpass_stage_{solver}_{niter}"

In [None]:
# Dimensions that always get a chunk size of -1
# (presumably "keep the whole dimension in a single chunk" -- dask convention).
non_chunked_dims = dict.fromkeys(
    ("baselineid", "polarisation", "spatial"), -1
)

# Chunking of the intermediate zarr file.
zarr_chunks = dict(
    non_chunked_dims,
    time=ntimes_per_ms_chunk,
    frequency=nchannels_per_chunk,
)

# Chunking requested when the zarr file is read back: time uses -1,
# frequency is split into nchannels_per_chunk-sized chunks.
vis_chunks = dict(
    non_chunked_dims,
    time=-1,
    frequency=nchannels_per_chunk,
)

In [None]:
# Run the load-data stage. The arguments are positional, so their order must
# match the stage definition. Afterwards upstream_output provides the "vis"
# and "gaintable" entries read below.
load_data_stage.stage_definition( # type: ignore
    upstream_output,
    nchannels_per_chunk,
    ntimes_per_ms_chunk,
    cache_dir,
    ack,
    datacolumn,
    field_id,
    data_desc_id,
    _cli_args_,
    output_dir,
)
input_vis = upstream_output["vis"]
# Deep copy of the starting gaintable: it is the initial solution given to
# solve_gaintable_new later, independent of whatever the pipeline stages
# subsequently store under upstream_output["gaintable"].
initial_gaintable = upstream_output["gaintable"].copy(deep=True)

In [None]:
# True  -> reuse the model visibilities previously exported to output_dir.
# False -> run the predict stage and export its result before importing it.
is_model_vis_exported = True

In [None]:
# Predict the model visibilities only when no export is available yet, then
# write them to output_dir so later runs can skip the prediction.
if not is_model_vis_exported:
    predict_vis_stage.stage_definition( # type: ignore
        upstream_output,
        beam_type,
        normalise_at_beam_centre,
        eb_ms,
        eb_coeffs,
        gleamfile,
        lsm_csv_path,
        fov,
        flux_limit,
        alpha0,
        _cli_args_,
    )

    input_modelvis = upstream_output["modelvis"]
    compute_and_clear_tasks(upstream_output)
    export_model_vis(input_modelvis, output_dir, zarr_chunks)

# Always read the model visibilities back from disk, so both code paths feed
# the same on-disk data (with vis_chunks chunking) to the bandpass stage.
# This fails if is_model_vis_exported is True but nothing was exported to
# output_dir beforehand.
input_modelvis = import_model_vis(output_dir, vis_chunks)
upstream_output["modelvis"] = input_modelvis
    


In [None]:

# Reference run: let the pipeline's bandpass calibration stage solve the
# gains using the configuration defined above.
bandpass_calibration_stage.stage_definition( # type: ignore
    upstream_output,
    run_solver_config,
    plot_config,
    visibility_key,
    export_gaintable,
    output_dir,
)

# Materialise the gaintable found in upstream_output after the stage; it is
# the expected value in the final comparison.
gaintable_expected = upstream_output["gaintable"]
gaintable_expected = gaintable_expected.compute()

In [None]:
from ska_sdp_datamodels.calibration.calibration_model import GainTable
from ska_sdp_datamodels.visibility.vis_model import Visibility
import logging
import scipy
from ska_sdp_func_python.calibration.solvers import (
    _solve_antenna_gains_itsubs_nocrossdata,
    _solve_antenna_gains_itsubs_matrix,
    _solve_antenna_gains_itsubs_scalar
)
from ska_sdp_func_python.calibration.alternative_solvers import(
    _jones_sub_solve,
_normal_equation_solve,
_normal_equation_solve_with_presumming,
)

# Logger shared by the solver functions defined in the following cells.
log = logging.getLogger("func-python-logger")

In [None]:

def _solve_with_mask(
    crosspol,
    gaintable_gain: numpy.ndarray,
    gaintable_weight: numpy.ndarray,
    gaintable_residual: numpy.ndarray,
    mask,
    niter,
    phase_only,
    row,
    tol,
    vis,
    x,
    xwt,
    refant,
    refant_sort,
):
    """
    Solve one row of the gain arrays with the gain-substitution solvers.

    Intended for the case where ``mask`` selects at least one element.
    ``x`` and ``xwt`` are modified in place: selected entries of ``x`` are
    divided by their weight, selected weights are scaled by their maximum,
    and everything outside the mask is set to zero.

    The solver is chosen from ``vis.visibility_acc.npol`` and ``crosspol``:

    * 2 polarisations, or 4 without ``crosspol``: no-cross-data solver,
    * 4 polarisations with ``crosspol``: matrix solver,
    * anything else: scalar solver.

    The solver's three results are written into row ``row`` of
    ``gaintable_gain``, ``gaintable_weight`` and ``gaintable_residual``.

    :param crosspol: Whether cross-polarisation terms are solved for.
    :param gaintable_gain: Gain array, updated in place at ``row``.
    :param gaintable_weight: Gain weight array, updated in place at ``row``.
    :param gaintable_residual: Residual array, updated in place at ``row``.
    :param mask: Boolean selection of the usable entries of ``x``/``xwt``.
    :param niter: Maximum number of solver iterations.
    :param phase_only: Solve for phase only.
    :param row: Index of the row to solve.
    :param tol: Solver convergence tolerance.
    :param vis: Object exposing ``visibility_acc.npol``.
    :param x: Weighted visibility sums, modified in place.
    :param xwt: Weight sums, modified in place.
    :param refant: Reference antenna index.
    :param refant_sort: Antenna ordering forwarded to the solver.
    :return: None
    """
    original_shape = x.shape
    unselected = ~mask

    # Turn the weighted sums into weighted means and normalise the weights.
    x[mask] = x[mask] / xwt[mask]
    x[unselected] = 0.0
    xwt[mask] = xwt[mask] / numpy.max(xwt[mask])
    xwt[unselected] = 0.0
    x = x.reshape(original_shape)

    # Pick the solver once; all three share the same calling convention.
    npol = vis.visibility_acc.npol
    if npol == 2 or (npol == 4 and not crosspol):
        solve_row = _solve_antenna_gains_itsubs_nocrossdata
    elif npol == 4 and crosspol:
        solve_row = _solve_antenna_gains_itsubs_matrix
    else:
        solve_row = _solve_antenna_gains_itsubs_scalar

    row_gain, row_weight, row_residual = solve_row(
        gaintable_gain[row, ...],
        gaintable_weight[row, ...],
        x,
        xwt,
        phase_only=phase_only,
        niter=niter,
        tol=tol,
        refant=refant,
        refant_sort=refant_sort,
    )
    gaintable_gain[row, ...] = row_gain
    gaintable_weight[row, ...] = row_weight
    gaintable_residual[row, ...] = row_residual

In [None]:
def solve_with_alternative_algorithm(
    solver,
    vis: Visibility,
    model_vis: numpy.ndarray,
    gain,
    niter=100,
    row=0,
    tol=1e-6,
):
    """
    Solve one row (time slice) of the gain array with a Jones-matrix solver.

    :param solver: Calibration solver to use. Options are:
        "jones_substitution", "normal_equations" and "normal_equations_presum"
    :param vis: Visibility containing the observed data; ``vis``, ``weight``,
        ``flags``, ``antenna1`` and ``antenna2`` are read from it
    :param model_vis: Model visibility array with the same number of elements
        as ``vis.vis``; it is reshaped to [time, baseline, chan, 2, 2]
    :param gain: Gain array indexed as [row, antenna, chan, rec1, rec2];
        row ``row`` is updated with the solution
    :param niter: Maximum number of iterations (default=100)
    :param row: Time slice to be calibrated (default=0)
    :param tol: Iteration stops when the fractional change
        in the gain solution is below this tolerance (default=1e-6)
    :return: The full ``gain`` array, with row ``row`` holding the solution
    :raises ValueError: if ``solver`` is "gain_substitution" or unknown
    """
    # Keep the full array under its own name. Rebinding `gain` to the row
    # slice made the final write-back assign the slice into a slice of
    # itself, which fails to broadcast for more than one antenna.
    gain_row = gain[row, ...]
    _, nchan_gt, nrec1, nrec2 = gain_row.shape
    ntime, nbl, nchan_vis, npol_vis = vis.vis.shape
    assert nrec1 == 2
    assert nrec1 == nrec2
    assert nrec1 * nrec2 == npol_vis
    assert nchan_gt in (1, nchan_vis)

    # incorporate flags into weights
    wgt = vis.weight.data * (1 - vis.flags.data)
    # flag the whole Jones matrix if any element is flagged
    wgt *= numpy.all(wgt > 0, axis=-1, keepdims=True)
    # reduce the dimension to a single weight per matrix
    #  - could weight pols separately, but may be better not to
    wgt = wgt[..., 0]

    vmdl = model_vis.reshape(ntime, nbl, nchan_vis, 2, 2)
    vobs = vis.vis.data.reshape(ntime, nbl, nchan_vis, 2, 2)

    # Update model if a starting solution is given.
    I2 = numpy.eye(2)
    ant1 = vis.antenna1.data
    ant2 = vis.antenna2.data
    if numpy.any(gain_row != I2):
        vmdl = numpy.einsum(
            "bfpi,tbfij,bfqj->tbfpq",
            gain_row[ant1],
            vmdl,
            gain_row[ant2].conj(),
        )

    log.debug(
        "solve_with_alternative_algorithm: "
        + "solving for %d chan in %d sub-band[s] using solver %s",
        nchan_vis,
        nchan_gt,
        solver,
    )

    # All supported solvers share one calling convention, so select the
    # function once instead of repeating the call for every option.
    if solver == "jones_substitution":
        solve_channel = _jones_sub_solve
    elif solver == "normal_equations":
        solve_channel = _normal_equation_solve
    elif solver == "normal_equations_presum":
        solve_channel = _normal_equation_solve_with_presumming
    elif solver == "gain_substitution":
        raise ValueError(
            "solve_with_alternative_algorithm: "
            + f"solver {solver} cannot be used in this function",
        )
    else:
        raise ValueError(
            f"solve_with_alternative_algorithm: unknown solver: {solver}",
        )

    for ch in range(nchan_gt):
        # select channels to average over. Just the current one if solving
        # each channel separately, or all of them if this is a joint solution.
        chan_vis = [ch] if nchan_gt == nchan_vis else range(nchan_vis)

        log.debug(
            "solve_with_alternative_algorithm: "
            + "sub-band %d, processing %d channels:",
            ch,
            len(chan_vis),
        )

        solve_channel(
            vobs[:, :, chan_vis],
            vmdl[:, :, chan_vis],
            wgt[:, :, chan_vis],
            ant1,
            ant2,
            gain_row,
            ch,
            niter,
            tol,
        )

    # _jones_phase_referencing(gain, refant)

    # Write the row back in case the slice above became a copy rather than a
    # view of the caller's array.
    gain[row, ...] = gain_row

    return gain

In [None]:
def find_best_refant_from_vis(
    flagged_vis: numpy.ndarray,
    flagged_weight: numpy.ndarray,
    baselines: numpy.ndarray,
    nants: int,
):
    """
    Rank antennas by how suitable they are as reference antenna.

    This method comes from katsdpcal.
    (https://github.com/ska-sa/katsdpcal/blob/
    200c2f6e60b2540f0a89e7b655b26a2b04a8f360/katsdpcal/calprocs.py#L332)
    Determine antenna whose FFT has the maximum peak to noise ratio (PNR) by
    taking the median PNR of the FFT over all baselines to each antenna.

    When the input has only one channel no FFT is taken; antennas are ranked
    by the sum of the weights on their baselines instead, with a small
    decreasing offset added so that ties resolve to the lower antenna index.

    Autocorrelations (both ends on the same antenna) never contribute.

    :param flagged_vis: 4-D visibility array; axis 1 is the baseline axis
        and axis 2 the channel axis along which the FFT is taken
    :param flagged_weight: Weights with the same layout as ``flagged_vis``;
        only used in the single-channel case
    :param baselines: Array of shape [nbaseline, 2] with the antenna indices
        of each baseline
    :param nants: Number of antennas
    :return: Array of antenna indices in decreasing order of score
        (median PNR, or summed weight for single-channel input)
    """
    _, _, nchan, _ = flagged_vis.shape
    score_per_ant = numpy.zeros((nants))

    def _baselines_of(ant):
        # XOR keeps baselines touching `ant` at exactly one end.
        return (baselines[:, 0] == ant) ^ (baselines[:, 1] == ant)

    if nchan == 1:
        for ant in range(nants):
            score_per_ant[ant] = numpy.sum(
                flagged_weight[:, _baselines_of(ant)]
            )
        # Tiny decreasing offset breaks ties in favour of lower indices.
        score_per_ant += numpy.linspace(1e-8, 1e-9, nants)
        return numpy.argsort(score_per_ant)[::-1]

    spectrum = scipy.fftpack.fft(flagged_vis, axis=2)
    peak_chan = numpy.argmax(numpy.abs(spectrum), axis=2)

    # Circularly shift every spectrum so that its peak lands on channel 0:
    # shifted[c] = original[(c + peak) % nchan]. Broadcasting builds the index
    # for all elements at once instead of one numpy.roll call per element.
    shift_index = (
        numpy.arange(nchan)[:, None] + peak_chan[:, :, None, :]
    ) % nchan
    spectrum = numpy.take_along_axis(spectrum, shift_index, axis=2)

    peak = numpy.max(numpy.abs(spectrum), axis=2)

    # Noise statistics come from the central half of the shifted spectrum,
    # away from the peak at channel 0.
    centre = numpy.s_[
        nchan // 2 - nchan // 4 : nchan // 2 + nchan // 4 + 1  # noqa E203
    ]
    centre_amp = numpy.abs(spectrum[:, :, centre])
    mean = numpy.mean(centre_amp, axis=2)
    std = numpy.std(centre_amp, axis=2) + 1e-9  # avoid division by zero

    for ant in range(nants):
        sel = _baselines_of(ant)
        pnr = (peak[:, sel] - mean[:, sel]) / std[:, sel]
        score_per_ant[ant] = numpy.median(pnr)

    return numpy.argsort(score_per_ant)[::-1]

In [None]:

def divide_visibility(vis: Visibility, model_flagged_vis: numpy.ndarray):
    """
    Divide the flagged observed visibilities by a model, giving the
    visibilities of an equivalent point source.

    This is a useful intermediate product for calibration: variation in time
    and frequency caused by the model structure is removed, so the data can
    be averaged to a limit set by the instrumental stability. The weights
    are scaled by ``|model|**2`` to compensate for the division.

    Division by zero is avoided: wherever the resulting weight is not
    positive the output visibility stays zero.

    :param vis: Visibility to be divided; ``visibility_acc.flagged_vis`` and
        ``visibility_acc.flagged_weight`` are read from it
    :param model_flagged_vis: Model visibility array to divide by
    :return: Tuple ``(x, xwt)`` of the divided visibilities and their weights
    """
    accessor = vis.visibility_acc
    observed = accessor.flagged_vis

    xwt = numpy.abs(model_flagged_vis) ** 2 * accessor.flagged_weight
    usable = xwt > 0.0

    x = numpy.zeros_like(observed)
    x[usable] = observed[usable] / model_flagged_vis[usable]

    return (x, xwt)

In [None]:
def apply_flag(x: numpy.ndarray, flags: numpy.ndarray):
    """
    Multiply ``x`` by ``1 - flags``.

    Entries where ``flags`` is 1 become zero; entries where ``flags`` is 0
    are left unchanged.

    :param x: Array to mask
    :param flags: Flag array broadcastable against ``x``
    :return: New array ``x * (1 - flags)``
    """
    keep = 1 - flags
    return x * keep

In [None]:
def solve_gaintable_new(
    vis: Visibility,
    model_vis: numpy.ndarray,  # [time, baseline, freq, pol]
    model_flags: numpy.ndarray,  # [time, baseline, freq, pol]
    gain: numpy.ndarray,
    gain_weights: numpy.ndarray,
    gain_residual: numpy.ndarray,
    phase_only=True,
    niter=30,
    tol=1e-6,
    crosspol=False,
    normalise_gains="mean",
    solver="gain_substitution",
    jones_type="T",
    timeslice=None,
    refant=0,
):
    """
    Solve gains by fitting an observed visibility to a model visibility.

    Array-based variant of ``solve_gaintable``: the model and the gain table
    are passed as plain arrays. Only row 0 of the gain arrays is solved, and
    ``gain``, ``gain_weights`` and ``gain_residual`` are updated in place.

    If ``model_vis`` and ``model_flags`` are both None, a point source model
    is assumed (gain_substitution solver only).

    :param vis: Visibility containing the observed data
    :param model_vis: Model visibilities [time, baseline, freq, pol], or None
    :param model_flags: Flags of the model visibilities, same shape as
        ``model_vis``; must be None exactly when ``model_vis`` is None
    :param gain: Gain array [row, antenna, chan, rec1, rec2] holding the
        starting solution; updated in place
    :param gain_weights: Gain weight array; updated in place
        (gain_substitution solver only)
    :param gain_residual: Gain residual array; updated in place
        (gain_substitution solver only)
    :param phase_only: Solve only for the phases. default=True when
        solver="gain_substitution", otherwise it is reset to False.
    :param niter: Maximum number of iterations (default=30)
    :param tol: Iteration stops when the fractional change
        in the gain solution is below this tolerance (default=1e-6)
    :param crosspol: Do solutions including cross polarisations
        i.e. XY, YX or RL, LR. Only used by the gain_substitution solver.
    :param normalise_gains: Normalises the gains (default="mean")
        options are None, "mean", "median".
        None means no normalization. Only available with gain_substitution,
        and only applied when ``phase_only`` is False.
    :param solver: Calibration algorithm to use (default="gain_substitution")
        options are:
        "gain_substitution" - original substitution algorithm with separate
        solutions for each polarisation term.
        "jones_substitution" - solve antenna-based Jones matrices as a whole,
        with independent updates within each iteration.
        "normal_equations" - solve normal equations within each iteration
        formed from linearisation with respect to antenna-based gain and
        leakage terms.
        "normal_equations_presum" - the same as the normal_equations option
        but with an initial accumulation of visibility products over time
        and frequency for each solution interval. This can be much faster
        for large datasets and solution intervals.
    :param jones_type: Accepted for signature compatibility; not used here
    :param timeslice: Accepted for signature compatibility; not used here
    :param refant: Reference antenna (default 0). Currently only activated for
        the gain_substitution solver.
    :return: Dictionary ``{"gain": gain}`` with the solved gain array
    :raises ValueError: if only one of ``model_vis``/``model_flags`` is given,
        if the model is zero, if a model is missing for a solver that needs
        one, or if ``normalise_gains`` is used with an unsupported solver
    """
    log.info("solve_gaintable: Starting calibration")
    log.info("solve_gaintable: Using solver %s", solver)

    if (model_vis is not None) ^ (model_flags is not None):
        raise ValueError("solve_gaintable: model_vis and model_flag are required")

    has_model = model_vis is not None

    if has_model:
        # "not (max > 0)" also rejects a model whose maximum is NaN.
        # pylint: disable=unneeded-not
        if not numpy.max(numpy.abs(model_vis)) > 0.0:
            raise ValueError("solve_gaintable: Model visibility is zero")

    if not has_model and solver != "gain_substitution":
        raise ValueError(f"solve_gaintable: model_vis required for solver {solver}")

    if normalise_gains is not None and solver != "gain_substitution":
        raise ValueError(
            f"solve_gaintable: normalise_gains unsupported with {solver}",
        )

    if phase_only is True and solver != "gain_substitution":
        log.warning(
            "solve_gaintable: resetting phase_only to False for solver %s",
            solver,
        )
        phase_only = False

    # this is only needed by the gain_substitution solver
    if phase_only:
        log.debug("solve_gaintable: Solving for phase only")
    else:
        log.debug("solve_gaintable: Solving for complex gain")

    nants = gain.shape[1]
    nchan = gain.shape[2]
    npol = vis.visibility_acc.npol

    # Sum over time, and over frequency too when a single gain channel is
    # solved for the whole band.
    axes = (0, 2) if nchan == 1 else 0

    if solver == "gain_substitution":
        if has_model:
            # Apply the model flags only when a model exists; doing this
            # unconditionally broke the point-source (no model) mode.
            model_flagged_vis = apply_flag(model_vis, model_flags)
            pointvis_vis, pointvis_weight = divide_visibility(
                vis, model_flagged_vis
            )
        else:
            # Point source model: use the observed data directly, as plain
            # arrays like the ones divide_visibility returns.
            pointvis_vis = vis.vis.data
            pointvis_weight = vis.weight.data

        pointvis_flags = vis.flags.data
        pointvis_flagged_vis = apply_flag(pointvis_vis, pointvis_flags)
        pointvis_flagged_weight = apply_flag(pointvis_weight, pointvis_flags)
        baselines = numpy.array(vis.baselines.data.tolist())

        refant_sort = find_best_refant_from_vis(
            pointvis_flagged_vis, pointvis_flagged_weight, baselines, nants
        )
        x_b = numpy.sum(
            (pointvis_vis * pointvis_weight) * (1 - pointvis_flags),
            axis=axes,
        )
        xwt_b = numpy.sum(
            pointvis_weight * (1 - pointvis_flags),
            axis=axes,
        )

        # Scatter the per-baseline sums into antenna-by-antenna matrices;
        # the (a1, a2) entry holds the conjugate of the (a2, a1) entry.
        x = numpy.zeros([nants, nants, nchan, npol], dtype="complex")
        xwt = numpy.zeros([nants, nants, nchan, npol])
        for ibaseline, (a1, a2) in enumerate(baselines):
            x[a1, a2, ...] = numpy.conjugate(x_b[ibaseline, ...])
            xwt[a1, a2, ...] = xwt_b[ibaseline, ...]
            x[a2, a1, ...] = x_b[ibaseline, ...]
            xwt[a2, a1, ...] = xwt_b[ibaseline, ...]

        mask = numpy.abs(xwt) > 0.0
        if numpy.sum(mask) > 0:
            _solve_with_mask(
                crosspol,
                gain,
                gain_weights,
                gain_residual,
                mask,
                niter,
                phase_only,
                0,
                tol,
                vis,
                x,
                xwt,
                refant,
                refant_sort,
            )
        else:
            # No usable data at all: unit gains with zero weight.
            gain[0, ...] = 1.0 + 0.0j
            gain_weights[0, ...] = 0.0
            gain_residual[0, ...] = 0.0

    else:
        # the remaining solvers require separate observed and model vis
        solve_with_alternative_algorithm(
            solver,
            vis,
            model_vis,
            gain,
            niter,
            0,
            tol,
        )

    if normalise_gains in ["median", "mean"] and not phase_only:
        normaliser = {
            "median": numpy.median,
            "mean": numpy.mean,
        }
        gabs = normaliser[normalise_gains](numpy.abs(gain[:]))
        gain[:] /= gabs

    log.info("solve_gaintable: Finished calibration")

    return {"gain": gain}

In [None]:
# Prepare in-memory inputs for the numpy solver: restore the baselines
# dimension on both datasets, then compute them and the initial gaintable.
# NOTE: this cell rebinds input_vis and input_modelvis, so running it twice
# applies restore_baselines_dim to already-restored data; whether that is
# harmless depends on restore_baselines_dim (not shown here). Prefer
# "Restart Kernel -> Run All" when re-executing.
from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import restore_baselines_dim
input_vis = restore_baselines_dim(input_vis)
input_modelvis = restore_baselines_dim(input_modelvis)
input_vis = input_vis.compute()
input_modelvis = input_modelvis.compute()
initial_gaintable = initial_gaintable.compute()

In [None]:
# Solve with the numpy re-implementation, using the same solver settings as
# the pipeline stage. solve_gaintable_new writes into the gain, weight and
# residual arrays it is given, so the arrays taken from initial_gaintable may
# be modified by this call (presumably .values exposes the underlying data
# rather than a copy -- verify if the initial values are needed afterwards).
# jones_type and timeslice are accepted but not used by solve_gaintable_new.
gaintable_actual = solve_gaintable_new(
    vis=input_vis,
    model_vis=input_modelvis.vis.data,
    model_flags=input_modelvis.flags.data,
    gain=initial_gaintable['gain'].values,
    gain_weights=initial_gaintable['weight'].values,
    gain_residual=initial_gaintable['residual'].values,
    solver=solver,
    phase_only=phase_only,
    niter=niter,
    tol=tol,
    crosspol=crosspol,
    normalise_gains=normalise_gains,
    jones_type="B",
    refant=refant,
    timeslice=timeslice,
)

In [None]:
# Final check: the gains from the numpy solver must agree with the gains
# produced by the pipeline's bandpass stage within the given tolerances.
# assert_allclose raises AssertionError on mismatch, so the message below is
# only printed when the comparison passes.
numpy.testing.assert_allclose(
    gaintable_expected["gain"].values,
    gaintable_actual["gain"],
    rtol=1.0e-6,
    atol=1.0e-6,
)
print("Gaintable values match expected results.")
