# Refactored Predict: Comparison Against the Original Implementation

## Imports

In [None]:
import sys

# EveryBeam coefficient files installed under the active environment's prefix.
eb_coeffs = sys.exec_prefix + "/share/everybeam/"

# Put the parent directory first on sys.path so modules found there take
# precedence over any other copy.
sys.path.insert(0, "../")

# Developer override: use a local EveryBeam build and its coefficients instead.
# sys.path.insert(0, "/home/maneesh/Work/SKAO/EveryBeam/build/python")
# eb_coeffs = "/home/maneesh/Work/SKAO/EveryBeam/coeffs"

In [None]:
import os
import shutil

import dask
import dask.array as da
from distributed import Client
import everybeam as eb
import numpy as np
import xarray as xr

from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import (
    predict_vis as predict_vis_og,
    prediction_central_beams as prediction_central_beams_og,
)
from ska_sdp_instrumental_calibration.logger import setup_logger

from ska_sdp_instrumental_calibration.processing_tasks.lsm import (
    generate_lsm_from_csv,
    generate_lsm_from_gleamegc,
)

from ska_sdp_instrumental_calibration.scheduler import UpstreamOutput
from ska_sdp_instrumental_calibration.workflow.stages.load_data import load_data_stage

from ska_sdp_datamodels.calibration.calibration_create import create_gaintable_from_visibility as og_create_gaintable_from_visibility 
from ska_sdp_instrumental_calibration.data_managers.gaintable import create_gaintable_from_visibility

from ska_sdp_piper.piper.utils.log_util import LogPlugin

from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import (
    apply_gaintable_to_dataset as apply_gaintable_to_dataset_og,
)

from notebook_utils import compare_arrays


# Seed NumPy's global RNG so the random `station_rm` values drawn further down
# are identical on every run of the notebook.
np.random.seed(seed=57)

logger = setup_logger("predict_ipynb")

## Set up Dask (optional)

In [None]:
# Connect to an already-running Dask scheduler. Set the DASK_SCHEDULER_ADDRESS
# environment variable to use a different scheduler; otherwise the local
# development scheduler at localhost:34567 is used.
# The globals() guard makes re-running this cell reuse the existing client
# instead of opening a second connection and registering the plugin again.
if "dask_client" not in globals():
    dask_client = Client(
        os.environ.get("DASK_SCHEDULER_ADDRESS", "localhost:34567")
    )
    # Forward worker log records to this client's process.
    dask_client.forward_logging()
    log_configure_plugin = LogPlugin(verbose=False)
    dask_client.register_plugin(log_configure_plugin)

## Set up inputs

In [None]:
# Root of the INST test data used throughout this notebook.
inst_data_dir = "/home/ska/Work/data/INST"

# Measurement set under test (swap in one of the commented alternatives).
input_ms_path = f"{inst_data_dir}/lg3/cal_bpp_vis-lg3-rotated.small.ms"
# input_ms_path = f"{inst_data_dir}/lg3/cal_bpp_vis-lg3-rotated.ms"
# input_ms_path = "/home/maneesh/Work/SKAO/ska-sdp-instrumental-calibration/data/demo.ms"

# NOTE(review): this cache directory sits under a different home directory
# from the input data above — confirm it exists on the machine running this.
vis_cache_dir = "/home/maneesh/Work/SKAO/ska-sdp-instrumental-calibration/cache"

# MS handed to the EveryBeam-based beam model; defaults to the input MS.
# eb_ms = f"{inst_data_dir}/sim/OSKAR_MOCK.ms"
eb_ms = input_ms_path

# Sky-model sources: a CSV component list and the GLEAM-EGC catalogue file.
lsm_csv_path = f"{inst_data_dir}/lg3/sky_model_cal.csv"
gleamfile = f"{inst_data_dir}/sim/gleamegc.dat"

# Source-selection parameters shared by the LSM and GSM builders below.
fov = 10.0
flux_limit = 1.0
alpha0 = -0.78

In [None]:
# Run the pipeline's load-data stage on the input MS, starting from a fresh
# UpstreamOutput, and take the visibility dataset it produces.
upstream_output = UpstreamOutput()

# Chunking used when loading (by name: channels / time samples per chunk).
nchannels_per_chunk = 32
ntimes_per_ms_chunk = 5

# NOTE(review): stage_definition is called positionally, so what the literal
# arguments (False, "DATA", 0, 0, the {"input": ...} dict and ".") mean
# depends on a signature not visible here. "DATA" looks like the MS column to
# read and the two zeros like field / data-description ids (cf. the
# "..._fid0_ddid0" cache name below) — confirm, and consider keyword arguments.
upstream_output = load_data_stage.stage_definition(
    upstream_output,
    nchannels_per_chunk,
    ntimes_per_ms_chunk,
    vis_cache_dir,
    False,
    "DATA",
    0,
    0,
    {"input": input_ms_path},
    ".",
)
vis = upstream_output.vis

# Alternative: read a previously cached visibility zarr directly
# (note: read_visibility_from_zarr is not imported in this notebook).
# vis = read_visibility_from_zarr("/home/maneesh/Work/SKAO/ska-sdp-instrumental-calibration/dhr_338cache/cal_bpp_vis-lg3-rotated.small.ms_fid0_ddid0", {})

# Local sky model from the GLEAM-EGC catalogue around the observation's phase
# centre, using the fov / flux_limit / alpha0 selection values set earlier.
lsm = generate_lsm_from_gleamegc(
    gleamfile=gleamfile,
    phasecentre=vis.phasecentre,
    fov=fov,
    flux_limit=flux_limit,
    alpha0=alpha0,
)

# Alternative: build the local sky model from the CSV component list instead.
# lsm = generate_lsm_from_csv(
#     csvfile=lsm_csv_path,
#     phasecentre=vis.phasecentre,
#     fov=fov,
#     flux_limit=flux_limit,
# )

In [None]:
# Gain-table layout under test. The active pair exercises jones_type "B" with
# a "full" timeslice; swap to the commented pair to exercise jones_type "G".
timeslice, jones_type = "full", "B"
# timeslice, jones_type = None, "G"

# To exercise the raw ska_sdp_datamodels implementation instead:
# og_gaintable = og_create_gaintable_from_visibility(vis, timeslice=timeslice, jones_type=jones_type)
# if timeslice and isinstance(timeslice, float):
#     og_gaintable["interval"].data[0] = timeslice + 1e-5

new_gaintable = create_gaintable_from_visibility(
    vis,
    jones_type=jones_type,
    timeslice=timeslice,
    lower_precision=False,
)

In [None]:
# Beam model for both predict paths. Any value other than "everybeam" (e.g.
# the commented None) makes the refactored path run without a beams factory.
beam_type = "everybeam"
# beam_type = None

In [None]:
# One random value per antenna, passed as `station_rm` to both predict
# implementations. Set both names to None to run without it.
n_antennas = new_gaintable.antenna.size
station_rm = da.from_array(np.random.rand(n_antennas))
station_rm_xdr = xr.DataArray(
    station_rm,
    coords={"antenna": new_gaintable.antenna},
    name="station_rm",
)

# station_rm = None
# station_rm_xdr = None

## Predict from components

In [None]:
from ska_sdp_instrumental_calibration.dask_wrappers.predict import predict_vis as predict_vis_new

## Compare predicted vis

In [None]:
# Reference model visibilities from the original predict_vis. It runs on a
# deep copy of vis re-chunked to a single time chunk.
# NOTE(review): presumably the original implementation expects one time
# chunk — confirm.
expected_model_vis = predict_vis_og(
    vis.copy(deep=True).chunk(time=-1),
    lsm,
    station_rm=station_rm_xdr,
    eb_coeffs=eb_coeffs,
    eb_ms=eb_ms,
    beam_type=beam_type,
)

In [None]:
# Persist the reference model visibilities to a zarr store in the working
# directory, removing any store left behind by a previous run first.
expected_model_vis_path = os.getcwd() + "/expected_model_vis.vis.zarr"
shutil.rmtree(expected_model_vis_path, ignore_errors=True)

# Build the write lazily, then execute it as an explicit dask computation.
writer = expected_model_vis.vis.to_zarr(
    expected_model_vis_path,
    compute=False,
    mode="w",
)
dask.compute(writer, optimize_graph=True)

In [None]:
from ska_sdp_instrumental_calibration.data_managers.local_sky_model import GlobalSkyModel
from ska_sdp_instrumental_calibration.processing_tasks.predict_model.beams import BeamsFactory

# Sky model for the refactored predict, built from the same catalogue file and
# fov / flux_limit / alpha0 as `lsm`.
# NOTE(review): the reference path used only generate_lsm_from_gleamegc, while
# GlobalSkyModel also receives lsm_csv_path — confirm both resolve to the same
# components, otherwise the comparison below is not like-for-like.
gsm = GlobalSkyModel(
    vis.configuration.location,
    vis.phasecentre,
    fov,
    flux_limit,
    alpha0,
    gleamfile,
    lsm_csv_path,
)

# A beams factory is only built for the EveryBeam model; otherwise None.
# NOTE(review): unlike the reference path, eb_coeffs is not passed here —
# confirm BeamsFactory locates the coefficient files on its own.
if beam_type == "everybeam":
    logger.info("Using EveryBeam model in predict")
    beams_factory = BeamsFactory(
        nstations=vis.configuration.id.size,
        array_location=vis.configuration.location,
        direction=vis.phasecentre,
        ms_path=eb_ms,
    )
else:
    beams_factory = None

# Refactored predict, driven by the gain table's solution times and
# solution-interval slices.
actual_model_vis = predict_vis_new(
    vis,
    gsm,
    new_gaintable.time.data,
    new_gaintable.soln_interval_slices,
    beams_factory,
    station_rm=station_rm_xdr,
)

In [None]:
# Persist the refactored model visibilities next to the reference store,
# clearing any leftover store from an earlier run.
actual_model_vis_path = os.getcwd() + "/actual_model_vis.vis.zarr"
shutil.rmtree(actual_model_vis_path, ignore_errors=True)

# Lazy write first, explicit dask execution second.
writer = actual_model_vis.vis.to_zarr(
    actual_model_vis_path,
    compute=False,
    mode="w",
)
dask.compute(writer, optimize_graph=True)

In [None]:
# Re-open both persisted stores lazily with dask (chunks={}) so the
# comparison reads exactly what was written to disk.
expected_model_vis_zarr = xr.open_dataarray(
    expected_model_vis_path,
    chunks={},
    engine="zarr",
)
actual_model_vis_zarr = xr.open_dataarray(
    actual_model_vis_path,
    chunks={},
    engine="zarr",
)

In [None]:
# Compare refactored vs reference model visibilities, real and imaginary parts
# separately. rtol=1e-16 is below float64 machine epsilon (~2.2e-16); assuming
# compare_arrays applies rtol like numpy.allclose, this demands the values
# be identical.
model_vis_tolerances = dict(rtol=1e-16, atol=0)
actual_model_data = actual_model_vis_zarr.data
expected_model_data = expected_model_vis_zarr.data

compare_arrays(
    np.real(actual_model_data),
    np.real(expected_model_data),
    **model_vis_tolerances,
    meta="Real values",
)

compare_arrays(
    np.imag(actual_model_data),
    np.imag(expected_model_data),
    **model_vis_tolerances,
    meta="Imag values",
)

## Central beam prediction

In [None]:
from ska_sdp_instrumental_calibration.dask_wrappers.beams import prediction_central_beams

## Test central beams prediction

In [None]:
# Reference central-beam gains from the original implementation (on a deep,
# single-time-chunk copy of vis), loaded into memory for the comparison.
expected_beams = prediction_central_beams_og(
    vis.copy(deep=True).chunk(time=-1),
    eb_coeffs=eb_coeffs,
    eb_ms=eb_ms,
    beam_type=beam_type,
)

expected_beams.load();

In [None]:
# Refactored central-beam gains, laid out on the new gain table and computed
# with the same beams factory as the refactored predict; loaded into memory.
actual_beams = prediction_central_beams(new_gaintable, beams_factory)

actual_beams.load();

In [None]:
# Compare refactored vs reference central-beam gains. The refactored gains are
# cast to complex64 once before comparing real, imaginary and complex values.
beam_tolerances = dict(rtol=1e-16, atol=0)
actual_gain = actual_beams.gain.data.astype(np.complex64)
expected_gain = expected_beams.gain.data

compare_arrays(
    np.real(actual_gain),
    np.real(expected_gain),
    **beam_tolerances,
    meta="Real values",
)

compare_arrays(
    np.imag(actual_gain),
    np.imag(expected_gain),
    **beam_tolerances,
    meta="Imag values",
)

compare_arrays(
    actual_gain,
    expected_gain,
    **beam_tolerances,
    meta="Complex values",
)

## Apply antenna gains

In [None]:
from ska_sdp_instrumental_calibration.dask_wrappers.apply import apply_gaintable_to_dataset

## Test apply central beams

In [None]:
# Feed the *refactored* beam gains through the original apply function. The
# expected gain table is copied and its gains are swapped for actual_beams',
# which isolates the apply step from any beam-prediction differences.
beams_to_apply = expected_beams.copy(deep=True)
beams_to_apply.gain.data = actual_beams.gain.data
vis_frequency_chunks = vis.chunksizes["frequency"]
beams_to_apply = beams_to_apply.chunk(frequency=vis_frequency_chunks)

expected_beam_corr_vis = apply_gaintable_to_dataset_og(
    vis.copy(deep=True).chunk(time=-1),
    beams_to_apply,
    inverse=True,
)

In [None]:
# Persist the reference beam-corrected visibilities to zarr in the working
# directory, clearing any store left by a previous run.
expected_beam_corr_vis_path = os.getcwd() + "/expected_beam_corr_vis.vis.zarr"
shutil.rmtree(expected_beam_corr_vis_path, ignore_errors=True)

# Lazy write first, explicit dask execution second.
writer = expected_beam_corr_vis.vis.to_zarr(
    expected_beam_corr_vis_path,
    compute=False,
    mode="w",
)
dask.compute(writer, optimize_graph=True)

In [None]:
# Refactored apply: correct vis with the refactored central beams, applied
# inversely (inverse=True), mirroring the reference call above.
actual_beam_corr_vis = apply_gaintable_to_dataset(
    vis,
    actual_beams,
    inverse=True,
)

In [None]:
# Persist the refactored beam-corrected visibilities next to the reference
# store, clearing any leftover store from an earlier run.
actual_beam_corr_vis_path = os.getcwd() + "/actual_beam_corr_vis.vis.zarr"
shutil.rmtree(actual_beam_corr_vis_path, ignore_errors=True)

# Lazy write first, explicit dask execution second.
writer = actual_beam_corr_vis.vis.to_zarr(
    actual_beam_corr_vis_path,
    compute=False,
    mode="w",
)
dask.compute(writer, optimize_graph=True)

In [None]:
# Re-open both beam-corrected stores lazily with dask (chunks={}) so the
# comparison reads exactly what was written to disk.
expected_beam_corr_vis_zarr = xr.open_dataarray(
    expected_beam_corr_vis_path,
    chunks={},
    engine="zarr",
)
actual_beam_corr_vis_zarr = xr.open_dataarray(
    actual_beam_corr_vis_path,
    chunks={},
    engine="zarr",
)

In [None]:
# Compare refactored vs reference beam-corrected visibilities.
# NOTE(review): rtol=1e-32 is far below float64 resolution (eps ~2.2e-16).
# Assuming compare_arrays applies rtol like numpy.allclose, this is an
# exact-equality check, just like the 1e-16 used elsewhere in this notebook.
corr_vis_tolerances = dict(rtol=1e-32, atol=0)
actual_corr_data = actual_beam_corr_vis_zarr.data
expected_corr_data = expected_beam_corr_vis_zarr.data

compare_arrays(
    np.real(actual_corr_data),
    np.real(expected_corr_data),
    **corr_vis_tolerances,
    meta="Real values",
)

compare_arrays(
    np.imag(actual_corr_data),
    np.imag(expected_corr_data),
    **corr_vis_tolerances,
    meta="Imag values",
)

compare_arrays(
    actual_corr_data,
    expected_corr_data,
    **corr_vis_tolerances,
    meta="Complex values",
)