In [1]:
import matplotlib
from ampel.ztf.archive.ArchiveDB import ArchiveDB
from astropy.time import Time
import itertools
import numpy as np
from scipy.stats import norm
from scipy.optimize import minimize
import matplotlib.pyplot as plt
import scipy.optimize
import scipy as scp
import datetime
import ztfquery
import datetime
import re
from ztfquery import alert
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.patches import Circle
from matplotlib.collections import PatchCollection
import csv
import os,io
import pickle
# import pymongo
from astropy.coordinates import SkyCoord
from astropy import units as u
import getpass
import psycopg2 
import sqlalchemy
import healpy as hp
from astropy.io import fits
from tqdm import tqdm
%matplotlib inline

In [2]:
def _read_or_prompt_secret(path, prompt):
    """Return the value cached in `path`, prompting for it (and caching it) if absent.

    Parameters
    ----------
    path : str
        File holding the cached value, in plain text (UTF-8).
    prompt : str
        Prompt shown by getpass when the file does not exist yet.

    Returns
    -------
    str
        The cached or freshly entered value.

    The cache file is created with owner-only permissions (0o600) because it
    may hold a password in plain text.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            # Tolerate a trailing newline added when the file is edited by hand.
            return f.read().rstrip("\r\n")
    except FileNotFoundError:
        value = getpass.getpass(prompt=prompt, stream=None)
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(value)
        return value


username = _read_or_prompt_secret(".AMPEL_user.txt", "Username: ")
password = _read_or_prompt_secret(".AMPEL_pass.txt", "Password: ")

In [3]:
from urllib.parse import quote_plus

# Open the AMPEL ZTF archive through the local SSH tunnel on port 5432.
# Credentials are percent-encoded so characters such as '@', ':' or '/'
# in the password cannot corrupt the connection URL.
try:
    client = ArchiveDB('postgresql://{0}:{1}@localhost:5432/ztfarchive'.format(
        quote_plus(username), quote_plus(password)))
except sqlalchemy.exc.OperationalError:
    print("You can't access the archive database without first opening the port.")
    print("Open a new terminal, and into that terminal, run the following command:")
    print("ssh -L5432:localhost:5433 ztf-wgs.zeuthen.desy.de")
    print("If that command doesn't work, you are either not a desy user or you have a problem in your ssh config.")
    raise

You can't access the archive database without first opening the port.
Open a new terminal, and into that terminal, run the following command:
ssh -L5432:localhost:5433 ztf-wgs.zeuthen.desy.de
If that command doesn't work, you are either not a desy user or you have a problem in your ssh config.


OperationalError: (psycopg2.OperationalError) server closed the connection unexpectedly
	This probably means the server terminated abnormally
	before or while processing the request.

(Background on this error at: http://sqlalche.me/e/e3q8)

In [None]:
def reassemble_alert(candid):
    """Rebuild an alert-packet-like dict for `candid` from the archive.

    The alert content and its image cutouts are fetched separately from the
    module-level ``client`` and merged. Packet fields that this function
    does not obtain from the archive are filled with the placeholder 'dunno'.

    Returns the dict produced by ``client.get_alert`` (modified in place).
    """
    placeholder = 'dunno'

    packet = client.get_alert(candid)
    stamps = client.get_cutout(candid)

    # e.g. key 'science' becomes 'cutoutScience'
    packet.update({
        'cutout{}'.format(kind.title()): {'stampData': stamps[kind], 'fileName': placeholder}
        for kind in stamps
    })
    packet.update(schemavsn=placeholder, publisher=placeholder)

    # isdiffpos is left exactly as the archive returned it.
    photopoints = [packet['candidate']] + packet['prv_candidates']
    for photopoint in photopoints:
        for field in ('pdiffimfilename', 'programpi', 'ssnamenr'):
            photopoint[field] = placeholder

    return packet

In [None]:
class GravWaveScanner:
    """Search the AMPEL/ZTF archive for optical counterparts of a GW event.

    The skymap is reduced to the pixels making up the ``prob_threshold``
    credible region. Those pixels are grouped into coarser HEALPix "cones"
    (resolution ``cone_nside``), and each cone is queried with an archive cone
    search. Alerts passing :meth:`filter_f` are accumulated in ``self.cache``.

    Parameters
    ----------
    gw_file : str
        Path to the skymap FITS file. HDU 1 must carry a ``DATE-OBS`` header
        card (merger time) and a ``PROB`` column.
    rad : float or None
        Stored as ``self.radius``; not otherwise used by this class.
    prob_threshold : float
        Cumulative probability the selected region must reach (default 0.9).
    cone_nside : int
        HEALPix resolution of the query cones (default 64).

    Notes
    -----
    Internal coordinates (``map_coords``, ``cone_coords``) are in radians.
    The ``PROB`` column is indexed with NESTED ordering throughout.
    NOTE(review): confirm against the ORDERING header card of the input map.
    """

    # Structured dtype shared by map_coords and cone_coords (radians).
    _COORD_DTYPE = np.dtype([("ra", np.float64), ("dec", np.float64)])

    def __init__(self, gw_file, rad, prob_threshold=0.9, cone_nside=64):
        self.gw_path = gw_file
        self.cone_nside = cone_nside
        self.parsed_file = self.read_map()
        self.merger_time = Time(self.parsed_file[1].header["DATE-OBS"], format="isot", scale="utc")

        print("MERGER TIME: {0}".format(self.merger_time))

        self.data = self.parsed_file[1].data
        self.prob_map = hp.read_map(gw_file)
        self.radius = rad
        self.prob_threshold = prob_threshold
        self.pixel_threshold = self.find_pixel_threshold(self.data["PROB"])
        self.scanned_pixels = []
        # (ra, dec) positions registered through add_scan()
        self.scanned_points = []
        self.map_coords = self.unpack_skymap()
        self.cone_ids, self.cone_coords = self.find_cone_coords()
        self.cache = []

    def add_scan(self, ra, dec):
        """Record that position (ra, dec) has been scanned."""
        self.scanned_points.append((ra, dec))

    def read_map(self):
        """Open the skymap FITS file at ``self.gw_path`` and return the HDU list.

        The HDU list is returned open because its table data is accessed
        later through ``self.parsed_file``.
        """
        print("Reading file: {0}".format(self.gw_path))
        return fits.open(self.gw_path)

    def find_pixel_threshold(self, data):
        """Return the smallest pixel probability inside the credible region.

        Pixels are ranked by decreasing probability and accumulated until the
        running sum exceeds ``self.prob_threshold``. The probability of the
        pixel that crosses it is returned. Returns 0.0 when the total never
        exceeds the threshold (including empty input).
        """
        print("")
        ranked_pixels = np.sort(np.asarray(data))[::-1]
        cumulative = np.cumsum(ranked_pixels)
        crossed = cumulative > self.prob_threshold
        if not crossed.any():
            return 0.0

        idx = int(np.argmax(crossed))  # first index where the sum crosses
        int_sum, prob = cumulative[idx], ranked_pixels[idx]
        print("Threshold found! \n To reach {0}% of probability, pixels with "
              "probability greater than {1} are included".format(
                  int_sum*100., prob))
        return prob

    def find_cone_ids(self):
        """Placeholder, not implemented (returns None); cone ids come from find_cone_coords()."""
        pass

    @staticmethod
    def extract_ra_dec(nside, index):
        """Return (ra, dec) in radians of NESTED pixel(s) `index` at resolution `nside`."""
        (colat, ra) = hp.pix2ang(nside, index, nest=True)
        dec = np.pi/2. - colat
        return (ra, dec)

    @staticmethod
    def extract_npix(nside, ra, dec):
        """Return the NESTED pixel index(es) at `nside` containing (ra, dec) in radians."""
        colat = np.pi/2. - dec
        return hp.ang2pix(nside, colat, ra, nest=True)

    def _ligo_nside(self):
        """HEALPix nside of the input skymap, derived from its pixel count."""
        return hp.npix2nside(len(self.data["PROB"]))

    def _credible_mask(self):
        """Boolean mask of skymap pixels inside the credible region.

        Uses ``>=`` so the pixel that crosses the cumulative threshold is
        included; zero-probability pixels are always excluded.
        """
        probs = self.data["PROB"]
        return (probs >= self.pixel_threshold) & (probs > 0)

    def _in_credible_region(self, ra_deg, dec_deg):
        """True if the position (in degrees) falls in a credible-region pixel."""
        pix = self.extract_npix(self._ligo_nside(), np.radians(ra_deg), np.radians(dec_deg))
        prob = self.data["PROB"][pix]
        return bool(prob > 0 and prob >= self.pixel_threshold)

    @classmethod
    def _as_coord_array(cls, ra, dec):
        """Pack parallel ra/dec arrays into a structured ("ra", "dec") array."""
        coords = np.empty(len(ra), dtype=cls._COORD_DTYPE)
        coords["ra"] = ra
        coords["dec"] = dec
        return coords

    def unpack_skymap(self):
        """Return the (ra, dec) centres, in radians, of all credible-region pixels.

        Pixels appear in increasing NESTED index order, matching the order
        of ``self.data["PROB"][self._credible_mask()]``.
        """
        pixel_ids = np.flatnonzero(self._credible_mask())
        ra, dec = self.extract_ra_dec(self._ligo_nside(), pixel_ids)
        return self._as_coord_array(ra, dec)

    def find_cone_coords(self):
        """Return the distinct cone pixels covering the credible region.

        Returns
        -------
        cone_ids : list of int
            Sorted NESTED pixel ids at ``self.cone_nside``.
        cone_coords : structured ndarray
            (ra, dec) centres in radians, in the same order as `cone_ids`.
        """
        pix = self.extract_npix(self.cone_nside, self.map_coords["ra"], self.map_coords["dec"])
        cone_ids = np.unique(np.asarray(pix, dtype=np.int64)).tolist()
        ra, dec = self.extract_ra_dec(self.cone_nside, np.asarray(cone_ids, dtype=np.int64))
        return cone_ids, self._as_coord_array(ra, dec)

    @staticmethod
    def wrap_around_180(ra):
        """Return a copy of `ra` (radians) with values above pi shifted by -2*pi.

        Maps [0, 2*pi) onto (-pi, pi], the longitude range of matplotlib's
        aitoff projection. The input array is NOT modified (mutating it would
        corrupt e.g. ``self.map_coords``, whose fields are views).
        """
        wrapped = np.array(ra, dtype=float, copy=True)
        wrapped[wrapped > np.pi] -= 2*np.pi
        return wrapped

    def plot_skymap(self):
        """Plot the credible-region pixels (coloured by probability) and the query cones."""
        probs = self.data["PROB"]

        plt.subplot(projection="aitoff")
        plt.scatter(self.wrap_around_180(self.map_coords["ra"]), self.map_coords["dec"],
                    c=probs[self._credible_mask()], vmin=0., vmax=probs.max(), s=1e-4)
        plt.title("LIGO SKYMAP")
        plt.show()

        plt.subplot(projection="aitoff")
        plt.scatter(self.wrap_around_180(self.cone_coords["ra"]), self.cone_coords["dec"])
        plt.title("CONE REGION")
        plt.show()

    def scan_cones(self, t_max=None):
        """Run an archive cone search for every cone not yet scanned.

        Parameters
        ----------
        t_max : astropy.time.Time, optional
            Upper time bound of the search; defaults to the current time *at
            call time* (a ``Time.now()`` default would be frozen at class
            definition).

        Accepted alerts are appended to ``self.cache`` and successful cone ids
        to ``self.scanned_pixels``. A failed query is reported and left
        unscanned, so a later call retries it.
        """
        if t_max is None:
            t_max = Time.now()

        # Covers a whole cone pixel; in degrees, like the cone-search radius.
        scan_radius = np.degrees(hp.max_pixrad(self.cone_nside))
        print("Commencing Ampel queries!")
        print("Scan radius is", scan_radius)
        print("So far, {0} pixels out of {1} have already been scanned.".format(
            len(self.scanned_pixels), len(self.cone_ids)
        ))

        for i, cone_id in enumerate(tqdm(list(self.cone_ids))):
            if cone_id in self.scanned_pixels:
                continue

            ra, dec = self.cone_coords[i]
            try:
                # cone_coords are radians; the cone search takes degrees,
                # matching scan_radius.
                self.cache += self.query_ampel(
                    float(np.degrees(ra)), float(np.degrees(dec)), scan_radius, t_max)
            except Exception as err:
                print("Query for cone {0} failed: {1!r}".format(cone_id, err))
                continue
            self.scanned_pixels.append(cone_id)

        print("Scanned {0} pixels".format(len(self.scanned_pixels)))

    def filter_f(self, res):
        """Return True if alert `res` is a plausible counterpart candidate.

        An alert passes when it meets all of the following:
        - It is a positive difference detection.
        - It lies in a credible-region pixel of the skymap.
        - It has no detection before the merger (earlier upper limits are fine).
        - It has at least one earlier post-merger detection, so two in total.
        - It passes the star-like (sgscore1 > 0.8) and real/bogus (rb < 0.2) cuts.
          Each cut is applied only when its score is present.

        NOTE(review): alert ra/dec are assumed to be in degrees (ZTF alert convention).
        """
        candidate = res["candidate"]

        # Positive detection. The archive may return isdiffpos as a boolean
        # rather than 't'/'1' -- NOTE(review): confirm.
        if candidate["isdiffpos"] not in ("t", "1", True):
            return False

        # Positional: inside the credible region of the skymap
        if not self._in_credible_region(candidate["ra"], candidate["dec"]):
            return False

        t_min_jd = self.merger_time.jd
        previous = res["prv_candidates"] or []

        # Veto past detections, but not past upper limits (isdiffpos is None)
        for prv_detection in previous:
            if prv_detection["isdiffpos"] is not None and prv_detection["jd"] < t_min_jd:
                return False

        # Require 2 detections: this one plus at least one earlier post-merger one
        n_detections = sum(
            1 for x in previous if x["isdiffpos"] is not None and x["jd"] > t_min_jd)
        if n_detections < 1:
            return False

        # Remove stars et al. (scores live in the candidate record and may be null)
        sgscore = candidate.get("sgscore1")
        if sgscore is not None and sgscore > 0.8:
            return False

        rb = candidate.get("rb")
        if rb is not None and rb < 0.2:
            return False

        print(res['objectId'], candidate["ra"], candidate["dec"], n_detections)
        return True

    def query_ampel(self, ra, dec, rad, t_max):
        """Cone-search the archive and return the alerts accepted by filter_f.

        `ra`, `dec` and `rad` are in degrees. The time window runs from the
        merger time to `t_max`.
        """
        query_res = list(client.get_alerts_in_cone(
            ra, dec, rad, self.merger_time.jd, t_max.jd, with_history=True))
        print("Found a total of {0} objects with cone search.".format(len(query_res)))
        print("This number will be reduced when box is applied")

        return [res for res in query_res if self.filter_f(res)]
    

In [None]:
# Skymap FITS file of the event. Set the GW_SKYMAP_PATH environment variable
# to point elsewhere instead of editing a machine-specific path here.
GW_SKYMAP_PATH = os.environ.get("GW_SKYMAP_PATH", "/Users/avocado/Downloads/LALInference.offline")
gw = GravWaveScanner(GW_SKYMAP_PATH, rad=None)

In [None]:
# Run the cone-by-cone archive queries. Query results accumulate in gw.cache,
# and the ids of successfully queried cones accumulate in gw.scanned_pixels.
gw.scan_cones()

In [None]:
# print(gw.parsed_file[1].header)