# Refactor Run Solver

## Imports

In [None]:
import sys

# Directory of EveryBeam data shipped in the local Poetry virtualenv
# (presumably the beam coefficient files — confirm).
# NOTE(review): machine-specific path; `eb_coeffs` is not referenced by the
# cells shown in this notebook.
eb_coeffs = "/home/maneesh/.cache/pypoetry/virtualenvs/ska-sdp-instrumental-calibration-vujiG8jS-py3.10/share/everybeam/"

# Alternative: use a locally built EveryBeam (python bindings + coefficients).
# sys.path.insert(0, "/home/maneesh/Work/SKAO/EveryBeam/build/python")
# eb_coeffs = "/home/maneesh/Work/SKAO/EveryBeam/coeffs"

In [None]:
import gc
import importlib
import os
import shutil
from typing import Literal, Optional, Union

import dask
import dask.array as da
from distributed import Client, LocalCluster, performance_report
import everybeam as eb
import numpy as np
import numpy.typing as npt
import xarray as xr
from astropy import constants as const
from astropy.coordinates import ITRS, AltAz, SkyCoord
from astropy.time import Time
from ska_sdp_datamodels.calibration import GainTable
from ska_sdp_datamodels.configuration import Configuration
from ska_sdp_datamodels.sky_model import SkyComponent
from ska_sdp_datamodels.visibility import Visibility
from ska_sdp_func_python.imaging.dft import dft_skycomponent_visibility
from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import (
    predict_vis,
    prediction_central_beams,
    restore_baselines_dim,
    simplify_baselines_dim,
)
from ska_sdp_instrumental_calibration.data_managers.visibility import (
    read_dataset_from_zarr,
)
from ska_sdp_instrumental_calibration.logger import setup_logger
from ska_sdp_instrumental_calibration.processing_tasks.beams import radec_to_xyz
from ska_sdp_instrumental_calibration.processing_tasks.lsm import (
    Component,
    convert_model_to_skycomponents,
    generate_lsm_from_csv,
    generate_lsm_from_gleamegc,
)
from ska_sdp_instrumental_calibration.processing_tasks.predict import (
    GenericBeams,
    dft_skycomponent_local,
    gaussian_tapers,
    generate_rotation_matrices,
    predict_from_components,
)
from ska_sdp_instrumental_calibration.scheduler import UpstreamOutput
from ska_sdp_instrumental_calibration.workflow.stages.load_data import load_data_stage
from ska_sdp_instrumental_calibration.workflow.utils import (
    create_bandpass_table,
    with_chunks,
)
from xarray.core.utils import Frozen
from ska_sdp_instrumental_calibration.workflow.utils import create_solint_slices

from ska_sdp_piper.piper.utils.log_util import LogPlugin

from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import (
    apply_gaintable_to_dataset,
)

from ska_sdp_func_python.calibration.solvers import solve_gaintable

from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import (
    restore_baselines_dim,
)


from ska_sdp_datamodels.science_data_model import ReceptorFrame

from ska_sdp_instrumental_calibration.workflow.utils import (
    get_indices_from_grouped_bins,
    get_intervals_from_grouped_bins,
)


# Fix NumPy's global RNG seed so any random draws made in this notebook are
# reproducible between runs.
np.random.seed(57)

# Notebook-level logger; `compare_arrays(..., output="log")` writes through it.
logger = setup_logger("run_solver_ipynb")

## Setup Dask (optional)

In [None]:
# Connect to an already running Dask scheduler on localhost:34567.
# The commented-out lines are the alternative: start a LocalCluster in this
# process and take its client.
# # if "local_cluster" not in globals():
# #     local_cluster = LocalCluster(
# #         n_workers=4, threads_per_worker=3, dashboard_address=":30088"
# #     )

# # if "dask_client" not in globals():
# #     # dask_client = local_cluster.get_client()
# NOTE(review): unlike the commented-out variant this is not guarded by a
# `globals()` check, so re-running the cell opens another Client and
# registers the plugin again.
dask_client = Client("localhost:34567")
# Forward log records emitted on the workers to this client process.
dask_client.forward_logging()
# Register the piper logging plugin on the cluster (presumably configures
# worker-side logging — confirm against ska_sdp_piper).
log_configure_plugin = LogPlugin(verbose=False)
dask_client.register_plugin(log_configure_plugin)

# # print(local_cluster.dashboard_link)

In [None]:
# if 'dask_client' in globals():
#     dask_client.close()
#     del dask_client

# if 'local_cluster' in globals():
#     local_cluster.close()
#     del local_cluster

## Utils

In [None]:
def compare_arrays(
    actual: np.ndarray,
    expected: np.ndarray,
    rtol=1e-6,
    atol=1e-12,
    meta="Data",
    output: Literal["print", "log", "raise"] = "print",
):
    """
    Compare two arrays (NumPy or Dask) element-wise with tolerance thresholds.

    This function checks absolute and relative differences between two arrays and
    reports the maximum absolute difference, maximum relative difference, and the
    number and percentage of elements that differ beyond the specified tolerances.

    Two elements are considered equal when
    ``|actual - expected| <= atol + rtol * |expected|`` (the rule used by
    ``numpy.isclose``) or when they compare exactly equal, so identical
    infinities match. NaN never matches anything, including another NaN, as in
    ``numpy.isclose(..., equal_nan=False)``.

    If either input is a Dask array, all required values are finalized in a single
    `dask.compute` call to minimize graph execution overhead.

    Parameters
    ----------
    actual : numpy.ndarray or dask.array.Array
        The array containing computed or observed values.
    expected : numpy.ndarray or dask.array.Array
        The reference array to compare against.
    rtol : float, default 1e-6
        Relative tolerance.
    atol : float, default 1e-12
        Absolute tolerance.
    meta : str, default "Data"
        Label used for output messages.
    output : {"print", "log", "raise"}, default "print"
        Determines how results are reported:

        - ``"print"``: print the comparison message.
        - ``"log"``: log it (``logger.error`` on mismatch, ``logger.info``
          otherwise).
        - ``"raise"``: raise ``AssertionError`` on mismatch; informational
          messages are dropped.

    Raises
    ------
    ValueError
        If ``output`` is not one of the supported modes.
    AssertionError
        If ``output="raise"`` and the arrays differ in shape or values.
    """

    def _raise_on_error(level, message):
        # "raise" mode stays silent for informational messages.
        if level == "error":
            raise AssertionError(message)

    actions = {
        "print": lambda level, m: print(m),
        "log": lambda level, m: (
            logger.error(m) if level == "error" else logger.info(m)
        ),
        "raise": _raise_on_error,
    }
    if output not in actions:
        raise ValueError(
            f"Unknown output mode {output!r}; expected one of {sorted(actions)}"
        )
    report = actions[output]

    is_dask = isinstance(actual, da.Array) or isinstance(expected, da.Array)

    if actual.dtype != expected.dtype:
        report(
            "info",
            f"{meta}: dtype mismatch : actual: {actual.dtype}, expected: {expected.dtype}.\n"
            f"Comparison may be influenced by different precision.",
        )

    # Element-wise arithmetic below would silently broadcast mismatching shapes.
    if actual.shape != expected.shape:
        report(
            "error",
            f"{meta}: shape mismatch : actual: {actual.shape}, "
            f"expected: {expected.shape}.",
        )
        return

    total = actual.size
    if total == 0:
        # Reductions such as max() are undefined on empty arrays; two empty
        # arrays of the same shape trivially match.
        report("info", f"{meta}: match within atol={atol}, rtol={rtol}")
        return

    abs_diff = np.abs(actual - expected)
    abs_diff_max = abs_diff.max()

    # element-wise relative difference (safe for expected == 0)
    rel_diff = abs_diff / np.maximum(np.abs(expected), 1e-30)
    rel_diff_max = rel_diff.max()

    # numpy-style comparison rule. (A more permissive alternative would be
    # `(abs_diff > atol) & (rel_diff > rtol)`.)
    # Counting elements that are *not close* (rather than `abs_diff > tol`)
    # makes NaN differences count as mismatches, since NaN comparisons are
    # always False. The equality term lets identical infinities pass, because
    # inf - inf is NaN.
    is_close = (abs_diff <= atol + rtol * np.abs(expected)) | (actual == expected)
    num_diff = (~is_close).sum()

    if is_dask:
        abs_diff_max, rel_diff_max, num_diff = da.compute(
            abs_diff_max, rel_diff_max, num_diff
        )

    if num_diff > 0:
        msg = (
            f"{meta}: do not match for atol={atol}, rtol={rtol}\n"
            f"\tmax abs diff = {abs_diff_max}\n"
            f"\tmax rel diff = {rel_diff_max}\n"
            f"\tdifferent elements: {num_diff} / {total} ({num_diff / total * 100:.6f}%)"
        )
        report("error", msg)
    else:
        msg = f"{meta}: match within atol={atol}, rtol={rtol}"
        report("info", msg)

In [None]:
# def create_gaintable_from_vis_new(
#     vis: Visibility,
#     timeslice: Union[float, Literal["auto"], None] = None,
#     jones_type: Literal["T", "G", "B"] = "T",
# ):
#     # TODO: Write full logic
#     # Reference the original create_gaintable_from_visibility function

#     # TODO: To simplify dask operations,
#     # if timeslice is provided such that 1 < len(solution_time) < len(vis.time)
#     # then still the gaintable must have len(solution_time) = len(vis.time)
#     # where the solution_time values are duplicated
#     # this can be optimized in the future

#     gaintable = create_bandpass_table(vis)
#     # Since its bandpass table, single interval contains whole time chunk
#     gaintable.attrs["soln_interval_slices"] = create_solint_slices(
#         vis.time, timeslice, True
#     )

#     # This should be part of the dim creation logic itself
#     gaintable = gaintable.rename(time="solution_time", frequency="solution_frequency")

#     # # Following chunking logics should be part of the data creation logic itself
#     # This is when we are sure that rest of the code can handle different solutions times
#     gaintable = gaintable.chunk({"solution_time": 1})

#     # if gaintable.jones_type == "B":
#     if gaintable.solution_frequency.size == vis.frequency.size:
#         gaintable = gaintable.chunk({"solution_frequency": vis.chunksizes["frequency"]})

#     return gaintable

In [None]:
def create_gaintable_from_vis_new(
    vis: Visibility,
    timeslice: Union[float, Literal["auto"], None] = None,
    jones_type: Literal["T", "G", "B"] = "T",
    lower_precision: bool = True,
):
    """
    Build a dask-backed GainTable matching ``vis``.

    Behaves like ``create_gaintable_from_vis``, with two differences:

    1. The ``gain``, ``weight`` and ``residual`` data variables are dask arrays.
    2. ``lower_precision`` toggles the element size of the data variables
       between 4 bytes (complex64 / float32) and 8 bytes (complex128 / float64).

    Parameters
    ----------
    vis : Visibility
        Visibility dataset that the gain table is derived from.
    timeslice : float, "auto" or None, default None
        Solution interval passed to ``create_solint_slices``. ``"auto"`` is
        accepted for backward compatibility and treated as ``None``.
    jones_type : {"T", "G", "B"}, default "T"
        ``"B"`` keeps one solution per visibility channel. ``"G"`` and ``"T"``
        use a single solution at the mean visibility frequency.
    lower_precision : bool, default True
        Use complex64/float32 if True, otherwise complex128/float64.

    Returns
    -------
    GainTable
        Gain table whose ``time``/``frequency`` dims are renamed to
        ``solution_time``/``solution_frequency``. It is chunked one solution
        time per chunk, with frequency chunks following ``vis`` when the
        frequency sizes match. It carries a ``soln_interval_slices``
        attribute.

    Raises
    ------
    ValueError
        If ``jones_type`` is not one of "T", "G" or "B".
    """
    # Backward compatibility: "auto" means no explicit timeslice.
    if timeslice == "auto":
        timeslice = None

    # TODO: review this time slice creation logic
    time_bins = create_solint_slices(vis.time, timeslice, False)
    solution_time = time_bins.mean().data
    solution_interval = get_intervals_from_grouped_bins(time_bins)
    n_times = len(solution_time)

    n_antennas = vis.visibility_acc.nants

    # Frequency sampling depends on the Jones type.
    if jones_type == "B":
        solution_frequency = vis.frequency.data
        n_frequency = len(solution_frequency)
    elif jones_type not in ("G", "T"):
        raise ValueError(f"Unknown Jones type {jones_type}")
    else:
        solution_frequency = np.mean(vis.frequency.data, keepdims=True)
        n_frequency = 1

    # Visibility holds a single receptor frame; reuse it for both receptors.
    receptor_frame = ReceptorFrame(vis.visibility_acc.polarisation_frame.type)
    n_receptors = receptor_frame.nrec

    complex_dtype, real_dtype = (
        (np.complex64, np.float32) if lower_precision else (np.complex128, np.float64)
    )

    shape = [n_times, n_antennas, n_frequency, n_receptors, n_receptors]

    gain_table = GainTable.constructor(
        gain=da.broadcast_to(da.eye(n_receptors, dtype=complex_dtype), shape),
        time=solution_time,
        interval=solution_interval,
        weight=da.ones(shape, dtype=real_dtype),
        residual=da.zeros(
            [n_times, n_frequency, n_receptors, n_receptors], dtype=real_dtype
        ),
        frequency=solution_frequency,
        receptor_frame=receptor_frame,
        phasecentre=vis.phasecentre,
        configuration=vis.configuration,
        jones_type=jones_type,
    )

    # Rename dimensions to be more explicit about their usage.
    gain_table = gain_table.rename(time="solution_time", frequency="solution_frequency")

    # One solution time per chunk; follow the visibility frequency chunking
    # whenever there is one solution per channel.
    chunks = {"solution_time": 1}
    if gain_table.solution_frequency.size == vis.frequency.size:
        chunks["solution_frequency"] = vis.chunksizes["frequency"]
    gain_table = gain_table.chunk(chunks)

    # Logic duplicated from create_solint_slices
    gain_table.attrs["soln_interval_slices"] = get_indices_from_grouped_bins(time_bins)

    return gain_table

## Setup inputs

In [None]:
input_ms_path = "/home/ska/Work/data/INST/lg3/cal_bpp_vis-lg3-rotated.small.ms"
model_vis_vis_zarr = "/home/maneesh/Work/SKAO/ska-sdp-instrumental-calibration/notebooks/refactors/expected_model_vis.vis.zarr"
# input_ms_path = "/home/ska/Work/data/INST/lg3/cal_bpp_vis-lg3-rotated.ms"
# input_ms_path = "/home/maneesh/Work/SKAO/ska-sdp-instrumental-calibration/data/demo.ms"

vis_cache_dir = "/home/maneesh/Work/SKAO/ska-sdp-instrumental-calibration/cache"

# eb_ms = "/home/ska/Work/data/INST/sim/OSKAR_MOCK.ms"
eb_ms = input_ms_path

lsm_csv_path = "/home/ska/Work/data/INST/lg3/sky_model_cal.csv"
gleamfile = "/home/ska/Work/data/INST/sim/gleamegc.dat"

fov = 10.0
flux_limit = 1.0
alpha0 = -0.78

In [None]:
upstream_output = UpstreamOutput()

# Chunk sizes used when converting the MS (channels / times per chunk).
nchannels_per_chunk = 64
ntimes_per_ms_chunk = 16

# Run the load-data stage directly, outside the pipeline runner.
# NOTE(review): arguments are positional; their mapping to the stage's
# parameters (e.g. what the False flag and the two 0 values select) should be
# confirmed against load_data_stage's definition.
upstream_output = load_data_stage.stage_definition(
    upstream_output,
    nchannels_per_chunk,
    ntimes_per_ms_chunk,
    vis_cache_dir,
    False,
    "DATA",
    0,
    0,
    {"input": input_ms_path},
    ".",
)
vis = upstream_output.vis

# Explicitly load it, in case it's a dask array
vis.antenna1.load()
vis.antenna2.load()

# Lazily open the pre-computed model visibilities (dask-backed, chunks={}).
model_vis_vis = xr.open_dataarray(model_vis_vis_zarr, engine="zarr", chunks={})

# Model visibility dataset: same coordinates/attributes as `vis`, with the
# `vis` data variable replaced by the model visibilities.
modelvis = vis.assign(vis=model_vis_vis)

# Initial gain table for the original run_solver, chunked in frequency like vis.
og_gaintable = create_bandpass_table(vis).chunk(frequency=vis.chunksizes["frequency"])

# Solution interval spanning the whole observation, intended to give a single
# solution time (NOTE(review): confirm create_solint_slices yields one bin).
timeslice = np.max(vis.time.data) - np.min(vis.time.data)
# NOTE(review): lower_precision=True makes gains complex64; if og_gaintable is
# 8-byte, the comparisons below will report dtype mismatches and small diffs.
new_gaintable = create_gaintable_from_vis_new(
    vis, timeslice=timeslice, jones_type="B", lower_precision=True
)

## Run solver

In [None]:
def _solve_gaintable_new(
    vis: Visibility,
    modelvis: Visibility,
    gain_table: GainTable,
    phase_only: bool,
    niter: int,
    tol: float,
    crosspol: bool,
    normalise_gains: Optional[str],
    solver: str,
    refant: int,
):
    """
    Solve one block of gains; a map-block compatible wrapper around
    `solve_gaintable`.

    The block is converted to the layout `solve_gaintable` works on:

    - ``solution_time`` is renamed to ``time``.
    - ``solution_frequency``, if present, is renamed to ``frequency``.
    - The visibilities go through ``restore_baselines_dim``.

    The *solved* table is then renamed back so that its dimensions match the
    template handed to ``xr.map_blocks``.

    Parameters
    ----------
    vis, modelvis : Visibility
        Observed and model visibility blocks.
    gain_table : GainTable
        Initial gains for this block. It is not modified; a deep copy is
        solved.
    phase_only, niter, tol, crosspol, normalise_gains, solver, refant
        Forwarded unchanged to `solve_gaintable`.

    Returns
    -------
    GainTable
        The solved gaintable xarray dataset, with the same dimension names as
        ``gain_table``.
    """
    # Deep copy as solve_gaintable mutates gaintable
    gain_table = gain_table.copy(deep=True)

    gain_table = gain_table.rename(solution_time="time")
    # For jones types T and G the caller keeps the "solution_frequency" dim.
    revert_frequency = "solution_frequency" in gain_table.dims
    if revert_frequency:
        gain_table = gain_table.rename(solution_frequency="frequency")

    vis = restore_baselines_dim(vis)
    modelvis = restore_baselines_dim(modelvis)

    solved_gaintable = solve_gaintable(
        vis=vis,
        modelvis=modelvis,
        gain_table=gain_table,
        phase_only=phase_only,
        niter=niter,
        tol=tol,
        crosspol=crosspol,
        normalise_gains=normalise_gains,
        solver=solver,
        refant=refant,
    )

    # Undo the renames on the returned table (not on the discarded input copy)
    # so its dims match the map_blocks template.
    solved_gaintable = solved_gaintable.rename(time="solution_time")
    if revert_frequency:
        solved_gaintable = solved_gaintable.rename(frequency="solution_frequency")

    return solved_gaintable


def run_solver_new(
    vis: Visibility,
    modelvis: Visibility,
    gaintable: GainTable,
    solver: str = "gain_substitution",
    refant: int = 0,
    niter: int = 200,
    phase_only: bool = False,
    tol: float = 1e-06,
    crosspol: bool = False,
    normalise_gains: Optional[str] = None,
) -> GainTable:
    """
    A generic function to solve for gaintables, given
    visibility, model visibility and gaintable.

    The solve is built lazily. For every entry of
    ``gaintable.soln_interval_slices``, the matching time range of ``vis`` and
    ``modelvis`` is solved through ``xr.map_blocks`` with
    ``_solve_gaintable_new``. The per-interval results are concatenated along
    ``solution_time``.

    Parameters
    ----------
    vis: Visibility
        Visibility dataset containing observed data.
    modelvis: Visibility
        Visibility dataset containing model data.
    gaintable: Gaintable
        GainTable dataset containing initial solutions. It needs
        ``solution_time``/``solution_frequency`` dims and a
        ``soln_interval_slices`` attribute, as produced by
        ``create_gaintable_from_vis_new``.
    solver: str, default: "gain_substitution"
        Solver type to use. Currently any solver type accepted by
        solve_gaintable.
    refant: int, default: 0
        Reference antenna. Note that how referencing is done
        depends on the solver.
    niter: int, default: 200
        Number of solver iterations.
    phase_only: bool, default: False
        Solve only for the phases.
    tol: float, default: 1e-06
        Iteration stops when the fractional change in the gain solution is
        below this tolerance.
    crosspol: bool, default: False
        Do solutions including cross polarisations.
    normalise_gains: str or None, default: None
        Normalises the gains; forwarded to solve_gaintable.

    Returns
    -------
    GainTable
        A new gaintable xarray dataset, concatenated along ``solution_time``.
        For jones type "B", the frequency dimension is named
        ``solution_frequency`` again.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        - ``refant`` is not a valid antenna index.
        - A non-"B" gaintable has more than one solution frequency.
        - The gaintable has no solution intervals.
    """
    if refant < 0 or refant >= len(gaintable.antenna):
        raise ValueError(f"Invalid refant: {refant}")

    is_bandpass = gaintable.jones_type == "B"
    if is_bandpass:
        # solution frequency same as vis frequency
        gaintable = gaintable.rename(solution_frequency="frequency")
        # Chunking, just to be sure that they match
        gaintable = gaintable.chunk(frequency=vis.chunksizes["frequency"])
    else:  # jones_type == T or G
        if gaintable.solution_frequency.size != 1:
            raise ValueError(
                "Gaintable solution frequency must either match to visibility "
                "frequency, or must be of size 1"
            )
        # A T/G solution uses all channels at once, and the template has no
        # "frequency" dim. With several frequency chunks, map_blocks would
        # map multiple input blocks onto the same output block, so use one
        # frequency chunk.
        vis = vis.chunk(frequency=-1)
        modelvis = modelvis.chunk(frequency=-1)

    solver_kwargs = {
        "phase_only": phase_only,
        "niter": niter,
        "tol": tol,
        "crosspol": crosspol,
        "normalise_gains": normalise_gains,
        "solver": solver,
        "refant": refant,
    }

    gaintable_across_solutions = []
    for idx, slc in enumerate(gaintable.soln_interval_slices):
        # Select index but keep dimension
        template_gaintable = gaintable.isel(solution_time=[idx])
        # Each solution interval is solved from a single time chunk.
        gaintable_per_solution = xr.map_blocks(
            _solve_gaintable_new,
            vis.isel(time=slc).chunk(time=-1),
            args=[
                modelvis.isel(time=slc).chunk(time=-1),
                template_gaintable,
            ],
            kwargs=solver_kwargs,
            template=template_gaintable,
        )
        gaintable_across_solutions.append(gaintable_per_solution)

    if not gaintable_across_solutions:
        raise ValueError("Gaintable has no solution intervals to solve for")

    combined_gaintable: GainTable = xr.concat(
        gaintable_across_solutions, dim="solution_time"
    )

    if is_bandpass:
        combined_gaintable = combined_gaintable.rename(frequency="solution_frequency")

    return combined_gaintable

## Test

In [None]:
from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import (
    run_solver as run_solver_og,
)


# Solver settings shared by the original `run_solver` and `run_solver_new`,
# so that both runs below are configured identically. Jones type and
# timeslice are not solver arguments here; they are fixed when the initial
# gain tables are created.
run_solver_config = dict(
    solver="gain_substitution",
    refant=0,
    niter=200,
    phase_only=False,
    tol=1e-06,
    crosspol=False,
    normalise_gains=None,
)

In [None]:
# Reference result: the original dask-wrapper run_solver on the original
# bandpass table. This only builds the lazy result; it is computed when written
# to Zarr below.
expected_gaintable = run_solver_og(
    vis=vis, modelvis=modelvis, gaintable=og_gaintable, **run_solver_config
)

In [None]:
# Zarr store for the reference result, in the current working directory.
expected_gaintable_path = f"{os.getcwd()}/expected_gaintable.zarr"
# Remove any store left over from a previous run.
shutil.rmtree(expected_gaintable_path, ignore_errors=True)

# Attributes are dropped before writing. NOTE(review): presumably because
# some attrs (e.g. slice lists) are not Zarr-serialisable — confirm.
writer = expected_gaintable.drop_attrs().to_zarr(
    expected_gaintable_path, mode="w", compute=False
)
# Execute the deferred write (this also runs the solver graph).
dask.compute(writer, optimize_graph=True)

In [None]:
# Refactored result: run_solver_new on the dask-backed gain table built by
# create_gaintable_from_vis_new, with the same solver settings as above.
actual_gaintable = run_solver_new(
    vis=vis, modelvis=modelvis, gaintable=new_gaintable, **run_solver_config
)

In [None]:
# Zarr store for the refactored result, in the current working directory.
actual_gaintable_path = f"{os.getcwd()}/actual_gaintable.zarr"
# Remove any store left over from a previous run.
shutil.rmtree(actual_gaintable_path, ignore_errors=True)

# Attributes are dropped before writing. NOTE(review): presumably because
# entries such as `soln_interval_slices` are not Zarr-serialisable — confirm.
writer = actual_gaintable.drop_attrs().to_zarr(
    actual_gaintable_path, mode="w", compute=False
)
# Execute the deferred write (this also runs the solver graph).
dask.compute(writer, optimize_graph=True)

In [None]:
# Re-open both stores lazily (dask-backed via chunks={}) for comparison, so the
# comparisons read the persisted results instead of recomputing the solves.
actual_gaintable_zarr = xr.open_dataset(actual_gaintable_path, engine="zarr", chunks={})
expected_gaintable_zarr = xr.open_dataset(
    expected_gaintable_path, engine="zarr", chunks={}
)

In [None]:
# Compare the solved gains essentially exactly (atol=0 and an rtol far below
# float precision): real part, imaginary part, then the complex values.
_gain_views = (
    ("Real", np.real),
    ("Imag", np.imag),
    ("Complex", lambda values: values),
)
for _label, _view in _gain_views:
    compare_arrays(
        _view(actual_gaintable_zarr.gain.data),
        _view(expected_gaintable_zarr.gain.data),
        rtol=1e-32,
        atol=0,
        meta=f"{_label} gain values",
    )

In [None]:
# Compare the solved weights essentially exactly (atol=0 and an rtol far below
# float precision). The weights are real-valued (the new table creates them
# with a float dtype), so splitting them into real/imaginary parts only added a
# trivially matching zero "imaginary" comparison and a misleading "complex"
# label. Compare them directly instead.
compare_arrays(
    actual_gaintable_zarr.weight.data,
    expected_gaintable_zarr.weight.data,
    rtol=1e-32,
    atol=0,
    meta="Weight values",
)