In [1]:
import healpy as hp
from astropy.io import fits
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
%matplotlib inline

In [23]:
class GravWaveScanner:
    
    def __init__(self, gw_file, rad, prob_threshold=0.9, cone_nside=64):
        self.gw_path = gw_file
        self.cone_nside = cone_nside
        self.parsed_file = self.read_map()
        self.data = self.parsed_file[1].data
        self.prob_map = hp.read_map(gw_file)
        self.radius = rad
        self.prob_threshold = prob_threshold
        self.pixel_threshold = self.find_pixel_threshold(self.data["PROB"])
        self.scanned_pixels = []
        self.map_coords = self.unpack_skymap()
        self.cone_ids, self.cone_coords = self.find_cone_coords()
        self.cache = []

    def add_scan(self, ra, dec):
        self.scanned_points += [(ra, dec)]
        
    def read_map(self, ):
        print("Reading file: {0}".format(self.gw_path))
        f = fits.open(gw_file)
        return f
        
        
    def find_pixel_threshold(self, data):
        print("")
        ranked_pixels = np.sort(data)[::-1]
        int_sum = 0.0
        pixel_threshold = 0.0

        for i, prob in enumerate(ranked_pixels):
            int_sum += prob
            if int_sum > self.prob_threshold:
                print("Threshold found! \n To reach {0}% of probability, pixels with "
                      "probability greater than {1} are included".format(
                          int_sum*100., prob))
                pixel_threshold = prob
                break
                
        return pixel_threshold
    
    def find_cone_ids(self):
        pass
    
    @staticmethod
    def extract_ra_dec(nside, index):
        (colat, ra) = hp.pix2ang(nside, index, nest=True)
        dec = np.pi/2. - colat
        return (ra, dec)
    
    @staticmethod
    def extract_npix(nside, ra, dec):
        colat = np.pi/2. - dec
        return hp.ang2pix(nside, colat, ra, nest=True)
    
    def unpack_skymap(self):
#         all_coords = []
#         prob_center = []

#         # new_probs = hp.pixelfunc.ud_grade(gw.data["PROB"], nside_out=cone_nside, power=-2)
#         new_probs =d["PROB"]
#         print(new_probs, np.sum(new_probs))


        ligo_nside = hp.npix2nside(len(self.data["PROB"]))

        threshold = self.find_pixel_threshold(self.data["PROB"])

        mask = self.data["PROB"] > threshold
        
        map_coords = []


        for i in tqdm(range(hp.nside2npix(ligo_nside))):
            if mask[i]:
                map_coords.append(self.extract_ra_dec(ligo_nside, i))

        map_coords = np.array(map_coords, dtype=np.dtype([("ra", np.float), 
                                                          ("dec", np.float)]))
        
        return map_coords

        # all_coords = all_coords[mask]

#         print(np.degrees(all_coords["ra"]), np.degrees(all_coords["dec"]))

    def find_cone_coords(self):
        cone_ids = []

        for ra, dec in self.map_coords:
        #     print(np.degrees(ra), np.degrees(dec))
        #     print(hp.ang2pix(64, np.pi/2. - dec, ra, nest=True))
            cone_ids.append(self.extract_npix(self.cone_nside, ra, dec))

        cone_ids = list(set(cone_ids))        
        
        cone_coords = []

        for i in tqdm(cone_ids):
            cone_coords.append(self.extract_ra_dec(self.cone_nside, i))

        cone_coords = np.array(
            cone_coords, dtype=np.dtype([("ra", np.float), ("dec", np.float)])
        )
        
        return cone_ids, cone_coords
    
    @staticmethod
    def wrap_around_180(ra):
        ra[ra > np.pi] -= 2*np.pi
        return ra
        
        
        
    def plot_skymap(self):
        
        plt.subplot(projection="aitoff")

        sc = plt.scatter(wrap_around_180(self.map_coords["ra"]), self.map_coords["dec"],
                        c=new_probs[mask], vmin=0.,  vmax=max(new_probs), s=1e-4)
        plt.title("LIGO SKYMAP")
        plt.show()
        
        plt.subplot(projection="aitoff")

        sc = plt.scatter(wrap_around_180(self.cone_coords["ra"]), self.cone_coords["dec"])
        plt.title("CONE REGION")
        plt.show()
        
    def scan_cones(self):
        
        scan_radius = np.degrees(hp.max_pixrad(self.cone_nside))
        print("Commencing Ampel queries!")
        print("Scan radius is", scan_radius)
        print("So far, {0} pixels out of {1} have already been scanned.".format(
            len(self.scanned_pixels), len(self.cone_ids)
        ))
        
        for i, cone_id in enumerate(tqdm(list(self.cone_ids))):
#             print(i)
            ra, dec = self.cone_coords[i]
            
            if cone_id not in self.scanned_pixels:
                try:
                    self.cache += self.query_ampel(ra, dec, scan_radius)
                    self.scanned_pixels.append(cone_id)
                except:
                    pass
        
        print("Scanned {0} pixels".format(len(self.scanned_pixels)))
            
        
    def query_ampel(self, ra, dec, rad):
        return []
        
        
        
        
        
    

In [24]:
gw = GravWaveScanner("/Users/avocado/Downloads/LALInference.offline", None)

Reading file: /Users/avocado/Downloads/LALInference.offline
NSIDE = 1024
ORDERING = NESTED in fits file
INDXSCHM = IMPLICIT
Ordering converted to RING

Threshold found! 
 To reach 90.00009741504594% of probability, pixels with probability greater than 3.8098001157350675e-06 are included



  0%|          | 0/12582912 [00:00<?, ?it/s]

Threshold found! 
 To reach 90.00009741504594% of probability, pixels with probability greater than 3.8098001157350675e-06 are included


100%|██████████| 12582912/12582912 [00:11<00:00, 1104074.42it/s]
100%|██████████| 153/153 [00:00<00:00, 10702.97it/s]


In [25]:
gw.scan_cones()

100%|██████████| 153/153 [00:00<00:00, 54503.87it/s]

Commencing Ampel queries!
Scan radius is 0.9541480607387777
So far, 0 pixels out of 153 have already been scanned.
Scanned 153 pixels



