<h1>Create initial images to be used in feather tutorial</h1>

In [None]:
import pkg_resources, os
# Point casacore's "measures.directory" at the data shipped inside the
# `casadata` package by adding an entry to the user's ~/.casarc.
# NOTE(review): pkg_resources is deprecated by setuptools; kept here because
# it is what this notebook already imports.
casa_data_dir = pkg_resources.resource_filename("casadata", "__data__")
rc_entry = "measures.directory: " + casa_data_dir

# "a+" creates the file if it is missing and always writes at the end.
# The `with` block guarantees the handle is closed even if the write fails.
with open(os.path.expanduser("~/.casarc"), "a+") as rc_file:
    rc_file.seek(0)  # append mode starts positioned at EOF; rewind to read
    existing_entries = rc_file.read().splitlines()
    # Only append when the exact entry is absent, so re-running this cell
    # (e.g. Restart & Run All) does not pile up duplicate lines.
    if rc_entry not in existing_entries:
        rc_file.write("\n" + rc_entry)

In [None]:
# download base images
from graphviper.utils.data import download

# Fetch the single-dish (sd) image first, then the interferometer (vla) image.
for base_image in ("feather_sim_sd_c1_pI.im", "feather_sim_vla_c1_pI.im"):
    download(base_image)

In [None]:
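# Optional: uncomment the lines below to start a local Dask client and/or to
# force a specific Dask scheduler before running the rest of the notebook.
# from graphviper.dask.client import local_client
# viper_client = local_client(cores=4, memory_limit="4GB")

# import dask
# dask.config.set(scheduler="synchronous")
# dask.config.set(scheduler="threads")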

<h2>Inputs to be specified by user</h2>

In [None]:
# Largest supported image side, in pixels (see the size note below).
MAX_IMSIZE = 4096

# ra, dec size, should not exceed 4096 x 4096
imsize = [2048, 2048]

# number of channels
nchan = 512

# currently, there is only one polarization and it is I

# Fail here, next to the inputs, instead of much later in the notebook when
# the images are cropped and filled.
if len(imsize) != 2 or any(not 0 < npix <= MAX_IMSIZE for npix in imsize):
    raise ValueError(
        f"imsize must be two positive sizes not exceeding {MAX_IMSIZE}, got {imsize}"
    )
if nchan < 1:
    raise ValueError(f"nchan must be at least 1, got {nchan}")

In [None]:
from xradio.image.image import (
    make_empty_sky_image
)
import numpy as np
# Conversion factor: one arcsecond expressed in radians.
rad_per_arcsec = np.pi/180/3600

# Pixel size: 15 arcseconds along each image axis, in radians.
cell_rad = 15*rad_per_arcsec

# Build an empty ("skeleton") sky image that carries only the coordinates:
# `nchan` evenly spaced channel values from 1.4e9 to 1.5e9 (presumably Hz),
# a single polarization "I" and a single time value 0.
skel_xds = make_empty_sky_image(
    phase_center=[0.6, -0.2],
    image_size=imsize,
    cell_size=[cell_rad, cell_rad],
    chan_coords=np.linspace(1.4e9, 1.5e9, nchan),
    pol_coords=["I"],
    time_coords=[0],
)
skel_xds

In [None]:
from xradio.image import read_image

# Side length, in pixels, that the crop arithmetic assumes for the template
# images (previously hard-coded as 4096, with 2048 as its centre).
TEMPLATE_SIZE = 4096


def _centered_slice(size, full_size=TEMPLATE_SIZE):
    """Return the slice picking `size` pixels centred on a `full_size`-pixel axis."""
    first = full_size // 2 - size // 2  # first pixel of the cut-out
    return slice(first, first + size)


# Crop only the axes requested smaller than the template; an axis requested
# at full size is left out of the selection and therefore kept whole.
sel_dict = {
    dim: _centered_slice(size)
    for dim, size in zip(("l", "m"), imsize)
    if size < TEMPLATE_SIZE
}

# Single-dish template image, cropped to the requested size.
xds_sd_temp = read_image("feather_sim_sd_c1_pI.im").isel(sel_dict)
xds_sd_temp

In [None]:
# Interferometer (VLA) template image, cropped with the same `sel_dict`
# selection that is applied to the single-dish template.
vla_full = read_image("feather_sim_vla_c1_pI.im")
xds_int_temp = vla_full.isel(sel_dict)
xds_int_temp

In [None]:
import dask.array as da
import xarray as xr
# Build a lazily evaluated, all-zero float32 array laid out exactly like the
# skeleton image.
dm = skel_xds.sizes

# Take both the dimension order and the shape from the same mapping so they
# can never disagree. `Dataset.sizes` is used rather than `Dataset.dims`
# because xarray has announced that `Dataset.dims` will stop being a mapping.
dims = tuple(dm)
shape = tuple(dm[dim] for dim in dims)

data = da.zeros(shape, dtype=np.float32)
zeros = xr.DataArray(data=data, coords=skel_xds.coords, dims=dims)
zeros

In [None]:
import copy

sky = xr.DataArray(
    zeros.copy(), coords=skel_xds.coords, dims=skel_xds.dims
)


def _fill_from_template(template_xds, chan_step=16):
    """Return a copy of the skeleton image whose SKY is filled from a template.

    Parameters
    ----------
    template_xds
        Template image; its "SKY" values are written into every block of
        channels and its "beam" attribute is copied onto the result.
    chan_step : int
        Number of frequency channels assigned per step (16, as before).

    Returns
    -------
    A deep copy of `skel_xds` with a filled "SKY" variable and a "beam"
    attribute.
    """
    filled = copy.deepcopy(skel_xds)
    filled["SKY"] = sky.copy()
    # Loop-invariant: evaluate the template pixels once instead of once per
    # channel block.
    template_values = template_xds["SKY"].values
    for min_chan in range(0, nchan, chan_step):
        max_chan = min(min_chan + chan_step, nchan)
        # The same template values are written into each channel block;
        # presumably the template has a single channel that is broadcast
        # along frequency (TODO confirm).
        filled["SKY"][{"frequency": slice(min_chan, max_chan)}] = template_values
    filled.attrs["beam"] = copy.deepcopy(template_xds.attrs["beam"])
    return filled


# Single-dish and interferometer images on the full (nchan-channel) grid.
xds_sd = _fill_from_template(xds_sd_temp)
xds_int = _fill_from_template(xds_int_temp)

xds_sd

In [None]:
# Estimate the memory, in GiB, needed for one single-frequency chunk of the
# SKY image, allowing for several images in memory plus a safety margin.
singleton_chunk_sizes = dict(xds_sd['SKY'].sizes)
print(singleton_chunk_sizes)
del singleton_chunk_sizes['frequency']  # Remove dimensions that will be chunked on.
print(singleton_chunk_sizes)

fudge_factor = 1.1
n_images_in_memory = 3.0

# Take the element size from the dtype itself so any dtype works, not only
# those spelled out in a hand-written lookup table.
bytes_per_element = xds_sd['SKY'].dtype.itemsize

# np.prod replaces np.product, which was removed in NumPy 2.0.
n_elements_per_chunk = np.prod(np.array(list(singleton_chunk_sizes.values())))

memory_singleton_chunk = n_images_in_memory*n_elements_per_chunk*fudge_factor*bytes_per_element/(1024**3)

memory_singleton_chunk

In [None]:
# Plot the single-dish image: polarization I, first channel, first time step.
sd_plane = xds_sd["SKY"].sel(polarization="I").isel(frequency=0, time=0)
sd_plane.plot()

In [None]:
# Plot the interferometer image: polarization I, first channel, first time
# step. The source is a point source, so it may be hard to spot in this plot.
int_plane = xds_int["SKY"].sel(polarization="I").isel(frequency=0, time=0)
int_plane.plot()

In [None]:
# These are the input images for the next step

import os, shutil

from xradio.image import write_image
# Write each image to its own zarr store, replacing any store left over from
# a previous run.
for image_xds, outfile in ((xds_sd, "sd.zarr"), (xds_int, "int.zarr")):
    already_there = os.path.exists(outfile)
    if already_there:
        shutil.rmtree(outfile)
    write_image(image_xds, outfile, "zarr")
    print(f"Wrote {outfile}")