Skip to content
This repository has been archived by the owner on Feb 7, 2024. It is now read-only.

Commit

Permalink
Fix nan fieldid detection.
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan-Willem committed Apr 12, 2021
1 parent 9a23330 commit 3395f0f
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 14 deletions.
15 changes: 8 additions & 7 deletions ngcasa/imaging/_imaging_utils/_aperture_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,12 @@ def _aperture_weight_grid_numpy_wrap(uvw,imaging_weight,field,cf_baseline_map,cf
sum_weight = np.zeros((n_imag_chan, n_imag_pol), dtype=np.double)

#print('Pos 2')

_aperture_weight_grid_jit(grid, sum_weight, uvw, freq_chan, chan_map, pol_map, cf_baseline_map, cf_chan_map, cf_pol_map, imaging_weight, weight_conv_kernel, n_uv, delta_lm, weight_support, oversampling, field, field_id, phase_gradient)




return grid, sum_weight
#XXXXXXXX

@jit(nopython=True, cache=True, nogil=True)
def _aperture_weight_grid_jit(grid, sum_weight, uvw, freq_chan, chan_map, pol_map, cf_baseline_map, cf_chan_map, cf_pol_map, imaging_weight, weight_conv_kernel, n_uv, delta_lm, weight_support, oversampling, field, field_id, phase_gradient):
c = 299792458.0
Expand Down Expand Up @@ -215,7 +213,7 @@ def _aperture_weight_grid_jit(grid, sum_weight, uvw, freq_chan, chan_map, pol_ma

for i_baseline in range(n_baseline):

if field[i_time,i_baseline] != const.INT_NAN:
if field[i_time,i_baseline] > -1:
field_indx = np.where(field_id == field[i_time,i_baseline])[0][0]

if prev_field != field_indx:
Expand Down Expand Up @@ -421,7 +419,7 @@ def _aperture_grid_jit(grid, sum_weight, do_psf, vis_data, uvw, freq_chan, chan_
# if field[i_time,i_baseline] == const.INT_NAN:
# print('Nan detected')

if field[i_time,i_baseline] != const.INT_NAN:
if field[i_time,i_baseline] > -1:
field_indx = np.where(field_id == field[i_time,i_baseline])[0][0]

# if field_indx != field[i_time,i_baseline]:
Expand Down Expand Up @@ -452,6 +450,7 @@ def _aperture_grid_jit(grid, sum_weight, do_psf, vis_data, uvw, freq_chan, chan_
v_offset = v_center_indx - v_pos
v_center_offset_indx = math.floor(v_offset * oversampling[1] + 0.5) + conv_v_center


for i_pol in range(n_pol):
if do_psf:
weighted_data = imaging_weight[i_time, i_baseline, i_chan, i_pol]
Expand Down Expand Up @@ -506,8 +505,10 @@ def _aperture_grid_jit(grid, sum_weight, do_psf, vis_data, uvw, freq_chan, chan_

grid[a_chan, a_pol, u_indx, v_indx] = grid[a_chan, a_pol, u_indx, v_indx] + conv * weighted_data
norm = norm + conv

sum_weight[a_chan, a_pol] = sum_weight[a_chan, a_pol] + imaging_weight[i_time, i_baseline, i_chan, i_pol]*np.real(norm**2)#*np.real(norm**2)#* np.real(norm) #np.abs(norm**2) #**2 term is needed since the pb is in the image twice (one naturally and another from the gcf)
if do_psf:
sum_weight[a_chan, a_pol] = sum_weight[a_chan, a_pol] + imaging_weight[i_time, i_baseline, i_chan, i_pol]*np.real(norm)
else:
sum_weight[a_chan, a_pol] = sum_weight[a_chan, a_pol] + imaging_weight[i_time, i_baseline, i_chan, i_pol]*np.real(norm**2)#*np.real(norm**2)#* np.real(norm) #np.abs(norm**2) #**2 term is needed since the pb is in the image twice (one naturally and another from the gcf)

return

Expand Down
15 changes: 12 additions & 3 deletions ngcasa/imaging/direction_rotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ def calc_rotation_mats(vis_dataset,field_dataset,rotation_parms):
#print(field_dataset.field_id.values)

field_id = np.unique(vis_dataset.FIELD_ID) # Or should the roation matrix be calculated for all fields #field_id = field_dataset.field_id
field_id = field_id[field_id != const.INT_NAN] #remove nan
#print(field_id)
field_id = field_id[field_id > -1] #remove nan
#print('field_id',field_id)

n_fields = len(field_id)
uvw_rotmat = np.zeros((n_fields,3,3),np.double)
Expand All @@ -161,10 +161,13 @@ def calc_rotation_mats(vis_dataset,field_dataset,rotation_parms):
rotmat_field_phase_center = R.from_euler('ZX',[[-np.pi/2 + field_phase_center[0],field_phase_center[1] - np.pi/2]]).as_matrix()[0]
uvw_rotmat[i_field,:,:] = np.matmul(rotmat_new_phase_center,rotmat_field_phase_center).T

#print(uvw_rotmat[i_field,:,:])

if rotation_parms['common_tangent_reprojection'] == True:
uvw_rotmat[i_field,2,0:2] = 0.0 # (Common tangent rotation needed for joint mosaics, see last part of FTMachine::girarUVW in CASA)

field_phase_center_cosine = _directional_cosine(field_phase_center)
#print("i_field, field, new",i_field,field_phase_center_cosine,new_phase_center_cosine)
phase_rotation[i_field,:] = np.matmul(rotmat_new_phase_center,(new_phase_center_cosine - field_phase_center_cosine))

return uvw_rotmat, phase_rotation, field_id
Expand All @@ -191,12 +194,18 @@ def apply_rotation_matrix(uvw, field_id, uvw_rotmat, rot_field_id):
#uvw[i_time,:,0:2] = -uvw[i_time,:,0:2] this gives the same result as casa (in the ftmachines uvw(negateUV(vb)) is used). In ngcasa we don't do this since the uvw definition in the gridder and vis.zarr are the same.
field_id_t = field_id[i_time,:,0]

unique_field_id = np.unique(field_id_t[field_id_t != const.INT_NAN])
unique_field_id = np.unique(field_id_t[field_id_t > -1])
#print(unique_field_id)
assert len(unique_field_id)==1, "direction_rotate only supports xds where field_id remains constant over baseline."
rot_field_indx = np.where(rot_field_id == unique_field_id[0])[0][0] #should be len 1
#print('rot_field_indx',rot_field_indx)

uvw_rot[i_time,:,:] = uvw[i_time,:,:] @ uvw_rotmat[rot_field_indx,:,:] #uvw time x baseline x uvw_indx, uvw_rotmat n_field x 3 x 3. 1 x 3 @ 3 x 3

#print('uvw_rotmat[rot_field_indx,:,:] ',uvw_rotmat[rot_field_indx,:,:] )
#print('uvw[i_time,:,:]', uvw[i_time,:,:])
#print('uvw_rot[i_time,:,:] ',uvw_rot[i_time,:,:])

#field_id_t = field_id[i_time,0,:]
#uvw[i_time,:,:] = uvw[i_time,:,:] @ uvw_rotmat[field_id_t,:,:] #uvw time x baseline x uvw_indx, uvw_rotmat n_field x 3 x 3. 1 x 3 @ 3 x 3
return uvw_rot
Expand Down
122 changes: 118 additions & 4 deletions ngcasa/imaging/make_grid.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,127 @@
this module will be included in the api
"""

def make_grid(vis_dataset, user_grid_parms, storage_parms):
#Removed for now.
#grid_parms['oversampling'] : int, default = 100
# The oversampling used for the convolutional gridding kernel. This will be removed in a later release and incorporated in the function that creates gridding convolutional kernels.
#grid_parms['support'] : int, default = 7
# The full support used for convolutional gridding kernel. This will be removed in a later release and incorporated in the function that creates gridding convolutional kernels.
#

def make_grid(vis_mxds, img_xds, grid_parms, vis_sel_parms, img_sel_parms):
"""
.. todo::
This function is not yet implemented
Parameters
----------
vis_mxds : xarray.core.dataset.Dataset
Input multi-xarray Dataset with global data.
img_xds : xarray.core.dataset.Dataset
Input image dataset.
grid_parms : dictionary
grid_parms['image_size'] : list of int, length = 2
The image size (no padding).
grid_parms['cell_size'] : list of number, length = 2, units = arcseconds
The image cell size.
grid_parms['chan_mode'] : {'continuum'/'cube'}, default = 'continuum'
Create a continuum or cube image.
grid_parms['fft_padding'] : number, acceptable range [1,100], default = 1.2
The factor that determines how much the gridded visibilities are padded before the fft is done.
vis_sel_parms : dictionary
vis_sel_parms['xds'] : str
The xds within the mxds to use to calculate the imaging weights for.
vis_sel_parms['data_group_in_id'] : int, default = first id in xds.data_groups
The data group in the xds to use.
img_sel_parms : dictionary
img_sel_parms['data_group_in_id'] : int, default = first id in xds.data_groups
The data group in the image xds to use.
img_sel_parms['image'] : str, default ='IMAGE'
The created image name.
img_sel_parms['sum_weight'] : str, default ='SUM_WEIGHT'
The created sum of weights name.
Returns
-------
img_xds : xarray.core.dataset.Dataset
The image_dataset will contain the image created and the sum of weights.
"""
print('######################### Start make_image #########################')
import numpy as np
from numba import jit
import time
import math
import dask.array.fft as dafft
import xarray as xr
import dask.array as da
import matplotlib.pylab as plt
import dask
import copy, os
from numcodecs import Blosc
from itertools import cycle

from cngi._utils._check_parms import _check_sel_parms, _check_existence_sel_parms
from ._imaging_utils._check_imaging_parms import _check_grid_parms
from ._imaging_utils._gridding_convolutional_kernels import _create_prolate_spheroidal_kernel, _create_prolate_spheroidal_kernel_1D
from ._imaging_utils._standard_grid import _graph_standard_grid
from ._imaging_utils._remove_padding import _remove_padding
from ._imaging_utils._aperture_grid import _graph_aperture_grid
from cngi.image import make_empty_sky_image

#print('****',sel_parms,'****')
_mxds = vis_mxds.copy(deep=True)
_img_xds = img_xds.copy(deep=True)
_vis_sel_parms = copy.deepcopy(vis_sel_parms)
_img_sel_parms = copy.deepcopy(img_sel_parms)
_grid_parms = copy.deepcopy(grid_parms)

##############Parameter Checking and Set Defaults##############
assert(_check_grid_parms(_grid_parms)), "######### ERROR: grid_parms checking failed"
assert('xds' in _vis_sel_parms), "######### ERROR: xds must be specified in sel_parms" #Can't have a default since xds names are not fixed.
_vis_xds = _mxds.attrs[_vis_sel_parms['xds']]

#Check vis data_group
_check_sel_parms(_vis_xds,_vis_sel_parms)

#Check img data_group
_check_sel_parms(_img_xds,_img_sel_parms,new_or_modified_data_variables={'sum_weight':'SUM_WEIGHT','grid':'GRID'},append_to_in_id=True)

##################################################################################

# Creating gridding kernel
_grid_parms['oversampling'] = 100
_grid_parms['support'] = 7

cgk, correcting_cgk_image = _create_prolate_spheroidal_kernel(_grid_parms['oversampling'], _grid_parms['support'], _grid_parms['image_size_padded'])
cgk_1D = _create_prolate_spheroidal_kernel_1D(_grid_parms['oversampling'], _grid_parms['support'])

_grid_parms['complex_grid'] = True
_grid_parms['do_psf'] = False
grids_and_sum_weights = _graph_standard_grid(_vis_xds, cgk_1D, _grid_parms, _vis_sel_parms)


if _grid_parms['chan_mode'] == 'continuum':
freq_coords = [da.mean(_vis_xds.coords['chan'].values)]
chan_width = da.from_array([da.mean(_vis_xds['chan_width'].data)],chunks=(1,))
imag_chan_chunk_size = 1
elif _grid_parms['chan_mode'] == 'cube':
freq_coords = _vis_xds.coords['chan'].values
chan_width = _vis_xds['chan_width'].data
imag_chan_chunk_size = _vis_xds.DATA.chunks[2][0]

phase_center = _grid_parms['phase_center']
image_size = _grid_parms['image_size']
cell_size = _grid_parms['cell_size']
phase_center = _grid_parms['phase_center']

pol_coords = _vis_xds.pol.data
time_coords = [_vis_xds.time.mean().data]

_img_xds = make_empty_sky_image(_img_xds,grid_parms['phase_center'],image_size,cell_size,freq_coords,chan_width,pol_coords,time_coords)

_img_xds[_img_sel_parms['data_group_out']['sum_weight']] = xr.DataArray(grids_and_sum_weights[1][None,:,:], dims=['time','chan','pol'])
_img_xds[_img_sel_parms['data_group_out']['grid']] = xr.DataArray(grids_and_sum_weights[0][:,:,None,:,:], dims=['u', 'v', 'time', 'chan', 'pol'])
_img_xds.attrs['data_groups'][0] = {**_img_xds.attrs['data_groups'][0],**{_img_sel_parms['data_group_out']['id']:_img_sel_parms['data_group_out']}}


print('######################### Created graph for make_image #########################')
return _img_xds

6 changes: 6 additions & 0 deletions ngcasa/imaging/make_gridding_convolution_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def make_gridding_convolution_function(mxds, gcf_parms, grid_parms, sel_parms):
cf_baseline_map,pb_ant_pairs = create_cf_baseline_map(_gcf_parms['unique_ant_indx'],_gcf_parms['basline_ant'],n_unique_ant)

cf_chan_map, pb_freq = create_cf_chan_map(_gcf_parms['freq_chan'],_gcf_parms['chan_tolerance_factor'])
#print('****',pb_freq)
pb_freq = da.from_array(pb_freq,chunks=np.ceil(len(pb_freq)/_gcf_parms['a_chan_num_chunk'] ))

cf_pol_map = np.zeros((len(_gcf_parms['pol']),),dtype=int) #create_cf_pol_map(), currently treating all pols the same
Expand Down Expand Up @@ -448,6 +449,11 @@ def calc_conv_size(sub_a_term,imsize,support_cut_level,oversampling,max_support)
assert(support_y < max_support[1]), "######### ERROR: support_cut_level too small or imsize too small." + str(support_y) + ",*," + str(max_support[1])

#print('approx_conv_size_x,approx_conv_size_y',approx_conv_size_x,approx_conv_size_y,support_x,support_y,max_support)
#print('support_x, support_y',support_x, support_y)
if support_x > support_y:
support_y = support_x
else:
support_x = support_y
return [support_x, support_y]

##########################
Expand Down

0 comments on commit 3395f0f

Please sign in to comment.