In [337]:
import numpy as np
from lsst.daf.persistence import Butler
# Open a Butler on the DM-13666 DEEP rerun of the HSC repository.
# NOTE(review): hard-coded absolute path; adjust when running elsewhere.
butler = Butler('/datasets/hsc/repo/rerun/DM-13666/DEEP')

In [338]:
# Fetch the 'deepCoadd_skyMap' dataset; it is indexed by tract below
# (sky_map[ii]) and handed to ArbitrarySkyQuery.
sky_map = butler.get('deepCoadd_skyMap')

In [339]:
import numpy as np

class ArbitrarySkyQuery(object):
    """
    Find the tracts and patches of a sky map that may overlap a circular
    region of sky.

    Every tract (and, lazily, every patch) is approximated by a bounding
    circle: a unit vector pointing at its center plus an angular radius
    large enough to reach all four corners of its bounding box.  A query
    circle is reported as a possible overlap when the angular separation
    between the two centers is smaller than the sum of the two radii, so
    the results are candidates ("may overlap"), not exact intersections.
    """

    def __init__(self, sky_map, verbose=True):
        """
        sky_map -- an indexable, sized collection of tracts.  Each tract
        must provide getWcs(), getBBox(), getCtrCoord(), getNumPatches()
        and patch lookup via tract[ix, iy].

        verbose -- if True (the default, matching historical behavior)
        print diagnostic output from find_tracts() and when the patches
        of a tract are initialized.
        """
        self.sky_map = sky_map
        self.verbose = verbose

        n_tracts = len(sky_map)

        # an array of the centers of all tracts (Cartesian unit vectors)
        self.tract_center_array = np.zeros((n_tracts, 3), dtype=float)

        # an array of the angular radii of all tracts (in degrees)
        self.tract_radius_array = np.zeros(n_tracts, dtype=float)

        # dicts keyed on tract index that give the Cartesian centers
        # (unit vectors) and angular radii (degrees) of all patches;
        # filled in lazily by _initialize_patch_coords()
        self.tract_patch_centers = {}
        self.tract_patch_radii = {}

        # a dict that associates the integer row of a patch in
        # self.tract_patch_centers[tract_id] with the (ix, iy)
        # identity of the patch
        self.tract_patch_lookup = {}

        # whether the patch centers/radii of a tract have been computed
        self.tract_patch_initialized = {}

        for ii in range(n_tracts):
            tract = self.sky_map[ii]
            wcs = tract.getWcs()
            bbox = tract.getBBox()

            tract_center = np.array(tract.getCtrCoord().getVector())
            corners = self._bbox_corner_vectors(wcs, bbox)

            self.tract_center_array[ii] = tract_center
            self.tract_radius_array[ii] = self._max_separation_deg(tract_center,
                                                                   corners)

            patch_grid_dim = tract.getNumPatches()
            lookup = [(i1, i2)
                      for i1 in range(patch_grid_dim[0])
                      for i2 in range(patch_grid_dim[1])]
            n_patches = len(lookup)

            # reshape keeps the lookup 2-D even for an empty patch grid
            self.tract_patch_lookup[ii] = np.array(lookup, dtype=int).reshape(-1, 2)
            self.tract_patch_initialized[ii] = False
            self.tract_patch_centers[ii] = np.zeros((n_patches, 3), dtype=float)
            self.tract_patch_radii[ii] = np.zeros(n_patches, dtype=float)

    @staticmethod
    def _unit_vector(ra, dec):
        """Cartesian unit vector for a position given by ra, dec in degrees."""
        ra_rad = np.radians(ra)
        dec_rad = np.radians(dec)
        return np.array([np.cos(dec_rad)*np.cos(ra_rad),
                         np.cos(dec_rad)*np.sin(ra_rad),
                         np.sin(dec_rad)])

    @staticmethod
    def _bbox_corner_vectors(wcs, bbox):
        """(4, 3) array of the sky vectors of the four corners of bbox."""
        return np.array([np.array(wcs.pixelToSky(xx, yy).getVector())
                         for xx in (bbox.getMinX(), bbox.getMaxX())
                         for yy in (bbox.getMinY(), bbox.getMaxY())],
                        dtype=float)

    @staticmethod
    def _max_separation_deg(center, points):
        """Largest angle (degrees) between unit vector center and unit vectors points."""
        # Rounding can push a dot product of unit vectors marginally outside
        # [-1, 1], which would make arccos return NaN; clip first.
        cos_sep = np.clip(np.dot(points, center), -1.0, 1.0)
        return np.degrees(np.arccos(cos_sep).max())

    @staticmethod
    def _overlap_mask(cos_sep, reach_deg):
        """Boolean mask: separation (given by its cosine) is smaller than reach_deg."""
        # cos() is only monotonic on [0, 180] degrees; without the clip a
        # combined radius above 180 degrees wraps around and wrongly
        # rejects regions that necessarily overlap.
        reach_deg = np.clip(reach_deg, 0.0, 180.0)
        return (cos_sep > np.cos(np.radians(reach_deg))) | (reach_deg >= 180.0)

    def find_tracts(self, ra, dec, angular_radius):
        """
        ra, dec, angular radius are all in degrees

        returns:
        --------
        The tuple produced by numpy.where; element [0] is the array of
        indices (into the sky map) of the tracts whose bounding circle
        overlaps the query circle.
        """
        xyz = self._unit_vector(ra, dec)
        dot_prod = np.dot(self.tract_center_array, xyz)
        valid = np.where(self._overlap_mask(dot_prod,
                                            angular_radius +
                                            self.tract_radius_array))

        if self.verbose:
            print('valid dot')
            print(dot_prod[valid])
            for i_valid in valid[0]:
                print(np.dot(xyz,self.tract_center_array[i_valid]))
        return valid

    def _initialize_patch_coords(self, tract_id):
        """
        compute the centers and radii for the patches in a tract

        tract_id is an integer denoting the tract we are dealing with
        """
        if self.tract_patch_initialized[tract_id]:
            return
        if self.verbose:
            print('initializing ',tract_id)
        tract = self.sky_map[tract_id]
        wcs = tract.getWcs()

        for ii, patch_id in enumerate(self.tract_patch_lookup[tract_id]):
            patch = tract[patch_id[0], patch_id[1]]
            corners = self._bbox_corner_vectors(wcs, patch.getOuterBBox())

            # The mean of unit vectors is shorter than a unit vector.  It
            # must be renormalized, otherwise dot products against it are
            # not cosines of angles and the radius below is inflated.
            center_xyz = corners.mean(axis=0)
            norm = np.linalg.norm(center_xyz)
            if norm > 0.0:
                center_xyz = center_xyz / norm

            self.tract_patch_centers[tract_id][ii] = center_xyz
            self.tract_patch_radii[tract_id][ii] = self._max_separation_deg(center_xyz,
                                                                            corners)

        self.tract_patch_initialized[tract_id] = True

    def find_patches(self, ra, dec, angular_distance, tract_id):
        """
        ra -- ra of boresite in degrees
        dec -- dec of boresite in degrees
        angular_distance -- radius of patch of sky in degrees
        tract_id -- integer ID of the tract being considered

        returns:
        --------
        A 2-D numpy array.  Each row contains the ix, iy indices of
        patches that may overlap the region
        """
        if not self.tract_patch_initialized[tract_id]:
            self._initialize_patch_coords(tract_id)

        xyz = self._unit_vector(ra, dec)

        patch_centers = self.tract_patch_centers[tract_id]
        patch_radii = self.tract_patch_radii[tract_id]
        valid = np.where(self._overlap_mask(np.dot(patch_centers, xyz),
                                            patch_radii + angular_distance))
        return self.tract_patch_lookup[tract_id][valid]

In [340]:
# Build the query helper; the constructor loops over every tract in
# sky_map to record its center and angular radius.
asq = ArbitrarySkyQuery(sky_map)

In [341]:
# Query circle: center (ra, dec) and radius (dist), all in degrees.
ra, dec, dist = 23.0, -12.0, 0.1

# valid_tracts is the numpy.where tuple; valid_tracts[0] holds the tract indices.
valid_tracts = asq.find_tracts(ra, dec, angular_radius=dist)

valid dot
[0.99992998 0.99986769]
0.9999299780253365
0.9998676940551104


In [342]:
# For each candidate tract, list the (ix, iy) patches that may overlap
# the query circle.
for tract_id in valid_tracts[0]:
    print('tract %d' % tract_id)
    patches = asq.find_patches(ra, dec,
                               angular_distance=dist,
                               tract_id=tract_id)
    print(patches)

tract 7318
initializing  7318
[[2 7]
 [2 8]
 [3 7]
 [3 8]
 [4 7]
 [4 8]]
tract 7555
initializing  7555
[[1 0]
 [2 0]
 [3 0]]


In [343]:
from lsst.afw.geom import SpherePoint, Angle, degrees

In [344]:
# Cross-check: ask the sky map itself which tract and patch contain the
# query point.
ra_afw, dec_afw = Angle(ra, degrees), Angle(dec, degrees)
sph_pt = SpherePoint(ra_afw, dec_afw)

tract_info = sky_map.findTract(sph_pt)
tract_id = tract_info.getId()
tract_truth = sky_map[tract_id]

# last expression of the cell, so the PatchInfo is displayed
tract_truth.findPatch(sph_pt)

PatchInfo(index=(3, 7), innerBBox=(minimum=(12000, 28000), maximum=(15999, 31999)), outerBBox=(minimum=(11900, 27900), maximum=(16099, 32099)))

In [345]:
# Sanity check: number of candidate tracts, total number of tracts, and
# the dot product between the query direction and one tract center.
print(len(valid_tracts[0]))
# tract_center_array is an attribute of the query object, not a global;
# the bare name only resolved through leftover kernel state.
print(len(asq.tract_center_array))
print(valid_tracts)
ra_rad = np.radians(ra)
dec_rad = np.radians(dec)
xyz = np.array([np.cos(dec_rad)*np.cos(ra_rad), np.cos(dec_rad)*np.sin(ra_rad),
                np.sin(dec_rad)])
# 7318 is the first tract index reported by find_tracts for this query
center = asq.tract_center_array[7318]
print(xyz)
print(center)
print(np.dot(xyz,center))

2
18938
(array([7318, 7555]),)
[ 0.90038961  0.38219272 -0.20791169]
[ 0.899605    0.37787851 -0.21890332]
0.9999299780253365
