# Refactor Run Solver

## Imports

In [None]:
import sys

sys.path.insert(0, "../")

In [None]:
import os
import shutil

import dask
from distributed import Client
import numpy as np
import xarray as xr
from ska_sdp_datamodels.calibration import GainTable
from ska_sdp_datamodels.visibility import Visibility

from ska_sdp_instrumental_calibration.data_managers.visibility import (
    read_visibility_from_zarr,
)
from ska_sdp_instrumental_calibration.logger import setup_logger

from ska_sdp_instrumental_calibration.scheduler import UpstreamOutput
from ska_sdp_instrumental_calibration.workflow.stages.load_data import load_data_stage
from ska_sdp_instrumental_calibration.workflow.utils import (
    create_bandpass_table,
    with_chunks,
)

from ska_sdp_piper.piper.utils.log_util import LogPlugin

from ska_sdp_func_python.calibration.solvers import solve_gaintable

from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import (
    restore_baselines_dim,
)
from ska_sdp_datamodels.calibration.calibration_create import (
    create_gaintable_from_visibility,
)

from notebook_utils import create_gaintable_from_vis_new, compare_arrays

# Seed NumPy's global RNG so that any NumPy-based randomness is repeatable
# between runs of this notebook.
np.random.seed(57)

# Notebook-wide logger, created through the project's setup_logger helper.
logger = setup_logger("run_solver_ipynb")

## Setup Dask (optional)

In [None]:
# Disabled alternative: start a LocalCluster from the notebook instead of
# attaching to an external scheduler. LocalCluster is not imported in the
# import cell, so that import has to be added before enabling this.
# if "local_cluster" not in globals():
#     local_cluster = LocalCluster(
#         n_workers=4, threads_per_worker=3, dashboard_address=":30088"
#     )
# if "dask_client" not in globals():
#     dask_client = local_cluster.get_client()
# print(local_cluster.dashboard_link)

# Attach to a Dask scheduler at localhost:34567; the scheduler has to be
# running already, otherwise this cell fails.
dask_client = Client("localhost:34567")
# Forward log records emitted on the workers to this client process.
dask_client.forward_logging()
# NOTE(review): LogPlugin presumably configures logging on each worker;
# confirm against ska_sdp_piper.
log_configure_plugin = LogPlugin(verbose=False)
dask_client.register_plugin(log_configure_plugin)

In [None]:
# if 'dask_client' in globals():
#     dask_client.close()
#     del dask_client

# if 'local_cluster' in globals():
#     local_cluster.close()
#     del local_cluster

## Setup inputs

In [None]:
# NOTE: every path below is an absolute, machine-specific location and has to
# be adapted before running this notebook elsewhere.

# Measurement set passed to the load-data stage as the "input" CLI argument.
input_ms_path = "/home/ska/Work/data/INST/lg3/cal_bpp_vis-lg3-rotated.small.ms"
# Zarr store holding the model visibilities; opened as a DataArray when the
# data is loaded.
model_vis_vis_zarr = "/home/maneesh/Work/SKAO/ska-sdp-instrumental-calibration/notebooks/refactors/actual_model_vis.vis.zarr"
# Alternative measurement sets (disabled):
# input_ms_path = "/home/ska/Work/data/INST/lg3/cal_bpp_vis-lg3-rotated.ms"
# input_ms_path = "/home/maneesh/Work/SKAO/ska-sdp-instrumental-calibration/data/demo.ms"

# Cache directory handed to the load-data stage.
vis_cache_dir = "/home/maneesh/Work/SKAO/ska-sdp-instrumental-calibration/cache"

# NOTE(review): eb_ms, lsm_csv_path, gleamfile, fov, flux_limit and alpha0 are
# not referenced by any other cell visible in this notebook; presumably left
# over from a model-prediction notebook. Confirm before removing.
# eb_ms = "/home/ska/Work/data/INST/sim/OSKAR_MOCK.ms"
eb_ms = input_ms_path

lsm_csv_path = "/home/ska/Work/data/INST/lg3/sky_model_cal.csv"
gleamfile = "/home/ska/Work/data/INST/sim/gleamegc.dat"

fov = 10.0
flux_limit = 1.0
alpha0 = -0.78

In [None]:
# Fresh upstream container; it is passed into the load-data stage, and the
# stage's return value replaces it below.
upstream_output = UpstreamOutput()

# Chunk sizes handed to the load-data stage.
nchannels_per_chunk = 32
ntimes_per_ms_chunk = 5

# Call the pipeline's load-data stage directly. All arguments are positional,
# so their meaning depends on the stage signature, which is not visible here.
# NOTE(review): False / "DATA" / 0 / 0 are presumably the ack flag, data
# column, field id and data-description id, and "." the output directory;
# confirm against load_data_stage.
upstream_output = load_data_stage.stage_definition(
    upstream_output,
    nchannels_per_chunk,
    ntimes_per_ms_chunk,
    vis_cache_dir,
    False,
    "DATA",
    0,
    0,
    {"input": input_ms_path},
    ".",
)
# Observed visibilities exposed by the stage output.
vis = upstream_output.vis

# Open the model visibilities from zarr as a dask-backed DataArray
# (chunks={} keeps the chunking defined by the store).
model_vis_vis = xr.open_dataarray(model_vis_vis_zarr, engine="zarr", chunks={})
# Same dataset as `vis`, but with the `vis` variable replaced by the model data.
modelvis = vis.assign(vis=model_vis_vis)

In [None]:
# Solution type under test. Swap in the commented pair to exercise "B".
# timeslice, jones_type = "full", "B"
timeslice, jones_type = None, "G"

# Options shared by the reference and the refactored gaintable constructors.
solution_options = {"timeslice": timeslice, "jones_type": jones_type}

# Reference gaintable, used for the raw call to solve_gaintable.
# To test against the original run_solver instead, use:
# og_gaintable = create_bandpass_table(vis).chunk(frequency=vis.chunksizes["frequency"])
og_gaintable = create_gaintable_from_visibility(vis, **solution_options)
if timeslice and isinstance(timeslice, float):
    # Only for a non-zero float timeslice: overwrite the first interval with
    # the timeslice plus a small epsilon.
    og_gaintable["interval"].data[0] = timeslice + 1e-5

# Gaintable built by the refactored constructor from notebook_utils; this is
# the input of run_solver_new.
new_gaintable = create_gaintable_from_vis_new(
    vis, lower_precision=False, **solution_options
)

## Run solver

In [None]:
def _solve_gaintable_new(
    vis: Visibility,
    modelvis: Visibility,
    gain_table: GainTable,
    phase_only: bool,
    niter: int,
    tol: float,
    crosspol: bool,
    normalise_gains: str,
    solver: str,
    refant: int,
) -> GainTable:
    """
    A map-block compatible wrapper function which internally calls
    `solve_gaintable` function.

    The gaintable arrives with the dimension names used by
    ``run_solver_new`` (``solution_time`` and, when present,
    ``solution_frequency``). They are renamed to ``time`` / ``frequency``
    for the call to ``solve_gaintable`` and renamed back on the result.

    Parameters
    ----------
    vis: Visibility
        Block of observed visibilities. Its baselines dimension is restored
        with ``restore_baselines_dim`` before solving.
    modelvis: Visibility
        Block of model visibilities, treated the same way as ``vis``.
    gain_table: GainTable
        Gaintable for this block. It must have a ``solution_time``
        dimension. It is deep-copied, so the caller's object is not
        modified.
    phase_only: bool
        Forwarded to ``solve_gaintable``.
    niter: int
        Forwarded to ``solve_gaintable``.
    tol: float
        Forwarded to ``solve_gaintable``.
    crosspol: bool
        Forwarded to ``solve_gaintable``.
    normalise_gains: str
        Forwarded unchanged to ``solve_gaintable``. ``run_solver_new``
        passes ``None`` by default.
    solver: str
        Forwarded to ``solve_gaintable``.
    refant: int
        Forwarded to ``solve_gaintable``.

    Returns
    -------
    Gaintable
        The solved gaintable xarray dataset, carrying the same dimension
        names as the input ``gain_table``.
    """
    # Map from the names used by the caller to the names used for the
    # solve_gaintable call.
    to_solver_names = {"solution_time": "time"}
    # For sol_types T and G, solution_frequency will be present
    if "solution_frequency" in gain_table.dims:
        to_solver_names["solution_frequency"] = "frequency"

    # Deep copy as solve_gaintable mutates gaintable
    solver_gain_table = gain_table.copy(deep=True).rename(to_solver_names)

    vis = restore_baselines_dim(vis)
    modelvis = restore_baselines_dim(modelvis)

    solved_gaintable = solve_gaintable(
        vis=vis,
        modelvis=modelvis,
        gain_table=solver_gain_table,
        phase_only=phase_only,
        niter=niter,
        tol=tol,
        crosspol=crosspol,
        normalise_gains=normalise_gains,
        solver=solver,
        refant=refant,
    )

    # Undo exactly the renames applied above.
    from_solver_names = {
        solver_name: caller_name
        for caller_name, solver_name in to_solver_names.items()
    }
    return solved_gaintable.rename(from_solver_names)


def run_solver_new(
    vis: Visibility,
    modelvis: Visibility,
    gaintable: GainTable,
    solver: str = "gain_substitution",
    refant: int = 0,
    niter: int = 200,
    phase_only: bool = False,
    tol: float = 1e-06,
    crosspol: bool = False,
    normalise_gains: str = None,
) -> GainTable:
    """
    A generic function to solve for gaintables, given
    visibility, model visibility and gaintable.

    One ``xr.map_blocks`` call is issued per solution interval of the
    gaintable, and the per-interval results are concatenated along time.

    Parameters
    ----------
    vis: Visibility
        Visibility dataset containing observed data.
    modelvis: Visibility
        Visibility dataset containing model data.
    gaintable: Gaintable
        GainTable dataset containing initial solutions. It must expose
        ``soln_interval_slices`` (one time slice of ``vis`` per solution)
        and ``jones_type``. For any jones_type other than "B" its frequency
        axis must have size 1.
    solver: str, default: "gain_substitution"
        Solver type to use. Currently any solver type accepted by
        solve_gaintable.
    refant: int, default: 0
        Reference antenna. Note that how referencing is done
        depends on the solver.
    niter: int, default: 200
        Number of solver iterations.
    phase_only: bool, default: False
        Solve only for the phases.
    tol: float, default: 1e-06
        Iteration stops when the fractional change in the gain solution is
        below this tolerance.
    crosspol: bool, default: False
        Do solutions including cross polarisations.
    normalise_gains: str, default: None
        Gain normalisation option, forwarded unchanged to solve_gaintable.

    Returns
    -------
    GainTable
        A new gaintable xarray dataset with ``time`` and ``frequency``
        dimensions. The input gaintable object is not modified.

    Raises
    ------
    ValueError
        If ``refant`` is not a valid index into ``gaintable.antenna``.
    AssertionError
        If jones_type is not "B" and the gaintable frequency axis does not
        have size 1.
    """
    if refant < 0 or refant >= len(gaintable.antenna):
        raise ValueError(f"Invalid refant: {refant}")

    # Each solution needs all of its time samples in a single block.
    vis_chunks_per_solution = {"time": -1}
    # Distinguish the gaintable's time axis from the visibility time axis so
    # that map_blocks does not try to align the two.
    gaintable = gaintable.rename(time="solution_time")
    soln_interval_slices = gaintable.soln_interval_slices

    if gaintable.jones_type == "B":
        # solution frequency same as vis frequency
        # Chunking, just to be sure that they match
        gaintable = gaintable.chunk(frequency=vis.chunksizes["frequency"])
    else:  # jones_type == T or G
        # Kept as an assert so the exception type seen by callers does not
        # change; the message is now a single parenthesised string (it was
        # previously split into a separate no-op statement and got lost).
        assert gaintable.frequency.size == 1, (
            "Gaintable frequency must either match the visibility "
            "frequency, or must be of size 1"
        )
        gaintable = gaintable.rename(frequency="solution_frequency")
        # Need to pass full frequency to process single solution
        vis_chunks_per_solution["frequency"] = -1

    # Identical for every solution interval, so build it once.
    solver_options = {
        "phase_only": phase_only,
        "niter": niter,
        "tol": tol,
        "crosspol": crosspol,
        "normalise_gains": normalise_gains,
        "solver": solver,
        "refant": refant,
    }

    gaintable_across_solutions = []
    for idx, slc in enumerate(soln_interval_slices):
        # Select by list index so the solution_time dimension is kept.
        template_gaintable = gaintable.isel(solution_time=[idx])
        gaintable_per_solution = xr.map_blocks(
            _solve_gaintable_new,
            vis.isel(time=slc).chunk(vis_chunks_per_solution),
            args=[
                modelvis.isel(time=slc).chunk(vis_chunks_per_solution),
                template_gaintable,
            ],
            kwargs=solver_options,
            template=template_gaintable,
        )
        gaintable_across_solutions.append(gaintable_per_solution)

    combined_gaintable: GainTable = xr.concat(
        gaintable_across_solutions, dim="solution_time"
    )

    # Restore the dimension names the caller passed in.
    combined_gaintable = combined_gaintable.rename(solution_time="time")

    if "solution_frequency" in combined_gaintable.dims:
        combined_gaintable = combined_gaintable.rename(solution_frequency="frequency")

    return combined_gaintable

## Test

In [None]:
from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import (
    run_solver as run_solver_og,
)


# Solver options shared by the reference call to solve_gaintable and the call
# to run_solver_new, so that both runs use identical settings. The values
# equal the defaults declared by run_solver_new.
run_solver_config = {
    "solver": "gain_substitution",
    "refant": 0,
    "niter": 200,
    "phase_only": False,
    "tol": 1e-06,
    "crosspol": False,
    "normalise_gains": None,
}

In [None]:
# Reference result.
#
# Disabled alternative (only for sol_type = B): call the original run_solver
# on the dask cluster.
# expected_gaintable = run_solver_og(
#     vis=vis.copy(deep=True).chunk(time=-1),
#     modelvis=modelvis.copy(deep=True).chunk(time=-1),
#     gaintable=og_gaintable,
#     **run_solver_config
# )

# Active path (any gaintable): raw call to solve_gaintable on in-memory data.
# WARNING: Entire visibility should fit in memory
restored_vis, restored_modelvis = (
    restore_baselines_dim(dataset).copy(deep=True).load()
    for dataset in (vis, modelvis)
)
expected_gaintable = solve_gaintable(
    restored_vis, restored_modelvis, og_gaintable, **run_solver_config
)

In [None]:
# Target zarr store for the reference result, inside the current working
# directory.
expected_gaintable_path = f"{os.getcwd()}/expected_gaintable.zarr"
# Remove a store left over from a previous run; a missing directory is ignored.
shutil.rmtree(expected_gaintable_path, ignore_errors=True)

# Build the write lazily (compute=False) and execute it explicitly afterwards.
# NOTE(review): attributes are presumably dropped because they cannot be
# serialised to zarr; confirm.
writer = expected_gaintable.drop_attrs().to_zarr(
    expected_gaintable_path, mode="w", compute=False
)
dask.compute(writer, optimize_graph=True)

In [None]:
# Result of the refactored solver, using the same options as the reference
# run. NOTE(review): presumably still lazy at this point (built with
# xr.map_blocks) and only evaluated when it is written out; confirm.
actual_gaintable = run_solver_new(
    vis=vis, modelvis=modelvis, gaintable=new_gaintable, **run_solver_config
)

In [None]:
# Target zarr store for the refactored solver's result, inside the current
# working directory.
actual_gaintable_path = f"{os.getcwd()}/actual_gaintable.zarr"
# Remove a store left over from a previous run; a missing directory is ignored.
shutil.rmtree(actual_gaintable_path, ignore_errors=True)

# Build the write lazily (compute=False) and execute it explicitly afterwards.
# NOTE(review): attributes are presumably dropped because they cannot be
# serialised to zarr; confirm.
writer = actual_gaintable.drop_attrs().to_zarr(
    actual_gaintable_path, mode="w", compute=False
)
dask.compute(writer, optimize_graph=True)

In [None]:
# Re-open both results from disk, so the comparison below runs on what was
# actually written. chunks={} yields dask-backed variables using the chunking
# defined by each store.
actual_gaintable_zarr = xr.open_dataset(actual_gaintable_path, engine="zarr", chunks={})
expected_gaintable_zarr = xr.open_dataset(
    expected_gaintable_path, engine="zarr", chunks={}
)

In [None]:
# Gain arrays as read back from the two zarr stores.
actual_gain = actual_gaintable_zarr.gain.data
expected_gain = expected_gaintable_zarr.gain.data

# Same tolerances for every comparison in this cell.
# NOTE(review): if compare_arrays follows numpy.allclose semantics, rtol=1e-32
# with atol=0 effectively demands exact equality; confirm in notebook_utils.
gain_tolerances = {"rtol": 1e-32, "atol": 0}

# Real part, imaginary part, then the complex values as a whole.
compare_arrays(
    np.real(actual_gain),
    np.real(expected_gain),
    meta="Real gain values",
    **gain_tolerances,
)

compare_arrays(
    np.imag(actual_gain),
    np.imag(expected_gain),
    meta="Imag gain values",
    **gain_tolerances,
)

compare_arrays(
    actual_gain,
    expected_gain,
    meta="Complex gain values",
    **gain_tolerances,
)

In [None]:
# Weight arrays as read back from the two zarr stores.
actual_weight = actual_gaintable_zarr.weight.data
expected_weight = expected_gaintable_zarr.weight.data

# Same tolerances for every comparison in this cell.
# NOTE(review): if compare_arrays follows numpy.allclose semantics, rtol=1e-32
# with atol=0 effectively demands exact equality; confirm in notebook_utils.
weight_tolerances = {"rtol": 1e-32, "atol": 0}

# Real part, imaginary part, then the values as a whole.
compare_arrays(
    np.real(actual_weight),
    np.real(expected_weight),
    meta="Real weight values",
    **weight_tolerances,
)

compare_arrays(
    np.imag(actual_weight),
    np.imag(expected_weight),
    meta="Imag weight values",
    **weight_tolerances,
)

compare_arrays(
    actual_weight,
    expected_weight,
    meta="Complex weight values",
    **weight_tolerances,
)