In [None]:
import numpy as np, matplotlib.pyplot as plt, pandas as pd
import warnings, json, sys, requests, gzip, io

from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
from matplotlib.colors import LogNorm

from penquins import Kowalski
from sklearn.model_selection import train_test_split

from bson.json_util import loads, dumps
from astropy.stats import sigma_clipped_stats
from astropy.io import fits
from astropy.io.fits.verify import VerifyWarning 
warnings.filterwarnings("ignore", category=VerifyWarning)

# ANSI escape sequences used to embolden console messages.
BOLD = "\033[1m"
END = "\033[0m"

# Service credentials are read from a JSON file kept outside the notebook;
# later cells look up the "btsse_*" and "kowalski_*" entries.
with open('misc/credentials.json', mode='r') as f:
    creds = json.load(f)

In [None]:
# Global matplotlib styling, applied in a single update call.
plt.rcParams.update({
    "font.family": "Times New Roman",
    "font.size": 12,
    "axes.linewidth": 1,
})

### Query the BTS Sample Explorer

In [None]:
# Old queries
# bts_trues_url      = "https://sites.astro.caltech.edu/ztf/bts/explorer.php?f=s&subsample=trans&classstring=&classexclude=&ps1img=y&lcfig=y&ztflink=lasair&lastdet=&startsavedate=&startpeakdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=&startabsmag=&starthostabs=&starthostcol=&startb=&startav=&endsavedate=&endpeakdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=18.5&endabsmag=&endhostabs=&endhostcol=&endb=&endav=&sort=peakmag&format=csv"
# bts_dim_falses_url = "https://sites.astro.caltech.edu/ztf/bts/explorer.php?f=s&subsample=trans&classstring=&classexclude=&quality=y&purity=y&ps1img=y&lcfig=y&ztflink=lasair&lastdet=&startsavedate=&startpeakdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=18.5&startabsmag=&starthostabs=&starthostcol=&startb=&startav=&endsavedate=&endpeakdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=&endabsmag=&endhostabs=&endhostcol=&endb=&endav=&sort=peakmag&reverse=y&format=csv"

# bts_trues_url      = "https://sites.astro.caltech.edu/ztf/bts/explorer.php?f=s&subsample=trans&classstring=&classexclude=&passok=y&refok=y&dateok=y&purity=y&ps1img=y&lcfig=y&ztflink=lasair&lastdet=&startsavedate=&startpeakdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=&startabsmag=&starthostabs=&starthostcol=&startb=&startav=&endsavedate=&endpeakdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=18.5&endabsmag=&endhostabs=&endhostcol=&endb=&endav=&sort=peakmag&format=csv"
# bts_var_falses_url = "https://sites.astro.caltech.edu/ztf/bts/explorer.php?f=s&subsample=var&classstring=&classexclude=&ztflink=lasair&lastdet=&startsavedate=&startpeakdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=&startabsmag=&starthostabs=&starthostcol=&startb=&startav=&endsavedate=&endpeakdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=&endabsmag=&endhostabs=&endhostcol=&endb=&endav=&format=csv"
# bts_dim_falses_url = "https://sites.astro.caltech.edu/ztf/bts/explorer.php?f=s&subsample=trans&classstring=&classexclude=&passok=y&refok=y&dateok=y&purity=y&ps1img=y&lcfig=y&ztflink=lasair&lastdet=&startsavedate=&startpeakdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=18.5&startabsmag=&starthostabs=&starthostcol=&startb=&startav=&endsavedate=&endpeakdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=&endabsmag=&endhostabs=&endhostcol=&endb=&endav=&sort=peakmag&reverse=y&format=csv"


In [None]:
# # v2
# query_urls = {
#     "rcf_trues":           "http://sites.astro.caltech.edu/ztf/rcf/explorer.php?f=s&coverage=any&samprcf=y&subsample=trans&classstring=&classexclude=&refok=y&purity=y&ps1img=y&lcfig=y&ztflink=fritz&startsavedate=&startpeakdate=&startlastdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=&startlastmag=&startabsmag=&starthostabs=&starthostcol=&startsavevis=&startlatevis=&startcurrvis=&startb=&startav=&endsavedate=&endpeakdate=&endlastdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=18.5&endlastmag=&endabsmag=&endhostabs=&endhostcol=&endsavevis=&endlatevis=&endcurrvis=&endb=&endav=&sort=peakmag&format=csv",
#     "rcf_dim_falses":      "http://sites.astro.caltech.edu/ztf/rcf/explorer.php?f=s&coverage=any&samprcf=y&subsample=trans&classstring=&classexclude=&covok=y&refok=y&purity=y&ps1img=y&lcfig=y&ztflink=fritz&startsavedate=&startpeakdate=&startlastdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=18.5&startlastmag=&startabsmag=&starthostabs=&starthostcol=&startsavevis=&startlatevis=&startcurrvis=&startb=&startav=&endsavedate=&endpeakdate=&endlastdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=&endlastmag=&endabsmag=&endhostabs=&endhostcol=&endsavevis=&endlatevis=&endcurrvis=&endb=&endav=&sort=peakmag&format=csv",
#     "rcf_var_falses":      "http://sites.astro.caltech.edu/ztf/rcf/explorer.php?f=s&coverage=any&samprcf=y&subsample=var&classstring=&classexclude=&refok=y&ps1img=y&lcfig=y&ztflink=fritz&startsavedate=&startpeakdate=&startlastdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=&startlastmag=&startabsmag=&starthostabs=&starthostcol=&startsavevis=&startlatevis=&startcurrvis=&startb=&startav=&endsavedate=&endpeakdate=&endlastdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=&endlastmag=&endabsmag=&endhostabs=&endhostcol=&endsavevis=&endlatevis=&endcurrvis=&endb=&endav=&sort=peakmag&format=csv",

#     "rcf_deep_trues"     : "http://sites.astro.caltech.edu/ztf/rcf/explorer.php?f=s&coverage=any&sampdeep=y&subsample=trans&classstring=&classexclude=&refok=y&purity=y&ps1img=y&lcfig=y&ztflink=fritz&startsavedate=&startpeakdate=&startlastdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=&startlastmag=&startabsmag=&starthostabs=&starthostcol=&startsavevis=&startlatevis=&startcurrvis=&startb=&startav=&endsavedate=&endpeakdate=&endlastdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=18.5&endlastmag=&endabsmag=&endhostabs=&endhostcol=&endsavevis=&endlatevis=&endcurrvis=&endb=&endav=&sort=peakmag&format=csv",
#     "rcf_deep_dim_falses": "http://sites.astro.caltech.edu/ztf/rcf/explorer.php?f=s&coverage=any&sampdeep=y&subsample=trans&classstring=&classexclude=&covok=y&refok=y&purity=y&ps1img=y&lcfig=y&ztflink=fritz&startsavedate=&startpeakdate=&startlastdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=18.5&startlastmag=&startabsmag=&starthostabs=&starthostcol=&startsavevis=&startlatevis=&startcurrvis=&startb=&startav=&endsavedate=&endpeakdate=&endlastdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=&endlastmag=&endabsmag=&endhostabs=&endhostcol=&endsavevis=&endlatevis=&endcurrvis=&endb=&endav=&sort=peakmag&format=csv",
#     "rcf_deep_var_falses": "http://sites.astro.caltech.edu/ztf/rcf/explorer.php?f=s&coverage=any&sampdeep=y&subsample=var&classstring=&classexclude=&refok=y&ps1img=y&lcfig=y&ztflink=fritz&startsavedate=&startpeakdate=&startlastdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=&startlastmag=&startabsmag=&starthostabs=&starthostcol=&startsavevis=&startlatevis=&startcurrvis=&startb=&startav=&endsavedate=&endpeakdate=&endlastdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=&endlastmag=&endabsmag=&endhostabs=&endhostcol=&endsavevis=&endlatevis=&endcurrvis=&endb=&endav=&sort=peakmag&format=csv"
# }

In [None]:
# # v3.1
# query_urls = {
#     "rcf_trues":    "http://sites.astro.caltech.edu/ztf/rcf/explorer.php?f=s&coverage=any&samprcf=y&sampdeep=y&subsample=trans&classstring=&classexclude=&refok=y&purity=y&ps1img=y&lcfig=y&ztflink=fritz&startsavedate=&startpeakdate=&startlastdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=&startlastmag=&startabsmag=&starthostabs=&starthostcol=&startsavevis=&startlatevis=&startcurrvis=&startb=&startav=&endsavedate=&endpeakdate=&endlastdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=18.5&endlastmag=&endabsmag=&endhostabs=&endhostcol=&endsavevis=&endlatevis=&endcurrvis=&endb=&endav=&sort=peakmag&format=csv",
#     "rcf_dim":      "http://sites.astro.caltech.edu/ztf/rcf/explorer.php?f=s&coverage=any&samprcf=y&subsample=all&classstring=&classexclude=&covok=y&refok=y&lcfig=y&ztflink=fritz&startsavedate=&startpeakdate=&startlastdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=18.5&startlastmag=&startabsmag=&starthostabs=&starthostcol=&startsavevis=&startlatevis=&startcurrvis=&startb=&startav=&endsavedate=&endpeakdate=&endlastdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=&endlastmag=&endabsmag=&endhostabs=&endhostcol=&endsavevis=&endlatevis=&endcurrvis=&endb=&endav=&sort=peakmag&format=csv",
#     "rcf_var":      "http://sites.astro.caltech.edu/ztf/rcf/explorer.php?f=s&coverage=any&samprcf=y&subsample=var&classstring=&classexclude=&refok=y&lcfig=y&ztflink=fritz&startsavedate=&startpeakdate=&startlastdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=&startlastmag=&startabsmag=&starthostabs=&starthostcol=&startsavevis=&startlatevis=&startcurrvis=&startb=&startav=&endsavedate=&endpeakdate=&endlastdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=&endlastmag=&endabsmag=&endhostabs=&endhostcol=&endsavevis=&endlatevis=&endcurrvis=&endb=&endav=&sort=peakmag&format=csv",
    
#     "rcf_deep_dim": "http://sites.astro.caltech.edu/ztf/rcf/explorer.php?f=s&coverage=any&sampdeep=y&subsample=all&classstring=&classexclude=&covok=y&refok=y&lcfig=y&ztflink=fritz&startsavedate=&startpeakdate=&startlastdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=18.5&startlastmag=&startabsmag=&starthostabs=&starthostcol=&startsavevis=&startlatevis=&startcurrvis=&startb=&startav=&endsavedate=&endpeakdate=&endlastdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=&endlastmag=&endabsmag=&endhostabs=&endhostcol=&endsavevis=&endlatevis=&endcurrvis=&endb=&endav=&sort=peakmag&format=csv",
#     "rcf_deep_var": "http://sites.astro.caltech.edu/ztf/rcf/explorer.php?f=s&coverage=any&sampdeep=y&subsample=var&classstring=&classexclude=&refok=y&lcfig=y&ztflink=fritz&startsavedate=&startpeakdate=&startlastdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=&startlastmag=&startabsmag=&starthostabs=&starthostcol=&startsavevis=&startlatevis=&startcurrvis=&startb=&startav=&endsavedate=&endpeakdate=&endlastdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=&endlastmag=&endabsmag=&endhostabs=&endhostcol=&endsavevis=&endlatevis=&endcurrvis=&endb=&endav=&sort=peakmag&format=csv"
# }

In [None]:
# v3.2
# Explorer (explorer.php) CSV queries keyed by sample name. The download cell
# below also uses each key as the file name data/base_data/<key>.csv.
# The three query strings differ mainly in `subsample` and the peak-magnitude bounds:
#   trues: subsample=trans, purity=y, endpeakmag=18.5
#   dims:  subsample=all, covok=y, purity=y, startpeakmag=18.5
#   vars:  subsample=var, no peak-magnitude bound
# NOTE(review): these are plain http:// URLs, yet the download cell sends HTTP
# basic-auth credentials with them. Presumably the https:// form of the same
# host works (the older BTS queries above use https) — confirm and switch.
query_urls = {
    "trues": "http://sites.astro.caltech.edu/ztf/rcf/explorer.php?f=s&coverage=any&samprcf=y&sampdeep=y&subsample=trans&classstring=&classexclude=&refok=y&purity=y&ps1img=y&lcfig=y&ztflink=fritz&startsavedate=&startpeakdate=&startlastdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=&startlastmag=&startabsmag=&starthostabs=&starthostcol=&startsavevis=&startlatevis=&startcurrvis=&startb=&startav=&endsavedate=&endpeakdate=&endlastdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=18.5&endlastmag=&endabsmag=&endhostabs=&endhostcol=&endsavevis=&endlatevis=&endcurrvis=&endb=&endav=&sort=peakmag&format=csv",
    "dims":  "http://sites.astro.caltech.edu/ztf/rcf/explorer.php?f=s&coverage=any&samprcf=y&sampdeep=y&subsample=all&classstring=&classexclude=&covok=y&refok=y&purity=y&lcfig=y&ztflink=fritz&startsavedate=&startpeakdate=&startlastdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=18.5&startlastmag=&startabsmag=&starthostabs=&starthostcol=&startsavevis=&startlatevis=&startcurrvis=&startb=&startav=&endsavedate=&endpeakdate=&endlastdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=&endlastmag=&endabsmag=&endhostabs=&endhostcol=&endsavevis=&endlatevis=&endcurrvis=&endb=&endav=&sort=peakmag&format=csv",
    "vars":  "http://sites.astro.caltech.edu/ztf/rcf/explorer.php?f=s&coverage=any&samprcf=y&sampdeep=y&subsample=var&classstring=&classexclude=&refok=y&lcfig=y&ztflink=fritz&startsavedate=&startpeakdate=&startlastdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=&startlastmag=&startabsmag=&starthostabs=&starthostcol=&startsavevis=&startlatevis=&startcurrvis=&startb=&startav=&endsavedate=&endpeakdate=&endlastdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=&endlastmag=&endabsmag=&endhostabs=&endhostcol=&endsavevis=&endlatevis=&endcurrvis=&endb=&endav=&sort=peakmag&format=csv",
}


In [None]:
# Download every explorer query and cache the returned CSV text under
# data/base_data/<set_name>.csv.
for set_name, url in query_urls.items():
    response = requests.get(
        url,
        auth=(creds["btsse_username"], creds["btsse_password"]),
        timeout=300,  # seconds; avoids hanging the notebook forever on a stalled connection
    )
    # Fail loudly on 401/404/5xx instead of saving an HTML error page as if it
    # were the CSV; later cells would otherwise break with a confusing parser error.
    response.raise_for_status()

    # Open the file only once the download succeeded, so a failed request
    # never truncates a previously cached CSV.
    with open(f"data/base_data/{set_name}.csv", "w") as f:
        f.write(response.text)
    print("Queried and wrote", set_name)


### Read queried data

In [None]:
# For every pair of cached query results, report how many ZTFIDs of the first
# set also occur in the second. Each CSV is read once up front rather than
# once per pair inside the nested loop.
ztfids_by_set = {
    name: pd.read_csv(f"data/base_data/{name}.csv")["ZTFID"]
    for name in query_urls.keys()
}

for name1, ids1 in ztfids_by_set.items():
    print(name1)
    for name2, ids2 in ztfids_by_set.items():
        print(f"  in {name2}")
        print("  ", np.sum(ids1.isin(ids2)), "/", len(ids1))
    print()

In [None]:
# Load the cached query results in the order of `query_urls` and remove from
# each set every ZTFID that an earlier set already contains, so an object
# belongs to exactly one set. Duplicates inside a single set are left alone.
queries = []
seen_ztfids = set()
for set_name in query_urls.keys():
    query = pd.read_csv(f"data/base_data/{set_name}.csv")
    query = query[~query['ZTFID'].isin(seen_ztfids)]
    seen_ztfids.update(query['ZTFID'])
    queries.append(query)

# One concat at the end instead of growing a frame inside the loop from an
# empty seed (which is quadratic and can trigger pandas' empty-entry
# concat deprecation warning).
all_queries = pd.concat(queries)

# NOTE: `vars` shadows the Python builtin; the name is kept because later
# cells refer to it.
trues = queries[0]
dims  = queries[1]
vars  = queries[2]

In [None]:
# Sample sizes after cross-query de-duplication: positives first, then the
# two negative sets combined, then everything.
print(len(trues), "rcf true sources")
print(len(vars) + len(dims), "rcf (deep) false sources")

print(len(all_queries), "total")

# v2
# 7009 total

# v3.1
# 4150 rcf true sources
# 7913 rcf (deep) false sources
# 12063 total


### Process list from Mat Smith

In [None]:
# Read the master list and rename its ID column to match the explorer CSVs.
MS_Ias = (
    pd.read_csv('mat-smith/ztfdr2_masterlist.csv')
    .rename(columns={"ztfname": "ZTFID"})
)
print("Total in MS list", len(MS_Ias))

In [None]:
# Flag rows whose ID does not contain the literal text 'ZTF'.
# na=False makes a missing ID count as "no match", so it is flagged too; without
# it the mask would hold NaN and the `~` negation would raise.
# regex=False because 'ZTF' is a plain substring, not a pattern.
nonZTF = ~MS_Ias['ZTFID'].str.contains('ZTF', na=False, regex=False)
nonZTF_idxs = MS_Ias['ZTFID'].index[nonZTF]

MS_Ias = MS_Ias.drop(index=nonZTF_idxs)
print("Total in MS list excluding non ZTF objects", len(MS_Ias))

In [None]:
# Keep only the MS objects that none of the explorer queries already returned.
inBTSSE = MS_Ias['ZTFID'].isin(all_queries['ZTFID'])
inBTSSE_idxs = MS_Ias.index[inBTSSE]

# Boolean selection of the complement; same rows as dropping `inBTSSE_idxs`
# because the row labels are unique.
MS_Ias = MS_Ias[~inBTSSE]
print("New in MS list", len(MS_Ias))

In [None]:
# Combined sample size: the remaining MS objects plus all explorer query results.
print("Total objects:", len(MS_Ias) + len(all_queries))

### Objects to remove

In [None]:
objs_to_remove = ["ZTF18abdiasx", "ZTF21abyazip", "ZTF18aaadqua", "ZTF18aarrwmi"]
print(len(all_queries)+len(MS_Ias))

# Strip the hand-picked objects from each table with one membership test per
# table instead of one inequality filter per object.
trues = trues[~trues["ZTFID"].isin(objs_to_remove)]
dims = dims[~dims["ZTFID"].isin(objs_to_remove)]
vars = vars[~vars["ZTFID"].isin(objs_to_remove)]
MS_Ias = MS_Ias[~MS_Ias["ZTFID"].isin(objs_to_remove)]

# Additionally discard `dims` rows whose "type" column is one of these values.
dims = dims[~dims["type"].isin(["bogus", "duplicate", "bogus?", "duplicate?"])]

# Rebuild the combined table from the filtered sets and print the new total.
queries = [trues, dims, vars]
all_queries = pd.concat(queries)
print(len(all_queries)+len(MS_Ias))

### Helper functions for querying Kowalski and processing alerts

In [None]:
# Open the Kowalski session used by all query cells below.
k = Kowalski(username=creds['kowalski_username'], password=creds['kowalski_password'])

# Explicit check instead of `assert`: assertions vanish under `python -O` and
# give no hint about what went wrong.
if not k.ping():
    raise ConnectionError("Kowalski did not respond to ping; check credentials and network access")

In [None]:
def query_kowalski(ZTFID, kowalski):
    """
    Query kowalski to get the candidate stamps and metadata of one or more objects.

    ADAPTED FROM https://github.com/growth-astro/ztfrest/
    https://zwickytransientfacility.github.io/ztf-avro-alert/schema.html

    Parameters
    ----------
    ZTFID : str or list/tuple of str
        A single object ID or a sequence of object IDs; one "find" query on
        the "ZTF_alerts" catalog is issued per ID.
    kowalski : object
        Client exposing ``query(dict)`` that returns a mapping whose ``'data'``
        entry is the list of matching alerts.

    Returns
    -------
    list or None
        All alerts returned for the requested IDs, concatenated in request
        order. ``None`` when ``ZTFID`` is neither a string nor a list/tuple.
    """

    if isinstance(ZTFID, str):
        list_ZTFID = [ZTFID]
    elif isinstance(ZTFID, (list, tuple)):
        list_ZTFID = list(ZTFID)
    else:
        print(f"{ZTFID} must be a list or a string")
        return None

    # The projection is identical for every object, so build it once.
    candidate_fields = (
        "jd", "fid", "programid", "isdiffpos", "rcid", "field", "ra", "dec",
        "magpsf", "sigmapsf", "distnr", "magnr", "sigmagnr", "sky", "fwhm",
        "classtar", "mindtoedge", "rb", "drb", "ndethist", "jdstarthist",
        "jdendhist", "scorr", "sgscore1", "distpsnr1", "sgscore2", "distpsnr2",
        "sgscore3", "distpsnr3", "magzpsci", "magzpsciunc", "neargaia", "maggaia",
    )
    classification_fields = ("acai_h", "acai_v", "acai_o", "acai_n", "acai_b")
    cutout_fields = ("cutoutScience", "cutoutTemplate", "cutoutDifference")

    projection = {"_id": 0, "candid": 1, "objectId": 1}
    projection.update({f"candidate.{name}": 1 for name in candidate_fields})
    projection.update({f"classifications.{name}": 1 for name in classification_fields})
    projection.update({name: 1 for name in cutout_fields})

    alerts = []

    for object_id in list_ZTFID:
        query = {
            "query_type": "find",
            "query": {
                "catalog": "ZTF_alerts",
                "filter": {
                    'objectId': object_id
                },
                # Fresh copy per request so the client can never alter the template.
                "projection": dict(projection),
            }
        }

        r = kowalski.query(query)

        # A failed query carries no 'data' entry. Report it and move on rather
        # than letting a KeyError throw away every alert collected so far.
        data = r.get('data') if isinstance(r, dict) else None
        if data is None:
            message = r.get('message', "response had no 'data' field") if isinstance(r, dict) else r
            print("  Query failed for", object_id, "-", message)
        elif len(data) == 0:
            print("  No data for", object_id)
        else:
            alerts += data
            print("  Finished querying", object_id)
    print(BOLD+f"Finished all queries, got {len(alerts)} alerts"+END+"\n")
    return alerts
    

In [None]:
def make_triplet(alert, normalize: bool = True):
    """
    Build the 63x63x3 (science, template, difference) image stack of an alert.

    ADAPTED FROM https://github.com/dmitryduev/braai

    Each cutout is decoded (gzip -> FITS), NaN pixels are replaced by the
    median of the remaining pixels, cutouts smaller than 63x63 are padded with
    that same median, and the result is optionally scaled to unit L2 norm.

    Parameters
    ----------
    alert : dict
        Alert packet with 'objectId', 'candid' and the three
        'cutout{Science,Template,Difference}' entries holding 'stampData'.
    normalize : bool
        Divide every cutout by its L2 norm. Skipped for the remaining cutouts
        once the alert has been flagged for dropping.

    Returns
    -------
    triplet : numpy.ndarray
        Array of shape (63, 63, 3); channels are science, template, difference.
    drop : bool
        True when any cutout is unusable: no finite median, all zeros, or NaN
        left after processing.
    """

    cutout_dict = dict()
    drop = False

    for cutout in ('science', 'template', 'difference'):
        cutout_data = loads(dumps([alert[f'cutout{cutout.capitalize()}']['stampData']]))[0]
        # unzip
        with gzip.open(io.BytesIO(cutout_data), 'rb') as f:
            with fits.open(io.BytesIO(f.read())) as hdu:
                # Copy so the pixels stay valid after the HDU list is closed.
                data = np.array(hdu[0].data)

        # Median of the non-NaN pixels; an all-NaN cutout has none (checked
        # first so numpy does not emit an "All-NaN slice" warning).
        medfill = np.nan if np.all(np.isnan(data)) else np.nanmedian(data)

        # `x == np.nan` is always False, so the test must be isfinite().
        if not np.isfinite(medfill):
            print(BOLD, alert['objectId'], END, "bad medfill (nan or inf)", alert['candid'])
            drop = True
            # Neutral fill so the returned array stays finite; the alert is dropped anyway.
            medfill = 0.0

        # replace nans with the median
        image = np.nan_to_num(data, nan=medfill)

        # pad to 63x63 if smaller. Done BEFORE normalising so the padding is on
        # the same scale as the real pixels.
        shape = image.shape
        if shape != (63, 63):
            print("bad shape", shape, alert['candid'], alert['objectId'])
            image = np.pad(image,
                           [(0, 63 - shape[0]),
                            (0, 63 - shape[1])],
                           mode='constant', constant_values=medfill)

        if not image.any():
            # Must be detected before dividing by the (zero) norm, which would
            # turn the image into NaN and hide the condition.
            print(BOLD, alert['objectId'], END, "zero image", alert['candid'])
            drop = True
        elif normalize and not drop:
            # Out-of-place division also works for integer-typed pixel data.
            image = image / np.linalg.norm(image)

        # Defensive: nothing above should leave NaN behind, but if it happens
        # report it, show the offending cutout and drop the alert.
        if np.any(np.isnan(image)):
            print(BOLD, alert['objectId'], END, "nan here", alert['candid'])
            drop = True
            plt.imshow(image, origin='upper', cmap=plt.cm.bone, norm=LogNorm())
            plt.show()

        cutout_dict[cutout] = image

    triplet = np.zeros((63, 63, 3))
    triplet[:, :, 0] = cutout_dict['science']
    triplet[:, :, 1] = cutout_dict['template']
    triplet[:, :, 2] = cutout_dict['difference']

    return triplet, drop


In [None]:
def plot_triplet(tr, show_fig: bool = True):
    """
    Draw the three channels of a triplet side by side.

    ADAPTED FROM https://github.com/dmitryduev/braai

    The first two channels are shown with a logarithmic colour scale, the
    third with the default linear one. Shows the figure when `show_fig` is
    true; otherwise returns it without showing.
    """

    fig = plt.figure(figsize=(8, 2), dpi=120)

    # (subplot position, panel title, extra imshow arguments) per channel
    panels = (
        (131, 'Science', {'norm': LogNorm()}),
        (132, 'Reference', {'norm': LogNorm()}),
        (133, 'Difference', {}),
    )
    for channel, (position, title, extra) in enumerate(panels):
        panel = fig.add_subplot(position)
        panel.axis('off')
        panel.imshow(tr[:, :, channel], origin='upper', cmap=plt.cm.bone, **extra)
        panel.title.set_text(title)

    if not show_fig:
        return fig
    plt.show()
    

In [None]:
def extract_triplets(alerts, normalize: bool = True, pop_triplet: bool = True):
    """
    Convert the cutouts of every alert into a triplet and drop unusable alerts.

    Parameters
    ----------
    alerts : sequence of dict
        Alert packets accepted by `make_triplet`.
    normalize : bool
        Forwarded to `make_triplet`.
    pop_triplet : bool
        When true, the three 'cutout*' entries are removed from each alert
        dict IN PLACE once its triplet has been built.

    Returns
    -------
    alerts : list of dict
        The alerts that `make_triplet` did not flag for dropping. Always a
        plain list, whether or not anything was dropped.
    triplets : numpy.ndarray
        Array of shape (len(alerts), 63, 63, 3), row-aligned with `alerts`.
    """
    triplets = np.empty((len(alerts), 63, 63, 3))
    keep = np.ones(len(alerts), dtype=bool)

    for i, alert in enumerate(alerts):
        triplets[i], drop = make_triplet(alert, normalize=normalize)
        if pop_triplet:
            for key in ('cutoutScience', 'cutoutTemplate', 'cutoutDifference'):
                alert.pop(key)
        keep[i] = not drop

    # Only copy the (large) triplet array when something actually has to go.
    if not keep.all():
        triplets = triplets[keep]

    kept_alerts = [alert for alert, kept in zip(alerts, keep) if kept]

    return kept_alerts, triplets


In [None]:
def process_cand_data(alerts, label):
    """
    Flatten alerts into one DataFrame row each and attach a label column.

    Every row holds the alert's 'candidate' fields merged with its
    'classifications' fields (when present), preceded by the columns
    "objectId", "candid" and "label".

    Parameters
    ----------
    alerts : sequence of dict
        Alerts with 'objectId', 'candid', 'candidate' and, optionally,
        'classifications'.
    label : int or sequence of int
        Either a single 0/1 applied to every alert, or one label per alert
        (list, tuple, numpy array or pandas Series).

    Returns
    -------
    pandas.DataFrame

    Raises
    ------
    ValueError
        If a label sequence does not have one entry per alert.
    TypeError
        If `label` is neither an integer nor a supported sequence.
    """
    # Alerts without a 'classifications' entry simply get NaN in those columns.
    cand_class_data = [
        {**alert['candidate'], **(alert.get('classifications') or {})}
        for alert in alerts
    ]

    df = pd.DataFrame(cand_class_data)
    df.insert(0, "objectId", [alert['objectId'] for alert in alerts])
    df.insert(1, "candid", [alert['candid'] for alert in alerts])

    # label must be int equalling 0, 1 or a list of 1s and 0s
    if isinstance(label, (list, tuple, np.ndarray, pd.Series)):
        if len(label) != len(alerts):
            raise ValueError(f"got {len(label)} labels for {len(alerts)} alerts")
        # np.asarray strips any pandas index so labels are assigned by position.
        df.insert(2, "label", np.asarray(label))
    elif isinstance(label, (int, np.integer)):
        # Also accepts numpy integers, which previously fell through and left
        # the frame without a label column.
        df.insert(2, "label", np.full((len(alerts),), int(label), dtype=int))
    else:
        raise TypeError(f"label must be an int or a sequence of ints, got {type(label).__name__}")
    print("Arranged candidate data and inserted labels")
    return df


### Query data from Kowalski, then separate and save triplets and candidate data

In [None]:
# Fetch all alerts of the `trues` objects, save their image triplets, then
# save the candidate table with every alert labelled 1.
print(f"Querying kowalski for {len(trues['ZTFID'])} objects")
true_alerts, true_triplets = extract_triplets(
    query_kowalski(trues['ZTFID'].to_list(), k),
    normalize=True,
)

np.save("data/base_data/true_triplets.npy", true_triplets)
del true_triplets  # release the large array once it is on disk
print("Saved and purged triplets\n")

num_true_alerts = len(true_alerts)
print(f"All {num_true_alerts} alerts are trues")

true_cand_data = process_cand_data(true_alerts, np.full(num_true_alerts, 1, dtype=int))
true_cand_data.to_csv('data/base_data/true_candidates.csv', index=False)
del true_cand_data
print("Saved and purged candidate data")

In [None]:
# Fetch all alerts of the `dims` objects, save their image triplets, then
# save the candidate table with every alert labelled 0.
print(f"Querying kowalski for {len(dims['ZTFID'])} objects")
dims_alerts, dims_triplets = extract_triplets(
    query_kowalski(dims['ZTFID'].to_list(), k),
    normalize=True,
)

np.save("data/base_data/dims_triplets.npy", dims_triplets)
del dims_triplets  # release the large array once it is on disk
print("Saved and purged triplets\n")

num_dims_alerts = len(dims_alerts)
print(f"All {num_dims_alerts} alerts are falses")

dims_cand_data = process_cand_data(dims_alerts, np.full(num_dims_alerts, 0, dtype=int))
dims_cand_data.to_csv('data/base_data/dims_candidates.csv', index=False)
del dims_cand_data
print("Saved and purged candidate data")

In [None]:
# Fetch all alerts of the `vars` objects, save their image triplets, then
# save the candidate table with every alert labelled 0.
print(f"Querying kowalski for {len(vars['ZTFID'])} objects")
vars_alerts, vars_triplets = extract_triplets(
    query_kowalski(vars['ZTFID'].to_list(), k),
    normalize=True,
)

np.save("data/base_data/vars_triplets.npy", vars_triplets)
del vars_triplets  # release the large array once it is on disk
print("Saved and purged triplets\n")

num_vars_alerts = len(vars_alerts)
print(f"All {num_vars_alerts} alerts are falses")

vars_cand_data = process_cand_data(vars_alerts, np.full(num_vars_alerts, 0, dtype=int))
vars_cand_data.to_csv('data/base_data/vars_candidates.csv', index=False)
del vars_cand_data
print("Saved and purged candidate data")


In [None]:
# Fetch all alerts of the remaining MS objects and save their image triplets.
print(f"Querying kowalski for {len(MS_Ias)} objects")
MS_alerts, MS_triplets = extract_triplets(
    query_kowalski(MS_Ias['ZTFID'].to_list(), k),
    normalize=True,
)

np.save("data/base_data/MS_triplets.npy", MS_triplets)
del MS_triplets  # release the large array once it is on disk
print("Saved and purged triplets\n")

num_MS_alerts = len(MS_alerts)

# An object is a "true" as soon as any one of its alerts has magpsf < 18.5;
# all alerts of such an object get label 1, all others 0.
MS_true_objs = {
    al['objectId'] for al in MS_alerts if al['candidate']['magpsf'] < 18.5
}
MS_labels = [int(al['objectId'] in MS_true_objs) for al in MS_alerts]
print(f"Generated labels: {np.sum(MS_labels)} trues, {len(MS_labels)-np.sum(MS_labels)} falses")

MS_cand_data = process_cand_data(MS_alerts, MS_labels)
MS_cand_data.to_csv('data/base_data/MS_candidates.csv', index=False)
del MS_cand_data
print("Saved and purged candidate data")

### Visualization helper functions

In [None]:
def plot_lightcurve(alerts):
    """
    Plot PSF magnitude against time for the given alerts, one series per filter.

    `alerts` is a DataFrame with the columns 'jd', 'fid', 'magpsf' and
    'sigmapsf'. Time is measured from the earliest 'jd'; the magnitude axis is
    inverted. Returns the (figure, axes) pair.
    """
    # filter id -> (marker colour, band letter used in the legend)
    band_style = {
        1: ('green', 'g'),
        2: ('red', 'r'),
        3: ('orange', 'i'),
    }

    alerts = alerts.sort_values(by='jd')
    days = alerts['jd'] - alerts['jd'].to_numpy()[0]

    fig, ax = plt.subplots(figsize=(8,5))

    for fid, (color, band) in band_style.items():
        in_band = alerts['fid'] == fid
        band_alerts = alerts.loc[in_band]
        ax.errorbar(
            days[in_band],
            band_alerts['magpsf'],
            fmt='o',
            color=color,
            yerr=band_alerts['sigmapsf'],
            label='ztf'+band,
        )

    ax.invert_yaxis()
    ax.set_xlabel("days since first detection", size=16, labelpad=10)
    ax.set_ylabel("PSF magnitude", size=16, labelpad=10)
    ax.legend(loc='upper right', bbox_to_anchor=(1.25, 1))

    return fig, ax

In [None]:
def plot_lightcurve_itr(alerts, triplets):
    """
    Show every tenth triplet, then plot the light curve alert by alert.

    Parameters
    ----------
    alerts : pandas.DataFrame
        One row per alert with 'jd', 'fid', 'magpsf', 'sigmapsf' and
        'isdiffpos'.
    triplets : sequence
        Image triplets; every tenth one is passed to `plot_triplet`.

    Returns
    -------
    (fig, ax)
        The light-curve figure. Upward triangles mark positive-subtraction
        alerts, downward triangles negative ones.
    """
    fid_to_color = {
        1: ('green',  'g'),
        2: ('red',    'r'),
        3: ('orange', 'i')
    }

    def _is_positive(flag):
        # The alert schema encodes isdiffpos as 't'/'1' (positive) or 'f'/'0'
        # (negative). Plain truthiness would call the string 'f' positive.
        if isinstance(flag, str):
            return flag.strip().lower() in ('t', '1', 'true')
        return bool(flag)

    first_detect = np.min(alerts['jd'].to_numpy())
    print(first_detect)
    # Use the `triplets` argument (the original read an undefined/global `trips`).
    for trip in triplets[::10]:
        plot_triplet(trip, show_fig=True)

    fig, ax = plt.subplots(figsize=(8,5))

    for _, alert in alerts.iterrows():
        jd = alert['jd'] - first_detect
        marker = '^' if _is_positive(alert['isdiffpos']) else 'v'
        ax.errorbar(jd, alert['magpsf'], fmt=marker, fillstyle='none', alpha=0.5,
                    color=fid_to_color[alert['fid']][0], yerr=alert['sigmapsf'])

    # Empty proxy artists so the legend shows one entry per band and sign.
    for marker, sense in (('^', 'pos'), ('v', 'neg')):
        for fid in (2, 1, 3):
            color, band = fid_to_color[fid]
            ax.plot([], [], marker=marker, color=color, fillstyle='none', alpha=0.5,
                    label=f'ztf{band} {sense} diff')

    ax.invert_yaxis()
    ax.set_xlabel("days since first detection", size=16, labelpad=10)
    ax.set_ylabel("PSF magnitude", size=16, labelpad=10)
    ax.legend(loc='upper right')

    return fig, ax

In [None]:
# ztfid = "ZTF18ablmduj"

# alerts = cand[cand['objectId']==ztfid]
# trips = triplets[cand['objectId']==ztfid]
# fig, ax = plot_lightcurve_itr(alerts, trips)
# ax.set_title(f"{ztfid} lightcurve", size=14)
# fig.tight_layout()
# plt.show()