# Image Data Group

- sky \<fft\> visibility  
- visibility_normalization  
- primary_beam \<fft\> aperture  
- primary_beam_squared 
- aperture_normalization  
- point_spread_function \<fft\> uv_sampling  
- uv_sampling_normalization  
- sky_model
- mask  


# Image Schema Suggested Changes

Each image is stored as its own xda (xarray DataArray), and all non-coordinate metadata should be contained in that xda's `attrs`.

In [14]:
from astropy.io import fits
import os
from xradio.image._util._fits.xds_from_fits import (
    _fits_header_to_xds_attrs,
    _fits_image_to_xds,
)
import xarray as xr
import json
import numpy as np

import numpy as np

def make_json_serializable(attributes):
    """
    Recursively convert attributes to JSON-serializable formats.

    Dicts, lists and tuples are walked recursively (tuples stay tuples, which
    JSON encodes as arrays). Numpy arrays become nested Python lists, and
    their elements are converted too, so object-dtype arrays holding dicts or
    numpy scalars are handled. Numpy integer, floating and bool scalars become
    the matching Python scalar. Any other value is returned unchanged.

    Parameters
    ----------
    attributes : object
        Attribute value, typically an ``xarray`` attrs dict.

    Returns
    -------
    object
        A structure with the same shape in which numpy containers/scalars
        have been replaced by their Python equivalents.
    """
    if isinstance(attributes, dict):
        return {key: make_json_serializable(value) for key, value in attributes.items()}
    if isinstance(attributes, list):
        return [make_json_serializable(item) for item in attributes]
    if isinstance(attributes, tuple):
        # Previously tuples were returned as-is, leaving numpy scalars inside.
        return tuple(make_json_serializable(item) for item in attributes)
    if isinstance(attributes, np.ndarray):
        # Recurse so that elements of object-dtype arrays are converted as well.
        return make_json_serializable(attributes.tolist())
    if isinstance(attributes, np.bool_):
        return bool(attributes)
    if isinstance(attributes, (np.integer, np.floating)):
        # np.floating also covers np.float64, so no separate branch is needed.
        return attributes.item()
    return attributes



class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serializes numpy arrays and numpy scalars.

    ``json`` natively handles ``np.float64`` (a ``float`` subclass), but not
    ``np.int64``, ``np.float32`` or ``np.bool_``. Those are converted with
    ``.item()``, and arrays with ``.tolist()``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, (np.integer, np.floating, np.bool_)):
            return obj.item()
        # Defer to the base class, which raises TypeError for unknown types.
        return super().default(obj)


# The image_set dataset being assembled; each image product found below is
# added to it as one data variable.
is_xds = xr.Dataset()

# Directory that holds the FITS files to convert.
images_dir = "Antennae_images_cube"


def convert_to_std_measures(meas_dict):
    """
    Convert a measure dict into the ``{"data", "dims", "attrs"}`` layout.

    ``meas_dict["value"]`` becomes ``data``, ``dims`` is an empty list, and
    every other entry of ``meas_dict`` is copied into ``attrs``.

    The input dict is left untouched; the previous implementation deleted
    ``"value"`` from it and reused it as ``attrs``.

    Parameters
    ----------
    meas_dict : dict
        Measure description that must contain a ``"value"`` key.

    Returns
    -------
    dict
        ``{"data": ..., "dims": [], "attrs": {...}}``.

    Raises
    ------
    KeyError
        If ``meas_dict`` has no ``"value"`` entry.
    """
    data = meas_dict["value"]
    attrs = {key: value for key, value in meas_dict.items() if key != "value"}
    return {"data": data, "dims": [], "attrs": attrs}


# Each CASA image product is identified by the last dot-separated suffix of its
# file name before ".fits" (e.g. "Antennae_North.cube.psf.fits" -> "psf").
# Value: (data variable name, "type" attribute, "units" attribute).
IMAGE_PRODUCTS = {
    "image": ("SKY", "sky", ["Jy/beam"]),
    "model": ("SKY_MODEL", "sky_model", ["Jy/beam"]),
    "mask": ("MASK", "mask", ["None"]),
    "pb": ("PRIMARY_BEAM", "primary_beam", ["None"]),
    "psf": ("POINT_SPREAD_FUNCTION", "point_spread_function", ["None"]),
    "residual": ("RESIDUAL", "residual", ["None"]),
    "sumwt": ("VISIBILITY_NORMALIZATION", "visibility_normalization", ["None"]),
    "weight": ("PRIMARY_BEAM_SQUARED", "primary_beam_squared", ["None"]),
}

# Sorted so that the order of data variables in is_xds is deterministic.
for image_name in sorted(os.listdir(images_dir)):
    base_name, extension = os.path.splitext(image_name)
    if extension != ".fits":
        continue

    # Match the product suffix exactly. Substring tests would let a name such
    # as "*.image.pbcor.fits" match both "image" and "pb", and the result
    # would depend on directory listing order.
    product = base_name.rsplit(".", 1)[-1]
    if product not in IMAGE_PRODUCTS:
        continue
    var_name, image_type, units = IMAGE_PRODUCTS[product]

    fits_image = os.path.join(images_dir, image_name)
    xds = _fits_image_to_xds(
        fits_image, chunks={}, verbose=False, do_sky_coords=False
    )

    # An image is a single data variable; its non-coordinate metadata is moved
    # into that variable's attrs (converted to JSON-serializable values).
    xda = xds["SKY"]
    xda.attrs = make_json_serializable(xds.attrs)
    attrs = xda.attrs

    attrs.pop("active_mask", None)

    if "direction" in attrs:
        direction_info = attrs.pop("direction")
        direction_info["reference"] = convert_to_std_measures(
            direction_info["reference"]
        )
        # FITS uses "longpole" but Astropy uses "lonpole".
        direction_info["lonpole"] = direction_info.pop("longpole")
        attrs["direction_info"] = direction_info

    if "pointing_center" in attrs:
        pointing_center = attrs.pop("pointing_center")
        # setdefault: do not fail when the image carried no "direction" attr.
        attrs.setdefault("direction_info", {})["primary_beam_center"] = {
            "data": pointing_center["value"],
            "dims": [],
            "attrs": {
                "initial": pointing_center["initial"],
                "type": "sky_coord",
                "frame": "FK5",
                "equinox": "J2000.0",
                "units": ["rad", "rad"],
            },
        }

    if "beam" in attrs:
        beam = attrs.pop("beam")
        if beam is not None:
            attrs["beam_info"] = {
                "major_axis": beam["bmaj"],
                "minor_axis": beam["bmin"],
                "position_angle": beam["pa"],
            }

    if "obsdate" in attrs:
        observation_date = convert_to_std_measures(attrs.pop("obsdate"))
        for key in ("scale", "format"):
            observation_date["attrs"][key] = observation_date["attrs"][key].lower()
        attrs["observation_date"] = observation_date

    if "telescope" in attrs:
        # Use the serialized copy in attrs (not the raw xds.attrs) so that no
        # numpy values leak into the stored metadata.
        telescope_info = attrs.pop("telescope")
        position = convert_to_std_measures(telescope_info["position"])
        # np.asarray works whether the value is already a list or an array.
        position["data"] = np.asarray(position["data"]).tolist()
        telescope_info["position"] = position
        attrs["telescope_info"] = telescope_info

    if product == "sumwt":
        # sumwt images are 1x1 in l/m; drop those axes.
        xda = xda.squeeze(dim=["l", "m"], drop=True)

    xda = xda.transpose(
        "time", "frequency", "polarization", "l", "m", missing_dims="ignore"
    )

    is_xds[var_name] = xda
    is_xds[var_name].attrs["type"] = image_type
    # Fresh list per variable so attrs never share a mutable object.
    is_xds[var_name].attrs["units"] = list(units)
            
# Normalise the coordinate attributes of the assembled image_set dataset.
# Accessing a coordinate's .attrs returns the dict stored on the dataset's
# variable, so the edits below persist on is_xds.
if "time" in is_xds.dims:
    time_attrs = is_xds.time.attrs
    time_attrs["scale"] = time_attrs["scale"].lower()
    time_attrs["units"] = [time_attrs["units"]]
    time_attrs["format"] = time_attrs["format"].lower()

if "frequency" in is_xds.dims:
    freq_attrs = is_xds.frequency.attrs
    freq_attrs["rest_frequency"] = convert_to_std_measures(
        freq_attrs["rest_frequency"]
    )
    freq_attrs["units"] = [freq_attrs["units"]]
    freq_attrs["observer"] = freq_attrs["frame"].lower()
    freq_attrs["wave_unit"] = [freq_attrs["wave_unit"]]

    # crval now has a descriptive name:
    freq_attrs["reference_value"] = {
        "data": freq_attrs["crval"],
        "dims": [],
        "attrs": {
            "units": freq_attrs["units"],
            "type": "frequency",
            "observer": freq_attrs["observer"],
        },
    }
    # Remove cdelt since this information is contained within the coordinate.
    del freq_attrs["cdelt"]
    del freq_attrs["frame"]

for axis in ("l", "m"):
    if axis not in is_xds.dims:
        continue
    axis_attrs = is_xds[axis].attrs
    axis_attrs["units"] = [axis_attrs["units"]]
    # The reference pixel is the 0-based index at which the coordinate takes
    # its reference value crval. int(crval) would be a coordinate value, not
    # a pixel index.
    offsets = np.abs(is_xds[axis].values - axis_attrs["crval"])
    axis_attrs["reference_pixel"] = int(np.argmin(offsets))
    del axis_attrs["cdelt"]
    del axis_attrs["crval"]
           
# Roles of the "base" data group and the data variable that fills each role.
BASE_DATA_GROUP = {
    "sky": "SKY",
    "sky_model": "SKY_MODEL",
    "mask": "MASK",
    "primary_beam": "PRIMARY_BEAM",
    "point_spread_function": "POINT_SPREAD_FUNCTION",
    "residual": "RESIDUAL",
    "visibility_normalization": "VISIBILITY_NORMALIZATION",
    "primary_beam_squared": "PRIMARY_BEAM_SQUARED",
}

# Only list roles whose data variable was actually loaded, so the data group
# never points at a variable that does not exist in is_xds.
is_xds.attrs["data_groups"] = {
    "base": {
        role: var_name
        for role, var_name in BASE_DATA_GROUP.items()
        if var_name in is_xds.data_vars
    }
}

# Key spelled correctly (was "schema_verion").
is_xds.attrs["schema_version"] = "0.0.1"
is_xds.attrs["type"] = "image"
# NOTE(review): version and creation date are fixed values from this schema
# draft, not read from the installed xradio or the current clock.
is_xds.attrs["xradio_version"] = "0.0.43"
is_xds.attrs["creation_date"] = "2024-11-12T14:06:24.605819"

# print(json.dumps(is_xds["SKY"].attrs, cls=NumpyEncoder, indent=4))

is_xds

Unnamed: 0,Array,Chunk
Bytes,66.76 MiB,66.76 MiB
Shape,"(1, 70, 1, 500, 500)","(1, 70, 1, 500, 500)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 66.76 MiB 66.76 MiB Shape (1, 70, 1, 500, 500) (1, 70, 1, 500, 500) Dask graph 1 chunks in 5 graph layers Data type float32 numpy.ndarray",70  1  500  500  1,

Unnamed: 0,Array,Chunk
Bytes,66.76 MiB,66.76 MiB
Shape,"(1, 70, 1, 500, 500)","(1, 70, 1, 500, 500)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,280 B,280 B
Shape,"(1, 70, 1)","(1, 70, 1)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 280 B 280 B Shape (1, 70, 1) (1, 70, 1) Dask graph 1 chunks in 6 graph layers Data type float32 numpy.ndarray",1  70  1,

Unnamed: 0,Array,Chunk
Bytes,280 B,280 B
Shape,"(1, 70, 1)","(1, 70, 1)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,66.76 MiB,66.76 MiB
Shape,"(1, 70, 1, 500, 500)","(1, 70, 1, 500, 500)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 66.76 MiB 66.76 MiB Shape (1, 70, 1, 500, 500) (1, 70, 1, 500, 500) Dask graph 1 chunks in 5 graph layers Data type float32 numpy.ndarray",70  1  500  500  1,

Unnamed: 0,Array,Chunk
Bytes,66.76 MiB,66.76 MiB
Shape,"(1, 70, 1, 500, 500)","(1, 70, 1, 500, 500)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,66.76 MiB,66.76 MiB
Shape,"(1, 70, 1, 500, 500)","(1, 70, 1, 500, 500)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 66.76 MiB 66.76 MiB Shape (1, 70, 1, 500, 500) (1, 70, 1, 500, 500) Dask graph 1 chunks in 5 graph layers Data type float32 numpy.ndarray",70  1  500  500  1,

Unnamed: 0,Array,Chunk
Bytes,66.76 MiB,66.76 MiB
Shape,"(1, 70, 1, 500, 500)","(1, 70, 1, 500, 500)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,66.76 MiB,66.76 MiB
Shape,"(1, 70, 1, 500, 500)","(1, 70, 1, 500, 500)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 66.76 MiB 66.76 MiB Shape (1, 70, 1, 500, 500) (1, 70, 1, 500, 500) Dask graph 1 chunks in 5 graph layers Data type float32 numpy.ndarray",70  1  500  500  1,

Unnamed: 0,Array,Chunk
Bytes,66.76 MiB,66.76 MiB
Shape,"(1, 70, 1, 500, 500)","(1, 70, 1, 500, 500)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,66.76 MiB,66.76 MiB
Shape,"(1, 70, 1, 500, 500)","(1, 70, 1, 500, 500)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 66.76 MiB 66.76 MiB Shape (1, 70, 1, 500, 500) (1, 70, 1, 500, 500) Dask graph 1 chunks in 5 graph layers Data type float32 numpy.ndarray",70  1  500  500  1,

Unnamed: 0,Array,Chunk
Bytes,66.76 MiB,66.76 MiB
Shape,"(1, 70, 1, 500, 500)","(1, 70, 1, 500, 500)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,66.76 MiB,66.76 MiB
Shape,"(1, 70, 1, 500, 500)","(1, 70, 1, 500, 500)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 66.76 MiB 66.76 MiB Shape (1, 70, 1, 500, 500) (1, 70, 1, 500, 500) Dask graph 1 chunks in 5 graph layers Data type float32 numpy.ndarray",70  1  500  500  1,

Unnamed: 0,Array,Chunk
Bytes,66.76 MiB,66.76 MiB
Shape,"(1, 70, 1, 500, 500)","(1, 70, 1, 500, 500)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,66.76 MiB,66.76 MiB
Shape,"(1, 70, 1, 500, 500)","(1, 70, 1, 500, 500)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 66.76 MiB 66.76 MiB Shape (1, 70, 1, 500, 500) (1, 70, 1, 500, 500) Dask graph 1 chunks in 5 graph layers Data type float32 numpy.ndarray",70  1  500  500  1,

Unnamed: 0,Array,Chunk
Bytes,66.76 MiB,66.76 MiB
Shape,"(1, 70, 1, 500, 500)","(1, 70, 1, 500, 500)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [2]:
# Display the data-group mapping stored in the dataset attributes.
is_xds.attrs["data_groups"]

{'base': {'sky': 'SKY',
  'sky_model': 'SKY_MODEL',
  'mask': 'MASK',
  'primary_beam': 'PRIMARY_BEAM',
  'point_spread_function': 'POINT_SPREAD_FUNCTION',
  'residual': 'RESIDUAL',
  'visibility_normalization': 'VISIBILITY_NORMALIZATION',
  'primary_beam_squared': 'PRIMARY_BEAM_SQUARED'}}

In [None]:
# Persist the image_set dataset; mode="w" overwrites an existing store.
is_xds.to_zarr(store="Antennae_images_cube.img.zarr", mode="w")

# Suggested API Changes:
   - open_image_set_xds
   - load_image_set_xds


In [None]:
#tclean command used to generate the images:

# tclean(
#     vis="Antennae_North.cal.ms",
#     datacolumn="data",
#     imagename="Antennae_North.cube",
#     spw="0",
#     field="",
#     phasecenter=12,
#     specmode="cube",
#     outframe="LSRK",
#     restfreq="345.79599GHz",
#     nchan=70,
#     start="1200km/s",
#     width="10km/s",
#     gridder="mosaic",
#     mosweight=True,
#     deconvolver="hogbom",
#     imsize=500,
#     cell="0.13arcsec",
#     pblimit=0.2,
#     restoringbeam="common",
#     interactive=True,
#     weighting="briggsbwtaper",
#     robust=0.5,
#     niter=20000,
#     threshold="5.0mJy",
#     savemodel="modelcolumn",
# )

In [None]:
# Convert CASA images to fits files in CASA:

# import os

# images_dir = "Antennae_images_cube"

# for image_name in os.listdir(images_dir):
#     fits_image_name = image_name + ".fits"
#     try:
#         exportfits(
#             os.path.join(images_dir, image_name),
#             os.path.join(images_dir, fits_image_name),
#         )
#     except:
#         print("Error converting image: ", image_name)