# Create Local Computing Cluster

In [None]:
import os
import sys
import importlib
import toolviper
from toolviper.dask import local_client
import xradio
import dask
from importlib.metadata import version

# Packages to reload, so that their freshly installed versions are reported.
package_names = ["toolviper", "xradio", "dask"]

print(
    "\nReminder:"
    "\nafter packages updates, reloading packages may not be enough."
    "\nRestarting the Jupyter kernel may also be required.\n"
)
# Each package is looked up in sys.modules, so it must already have been
# imported (see the imports above). importlib.reload() re-executes only the
# given module object — presumably why the reminder above suggests a
# kernel restart. The printed version comes from importlib.metadata.
for package_name in package_names:
    package = sys.modules[package_name]
    importlib.reload(package)
    print(f"Using (reloaded) {package_name} version: {version(package_name)}")





In [None]:
from toolviper.dask import local_client

# Start a local computing cluster through toolviper. The client is left as
# the cell's last expression so that Jupyter displays it.
viper_client = local_client(
    cores=10,
    memory_limit="0.5GB"  # Per core
)
viper_client

# Convert on-disk MSv2 File to on-disk Processing Set of MsV4 Files
  * Convert the on-disk MSv2 "file" to the MSv4 ProcessingSet format
  * Store the result on disk, in Zarr format 

In [None]:
import os
from xradio.measurement_set import convert_msv2_to_processing_set

ms_file = "uid___A002_X11a51f7_X1a7.ms.split.cal"
# ms_file = "solar-regional-maps.ms"
# ms_file = "uid___A002_Xae00c5_X2e6b.ms.cal.split.shrink"

# xradio 0.58:
# The name of the Processing Set (out_file) MUST end with '.ps.zarr';
# otherwise xradio appends this extension to the user-supplied name.
processing_set_file = f"{ms_file}.ps.zarr"

# Default partitioning (disabled, kept for reference)
# if not os.path.exists(processing_set_file):
#     convert_msv2_to_processing_set(
#         in_file=ms_file,
#         out_file=processing_set_file,
#         parallel_mode="partition",
#         overwrite=True,
#     )

# Finer partitioning: split by scan and by antenna
# Takes ~ 1 hour !!!
# The conversion only runs when the output store does not exist yet, so
# re-running this cell reuses a previous conversion.
fine_processing_set_file = f"{ms_file}.by-scan.by-antenna.ps.zarr"
if not os.path.exists(fine_processing_set_file):
    convert_msv2_to_processing_set(
        in_file=ms_file,
        out_file=fine_processing_set_file,
        partition_scheme=["SCAN_NUMBER", "ANTENNA1"],
        parallel_mode="partition",
        overwrite=True,
    )

# Utilities

## head_tail

In [None]:
def head_tail(my_list, n=3, elements='lines'):
    """
    Print the first `n` and last `n` elements of a sequence, one per line.

    A '...' line marks the skipped middle part. If there are `2n` or fewer
    elements, all of them are printed and no '...' is shown. A final line
    reports the total number of elements.

    Parameters
    ----------
    my_list : iterable
        Elements to summarize. Any iterable is accepted (list, tuple,
        range, numpy array, ...); it is materialized into a list first.
    n : int, optional
        Number of elements shown at each end (default 3). Must be >= 0.
    elements : str, optional
        Label for the count line, e.g. 'lines' or 'MsV4(s)'. A "(s)" plural
        marker is appended unless the label already ends with one.

    Raises
    ------
    ValueError
        If `n` is negative.
    """
    if n < 0:
        raise ValueError(f"n must be >= 0, got {n}")
    items = list(my_list)
    if len(items) <= 2 * n:
        summary = items
    else:
        # Slice the tail with an explicit start index: `items[-n:]` would
        # return the WHOLE list when n == 0.
        summary = items[:n] + ['...'] + items[len(items) - n:]
    print('\n'.join(map(str, summary)))
    label = elements if elements.endswith('(s)') else f"{elements}(s)"
    print(f"[{len(items):,} {label}]")

# Select Data

The goal of this section is to select
  * from the on-disk Processing Set of the Solar Regional Maps single-dish observation,
  * data required to image the first full scan of the Sun,
  * in a manner equivalent to that of the original CASA6 script

The original CASA6 imaging script (scriptForSDFullSunImaging.py) performs the following data selection:
  * Antenna Selection
    * Select the antenna with good Tsys for all EBs (check Tsys.plots)
    * antenna = 'PM04'
  * Time Selection 
    * Scan Name / Scan Number
        * scanNum = '6'
    * Observation Intent
        * intent = '\*ON_SOURCE\*'
    * Field
        * field = '1'
        * Note: field name: 'Sun'
  * Frequency Selection (Spectral Window)
    * SPW with the best Tsys values.
    * spw = '3'
    * Note: corresponding spectral window
        * 106.979 GHz +/- 1GHz
        * [105.979 GHz, 107.979 GHz], 1 channel of width 2GHz

## open_processing_set: too slow !!!

In [None]:
import pandas as pd
from xradio.measurement_set.open_processing_set import open_processing_set
from xarray import DataTree

pd.options.display.max_colwidth = 100

# Opening the whole fine-grained Processing Set is slow:
# Takes 11 minutes @ 10 cores !!!
# Takes 13 minutes @  4 cores
#     * but setting up a 4-core cluster takes only 6 seconds
# Disabled (`if False`): the selection cells below read the on-disk MsV4
# metadata directly instead.
if False:
    fine_ps_xdt: DataTree = open_processing_set(
        ps_store=fine_processing_set_file,
        # intents=None  # all
        intents=['OBSERVE_TARGET#ON_SOURCE']
    )
    on_source_fine_ps_xdt = fine_ps_xdt


In [None]:
if False:
    # Full, non-selected DataTree
    # Takes 9 minutes @ 10 cores
    # Disabled (`if False`) because of its cost; kept for reference.
    ps_full = ps = open_processing_set(
        ps_store=fine_processing_set_file,
        intents=None  # all
    )

## Collect all MsV4s

In [None]:
from glob import glob

print(
    "Listing MsV4s of on-disk Processing Set:\n"
    f"{fine_processing_set_file}\n"
)

# Every non-hidden entry of the Processing Set store (presumably one per
# MsV4). glob's `*` does not match names starting with '.', so dot-files at
# the store root are not listed.
fine_msv4s = sorted(glob(f"{fine_processing_set_file}/*"))
head_tail(fine_msv4s, elements="MsV4")

# The selection starts with all MsV4s and is narrowed by the cells below.
msv4s_selection = set(fine_msv4s)
print(f'\nCurrent Selection: {len(msv4s_selection)} MsV4s')

## Filter by Scan Intent

### Observation Intent Storage Location
Observation Intent information is stored:
  * in \<Processing Set\>/\<MsV4\>/.zmetadata JSON file
  * under key: metadata."scan_name/.zattrs".scan_intents

Example:

In [None]:
import json
from pprint import pprint

# Consolidated Zarr metadata of the first MsV4 of the Processing Set
metadata_file = f"{fine_msv4s[0]}/.zmetadata"
print(f"metadata file:\n{metadata_file}")

# JSON is UTF-8 by specification: read it with an explicit encoding rather
# than the platform's default text encoding.
with open(metadata_file, 'r', encoding='utf-8') as f:
    root = json.load(f)

print('\nmetadata."scan_name/.zattrs".scan_intents:')
pprint(
    root['metadata']['scan_name/.zattrs']['scan_intents'],
    indent=4,
)


### Restrict current selection to selected intent

In [None]:
import os

intent_selection = "OBSERVE_TARGET#ON_SOURCE"

def intent_match(msv4_path, intent_selection):
    """
    Return True if `intent_selection` is one of the MsV4's scan intents.

    The intents are read from the consolidated Zarr metadata file
    `<msv4_path>/.zmetadata`, under
    metadata["scan_name/.zattrs"]["scan_intents"]; the comparison is an
    exact membership test.

    Raises
    ------
    OSError
        If the metadata file cannot be opened.
    KeyError
        If the expected metadata keys are missing.
    """
    import json  # local import: the block does not rely on other cells

    metadata_file = os.path.join(msv4_path, ".zmetadata")
    # JSON is UTF-8 by specification; do not rely on the locale encoding.
    with open(metadata_file, 'r', encoding='utf-8') as f:
        root = json.load(f)
    scan_attrs = root['metadata']["scan_name/.zattrs"]
    return intent_selection in scan_attrs["scan_intents"]

print(f"Current selection: {len(msv4s_selection)} MsV4(s)")

# Keep only the MsV4s whose scan intents include the selected intent.
# The whole right-hand side is evaluated before the name is rebound.
msv4s_selection = set(
    filter(
        lambda msv4_path: intent_match(msv4_path, intent_selection),
        msv4s_selection,
    )
)

print(f"Updated selection: {len(msv4s_selection)} MsV4(s)")
head_tail(sorted(msv4s_selection), elements="MsV4(s)")

## Filter by Field Name

### Field Name Storage Location

Field name can be retrieved e.g. from the field_name coordinate of the VisibilityXds, stored under:
  * the \<Processing Set\>/\<MsV4\>/field_name directory

Example:

In [None]:
import zarr

# field_name coordinate of the first MsV4 of the Processing Set
field_name_dir = f"{fine_msv4s[0]}/field_name"
print(f"field_name directory:\n{field_name_dir}")

# Open read-only, like the *_match() helpers: zarr's default open mode is
# writable and may create a new (empty) array when the path is missing.
field_names = zarr.open_array(field_name_dir, mode='r')

print("\nfield_name coordinate array:")
head_tail(sorted(field_names[:]), elements="elements")

### Restrict current selection to selected field name

In [None]:
field_name_selection = "Sun_1"  # MsV2.FieldName_MsV2.FieldId

def field_name_match(msv4_path, field_name_selection):
    """
    Tell whether an MsV4 belongs to the requested field.

    Opens `<msv4_path>/field_name` read-only with zarr and compares its
    first element with `field_name_selection`. Only element 0 is looked
    at: this assumes the field name is constant within an MsV4
    (TODO confirm for the partitioning in use).
    """
    field_names = zarr.open_array(
        os.path.join(msv4_path, "field_name"), mode='r'
    )
    first_field_name = field_names[0]
    return first_field_name == field_name_selection

print(f"Current selection: {len(msv4s_selection)} MsV4(s)")

# Keep only the MsV4s observed on the selected field.
# The whole right-hand side is evaluated before the name is rebound.
msv4s_selection = set(
    filter(
        lambda msv4_path: field_name_match(msv4_path, field_name_selection),
        msv4s_selection,
    )
)

print(f"Updated selection: {len(msv4s_selection)} MsV4(s)")
head_tail(sorted(msv4s_selection), elements="MsV4(s)")



## Filter by Scan Name

### Scan Name Storage Location

Scan name can be retrieved e.g. from the scan_name coordinate of the VisibilityXds, stored under:
  * the \<Processing Set\>/\<MsV4\>/scan_name directory

Example:

In [None]:
import zarr

# scan_name coordinate of the first MsV4 of the Processing Set
scan_name_dir = f"{fine_msv4s[0]}/scan_name"
print(f"scan_name directory:\n{scan_name_dir}")

# Open read-only, like the *_match() helpers: zarr's default open mode is
# writable and may create a new (empty) array when the path is missing.
scan_names = zarr.open_array(scan_name_dir, mode='r')

print("\nscan_name coordinate array:")
head_tail(sorted(scan_names[:]), elements="elements")

### Restrict current selection to selected scan name

In [None]:
scan_name_selection = "6"  # First Full Sun Scan

def scan_name_match(msv4_path, scan_name_selection):
    """
    Tell whether an MsV4 belongs to the requested scan.

    Opens `<msv4_path>/scan_name` read-only with zarr and compares its
    first element with `scan_name_selection`. Only element 0 is looked
    at; with the by-scan partitioning used here the scan name is presumably
    constant within an MsV4 (TODO confirm).
    """
    scan_names = zarr.open_array(
        os.path.join(msv4_path, "scan_name"), mode='r'
    )
    first_scan_name = scan_names[0]
    return first_scan_name == scan_name_selection

print(f"Current selection: {len(msv4s_selection)} MsV4(s)")

# Keep only the MsV4s of the selected scan.
# The whole right-hand side is evaluated before the name is rebound.
msv4s_selection = set(
    filter(
        lambda msv4_path: scan_name_match(msv4_path, scan_name_selection),
        msv4s_selection,
    )
)

print(f"Updated selection: {len(msv4s_selection)} MsV4(s)")
head_tail(sorted(msv4s_selection), elements="MsV4")

## Filter by Antenna Name

### Antenna Name Storage Location

Antenna names can be retrieved e.g. from the baseline_antenna{1,2}_name coordinates of the VisibilityXds, stored under:
  * the \<Processing Set\>/\<MsV4\>/baseline_antenna{1,2}_name directories

Example:

In [None]:
import zarr

# baseline_antenna1_name coordinate of the first MsV4 of the Processing Set
antenna1_name_dir = f"{fine_msv4s[0]}/baseline_antenna1_name"
print(f"antenna_name directory:\n{antenna1_name_dir}")

# Open read-only, like the *_match() helpers: zarr's default open mode is
# writable and may create a new (empty) array when the path is missing.
antenna1_names = zarr.open_array(antenna1_name_dir, mode='r')

print("\nantenna_name coordinate array:")
head_tail(sorted(antenna1_names[:]), elements="elements")

### Restrict current selection to selected antenna

In [None]:
antenna_name_selection = "PM04"

def _antenna_name_matches(antenna_name, antenna_name_selection):
    """
    Return True if `antenna_name` denotes the selected antenna.

    Antenna names are of the form "<antenna>_<station>" (e.g. "PM04_T704").
    Accept either an exact match, or the antenna part followed by '_'.
    A bare prefix is not enough: "CM1" must not select "CM10_...".
    """
    return (
        antenna_name == antenna_name_selection
        or antenna_name.startswith(f"{antenna_name_selection}_")
    )

def antenna_name_match(msv4_path, antenna_name_selection):
    """
    Return True if both baseline antennas of the MsV4 are the selected one.

    Only the first element of the baseline_antenna1_name and
    baseline_antenna2_name coordinate arrays (opened read-only with zarr)
    is checked; with the by-antenna partitioning used here all elements
    presumably name the same antenna (TODO confirm).
    """
    antenna_dirs = [
        "baseline_antenna1_name",
        "baseline_antenna2_name",
    ]
    for antenna_dir in antenna_dirs:
        antenna_path = os.path.join(msv4_path, antenna_dir)
        antenna_names = zarr.open_array(
            antenna_path,
            mode='r'
        )
        first_antenna_name = str(antenna_names[0])
        if not _antenna_name_matches(first_antenna_name, antenna_name_selection):
            return False

    return True

print(f"Current selection: {len(msv4s_selection)} MsV4(s)")

# Keep only the MsV4s recorded by the selected antenna.
# The whole right-hand side is evaluated before the name is rebound.
msv4s_selection = set(
    filter(
        lambda msv4_path: antenna_name_match(msv4_path, antenna_name_selection),
        msv4s_selection,
    )
)

print(f"Updated selection: {len(msv4s_selection)} MsV4(s)")
head_tail(sorted(msv4s_selection), elements="MsV4(s)")

## Filter by Spectral Window Id

### Spectral Window Id Storage Location

Spectral Window Id information is encoded in the spectral window name, stored:
  * in \<Processing Set\>/\<MsV4\>/frequency/.zattrs JSON file
  * under the key: spectral_window_name (the name ends with the Spectral Window Id)

Example:

In [None]:
import json
from pprint import pprint

# Attributes of the frequency coordinate of the first MsV4
metadata_file = f"{fine_msv4s[0]}/frequency/.zattrs"
print(f"metadata file:\n{metadata_file}")

# .zattrs is a JSON document (UTF-8 by specification): read it with an
# explicit encoding rather than the platform's default text encoding.
with open(metadata_file, 'r', encoding='utf-8') as f:
    root = json.load(f)

print('\nSpectral Window Name:')
pprint(
    root['spectral_window_name'],
    indent=4,
)

### Restrict current selection to selected spectral window id

In [None]:
spw_id_selection = "3"

def spw_id_match(msv4_path, spw_id_selection):
    """
    Return True if the MsV4's spectral window name carries the selected id.

    The name is read from `<msv4_path>/frequency/.zattrs` (key
    'spectral_window_name'). For a numeric selection, the id is the whole
    run of digits at the end of the name (e.g. "..._13" -> 13), compared
    numerically, so "3" no longer also selects spectral window 13 or 23.
    Non-numeric selections keep plain suffix matching.
    """
    import json  # local imports: the block does not rely on other cells
    import re

    metadata_file = os.path.join(msv4_path, "frequency", ".zattrs")
    # JSON is UTF-8 by specification; do not rely on the locale encoding.
    with open(metadata_file, 'r', encoding='utf-8') as f:
        root = json.load(f)
    spw_name = str(root['spectral_window_name'])

    spw_id = str(spw_id_selection)
    if not spw_id.isdecimal():
        return spw_name.endswith(spw_id)

    # A textual suffix test is wrong here: "spw_13".endswith("3") is True.
    trailing_digits = re.search(r"(\d+)$", spw_name)
    return (
        trailing_digits is not None
        and int(trailing_digits.group(1)) == int(spw_id)
    )

print(f"Current selection: {len(msv4s_selection)} MsV4(s)")

# Keep only the MsV4s of the selected spectral window.
# The whole right-hand side is evaluated before the name is rebound.
msv4s_selection = set(
    filter(
        lambda msv4_path: spw_id_match(msv4_path, spw_id_selection),
        msv4s_selection,
    )
)

print(f"Updated selection: {len(msv4s_selection)} MsV4(s)")
head_tail(sorted(msv4s_selection), elements="MsV4(s)")

# Inspect Selected Data

In this section we inspect the DataTree of the first selected MsV4.<br>
In our specific case, we have only 1 selected MsV4.


## Display the root node

In [None]:
import xarray
from xarray import DataTree


# `msv4s_selection` is a set, whose iteration order for strings changes
# between interpreter runs (hash randomization): sort it so that "the first
# selected MsV4" is reproducible.
selected_msv4s = sorted(msv4s_selection)
if not selected_msv4s:
    raise RuntimeError(
        "No MsV4 left after data selection: check the selection filters."
    )
current_msv4 = selected_msv4s[0]

# Open the selected MsV4 lazily as an xarray DataTree from its Zarr store
current_msv4_xdt: DataTree = xarray.open_datatree(
    current_msv4,
    engine="zarr"
)

root_node = current_msv4_xdt
display(root_node)



In [None]:
# Display the time coordinate (data-taking times) of the opened MsV4;
# its values are used later to interpolate the pointing directions.
current_msv4_xdt.time

## Inspect the DataTree structure

In [None]:
from collections.abc import Iterable

# DataTree properties reported for the root node
attr_names = [
    "depth",
    "width",
    "is_root",
    "children",
    "groups"
]

# Non-iterable values are printed on the same line as their name;
# iterable ones (e.g. children, groups) get one indented line per item.
for attr_name in attr_names:
    attr = getattr(root_node, attr_name)
    if isinstance(attr, Iterable):
        print(f"{attr_name}:")
        for element in attr:
            print("    " + str(element))
    else:
        print(f"{attr_name}: {attr}")

## Inspect Antenna Information

In [None]:
# Show the baseline antenna coordinates of the selected MsV4; for this
# single-dish selection both are expected to name the selected antenna.
print("baseline_antenna1_name coordinate:")
display(current_msv4_xdt.baseline_antenna1_name)

print("baseline_antenna1_name values:")
display(current_msv4_xdt.baseline_antenna1_name.values)

print("baseline_antenna2_name values:")
display(current_msv4_xdt.baseline_antenna2_name.values)

## Inspect Antenna Pointings Data

In [None]:
from xarray import DataTree


# Child node holding the antenna pointing dataset of the MsV4
pointing_node: DataTree = current_msv4_xdt['/pointing_xds']
display(pointing_node)


# Plot Antenna Pointings: Altitude vs Azimuth

## Check Azimuth angle unit

In [None]:
# Pointing directions with length-1 dimensions dropped
pointing_xar = pointing_node.POINTING_BEAM.squeeze()

# Split the local sky direction into its azimuth and altitude components.
# Later cells treat these values as radians (np.degrees, unit="rad");
# the displayed `az` attributes let you check that unit.
az = pointing_xar.sel(local_sky_dir_label='az')
alt = pointing_xar.sel(local_sky_dir_label='alt')

az

## Create the Plot

In [None]:
import numpy as np
import pandas as pd
import pathlib
from matplotlib import pyplot as plt

plt.style.use("dark_background")

# Figure title
asdm_uid = fine_processing_set_file.split('.')[0]
current_msv4_name = pathlib.Path(current_msv4).name
plt.suptitle(
    f"Processing Set: {fine_processing_set_file}\n"
    f"MsV4: {current_msv4_name}")

plt.subplots_adjust(top=0.8)

# Plot title
antenna_name = str(pointing_xar.antenna_name.values)
# Single quotes inside the f-string keep this cell valid before Python 3.12
# (reusing the enclosing quote type inside an f-string needs PEP 701).
plt.title(f"Antenna {antenna_name.replace('_', ' at station ')}")

# Axes labels
plt.xlabel('Azimuth (degrees)')
plt.ylabel('Altitude (degrees)')

# Plot data (np.degrees: the stored angles are treated as radians)
pointings_azimuth_deg = np.degrees(az.values)
pointings_altitude_deg = np.degrees(alt.values)

color_values = pointing_xar.time_pointing.values
marker_sizes = np.full(color_values.size, 0.1)
opacity = np.full(color_values.size, 0.8)

# Distinguish off-source pointings, and initial on-source pointing
# The value 2 is empirical
marker_sizes[0:2] = 128
opacity[0:2] = 1

# Create plot
scatter_plot = plt.scatter(
    pointings_azimuth_deg,
    pointings_altitude_deg,
    s=marker_sizes,
    alpha=opacity,
    # color="DodgerBlue",
    c=color_values,
    cmap='viridis'
)
scatter_plot.axes.set_aspect('equal', adjustable='datalim')

# Add colorbar
cbar = plt.colorbar(scatter_plot)

# Update colormap labels to display 5 equidistant human-readable date-times
# NOTE(review): pd.to_datetime(unit='s') assumes time_pointing values are
# seconds since the Unix epoch — confirm against its format/units attrs.
num_labels = 5
min_time = color_values.min()
max_time = color_values.max()
equidistant_times = np.linspace(min_time, max_time, num_labels)
equidistant_date_times = pd.to_datetime(equidistant_times, unit='s')

cbar.set_ticks(equidistant_times)
cbar.set_ticklabels(equidistant_date_times.strftime('%H:%M:%S'))


plt.show()




# Convert Antenna Pointings to Celestial Reference Frame ICRS

## Inspect Antenna DataTree

In [None]:
# Antenna information node of the selected MsV4
antenna_xdt = current_msv4_xdt.antenna_xds
display(antenna_xdt)

## Display Antenna Position Array

In [None]:
# Antenna position array (its cartesian_pos_label coordinate names x/y/z)
antenna_position_xar = antenna_xdt.ANTENNA_POSITION
display(antenna_position_xar)

## Retrieve Antenna Position Coordinates

In [None]:
# Name of the first antenna of the antenna dataset
antenna_name = antenna_xdt.antenna_name.values[0]
print(f"Antenna Name: {antenna_name}")

# Position of that antenna, one value per cartesian_pos_label
antenna_position = (
    antenna_xdt.ANTENNA_POSITION.sel({'antenna_name': antenna_name})
)

# Collect the coordinates by label (used later as x/y/z, in metres, to
# build the observer's EarthLocation).
antenna_coords = {}
print("Antenna Position Coordinates:")
for coord_name in antenna_position_xar.cartesian_pos_label.values:
    coord_value = antenna_coords[coord_name] = (
        antenna_position.sel({'cartesian_pos_label': coord_name}).values
    )
    print(f"    {coord_name}: {coord_value:>15f}")


## Convert Antenna Pointing Directions to ICRS

In [None]:
from xarray import DataArray

# Pointing directions of the selected antenna only
pointing_xdt = current_msv4_xdt.pointing_xds
pointing_beam: DataArray = pointing_xdt.POINTING_BEAM
antenna_pointings: DataArray = pointing_beam.sel(antenna_name=antenna_name)

### Inspect Antenna Pointings Times

In [None]:
# The format and units attributes are needed to build astropy Time objects
display(antenna_pointings.time_pointing)

print(
    f"Pointings Times Format: {antenna_pointings.time_pointing.format}\n"
    f"Pointings Times Unit  : {antenna_pointings.time_pointing.units}\n"
)

### Convert Pointings Directions to ICRS using Astropy

In [None]:
from astropy.coordinates import SkyCoord, EarthLocation
from astropy.time import Time as AstroTime
from astropy.units import Unit as AstroUnit

# Pointing timestamps as an astropy Time, built from the time_pointing
# values and its `units`, `format` and `scale` attributes.
pointings_times = AstroTime(
    antenna_pointings.time_pointing.values * AstroUnit(
        antenna_pointings.time_pointing.units
    ),
    format=antenna_pointings.time_pointing.format,
    scale=antenna_pointings.time_pointing.scale
)

# Horizontal (AltAz) pointing directions, observed from the antenna's
# geocentric position (antenna_coords, in metres).
# NOTE(review): alt/az come from `pointing_xar` (squeezed POINTING_BEAM)
# while the timestamps come from `antenna_pointings` (POINTING_BEAM selected
# by antenna name); they agree only if the MsV4 holds a single antenna —
# TODO confirm.
pointings_az_alt_coords = SkyCoord(
    # The commented-out 'degrees' variant below produced an incorrect plot
    # because it swaps the arguments: azimuth is passed as `alt` and
    # altitude as `az`. With the arguments in the right order it would
    # presumably work as well.
    # alt=pointings_azimuth_deg,
    # az=pointings_altitude_deg,
    # unit="deg",
    alt=alt,
    az=az,
    unit="rad",
    obstime=pointings_times,
    frame='altaz',
    location=EarthLocation.from_geocentric(
        antenna_coords['x'],
        antenna_coords['y'],
        antenna_coords['z'],
        unit="m"
    )
)

# Transform the pointing directions to the ICRS celestial frame
pointings_icrs_coords = pointings_az_alt_coords.icrs


# Plot Antenna Pointings: Declination vs Right Ascension

## Inspect Astropy SkyCoord data

In [None]:
# Show the representation holding the ICRS coordinates (lon = RA,
# lat = Dec) and the units of its components.
print(
    "Pointings ICRS Coords data:"
    f"\n{type(pointings_icrs_coords.data)}"
    f"\n{pointings_icrs_coords.data}"
)

print(
    "\nPointings ICRS Coords data Units:"
    f" [{pointings_icrs_coords.data.lon.unit}"
    f", {pointings_icrs_coords.data.lat.unit}]"
)

## Create the Plot

In [None]:
import pandas as pd
from matplotlib import pyplot as plt
from astropy.coordinates import Angle as AstroAngle
from astropy.units import Unit as AstroUnit

# Plot markers: default sizes and colors
color_values = pointing_xar.time_pointing.values
marker_sizes = np.full(color_values.size, 0.1)
opacity = np.full(color_values.size, 0.8)

# Distinguish off-source pointings, and initial on-source pointing
# The value 2 is empirical
# TODO: find out how-to know whether a data is taken on-source or off-source
marker_sizes[0:2] = 128
opacity[0:2] = 1
# Highlight last pointing
marker_sizes[-1] = 128
opacity[-1] = 1

plt.rc('text', usetex=False)

scatter_plot_icrs = plt.scatter(
    pointings_icrs_coords.data.lon,
    pointings_icrs_coords.data.lat,
    s=marker_sizes,
    alpha=opacity,
    c=color_values,
    cmap='viridis'
)
scatter_plot_icrs.axes.set_aspect('equal', adjustable='datalim')

# Figure title
plt.suptitle(
    f"Processing Set: {fine_processing_set_file}\n"
    f"MSv4: {current_msv4_name}")

plt.subplots_adjust(top=0.8)

# Plot title
# Single quotes inside the f-string keep this cell valid before Python 3.12.
plt.title(f"Antenna {antenna_name.replace('_', ' at station ')}")

# Axes labels
plt.xlabel('Right Ascension')
plt.ylabel('Declination')

# X axis ticks and ticks labels
# ---- First pointing (OFF source) and last pointing (ON source)
xticks_indices = [0, -1]
# ---- ON source: min and max, searched after the first (OFF source)
# pointing. argmin/argmax return positions *within the slice*, so shift
# them by the slice start to index the full array below.
on_source_start = 1
xticks_indices.extend([
    on_source_start
    + np.argmin(pointings_icrs_coords.data.lon.value[on_source_start:]),
    on_source_start
    + np.argmax(pointings_icrs_coords.data.lon.value[on_source_start:])
])

xticks_locations = pointings_icrs_coords.data.lon[xticks_indices].value
xticks_unit = pointings_icrs_coords.data.lon.unit
xticks_labels = AstroAngle(
    xticks_locations,
    unit=AstroUnit(xticks_unit)
    ).to_string(
        sep='hms',
        precision=2
)

scatter_plot_icrs.axes.set_xticks(
    xticks_locations,
    labels=xticks_labels,
    minor=False,
    fontsize=8
)

# Right Ascension increases to the left on the sky
scatter_plot_icrs.axes.xaxis.set_inverted(True)

# Y axis ticks and ticks labels
y_min = pointings_icrs_coords.data.lat.value.min()
y_max = pointings_icrs_coords.data.lat.value.max()
yticks_locations = [
    y_min,
    np.mean([y_min, y_max]),
    y_max
]
yticks_unit = pointings_icrs_coords.data.lat.unit
yticks_labels = AstroAngle(
    yticks_locations,
    unit=AstroUnit(yticks_unit)
    ).to_string(
        unit=AstroUnit('degree'),
        precision=2,
        format='unicode'
)

scatter_plot_icrs.axes.set_yticks(
    yticks_locations,
    labels=yticks_labels,
    minor=False,
    fontsize=8
)

# Axes grid
scatter_plot_icrs.axes.grid(
    visible=True,
    which='major',
    axis='both',
    alpha=0.5
)

# Colorbar for color-annotated time
cbar_icrs = plt.colorbar(scatter_plot_icrs)
# ---- Update colormap labels
# to display 5 equidistant human-readable date-times
# NOTE(review): unit='s' assumes Unix-epoch seconds — confirm.
num_labels = 5
min_time = color_values.min()
max_time = color_values.max()
equidistant_times = np.linspace(min_time, max_time, num_labels)
equidistant_date_times = pd.to_datetime(equidistant_times, unit='s')
cbar_icrs.set_ticks(equidistant_times)
cbar_icrs.set_ticklabels(equidistant_date_times.strftime('%H:%M:%S'))



plt.show()

# Compute Data Celestial Directions: Interpolate Celestial Pointing Directions at Data-taking Times

## Compute Cubic Spline Interpolators Using SciPy

In [None]:
from scipy.interpolate import CubicSpline

# Known samples: pointing timestamps (expressed in the Time object's own
# format) and the matching ICRS Right Ascension / Declination.
times_known = pointings_times.value
ra_known = pointings_icrs_coords.ra
dec_known = pointings_icrs_coords.dec

# Compute splines
# NOTE(review): scipy's CubicSpline requires strictly increasing times.
# RA is interpolated as-is: a track crossing RA = 0/360 deg would need to
# be unwrapped first (not an issue if the Sun stays away from RA 0).
ra_spline = CubicSpline(times_known, ra_known)
dec_spline = CubicSpline(times_known, dec_known)

## Interpolate Pointing Directions (Unknown at Data-taking Times) using Interpolators

In [None]:
# Evaluate the splines at the data-taking times.
# NOTE(review): assumes the MsV4 time values use the same time
# format/units as `times_known` (pointing times) — confirm.
data_taking_times = current_msv4_xdt.time.values

ra_data = ra_spline(data_taking_times)
dec_data = dec_spline(data_taking_times)

# Plot Data Celestial Directions

In [None]:
import pandas as pd
from matplotlib import pyplot as plt
from astropy.coordinates import Angle as AstroAngle
from astropy.units import Unit as AstroUnit

# Figure title
plt.suptitle(
    f"ProcessingSet: {fine_processing_set_file}\n"
    f"MSv4: {current_msv4_name}")

plt.subplots_adjust(top=0.8)

# Plot title
# Single quotes inside the f-string keep this cell valid before Python 3.12.
plt.title(f"Antenna {antenna_name.replace('_', ' at station ')}")

# Plot markers: default sizes and colors
color_values = data_taking_times
marker_sizes = np.full(color_values.size, 0.1)
opacity = np.full(color_values.size, 0.25)

# Highlight first and last interpolated pointing directions
marker_sizes[0] = 128
opacity[0] = 1

marker_sizes[-1] = 128
opacity[-1] = 1

plt.rc('text', usetex=False)

n_points = color_values.shape[0]
scatter_plot_icrs_interp = plt.scatter(
    ra_data[0:n_points],
    dec_data[0:n_points],
    s=marker_sizes[0:n_points],
    alpha=opacity[0:n_points],
    c=color_values[0:n_points],
    cmap='viridis'
)
# Redraw initial point on top
n_points = 1
plt.scatter(
    ra_data[0:n_points],
    dec_data[0:n_points],
    s=marker_sizes[0:n_points],
    alpha=opacity[0:n_points],
    c=color_values[0:n_points],
    cmap='viridis'
)


scatter_plot_icrs_interp.axes.set_aspect('equal', adjustable='datalim')

# Axes labels
plt.xlabel('Right Ascension')
plt.ylabel('Declination')

# X axis ticks and ticks labels
# ---- First and last interpolated directions, then RA min and max
xticks_indices = [0, -1]
xticks_indices.extend([
    np.argmin(ra_data),
    np.argmax(ra_data)
])

xticks_locations = ra_data[xticks_indices]
xticks_unit = pointings_icrs_coords.ra.unit
xticks_labels = AstroAngle(
    xticks_locations,
    unit=AstroUnit(xticks_unit)
    ).to_string(
        sep='hms',
        precision=0
)

scatter_plot_icrs_interp.axes.set_xticks(
    xticks_locations,
    labels=xticks_labels,
    minor=False,
    fontsize=8
)

# Right Ascension increases to the left on the sky
scatter_plot_icrs_interp.axes.xaxis.set_inverted(True)

# Y axis ticks and ticks labels
y_min = dec_data.min()
y_max = dec_data.max()
yticks_locations = [
    y_min,
    np.mean([y_min, y_max]),
    y_max
]
yticks_unit = pointings_icrs_coords.dec.unit
yticks_labels = AstroAngle(
    yticks_locations,
    unit=AstroUnit(yticks_unit)
    ).to_string(
        unit=AstroUnit('degree'),
        precision=2,
        format='unicode'
)

scatter_plot_icrs_interp.axes.set_yticks(
    yticks_locations,
    labels=yticks_labels,
    minor=False,
    fontsize=8
)

# Axes grid
scatter_plot_icrs_interp.axes.grid(
    visible=True,
    which='major',
    axis='both',
    alpha=0.5
)

# Colorbar for color-annotated time
cbar_icrs_interp = plt.colorbar(scatter_plot_icrs_interp)
# ---- Update colormap labels
# to display 5 equidistant human-readable date-times
# NOTE(review): unit='s' assumes Unix-epoch seconds — confirm.
num_labels = 5
min_time = color_values.min()
max_time = color_values.max()
equidistant_times = np.linspace(min_time, max_time, num_labels)
equidistant_date_times = pd.to_datetime(equidistant_times, unit='s')
cbar_icrs_interp.set_ticks(equidistant_times)
cbar_icrs_interp.set_ticklabels(equidistant_date_times.strftime('%H:%M:%S'))

plt.show()

# Project Data Celestial directions: Spherical World Coordinate System (WCS) Projection

## Construct Astropy WCS Object

  * Because in the end we want to create an image in FITS format, we proceed as follows to keep things consistent:
    * Create image's FITS file header
    * Construct an Astropy WCS object from that header
  * We hard-code most FITS parameters with values stored in the FITS header of the full Sun reference FITS image created by CASA.
  * Deriving these values from user-specified input parameters will be addressed later
  * The short-term goal is to compare the prototype image with the reference CASA6 image.

In [None]:
# Refs: 
#   * Astropy WCS: 
#       https://docs.astropy.org/en/stable/wcs/example_create_imaging.html
#   * FITS Keywords:
#       https://heasarc.gsfc.nasa.gov/docs/fcg/standard_dict.html
from astropy import wcs
from astropy.units import Quantity
from astropy.io import fits
from astropy import units as u
import numpy as np
from dotmap import DotMap


# uid___A002_X11a51f7_X1a7, ALMA Band 3, full Sun imaging:
# related parameters of CASA6 tsdimaging task:
#
# Image Dimensions
# ---- Spatial Axes Sizes
imsize = [400, 400]  # x and y image size in pixels
x_axis_size, y_axis_size = imsize
# ---- Frequency Axis Size
mode = 'channel'
nchan = 1  # number of channels (planes) in output image (-1=all)
frequency_axis_size = nchan
# ---- Polarization Axis Size
stokes = 'I'
# Only Stokes I is supported so far. NAXIS4 must be an integer: fail loudly
# instead of silently writing a placeholder string into the header.
if stokes != 'I':
    raise NotImplementedError(
        f"stokes={stokes!r} is not supported yet: only 'I' is implemented"
    )
polarization_axis_size = 1
#
# World Coordinate System-related parameters
phasecenter = 'Sun'
projection = 'SIN'
cell_size = DotMap({
    'x': '6.0 arcsec',
    'y': '6.0 arcsec'
})

# Time to play with the FITS cards!
fits_cards = []

initial_ordered_cards = [
    # SIMPLE and EXTEND are FITS *logical* keywords: pass Python booleans so
    # astropy writes them as T, not as the quoted character string 'T'.
    fits.Card('SIMPLE', True, 'Standard FITS'),
    fits.Card('BITPIX', -32, 'Floating point (32 bit)'),
    fits.Card('NAXIS', 4),
    fits.Card('NAXIS1', x_axis_size),
    fits.Card('NAXIS2', y_axis_size),
    fits.Card('NAXIS3', frequency_axis_size),
    fits.Card('NAXIS4', polarization_axis_size),
    fits.Card('EXTEND', True)  # FITS file may contain extensions
]

scaling_cards = [
    # FITS default values since we do not store floats as integers
    fits.Card('BSCALE', 1.0, 'PHYSICAL = PIXEL*BSCALE + BZERO'),
    fits.Card('BZERO', 0.0),
]

beam_cards = [
    # Beam size is set by Python function set_beam_size()
    # inside CASA6 task tsdimaging

    # Hard-coding here values from reference FITS image header,
    # exported from CARTA
    # Full width at half maximum (FWHM) of the major axis of the beam,
    # in degrees.
    fits.Card('BMAJ', 1.667502277778E-02),
    # The FWHM of the minor axis of the beam, also in degrees.
    fits.Card('BMIN', 1.667502277778E-02),
    # Beam position angle, in degrees, measured counterclockwise from north.
    fits.Card('BPA', 0.0),

]

brightness_cards = [
    # CASA6 task tsdimaging retrieves the brightness unit by calling Python
    # function get_brightness_unit_from_ms()

    # Hard-coding here values from reference FITS image header,
    # exported from CARTA
    # Physical meaning of the pixel values in the image
    fits.Card('BTYPE', 'Intensity'),
    # Target of observation
    fits.Card('OBJECT', 'Sun'),
    # Physical units of the pixel values in the image
    fits.Card('BUNIT', 'K', 'Brightness (pixel) unit'),
]

wcs_cards = [
    # World Coordinate System (WCS) cards
    # Celestial reference system used for RA/Dec astronomical coordinates
    fits.Card('RADESYS', 'ICRS')
]
wcs_cards.extend([  # Axis 1: X axis: RA---SIN
    # Name of coordinate axis 1
    fits.Card('CTYPE1', f'RA---{projection}'),
    # World coordinate value at reference pixel defined by CRPIX1,
    # hard-coded from reference FITS image header
    fits.Card('CRVAL1', 1.151895265540E+02),  # RA at reference pixel
    # Increment per pixel along axis 1
    fits.Card(
        'CDELT1',  # Pixel scale in RA (degrees/pixel)
        # The sign indicates axis orientation
        # Negative sign because RA increases to the left
        - Quantity(cell_size.x).to(u.degree).value
    ),
    # Coordinate of reference pixel along axis 1
    # hard-coded from reference FITS image header,
    # offset by 0.5 pixel from axis center:
    # 0.5 * (1 + x_axis_size) = 200.5
    fits.Card('CRPIX1', 201),  # Reference pixel X
    # Unit of the world coordinate along axis 1
    fits.Card('CUNIT1', 'deg'),  # RA values are in degrees
])
wcs_cards.extend([  # Axis 2: Y axis: DEC--SIN
    # Name of coordinate axis 2
    fits.Card('CTYPE2', f'DEC--{projection}'),
    # World coordinate value at reference pixel defined by CRPIX2,
    # hard-coded from reference FITS image header
    fits.Card('CRVAL2', 2.141998704917E+01),  # Declination at reference pixel
    # Increment per pixel along axis 2
    fits.Card(
        'CDELT2',  # Pixel scale in DEC (degrees/pixel)
        Quantity(cell_size.y).to(u.degree).value
    ),
    # Coordinate of reference pixel along axis 2
    # hard-coded from reference FITS image header,
    # offset by 0.5 pixel from axis center:
    # 0.5 * (1 + y_axis_size) = 200.5
    fits.Card('CRPIX2', 201),  # Reference pixel Y
    # Unit of the world coordinate along axis 2
    fits.Card('CUNIT2', 'deg'),  # DEC values are in degrees
])
wcs_cards.extend([  # Axis 3: Frequency axis
    # Name of coordinate axis 3
    fits.Card('CTYPE3', 'FREQ'),
    # World coordinate value at reference pixel defined by CRPIX3,
    # hard-coded from reference FITS image header
    fits.Card('CRVAL3', 1.069829163331E+11),  # Frequency at reference pixel
    # Increment per pixel along axis 3
    fits.Card(
        'CDELT3',  # Pixel scale in frequency (Hz/pixel)
        2.000073216857E+09  # hard-coded from reference FITS image header
    ),
    # Coordinate of reference pixel along axis 3
    # hard-coded from reference FITS image header,
    fits.Card('CRPIX3', 1),  # Reference pixel ... "spectral window index" ?
    # Unit of the world coordinate along axis 3
    fits.Card('CUNIT3', 'Hz'),  # Frequency values are in Hz
])
wcs_cards.extend([  # Axis 4: Polarization: Stokes I
    # Name of coordinate axis 4
    fits.Card('CTYPE4', 'STOKES'),
    # World coordinate value at reference pixel defined by CRPIX4,
    # hard-coded from reference FITS image header
    # (FITS Stokes code 1 = I)
    fits.Card('CRVAL4', 1.0),
    # Increment per pixel along axis 4
    fits.Card(
        'CDELT4',
        1.0  # hard-coded from reference FITS image header
    ),
    # Coordinate of reference pixel along axis 4
    # hard-coded from reference FITS image header,
    fits.Card('CRPIX4', 1),  # Reference pixel "polarization index" ?
    # Unit of the world coordinate along axis 4
    fits.Card('CUNIT4', ''),  # Dimensionless
])
wcs_cards.extend([  # Spectral Reference Frame
    fits.Card('RESTFRQ', 1.069790000000E+11, 'Rest Frequency (Hz)'),
    fits.Card('SPECSYS', 'LSRK', 'Spectral reference frame'),
    fits.Card('ALTRVAL', -1.097493083694E+04,
              'Alternate frequency reference value'),
    fits.Card('ALTRPIX', 1, 'Alternate frequency reference pixel'),
    fits.Card('VELREF', 257, '1 LSR, 2 HEL, 3 OBS, +256 Radio')
])

observation_cards = [
    fits.Card('TELESCOP', 'ALMA'),
    fits.Card('OBSERVER', 'bvilavil'),
    fits.Card('DATE-OBS', '2024-07-15T16:43:07.343999'),
    fits.Card('TIMESYS', 'UTC'),
    fits.Card('OBSRA', 1.151895265540E+02),
    fits.Card('OBSDEC', 2.141998704917E+01),
    fits.Card('OBSGEO-X', 2.225142180269E+06),
    fits.Card('OBSGEO-Y', -5.440307370349E+06),
    fits.Card('OBSGEO-Z', -2.481029851874E+06),
    fits.Card('INSTRUME', 'ALMA'),
]

for cards_group in [
            initial_ordered_cards,
            scaling_cards,
            beam_cards,
            brightness_cards,
            wcs_cards,
            observation_cards
        ]:
    fits_cards.extend(cards_group)

fits_header = fits.Header(fits_cards)
print("FITS Header:")
display(fits_header)

wcs_4d = wcs.WCS(fits_header)
print("\nWorld Coordinate System object from header:")
display(wcs_4d)

In [None]:
from astropy.io import fits
import numpy as np

fits_name = 'reference-image/antenna-PM04.fullSun-num0_sci.spw3_.sd.TP.cont.I.manual.sd.not-rescaled.fits'

# Print the HDU summary of the reference image; the context manager
# releases the file handle once the summary has been printed.
with fits.open(fits_name) as hdul:
    hdul.info()

# Empty image laid out as (x, y, frequency, polarization), presumably
# matching fits_header axes 1..4. It is transposed when wrapped in an HDU
# because astropy stores FITS data with the axis order reversed.
image_shape = (
    x_axis_size,
    y_axis_size,
    frequency_axis_size,
    polarization_axis_size
)
print(f"Image shape: {image_shape}")
image_array = np.zeros(image_shape, np.float32)

my_hdul = fits.HDUList(
    hdus=[
        fits.PrimaryHDU(
            data=np.transpose(image_array),
            header=fits_header
        )
    ]
)

my_hdul.info()

## Perform Spherical WCS Projection of Celestial Pointing Directions Interpolated at Data-Taking Time

In [None]:
# Pixel indices are produced 0-based (NumPy convention), not FITS 1-based
world_array_origin = 0
# One (RA, Dec) pair per row: shape (n_samples, 2)
world_coordinates = np.stack((ra_data, dec_data), axis=1)

print("World coordinates:")
print(world_coordinates, world_coordinates.shape)

# Keep only the celestial (longitude, latitude) axes of the 4D image WCS
# and project every pointing sample onto the image pixel grid
wcs_2d = wcs_4d.sub(['longitude', 'latitude'])
pixel_coordinates = wcs_2d.wcs_world2pix(world_coordinates, world_array_origin)

print("\nPixel coordinates:")
print(pixel_coordinates, pixel_coordinates.shape)

## Plot Projected Data Directions

In [None]:
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from pathlib import Path

# Columns 0 and 1 of the projected coordinates are the two pixel axes
x, y = (
    pixel_coordinates[:, 0],
    pixel_coordinates[:, 1]
)

plt.scatter(
    x, y,
    color='white',
    alpha=0.5,
    s=0.0005
)
axes: Axes = plt.gca()
axes.set_aspect('equal')

# Figure title
antenna_name_only = antenna_name.split('_')[0]
ps_name = Path(fine_processing_set_file).name
current_msv4_name = Path(current_msv4).name
plt.suptitle(
    f"ProcessingSet: {ps_name}"
    f"\nMSv4: {current_msv4_name} Antenna: {antenna_name_only}"
)
plt.subplots_adjust(top=0.8)
# Plot Title
plt.title(f'WCS-{projection} Projected Data Directions')

png_name = (
    f"results/{ps_name}.{current_msv4_name}.{antenna_name}"
    f".data-directions.projection.WCS-{projection}.png"
)
# Save BEFORE plt.show(): in Jupyter the inline backend closes the figure
# when it is shown, so a savefig() issued afterwards writes an empty canvas.
Path(png_name).parent.mkdir(parents=True, exist_ok=True)
plt.savefig(png_name)
print(f"Saved: {png_name}")
plt.show()

# Prepare for Calling the Python Single Dish Gridder

## Check how CASA6 SDGrid handles tsdimaging parameter: stokes = 'I'
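* Q: Is Stokes I computed before calling the gridder?
* A: No. In the debugger session below, `polMap` maps both data polarizations to image polarization 0 (`polMap = [0, 0]`) at the `ggridsd` call.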

In [None]:
# lldb transcript taken inside SDGrid::put, stopped at the ggridsd() call
# (SDGrid.cc:975). polMap has 2 elements and both are 0: both visibility
# correlations are mapped onto image polarization 0, so Stokes I is not
# formed before the Fortran gridder is called.
long_answer = """
(lldb) frame select
frame #0: 0x000000010de79109 libcasacpp_synthesis.6.7.2.3.dylib`casa::refim::SDGrid::put(this=0x00007fa406facc00, vb=0x00006000033e2f70, row=-1, dopsf=false, type=CORRECTED) at SDGrid.cc:975:9
   972 	
   973 	      if (call_ggridsd) { // Call plain gridder
   974 	
-> 975 	        ggridsd(
   976 	          xyPositions.getStorage(del),
   977 	          datStorage,
   978 	          &s[0],
(lldb) p polMap.nelements()
(size_t) 2
(lldb) p polMap[0]
(int) 0
(lldb) p polMap[1]
(int) 0
(lldb) 

"""

## Check what weights CASA6 SDGrid::put passes to the Fortran gridder

### Conclusion

In [None]:
def casa6_gridder_input_weights(
        use_imaging_weights: bool = False,
        ms_has_weight_spectrum_column=True
        ) -> str:
    """Return a label naming the weights CASA6 SDGrid::put feeds the gridder.

    Follows the branching of ``SDGrid::pickWeights``:

    * imaging weights requested -> ``"vb.imagingweight()"``;
    * otherwise the per-correlation weights combined to Stokes I, read from
      WEIGHT_SPECTRUM when the MS has that column, else from WEIGHT.

    The result is a descriptive string, not the weights themselves.
    """
    if not use_imaging_weights:
        column = "WEIGHT_SPECTRUM" if ms_has_weight_spectrum_column else "WEIGHT"
        return f"to_stokes_I(ms.MAIN.{column})"
    return "vb.imagingweight()"

class UseCase:
    """Named set of settings that select the CASA6 gridder input weights.

    Every constructor argument is stored unchanged on a same-named attribute.
    """

    def __init__(
        self,
        name=None,
        use_imaging_weights=None,
        ms_has_weight_spectrum_column=None,
    ):
        self.ms_has_weight_spectrum_column = ms_has_weight_spectrum_column
        self.use_imaging_weights = use_imaging_weights
        self.name = name

# The solar regional imaging configuration: no imaging weights, and the
# MS carries a WEIGHT_SPECTRUM column
use_case = UseCase(
    name="Solar Regional Imaging",
    use_imaging_weights=False,
    ms_has_weight_spectrum_column=True,
)


print("CASA6 Gridder Input Weights:")
use_case_input_weights = casa6_gridder_input_weights(
    use_imaging_weights=use_case.use_imaging_weights,
    ms_has_weight_spectrum_column=use_case.ms_has_weight_spectrum_column,
)
print("    * Use Case: {}:\n      => {}".format(use_case.name, use_case_input_weights))


### Details

#### CASA6 SDGrid::put

In [None]:
class SDGrid:  # put
    """Study notes on CASA6 ``SDGrid::put`` (C++), kept as docstrings.

    The methods below have no body besides their docstrings, so calling them
    does nothing and returns ``None``. A later cell redefines ``SDGrid``
    (pickWeights notes), which rebinds this name.
    """
    # Copy of SDGrid::put from CASA6 SDGrid.cc, annotated with lldb output
    # showing the data/flags shapes observed during a run.
    def put(self):
        """
        void SDGrid::put(const vi::VisBuffer2& vb, Int row, Bool dopsf,
             FTMachine::Type type)
        {
          LogIO os(LogOrigin("SDGrid", "put"));

          gridOk(convSupport);

          // There is no channel mapping cache in VI/VB2 version of FTMachine
          // Perform matchChannel everytime
          matchChannel(vb);

          // No point in reading data if its not matching in frequency
          if (max(chanMap)==-1) return;

          Matrix<Float> imagingweight;
          //imagingweight=&(vb.imagingWeight());
          pickWeights(vb, imagingweight);

          if (type==FTMachine::PSF || type==FTMachine::COVERAGE) dopsf=true;
          if (dopsf) type=FTMachine::PSF;
          Cube<Complex> data;
          Cube<Int> flags; //Fortran gridder need the flag as ints
          Matrix<Float> elWeight;
          interpolateFrequencyTogrid(vb, imagingweight,data, flags, elWeight, type);
          // (lldb) p data.shape()
          // (const casacore::IPosition) {
          //   size_p = 3
          //   buffer_p = ([0] = 2, [1] = 1, [2] = 1, [3] = 140702053814768)
          //   data_p = 0x00007ff7bfefd1f8
          // }
          // (lldb)
          // (lldb) p flags.shape()
          // (const casacore::IPosition) {
          //   size_p = 3
          //   buffer_p = ([0] = 2, [1] = 1, [2] = 1, [3] = 11810099496)
          //   data_p = 0x00007ff7bfefd0f8
          // }
          // (lldb) 
          Bool iswgtCopy;
          const Float *wgtStorage;
          wgtStorage=elWeight.getStorage(iswgtCopy);
          Bool isCopy;
          const Complex *datStorage=0;
          if (!dopsf) datStorage=data.getStorage(isCopy);

          // If row is -1 then we pass through all rows
          Int startRow, endRow, nRow;
          if (row==-1) {
            nRow=vb.nRows();
            startRow=0;
            endRow=nRow-1;
          } else {
            nRow=1;
            startRow=row;
            endRow=row;
          }

          Vector<Int> rowFlags(vb.flagRow().nelements(), 0);
          for (Int rownr=startRow; rownr<=endRow; rownr++) {
            if(vb.flagRow()(rownr)) rowFlags(rownr)=1;
          }

          // Take care of translation of Bools to Integer
          Int idopsf = dopsf ? 1 : 0;

          { // Compute spectra pixel coordinates and call gridder
            // Make sure failed getXYPos does not fall on grid
            constexpr Double kFarAway = -1e9;
            Matrix<Double> xyPositions(2, endRow-startRow+1, kFarAway);
            for (Int rownr=startRow; rownr<=endRow; rownr++) {
              if (getXYPos(vb, rownr)) {
                xyPositions(0, rownr)=xyPos(0);
                xyPositions(1, rownr)=xyPos(1);
              }
            }
            { // Call gridder
              Bool del;
              const IPosition& fs=flags.shape();
              std::vector<Int> s(fs.begin(), fs.end());
              Bool datCopy, wgtCopy;
              Complex * datStor=griddedData.getStorage(datCopy);
              Float * wgtStor=wGriddedData.getStorage(wgtCopy);

              //Bool call_ggridsd = !clipminmax_ || dopsf;
              Bool call_ggridsd = !clipminmax_;

              if (call_ggridsd) { // Call plain gridder

                ggridsd(
                  xyPositions.getStorage(del),
                  datStorage,
                  &s[0],
                  &s[1],
                  &idopsf,
                  flags.getStorage(del),
                  rowFlags.getStorage(del),
                  wgtStorage,
                  &s[2],
                  &row,
                  datStor,
                  wgtStor,
                  &nx,
                  &ny,
                  &npol,
                  &nchan,
                  &convSupport,
                  &convSampling,
                  convFunc.getStorage(del),
                  chanMap.getStorage(del),
                  polMap.getStorage(del),
                  sumWeight.getStorage(del)
                );

              } else { // Call clipping gridder
                Bool gminCopy;
                Complex *gminStor = gmin_.getStorage(gminCopy);
                Bool gmaxCopy;
                Complex *gmaxStor = gmax_.getStorage(gmaxCopy);
                Bool wminCopy;
                Float *wminStor = wmin_.getStorage(wminCopy);
                Bool wmaxCopy;
                Float *wmaxStor = wmax_.getStorage(wmaxCopy);
                Bool npCopy;
                Int *npStor = npoints_.getStorage(npCopy);

                ggridsdclip(
                  xyPositions.getStorage(del),
                  datStorage,
                  &s[0],
                  &s[1],
                  &idopsf,
                  flags.getStorage(del),
                  rowFlags.getStorage(del),
                  wgtStorage,
                  &s[2],
                  &row,
                  datStor,
                  wgtStor,
                  npStor,
                  gminStor,
                  wminStor,
                  gmaxStor,
                  wmaxStor,
                  &nx,
                  &ny,
                  &npol,
                  &nchan,
                  &convSupport,
                  &convSampling,
                  convFunc.getStorage(del),
                  chanMap.getStorage(del),
                  polMap.getStorage(del),
                  sumWeight.getStorage(del)
                );

                gmin_.putStorage(gminStor, gminCopy);
                gmax_.putStorage(gmaxStor, gmaxCopy);
                wmin_.putStorage(wminStor, wminCopy);
                wmax_.putStorage(wmaxStor, wmaxCopy);
                npoints_.putStorage(npStor, npCopy);
              }
              griddedData.putStorage(datStor, datCopy);
              wGriddedData.putStorage(wgtStor, wgtCopy);
            }
          }

          { // Free memory
            if (!dopsf) data.freeStorage(datStorage, isCopy);

            elWeight.freeStorage(wgtStorage, iswgtCopy);
          }

        }
        """
    # Hand-simplified, renamed pseudo-C++ version of put() above, for reading
    # only. It is not compilable as-is (e.g. `assert` is used without
    # parentheses).
    def readable_put(self):
        """
        A simplified and more readable put

        void SDGrid::put(
            const vi::VisBuffer2& vb,
            Int row,
            Bool dopsf,
            FTMachine::Type ftMachineType)
        {
            // gridOk(convSupport);
            assert nx > 2 * convSupport;
            assert ny > 2 * convSupport;

            // TODO
            matchChannel(vb);

            // No point in reading data if it does not match in frequency
            if ( max(chanMap) == -1 ) return;

            // TODO: Compute imagingweight
            Matrix<Float> imagingweight;
            pickWeights(vb, imagingweight);

            // Tweak 'dopsf' input parameter
            if (ftMachineType == FTMachine::PSF or
                ftMachineType == FTMachine::COVERAGE)
                    dopsf = true;

            if (dopsf) ftMachineType = FTMachine::PSF;


            // Select from the Visibility Buffer
            // - and depending on the Fourier Transform machine type -
            // the chunk of data of shape (nvispol, nvischan, nrow)
            // that we will feed to the gridder:
            // => gridderInputDataCube
            //
            // Compute its associated flags
            // => gridderInputDataCubeFlags_Int
            // by flagging non-selected data
            //
            // Compute the channelWeight matrix (nvischan, nrow)
            //
            // Solar Data Imaging Case:
            //  * ftMachineType = FTMachine::CORRECTED
            //  * gridderInputDataCube = vb.visCubeCorrected
            //  * nvischan = 1: in that case there is no interpolation and:
            //    channelWeight = imagingWeight
            Cube<Complex> gridderInputDataCube; // (nvispol, nvischan, nrow)
            // Fortran gridder needs the flag as ints
            Cube<Int> gridderInputDataCubeFlags_Int; // (nvispol, nvischan, nrow)
            Matrix<Float> channelWeight; // (nvischan, nrow)

            // From FTMachine.h:
            // interpolate visibility data of vb to grid frequency definition
            // flag will be set the one as described in interpolateArray1D
            // return false if no interpolation is done...
            // e.g for nearest case
            interpolateFrequencyTogrid(
                vb,
                imagingweight,
                gridderInputDataCube,
                gridderInputDataCubeFlags_Int,
                channelWeight,
                ftMachineType
            );

            Bool iswgtCopy;
            const Float *channelWeightStorage =
                channelWeight.getStorage(iswgtCopy);

            Bool isCopy;
            const Complex *gridderInputDataCubeStorage = 0;
            if (not dopsf) {
                gridderInputDataCubeStorage =
                    gridderInputDataCube.getStorage(isCopy);
            }

            // If row is -1 then we pass through all rows
            Int startRow, endRow, nRows;
            if (row == -1) {
              nRows = vb.nRows(); startRow =   0; endRow = nRows-1;
            } else {
              nRows =          1; startRow = row; endRow =     row;
            }

            // Convert boolean row flags to integers
            Vector<Int> rowFlags_Int(vb.flagRow().nelements(), 0);
            for (Int rownr=startRow; rownr<=endRow; rownr++) {
                if (vb.flagRow()(rownr)) rowFlags_Int(rownr) = 1;
            }

            // Take care of translation of Bools to Integer
            Int dopsf_Int = dopsf ? 1 : 0;

            { // Compute spectra pixel coordinates and call gridder
                // Make sure failed getXYPos does not fall on grid
                constexpr Double kFarAway = -1e9;
                Matrix<Double> xyPositions(2, nRows, kFarAway);
                // NOTE(review): indexed by rownr as in SDGrid.cc; since
                // xyPositions has nRows columns, this is only in range when
                // startRow == 0 (e.g. row == -1)
                for (Int rownr=startRow; rownr<=endRow; rownr++) {
                  if (getXYPos(vb, rownr)) {
                    xyPositions(0, rownr)=xyPos(0);
                    xyPositions(1, rownr)=xyPos(1);
                  }
                }

                { // Call gridder
                    // Convert gridder input data shape to integers
                    const IPosition& gridderInputDataShape =
                        gridderInputDataCube.shape();
                    std::vector<Int> gridderInputDataShape_Int(
                      gridderInputDataShape.begin(),
                      gridderInputDataShape.end()
                    );

                    Bool datCopy, wgtCopy;
                    Complex * griddedDataStorage =
                        griddedData.getStorage(datCopy);
                    Float * griddedWeightStorage =
                        wGriddedData.getStorage(wgtCopy);


                    Bool call_ggridsd = not clipminmax_;

                    Bool del;
                    if (call_ggridsd) { // Call plain gridder

                        ggridsd(
                            /* xy(2, nRow)= */
                                xyPositions.getStorage(del),
                            /* values(nvispol, nvischan, nrow)= */
                                gridderInputDataCubeStorage,
                            /* nvispol= */
                                &gridderInputDataShape_Int[0],
                            /* nvischan= */
                                &gridderInputDataShape_Int[1],
                            /* dowt= */
                                &dopsf_Int,
                            /* flag(nvispol, nvischan, nrow)= */
                                gridderInputDataCubeFlags_Int.getStorage(del),
                            /* rflag(nrow)= */
                                rowFlags_Int.getStorage(del),
                            /* weight(nvischan, nrow)= */
                                channelWeightStorage,
                            /* nrow */
                                &gridderInputDataShape_Int[2],
                            /* irow */
                                &row,
                            /* grid(nx, ny, npol, nchan) */
                                griddedDataStorage,
                            /* wgrid(nx, ny, npol, nchan) */
                                griddedWeightStorage,
                            /* nx */
                                &nx,
                            /* ny */
                                &ny,
                            /* npol */
                                &npol,
                            /* nchan */
                                &nchan,
                            /* support */
                                &convSupport,
                            /* sampling */
                                &convSampling,
                            /* convFunc */
                                convFunc.getStorage(del),
                            /* chanmap(nvischan) */
                                chanMap.getStorage(del),
                            /* polmap(nvispol) */
                                polMap.getStorage(del),
                            /* sumwt(npol, nchan) */
                                sumWeight.getStorage(del)
                        );

                    } else { // Call clipping gridder
                        Bool gminCopy;
                        Complex *gminStor = gmin_.getStorage(gminCopy);
                        Bool gmaxCopy;
                        Complex *gmaxStor = gmax_.getStorage(gmaxCopy);
                        Bool wminCopy;
                        Float *wminStor = wmin_.getStorage(wminCopy);
                        Bool wmaxCopy;
                        Float *wmaxStor = wmax_.getStorage(wmaxCopy);
                        Bool npCopy;
                        Int *npStor = npoints_.getStorage(npCopy);

                        // Same arguments as ggridsd, plus the
                        // npoints_/gmin_/wmin_/gmax_/wmax_ buffers
                        ggridsdclip(
                            xyPositions.getStorage(del),
                            gridderInputDataCubeStorage,
                            &gridderInputDataShape_Int[0],
                            &gridderInputDataShape_Int[1],
                            &dopsf_Int,
                            gridderInputDataCubeFlags_Int.getStorage(del),
                            rowFlags_Int.getStorage(del),
                            channelWeightStorage,
                            &gridderInputDataShape_Int[2],
                            &row,
                            griddedDataStorage,
                            griddedWeightStorage,
                            npStor,
                            gminStor,
                            wminStor,
                            gmaxStor,
                            wmaxStor,
                            &nx,
                            &ny,
                            &npol,
                            &nchan,
                            &convSupport,
                            &convSampling,
                            convFunc.getStorage(del),
                            chanMap.getStorage(del),
                            polMap.getStorage(del),
                            sumWeight.getStorage(del)
                        );

                        gmin_.putStorage(gminStor, gminCopy);
                        gmax_.putStorage(gmaxStor, gmaxCopy);
                        wmin_.putStorage(wminStor, wminCopy);
                        wmax_.putStorage(wmaxStor, wmaxCopy);
                        npoints_.putStorage(npStor, npCopy);
                    }
                    griddedData.putStorage(griddedDataStorage, datCopy);
                    wGriddedData.putStorage(griddedWeightStorage, wgtCopy);
                }
            }

            { // Free memory
                if (not dopsf)
                    gridderInputDataCube.freeStorage(
                        gridderInputDataCubeStorage, isCopy);

                channelWeight.freeStorage(channelWeightStorage, iswgtCopy);
            }

        }
        """

#### CASA6 void SDGrid::pickWeights(const vi::VisBuffer2& vb, Matrix<Float>& weight)

In [None]:
class SDGrid:  # pickWeights
    """Study notes on CASA6 ``SDGrid::pickWeights`` (C++), kept as a docstring.

    This rebinds the ``SDGrid`` name defined in an earlier cell.
    ``pickWeights`` has no body besides its docstring and returns ``None``.
    """
    # Summary of the C++ below: use the imaging weights if useImagingWeight_p.
    # Otherwise use WEIGHT_SPECTRUM, or WEIGHT when WEIGHT_SPECTRUM is empty:
    #   * npol == 1: the weights are used as-is;
    #   * npol == 2: the two correlations are combined as 4*w0*w1/(w0+w1);
    #   * any other npol raises AipsError.
    def pickWeights(self):
        """
        void SDGrid::pickWeights(const vi::VisBuffer2& vb, Matrix<Float>& weight){
          weight.resize();

          if (useImagingWeight_p) {
            weight.reference(vb.imagingWeight());
          } else {
            const Cube<Float> weightSpec(vb.weightSpectrum());
            weight.resize(vb.nChannels(), vb.nRows());

            // CAS-9957 correct weight propagation from linear/circular correlations to Stokes I
            const auto toStokesWeight = [](float weight0, float weight1) {
                  const auto denominator = weight0 + weight1;
                  const auto numerator = weight0 * weight1;
                  constexpr float fmin = std::numeric_limits<float>::min();
                  return abs(denominator) < fmin ? 0.0f : 4.0f * numerator / denominator;
            };

            if (weightSpec.nelements() == 0) {
              const auto &weightMat = vb.weight();
              const ssize_t npol = weightMat.shape()(0);
              if (npol == 1) {
                const auto weight0 = weightMat.row(0);
                for (rownr_t k = 0; k < vb.nRows(); ++k) {
                  weight.column(k).set(weight0(k));
                }
              } else if (npol == 2) {
                const auto weight0 = weightMat.row(0);
                const auto weight1 = weightMat.row(1);
                for (rownr_t k = 0; k < vb.nRows(); ++k) {
                  weight.column(k).set(toStokesWeight(weight0(k), weight1(k)));
                }
              } else {
                // It seems current code doesn't support 4 pol case. So, give up
                // processing such data to avoid producing unintended result
                throw AipsError("Imaging full-Stokes data (npol=4) is not supported.");
              }
            } else {
              const ssize_t npol = weightSpec.shape()(0);
              if (npol == 1) {
                weight = weightSpec.yzPlane(0);
              } else if (npol == 2) {
                const auto weight0 = weightSpec.yzPlane(0);
                const auto weight1 = weightSpec.yzPlane(1);
                for (rownr_t k = 0; k < vb.nRows(); ++k) {
                  for (int chan = 0; chan < vb.nChannels(); ++chan) {
                    weight(chan, k) = toStokesWeight(weight0(chan, k), weight1(chan, k));
                  }
                }
              } else {
                // It seems current code doesn't support 4 pol case. So, give up
                // processing such data to avoid producing unintended result
                throw AipsError("Imaging full-Stokes data (npol=4) is not supported.");
              }
            }
          }
        }
        """

#### CASA6 SDGrid::pickWeights: Debug Session

In [None]:
# lldb transcript inside SDGrid::pickWeights for the imaged data: the
# WEIGHT_SPECTRUM branch is taken with npol == 2, so the two correlation
# weights are combined with toStokesWeight.
debug_session = """
   2207	      } else if (npol == 2) {
   2208	        const auto weight0 = weightSpec.yzPlane(0);
-> 2209	        const auto weight1 = weightSpec.yzPlane(1);
   2210	        for (rownr_t k = 0; k < vb.nRows(); ++k) {
   2211	          for (int chan = 0; chan < vb.nChannels(); ++chan) {
   2212	            weight(chan, k) = toStokesWeight(weight0(chan, k), weight1(chan, k));
Target 0: (Python) stopped.
(lldb) p npol
(const ssize_t) 2
"""

## Create single_dish_gridder.py file
* Gathered the functions from the NAOJ imaging/gridding Jupyter notebook into:
  * the single_dish_gridder.py file
* This is a Python implementation of the casacore Fortran single-dish gridder

## Test astroviper Prolate Spheroidal Function
  * which is defined here:

In [None]:
# Path, relative to the astroviper package root, of the module that provides
# the prolate spheroidal function imported in the next cell.
"""
astroviper/_domain/_imaging/_imaging_utils/gcf_prolate_spheroidal.py
"""

In [None]:
import astroviper
import sys
import os
import numpy as np

# Make astroviper's private imaging utilities importable as top-level modules
imaging_utils_rel_path = "_domain/_imaging/_imaging_utils"
imaging_utils_abs_path = os.path.join(str(astroviper.__path__[0]), imaging_utils_rel_path)
if imaging_utils_abs_path not in sys.path:
    sys.path.insert(0, imaging_utils_abs_path)

from gcf_prolate_spheroidal import (  # pyright: ignore[reportMissingImports]
    _prolate_spheroidal_function as prolate_spheroidal_function,
)

# Sample the function on [0, 1] with test_samples_per_pixel points per pixel
# of support (endpoints included)
test_support = 6  # pixels
test_samples_per_pixel = 100
test_n_samples = test_support * test_samples_per_pixel + 1
sampling_points = np.linspace(0, 1, num=test_n_samples, endpoint=True)
_, convolution_function_values = prolate_spheroidal_function(sampling_points)

# Storage array twice as long as the samples: samples first, zeros after
gridder_conv_func_data = np.zeros(2 * sampling_points.shape[0], dtype=np.float32)
gridder_conv_func_data[:test_n_samples] = convolution_function_values[:test_n_samples]

# Overlay: thick blue = raw samples, thin lime = padded storage array
fig, axes = plt.subplots()
plt.title("astroviper Prolate Spheroidal Function")
axes.plot(
    np.arange(convolution_function_values.shape[0]),
    convolution_function_values,
    color="DodgerBlue",
    linewidth=10,
    label="Function Values",
)
axes.plot(
    np.arange(gridder_conv_func_data.shape[0]),
    gridder_conv_func_data,
    color="Lime",
    label="Storage Array Values",
)
legend = axes.legend(borderpad=0.8)





# Call the Python Single Dish Gridder

In [None]:
import numpy as np
import sys
import os
from dotmap import DotMap
from astropy.units import Quantity  # NOTE(review): not used in this cell
from single_dish_gridder import ggridsd
from xarray import DataArray
import astroviper

# When True, grid only the first 10_000 time samples (see nrow below)
DEBUGGING = False


# Reference copy of the single_dish_gridder.ggridsd keyword signature; its
# parameter names are exactly the keys set on gridder_params in this cell.
ggridsd_signature = """
ggridsd(
    grid: Annotated[
            numpy.ndarray[numpy.complex64],
            'shape=(nx, ny, npol, nchan)'
        ] = None, 
    wgrid: Annotated[
        numpy.ndarray[numpy.float32],
        'shape=(nx, ny, npol, nchan)'
    ] = None,
    sumwt: Annotated[
        numpy.ndarray[numpy.float32],
        'shape=(npol, nchan)'
    ] = None,
    xy: Annotated[
        numpy.ndarray[numpy.float64],
        'shape=(2, nrow)'
    ] = None,
    values: Annotated[
        numpy.ndarray[complex],
        'shape=(nvispol, nvischan, nrow)'
    ] = None,
    flag: Annotated[
        numpy.ndarray[bool],
        'shape=(nvispol, nvischan, nrow)'
    ] = None,
    irow: int = None,
    nrow: int = None,
    rflag: Annotated[
        numpy.ndarray[bool], 'shape=(nrow, )'
    ] = None,
    nvispol: int = None,
    nvischan: int = None,
    weight: Annotated[
        numpy.ndarray[float],
        'shape=(nvischan, nrow)'
    ] = None,
    nx: int = None,
    ny: int = None,
    npol: int = None,
    nchan: int = None,
    grid_weight: bool = None,
    chanmap: Annotated[
        numpy.ndarray[int],
        'shape=(nvischan, )'
    ] = None,
    polmap: Annotated[
        numpy.ndarray[int],
        'shape=(nvispol, )'
    ] = None,
    convFunc: Annotated[
        numpy.ndarray[numpy.float32],
        'shape=(unknown/runtime-defined, )'
    ] = None,
    support: int = None,
    sampling: int = None)

"""

# Gridder parameters: collected as DotMap attributes, later passed to
# ggridsd as keyword arguments via gridder_params.toDict()
gridder_params = DotMap()

# In-Out Images
# ---- In-Out Images: shape
nx = x_axis_size
ny = y_axis_size
npol = polarization_axis_size
nchan = frequency_axis_size

# ---- In-Out Images: data arrays
# Zero-initialized in-out arrays handed to ggridsd: gridded values and
# gridded weights (nx, ny, npol, nchan), and sum of weights (npol, nchan)
image_shape = (nx, ny, npol, nchan)
image_array = np.zeros(image_shape, dtype=np.complex64)

weight_image_shape = image_shape
weight_image_array = np.zeros(weight_image_shape, dtype=np.float32)

sum_weight_shape = (npol, nchan)
sum_weight_array = np.zeros(sum_weight_shape, dtype=np.float32)

gridder_params.grid = image_array
gridder_params.wgrid = weight_image_array
gridder_params.sumwt = sum_weight_array

# Input Data and Metadata
# ---- Data Chunk Sizes: time dimension
# Number of time samples (rows) to grid: all of them, or 10_000 when DEBUGGING

nrow = (
    current_msv4_xdt.sizes['time'] if not DEBUGGING else 10_000
)

# ---- Metadata: data spatial coordinates
# pixel_coordinates is (n_samples, 2); ggridsd expects xy as (2, nrow)
xy = pixel_coordinates.transpose()[:, 0:nrow]

assert xy.shape[0] == 2
assert xy.shape[1] == nrow, f"xy.shape[1]: {xy.shape[1]} != nrow: {nrow}"

gridder_params.xy = xy

# ---- Data: calibrated temperature
# VISIBILITY_CORRECTED of baseline_id 0, reordered to the
# (nvispol, nvischan, nrow) layout ggridsd expects
temperature: DataArray = current_msv4_xdt.VISIBILITY_CORRECTED.isel(
    {
        'baseline_id': 0  # Dirty, quick fix for API change in xradio 1.1
    }
).transpose('polarization', 'frequency', 'time').isel(
    time=slice(0, nrow)
)

gridder_params.values = temperature.values

# ---- Metadata: data flags
# Same selection and (nvispol, nvischan, nrow) layout as the data above.
# Older selection by antenna_name (pre xradio 1.1), kept for reference:
# flags: DataArray = current_msv4_xdt.FLAG.sel(
#     {
#         'antenna_name': f"{antenna_name}"
#     }
# ).transpose('polarization', 'frequency', 'time').isel(
#     time=slice(0, nrow)
# )
flags: DataArray = current_msv4_xdt.FLAG.isel(
    {
         'baseline_id': 0
    }
).transpose('polarization', 'frequency', 'time').isel(
    time=slice(0, nrow)
)
gridder_params.flag = flags.values

# Data Chunk: Row Selection
# irow = -1 means "grid every row", the same convention as the row argument
# of SDGrid::put
grid_all_rows = -1
grid_what_rows = grid_all_rows

gridder_params.irow = grid_what_rows

# Data Chunk: Time Dimension
# ---- For now: process all data in 1 chunk:
#  * The full solar regional maps (uid___A002_X11a51f7_X1a7) Processing Set's
#    zarr file is 3.3 GiB
#  * The single MsV4 of the full Sun scan we image
#    (1 antenna, 1 spectral window) is only 7.2 MiB

gridder_params.nrow = nrow

# Row flags: dropped in msv4, so no row is flagged here
row_flags = np.zeros(nrow, dtype=bool)

gridder_params.rflag = row_flags

# Selected dataset dimensions
nvispol = temperature.sizes['polarization']
nvischan = temperature.sizes['frequency']

gridder_params.nvispol = nvispol
gridder_params.nvischan = nvischan


# "Channel weight"
# Load WEIGHT into memory before combining polarizations (presumably it is
# lazily backed by the zarr store)
weight = current_msv4_xdt.WEIGHT
weight.load()
# ---- Aggregate polarization weights
def stokesI_weight(w0, w1):
    """Combine two correlation weights into a Stokes I weight.

    Python port of the ``toStokesWeight`` lambda in CASA6
    ``SDGrid::pickWeights`` (CAS-9957)::

        const auto toStokesWeight = [](float weight0, float weight1) {
              const auto denominator = weight0 + weight1;
              const auto numerator = weight0 * weight1;
              constexpr float fmin = std::numeric_limits<float>::min();
              return abs(denominator) < fmin ? 0.0f : 4.0f * numerator / denominator;
        };

    The result is ``4 * w0 * w1 / (w0 + w1)``, or 0 wherever
    ``|w0 + w1|`` is below the smallest normal float32 (``fmin``).

    Parameters
    ----------
    w0, w1 : array_like
        Weights of the two correlations; must broadcast against each other.

    Returns
    -------
    numpy.ndarray
        float32 array of combined weights.
    """
    w0 = np.asarray(w0)
    w1 = np.asarray(w1)
    numerator = w0 * w1
    denominator = w0 + w1
    tiny = np.finfo(np.float32).tiny  # == std::numeric_limits<float>::min()
    # CASA returns 0 only when |denominator| < fmin, so divide wherever
    # |denominator| >= fmin; elsewhere the zero-initialized output is kept.
    return np.divide(
        4 * numerator, denominator,
        out=np.zeros_like(numerator, dtype=np.float32),
        where=np.abs(denominator) >= tiny,
    )


channel_weight = xarray.apply_ufunc(
    stokesI_weight,
    weight.sel(dict(polarization="XX")),
    weight.sel(dict(polarization="YY")),
    input_core_dims=[['frequency', 'time'], ['frequency', 'time']],
    output_core_dims=[['frequency', 'time']],
    vectorize=True
)
channel_weight.name = "Channel Weight"
gridder_params.weight = channel_weight.squeeze(dim='baseline_id').isel(
    time=slice(0, nrow)
).values

# Image dimensions
gridder_params.nx = nx
gridder_params.ny = ny
gridder_params.npol = npol
gridder_params.nchan = nchan

# Data to grid: temperature value or temperature weight value
gridder_params.grid_weight = False

# Data binning
# ---- Data channel 0 goes to image channel 0
gridder_params.chanmap = np.zeros(nvischan, dtype=int)
# ---- Data polarizations XX and YY (0 and 1) go to
#      image polarization I (0)
gridder_params.polmap = np.zeros(nvispol, dtype=int)

# Discrete convolution function
gridder_params.convFunc = None  # For now
gridder_params.support = 6  # pixels, hardcode for now
                            # taken from Full Sun Imaging script
gridder_params.sampling = 100  # Hard-coded in SDGrid.cc

# ---- Discrete convolution function values
# Use astroviper Prolate Spheroidal function, defined in
# astroviper/<imaging_utils_rel_path>/gcf_prolate_spheroidal.py
imaging_utils_rel_path = "_domain/_imaging/_imaging_utils"
imaging_utils_abs_path = os.path.join(
    str(astroviper.__path__[0]),
    imaging_utils_rel_path
)

if imaging_utils_abs_path not in sys.path:
    sys.path.insert(0, imaging_utils_abs_path)

from gcf_prolate_spheroidal import ( # pyright: ignore[reportMissingImports]
    _prolate_spheroidal_function as
    prolate_spheroidal_function
)

n_samples = (
    gridder_params.support * gridder_params.sampling + 1
)
sampling_points = np.linspace(0, 1, num=n_samples, endpoint=True)
_, prolate_spheroidal_values = (
    prolate_spheroidal_function(sampling_points)
)

# CASA6 (casacore) gridder needs more values,
# so that it can make multiplications by zero ...
gridder_conv_func_values = np.zeros(
    2 * prolate_spheroidal_values.shape[0],
    dtype = np.float32
)

gridder_conv_func_values[0:n_samples] = (
    prolate_spheroidal_values[0:n_samples]
)

gridder_params.convFunc = gridder_conv_func_values
gridder_params_dict = gridder_params.toDict()

# Call the gridder
%time ggridsd(**gridder_params_dict)


# Create Single Dish Image

Divide Gridded Values by Gridded Weights

In [None]:
import numpy as np

# Smallest positive normal float32: pixels whose |weight| does not
# exceed it are considered empty.
tiny = np.finfo(np.float32).tiny

# Allocate the output image up front; pixels skipped by the division
# below keep this initial value of 0.
single_dish_image_array = np.zeros_like(weight_image_array, dtype=np.float32)

# Real part of the gridded values divided by the gridded weights, computed
# only where the weight is significantly non-zero. np.divide returns its
# `out` buffer, so this assignment rebinds the name to the same array.
single_dish_image_array = np.divide(
    image_array.real,
    weight_image_array,
    out=single_dish_image_array,
    where=abs(weight_image_array) > tiny,
)


# Create Single Dish Image Mask

* For now, the mask is taken from the reference CASA6 image: its NaN pixels are also set to NaN in the single-dish image

In [None]:
# Reference CASA6 image (FITS export) whose NaN pixels are reused as the
# mask of the prototype single-dish image.
reference_image_path = (
    "reference-image/"
    "antenna-PM04.fullSun-num0_sci.spw3_"
    ".sd.TP.cont.I.manual.sd.not-rescaled.fits"
)

with fits.open(reference_image_path) as hdul:
    # Reverse the FITS data axis order; presumably this matches the axis
    # order of single_dish_image_array (TODO confirm — the boolean
    # indexing below raises if the shapes differ).
    ref_image_array = hdul[0].data.transpose()
    print(ref_image_array.shape)
    # NaN pixels of the reference image are the masked ones
    ref_image_mask = np.isnan(ref_image_array)
    single_dish_image_array[ref_image_mask] = np.nan

# Display Single Dish Image

In [None]:
from pathlib import Path

# Drop length-1 axes so that the image can be displayed as a 2-D array
# (assumes exactly two axes are longer than 1 — TODO confirm).
single_dish_2d_image_array = single_dish_image_array.squeeze()

# MS file name with a trailing ".ms" removed, if any; names such as
# "<uid>.ms.split.cal" are kept whole.
uid_name = Path(ms_file).resolve().name.removesuffix('.ms')

# One line per item, followed by an empty line
print(
    f"Dataset:            {uid_name}",
    f"MsV4:               {current_msv4_name}",
    f"Antenna Name:       {antenna_name}",
    f"Spectral Window ID: {spw_id_selection}",
    sep="\n",
    end="\n\n",
)

# Transposed for display, with array row 0 at the bottom
axes_image = plt.imshow(
    single_dish_2d_image_array.T,
    origin='lower',
)

# Export Imaging Results to FITS format

## Single Dish Image

In [None]:
import os

from astropy.io import fits
import numpy as np

# Primary HDU: the single-dish image with its axis order reversed by
# np.transpose (presumably (x, y, pol, chan) -> (chan, pol, y, x), the
# numpy layout of a FITS cube — TODO confirm), described by fits_header.
my_hdul = fits.HDUList(
    hdus=[
        fits.PrimaryHDU(
            data=np.transpose(single_dish_image_array),
            header=fits_header
        )
    ]
)

# <dataset>.antenna-<antenna>.spw-<spw id>.fits
prototype_fits_image_name = (
    f"{uid_name}.antenna-{antenna_name_only}.spw-{spw_id_selection}"
    ".fits"
)
prototype_fits_image_path = f"results/{prototype_fits_image_name}"

# Save to disk. Create the output directory first: HDUList.writeto fails
# when the parent directory of the target file does not exist.
os.makedirs(os.path.dirname(prototype_fits_image_path), exist_ok=True)
my_hdul.writeto(
    prototype_fits_image_path,
    overwrite=True
)


## Single Dish Weight Image

In [None]:
import os

from astropy.io import fits
import numpy as np

# Weight image header: a copy of the single-dish image header in which
# BUNIT is blanked (eight spaces), i.e. no brightness unit.
weight_image_fits_header = fits_header.copy()
weight_image_fits_header.set(
    "BUNIT",
    value=" " * 8,
    comment="Brightness (pixel) unit"
)

# Primary HDU: gridded weights with their axis order reversed by
# np.transpose, as done for the single-dish image export.
weight_image_hdul = fits.HDUList(
    hdus=[
        fits.PrimaryHDU(
            data=np.transpose(weight_image_array),
            header=weight_image_fits_header
        )
    ]
)

# <single-dish image name without ".fits">.weight.fits
prototype_fits_weight_image_name = (
    prototype_fits_image_name.removesuffix(".fits")
    + ".weight.fits"
)

prototype_fits_weight_image_path = (
    f"results/{prototype_fits_weight_image_name}"
)

# Save to disk. Create the output directory first: HDUList.writeto fails
# when the parent directory of the target file does not exist.
os.makedirs(
    os.path.dirname(prototype_fits_weight_image_path), exist_ok=True
)
weight_image_hdul.writeto(
    prototype_fits_weight_image_path,
    overwrite=True
)



# Visualize Results and Compare With Reference CASA6 Images

## Visualization Code

In [None]:
import matplotlib.pyplot as plt

from astropy.wcs import WCS
from astropy.io import fits
from astropy import units as u

from enum import Enum

class ImageType(Enum):
    """Kind of FITS image compared by ``compare_images``."""

    # "Plain" single-dish (brightness temperature) image
    PlainImage = 1
    # Associated gridded-weight image
    WeightImage = 2

# The inch comes from astropy's imperial units sub-module; alias it as
# ``u.inch`` (module-wide side effect) so that figure sizes can be
# converted with ``.to(u.inch)`` in compare_images.
u.inch = u.imperial.inch

def _load_primary_plane(fits_image_path):
    """Return ``(header, plane)`` read from the primary HDU of a FITS file.

    ``plane`` is ``data[0, 0, :, :]``: the first element along the two
    leading numpy axes of a 4-D data cube (presumably the first channel
    and first polarization — TODO confirm against the image headers).
    The file is closed before returning; header and plane are copied so
    that they stay valid once the (possibly memory-mapped) file is closed.
    """
    with fits.open(fits_image_path) as hdul:
        primary_hdu = hdul[0]
        header = primary_hdu.header.copy()
        plane = np.array(primary_hdu.data[0, 0, :, :])
    return header, plane


def _label_sky_axes(axes, header, title):
    """Set *title*, RA/Dec axis labels (frame from RADESYS) and a white grid."""
    axes.set_title(title)
    reference_frame = header.get('RADESYS')
    axes.set_xlabel(f"Right Ascension ({reference_frame})")
    axes.set_ylabel(f"Declination ({reference_frame})")
    axes.grid(color="white", linewidth=0.5)


def compare_images(
        ref_fits_image_path,
        prototype_fits_image_path,
        figure_path,
        image_type: ImageType=ImageType.PlainImage,
        ):
    """Plot a prototype FITS image against a reference CASA6 FITS image.

    Builds a 2 x 2 figure: reference image, prototype image, their
    difference (prototype - reference, NaN pixels drawn as 0) and line
    histograms of both images. The figure is saved to *figure_path*.

    The figure title reads the notebook globals ``uid_name``,
    ``scan_name_selection``, ``antenna_name_only`` and
    ``spw_id_selection``, which must be defined before calling.

    Parameters
    ----------
    ref_fits_image_path : str
        Reference CASA6 image, FITS format.
    prototype_fits_image_path : str
        Prototype image, FITS format.
    figure_path : str
        Output file passed to ``Figure.savefig``.
    image_type : ImageType
        ``PlainImage``: the prototype uses the reference data range as
        colour limits. ``WeightImage``: both images are capped at the 99th
        percentile of the reference weights.
    """
    is_plain_image = image_type == ImageType.PlainImage
    nrows, ncols = 2, 2

    # ---- Figure: 30 cm x 30 cm
    figure_width = 30 * u.cm
    figure_height = figure_width
    figure = plt.figure(
        figsize=(
            figure_width.to(u.inch).value,
            figure_height.to(u.inch).value,
        )
    )
    figure.suptitle(
        f"ALMA Dataset: {uid_name} "
        f"Scan: {scan_name_selection} "
        f"Antenna: {antenna_name_only} "
        f"Spectral Window: {spw_id_selection}"
    )

    # ---- Data (files are closed once read)
    ref_header, ref_image_data = _load_primary_plane(ref_fits_image_path)
    prototype_header, prototype_image_data = _load_primary_plane(
        prototype_fits_image_path
    )
    ref_wcs = WCS(ref_header).sub(['longitude', 'latitude'])
    prototype_wcs = WCS(prototype_header).sub(['longitude', 'latitude'])

    # ---- Colour limits
    if image_type == ImageType.WeightImage:
        # nanpercentile ignores masked (NaN) pixels, which would otherwise
        # turn the cap into NaN.
        ref_percentile_99 = np.nanpercentile(ref_image_data, 99)
        ref_limits = {"vmax": ref_percentile_99}
        prototype_limits = {"vmax": ref_percentile_99}
    else:
        # Reference autoscaled; prototype shown with the reference range
        ref_limits = {}
        prototype_limits = {
            "vmin": np.nanmin(ref_image_data),
            "vmax": np.nanmax(ref_image_data),
        }

    # ---- Reference image (top left)
    ref_image_axes = figure.add_subplot(
        nrows, ncols, 1, projection=ref_wcs
    )
    ref_image = ref_image_axes.imshow(
        ref_image_data, origin='lower', cmap='inferno', **ref_limits
    )
    _label_sky_axes(
        ref_image_axes, ref_header,
        "Reference CASA6 Image" if is_plain_image
        else "Reference CASA6 Weight Image"
    )
    colorbar = figure.colorbar(ref_image, ax=ref_image_axes)
    if is_plain_image:
        colorbar.set_label('Brightness Temperature [K]', fontsize=10)

    # ---- Prototype image (top right)
    prototype_image_axes = figure.add_subplot(
        nrows, ncols, 2, projection=prototype_wcs
    )
    prototype_image = prototype_image_axes.imshow(
        prototype_image_data, origin='lower', cmap='inferno',
        **prototype_limits
    )
    _label_sky_axes(
        prototype_image_axes, prototype_header,
        "Prototype RADPS Image" if is_plain_image
        else "Prototype RADPS Weight Image"
    )
    colorbar = figure.colorbar(prototype_image, ax=prototype_image_axes)
    if is_plain_image:
        colorbar.set_label('Brightness Temperature [K]', fontsize=10)

    # ---- Difference image (bottom left), drawn on the reference WCS
    diff_image_axes = figure.add_subplot(
        nrows, ncols, 3, projection=ref_wcs
    )
    diff_image_data = prototype_image_data - ref_image_data
    # "Hide" image mask: NaN pixels are drawn as a zero difference
    diff_image_data[np.isnan(diff_image_data)] = 0
    diff_image = diff_image_axes.imshow(
        diff_image_data, origin='lower', cmap='viridis_r'
    )
    _label_sky_axes(
        diff_image_axes, ref_header, "Prototype RADPS - Reference CASA6"
    )
    colorbar = figure.colorbar(diff_image, ax=diff_image_axes)
    if is_plain_image:
        colorbar.set_label('Temperature Difference [K]', fontsize=10)

    # ---- Histograms (bottom right)
    hist_axes = figure.add_subplot(nrows, ncols, 4)
    # Only finite pixels can be binned (NaN = masked)
    valid_ref_data = ref_image_data[np.isfinite(ref_image_data)]
    valid_prototype_data = prototype_image_data[
        np.isfinite(prototype_image_data)
    ]
    # Shared bin edges: with per-image bins, the two curves would count
    # over different intervals and could not be compared bin by bin.
    bin_edges = np.histogram_bin_edges(
        np.concatenate([valid_ref_data, valid_prototype_data]), bins=100
    )
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    for data_array, label, color in (
        (valid_ref_data, "Reference CASA6", "DodgerBlue"),
        (valid_prototype_data, "Prototype RADPS", "Orange"),
    ):
        counts, _ = np.histogram(data_array, bins=bin_edges)
        hist_axes.plot(bin_centers, counts, label=label, color=color)
    hist_axes.set_yscale('log')
    hist_axes.legend()
    hist_axes.grid(visible=True, which='both', axis='x')
    hist_axes.set_title("Images Histograms")
    if is_plain_image:
        hist_axes.set_xlabel("Brightness Temperature Bin Center [K]")
    else:
        hist_axes.set_xlabel("Weight Bin Center")
    hist_axes.set_ylabel("Count")

    figure.tight_layout()
    figure.savefig(figure_path)
    print(f"Saved figure: {figure_path}")

## Single Dish Images Comparison

In [None]:
# Reference CASA6 Image
ref_fits_image_name = (
    "antenna-PM04.fullSun-num0_sci.spw3_"
    ".sd.TP.cont.I.manual.sd.not-rescaled.fits"
)
ref_fits_image_path = "reference-image/" + ref_fits_image_name

# The comparison figure is named after the prototype FITS image
figure_name = (
    prototype_fits_image_name.removesuffix('.fits')
    + '.compare_with_ref.png'
)
figure_path = "results/" + figure_name

compare_images(
    ref_fits_image_path=ref_fits_image_path,
    prototype_fits_image_path=prototype_fits_image_path,
    figure_path=figure_path,
)


## Single Dish Weight Images Comparison

In [None]:
# Reference CASA6 Weight Image
# NOTE(review): unlike the plain reference image name, this one has no
# "antenna-PM04." prefix — confirm it belongs to the same imaging run.
ref_fits_weight_image_name = (
    "fullSun-num0_sci.spw3_.sd.TP.cont.I.manual.sd.weight.fits"
)
ref_fits_weight_image_path = "reference-image/" + ref_fits_weight_image_name

# The comparison figure is named after the prototype FITS weight image
weight_figure_name = (
    f"{prototype_fits_weight_image_name.removesuffix('.fits')}"
    ".compare_with_ref.png"
)
weight_figure_path = "results/" + weight_figure_name

compare_images(
    ref_fits_image_path=ref_fits_weight_image_path,
    prototype_fits_image_path=prototype_fits_weight_image_path,
    figure_path=weight_figure_path,
    image_type=ImageType.WeightImage,
)

# Conclusion
  * We have implemented in pure Python a prototype single-dish imager
    * taking as input an xradio Processing Set
      * https://xradio.readthedocs.io/en/latest/measurement_set/overview.html
    * creating 2 images, as the CASA6 tsdimaging task does:
      * a "plain" single-dish image and
      * its associated "weight" image
    * directly in FITS format (the CASA6 tsdimaging task creates images in CASA format)
  * We then compared the prototype imaging results with the CASA6 reference images
  * Results show that, although the images are (as expected) not strictly equal, they are reasonably close.


# Notes
  * Many features of the current CASA6 tsdimaging task are not implemented, e.g.:
    * Frequency frame conversion
    * Channel binning
    * Ephemeris handling
    * Heuristics computing the "phasecenter" when it is not user-specified
    * etc.
  * Performance considerations are out of the scope of this work

# Questions
  * Next steps?