# Validate Bandpass Polarisation Stage

## Imports

In [None]:
from ska_sdp_instrumental_calibration.workflow.stages.bandpass_calibration import (
    bandpass_calibration_stage,
)

from ska_sdp_instrumental_calibration.data_managers.dask_wrappers import run_solver

from ska_sdp_instrumental_calibration.data_managers.visibility import (
    read_dataset_from_zarr,
    write_ms_to_zarr,
)
from ska_sdp_instrumental_calibration.workflow.utils import (
    create_bandpass_table,
    with_chunks,
)
from ska_sdp_datamodels.calibration.calibration_create import (
    create_gaintable_from_visibility,
)

import numpy as np
import xarray as xr

import os

In [None]:
# If you need to connect to a dask client
# from distributed import Client

# client = Client()

## Generate data

We use the scripts stored in `scripts/ska_low_sim` to generate the data.
Information about how to use the scripts is available on this [confluence page](https://confluence.skatelescope.org/x/eZHMF).

### Simulation config

The following custom configuration is used for the simulation (referred to below as `custom_sim.yaml`):

```yaml
scenario: "low40s-model"          # Scenario name (used for output folder prefix)

# ===============================
# Global simulation parameters
# ===============================

n_stations: 40                                         # Number of stations
tel_model: "./SKA-Low_AA2_40S_rigid-rotation_model.tm" # Telescope model directory

simulation_start_frequency_hz: 123.0e6                  # Start frequency (Hz)
simulation_end_frequency_hz: 153.0e6                    # End frequency (Hz)
correlated_channel_bandwidth_hz: 21.70138888888889e3    # Channel width (Hz)

observing_time_mins: 10                              # Observation duration (minutes)
sampling_time_sec: 3.3973862400000003                   # Dump/integration time (seconds)

fields:
  EoR2:
    Cal1:
      ra_deg: 197.914612
      dec_deg: -22.277973
      scan_id_start: 300
      transit_time: "2000-01-03 22:33:30.000"

# ==================================
# Options for generate_gaintable.py
# ==================================

generate_gaintable:
  output_gaintable: &gen_gaintable "./gaintables/custom_gaintable.h5"

  spline_data_path: "./SKA_Low_AA2_SP5175_spline_data.npz" # Bandpass spline data file
  station_offset: true              # Apply per-station amplitude/phase offsets
  time_variant: true                # Apply time-dependent effects

  rfi: false                        # Inject RFI band
  rfi_start_freq_hz: 154.25347222228538e6        # Hz
  rfi_end_freq_hz: 159.8090277778474e6           # Hz

  plot: true                        # Generate diagnostic plots
  plot_output_dir: "./gaintables/generation_plots/"

# ===============================
# Options for run_sim.py
# ===============================

run_sim:
  oskar_sif: "./OSKAR-2.11.1-Python3.sif" # Path to OSKAR Singularity image

  # GLEAM sky model. Optional. Comment to disable.
  gleam_file: "./sky-models/GLEAM_EGC.fits" # GLEAM catalogue FITS file
  field_radius_deg: 10.0            # Radius of field of view (degrees)

  # Corruptions to be applied. All are optional. Comment to disable.
  # gaintable: *gen_gaintable           # Gaintable containing bandpass corruptions
  # cable_delay: "./cable_delays/cable_length_error_40s.txt" # Cable delay error file
  # tec_screen: "./tec/calibrator_iono_tec.fits" # Ionospheric TEC screen FITS

  # Imaging parameters using wsclean. Optional. Comment to disable.
  create_dirty_image: true          # Whether to run wsclean imaging
  image_size: 1024                  # Image size (pixels)
  pixel_size: "2arcsec"             # Pixel size (angular units)

  # Extra parameters to pass directly to run_oskar.py
  run_oskar_extra_params: "--use-gpus --double-precision"
  ```

### Overview of steps followed

Assuming you have followed the steps on the [confluence page](https://confluence.skatelescope.org/x/eZHMF) to get the data and set up the Python environment, here are the general steps to simulate the data needed to run this notebook:

1. Create a custom gaintable using the above config and the `generate_gaintable.py` script

    ```bash
    python3 generate_gaintable.py custom_sim.yaml
    ```

2. Convert the gaintable to a DP3 h5parm file

    ```bash
    python3 utils/h5parm_from_oskar_gains.py custom_sim.yaml "./gaintables/custom_gaintable.h5"
    ```

3. Create "model" visibilities

    ```bash
    python run_sim.py custom_sim.yaml
    ```

4. Apply the corruptions using DP3 to create corrupted visibilities

    ```bash
    msin="path_to_the_model_visibilities"
    parmdb="./gaintables/custom_gaintable.h5parm"
    msout="corrupted_visibilities.ms"
    DP3_CMD=${DP3_CMD:-"DP3"}

    ${DP3_CMD} steps="[applycal]" msin=${msin}\
	  applycal.invert=False applycal.parmdb=${parmdb}\
	  applycal.missingantennabehavior=flag applycal.updateweights=True\
	  applycal.steps="[phase, amplitude]" applycal.phase.correction=phase000\
	  applycal.amplitude.correction=amplitude000 msout=${msout}
    ```


## Convert MS to Zarr files

In [None]:
# Set these paths to the simulated data generated in previous steps

# Paths to the simulated MeasurementSets; fill these in before running the cells below.
vis_ms: str = ""  # visibilities to calibrate (corrupted data)
model_ms: str = ""  # model visibilities

In [None]:
# Every cache and output location is rooted at the current working directory.
working_dir = os.getcwd()

# zarr caches converted from the visibility and model MeasurementSets.
cache_root = f"{working_dir}/cache"
vis_cache_directory = f"{cache_root}/vis/"
model_cache_directory = f"{cache_root}/model/"

output_dir = f"{working_dir}/output/"

In [None]:
nchannels_per_chunk = 20
ntimes_per_ms_chunk = 20

# Dimensions that are never split (-1), shared by the zarr store and the
# loaded visibility dataset.
non_chunked_dims = dict.fromkeys(("baselineid", "polarisation", "spatial"), -1)

# Chunking of the intermediate zarr files written from the MeasurementSets.
zarr_chunks = dict(
    non_chunked_dims,
    time=ntimes_per_ms_chunk,
    frequency=nchannels_per_chunk,
)

# The pipeline only works on frequency chunks, and later stages are expected
# to follow the same pattern, so time is left unsplit (-1) here.
vis_chunks = dict(
    non_chunked_dims,
    time=-1,
    frequency=nchannels_per_chunk,
)

In [None]:
# Convert the visibility MeasurementSet into a chunked zarr cache.
# Fail early with a clear message if the placeholder path was never filled in.
if not vis_ms:
    raise ValueError(
        "vis_ms is empty: set it to the path of the visibility "
        "MeasurementSet before running this cell."
    )
write_ms_to_zarr(vis_ms, vis_cache_directory, zarr_chunks)

In [None]:
# Convert the model-visibility MeasurementSet into a chunked zarr cache.
# Fail early with a clear message if the placeholder path was never filled in.
if not model_ms:
    raise ValueError(
        "model_ms is empty: set it to the path of the model "
        "MeasurementSet before running this cell."
    )
write_ms_to_zarr(model_ms, model_cache_directory, zarr_chunks)

## Load visibilities from zarr and create initial gaintable

In [None]:
# Open both zarr caches with the frequency-only chunking used by the pipeline.
vis_ds = read_dataset_from_zarr(vis_cache_directory, vis_chunks)

# Workaround so the time coordinates match: the model's "time" coordinate is
# replaced by the observed one. NOTE(review): this assumes both datasets hold
# the same time samples in the same order — confirm.
model_ds = read_dataset_from_zarr(model_cache_directory, vis_chunks).assign_coords(
    time=vis_ds.time
)

In [None]:
# Initial bandpass gaintable built from the observed visibilities, with
# vis_chunks applied via with_chunks.
# NOTE(review): the gaintable argument to run_solver in the next cell is
# commented out, and that cell reassigns `gaintable`, so this table is
# currently unused — confirm whether it should be passed to the solver.
gaintable = with_chunks(create_bandpass_table(vis_ds), vis_chunks)

# Alternative constructor kept for reference:
# gaintable = create_gaintable_from_visibility(vis_ds, jones_type='B', timeslice=None)

## Test run_solver directly

In [None]:
# Solve directly with run_solver using the Jones-substitution solver.
# Optional arguments not passed here: gaintable, refant, phase_only, tol,
# crosspol, normalise_gains. The result replaces the `gaintable` variable.
gaintable = run_solver(
    vis_ds,
    model_ds,
    niter=50,
    solver="jones_substitution",
)

In [None]:
# Load the solver output into memory.
# NOTE(review): presumably run_solver returns a lazily evaluated (dask-backed)
# xarray object, so this is where the solve is actually computed — confirm.
gaintable.load()

### Export gaintable to h5parm

In [None]:
# Export the solved gaintable as an h5parm file inside output_dir (set in an
# earlier cell), creating the directory if it does not exist yet.
from ska_sdp_instrumental_calibration.data_managers.data_export.export_to_h5parm import (
    export_gaintable_to_h5parm,
)

os.makedirs(output_dir, exist_ok=True)
h5parm_path = os.path.join(output_dir, "actual_bandpass_stage.h5parm")
export_gaintable_to_h5parm(gaintable, h5parm_path)