In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
import pandas as pd
import json

from penquins import Kowalski

from sklearn.model_selection import train_test_split

from matplotlib.colors import LogNorm
from astropy.stats import sigma_clipped_stats
from astropy.io.fits.verify import VerifyWarning
import warnings
warnings.filterwarnings("ignore", category=VerifyWarning)

import sys
BOLD = "\033[1m"
END  = "\033[0m"

# %matplotlib notebook

In [None]:
# Global matplotlib styling shared by every figure drawn in this notebook.
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams["font.size"] = 12
plt.rcParams["axes.linewidth"] = 1

### Query the BTS Sample Explorer

In [None]:
# Earlier versions of the BTS Sample Explorer queries. No visible cell
# references these two names; the download cell uses the bts_* URLs defined in
# the next cell. Unlike those, old_bts_trues_url carries no
# passok/refok/dateok/purity flags and old_bts_dim_falses_url uses
# quality=y&purity=y.
old_bts_trues_url      = "https://sites.astro.caltech.edu/ztf/bts/explorer.php?f=s&subsample=trans&classstring=&classexclude=&ps1img=y&lcfig=y&ztflink=lasair&lastdet=&startsavedate=&startpeakdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=&startabsmag=&starthostabs=&starthostcol=&startb=&startav=&endsavedate=&endpeakdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=18.5&endabsmag=&endhostabs=&endhostcol=&endb=&endav=&sort=peakmag&format=csv"
old_bts_dim_falses_url = "https://sites.astro.caltech.edu/ztf/bts/explorer.php?f=s&subsample=trans&classstring=&classexclude=&quality=y&purity=y&ps1img=y&lcfig=y&ztflink=lasair&lastdet=&startsavedate=&startpeakdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=18.5&startabsmag=&starthostabs=&starthostcol=&startb=&startav=&endsavedate=&endpeakdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=&endabsmag=&endhostabs=&endhostcol=&endb=&endav=&sort=peakmag&reverse=y&format=csv"


In [None]:
# BTS Sample Explorer CSV exports (format=csv) downloaded by the next cell:
#   bts_trues_url      - subsample=trans, endpeakmag=18.5
#   bts_var_falses_url - subsample=var, no magnitude limits
#   bts_dim_falses_url - subsample=trans, startpeakmag=18.5
# The two "trans" queries also set passok, refok, dateok and purity to "y".
# NOTE(review): presumably endpeakmag/startpeakmag select peak magnitudes
# brighter/fainter than 18.5 - confirm against the explorer's query form.
bts_trues_url      = "https://sites.astro.caltech.edu/ztf/bts/explorer.php?f=s&subsample=trans&classstring=&classexclude=&passok=y&refok=y&dateok=y&purity=y&ps1img=y&lcfig=y&ztflink=lasair&lastdet=&startsavedate=&startpeakdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=&startabsmag=&starthostabs=&starthostcol=&startb=&startav=&endsavedate=&endpeakdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=18.5&endabsmag=&endhostabs=&endhostcol=&endb=&endav=&sort=peakmag&format=csv"
bts_var_falses_url = "https://sites.astro.caltech.edu/ztf/bts/explorer.php?f=s&subsample=var&classstring=&classexclude=&ztflink=lasair&lastdet=&startsavedate=&startpeakdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=&startabsmag=&starthostabs=&starthostcol=&startb=&startav=&endsavedate=&endpeakdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=&endabsmag=&endhostabs=&endhostcol=&endb=&endav=&format=csv"
bts_dim_falses_url = "https://sites.astro.caltech.edu/ztf/bts/explorer.php?f=s&subsample=trans&classstring=&classexclude=&passok=y&refok=y&dateok=y&purity=y&ps1img=y&lcfig=y&ztflink=lasair&lastdet=&startsavedate=&startpeakdate=&startra=&startdec=&startz=&startdur=&startrise=&startfade=&startpeakmag=18.5&startabsmag=&starthostabs=&starthostcol=&startb=&startav=&endsavedate=&endpeakdate=&endra=&enddec=&endz=&enddur=&endrise=&endfade=&endpeakmag=&endabsmag=&endhostabs=&endhostcol=&endb=&endav=&sort=peakmag&reverse=y&format=csv"


In [None]:
# Download the three CSV exports with curl. IPython substitutes the Python
# variable named after `$` into each shell command before running it.
# NOTE(review): presumably data/base_data/ must already exist - confirm, since
# no visible cell creates it.
!curl -o data/base_data/bts_trues.csv "$bts_trues_url"
!curl -o data/base_data/bts_var_falses.csv "$bts_var_falses_url"
!curl -o data/base_data/bts_dim_falses.csv "$bts_dim_falses_url"

# m = -2.5log(c)+zp
# 12=-2.5log(255)+zp
# m1-m2=-2.5log(f1/f2)
# 12+2.5*np.log10(255)

### Read queried data

In [None]:
# Load the three downloaded BTS tables; later cells read their "ZTFID" column.
bts_trues = pd.read_csv("data/base_data/bts_trues.csv")
bts_var_falses = pd.read_csv("data/base_data/bts_var_falses.csv")
bts_dim_falses = pd.read_csv("data/base_data/bts_dim_falses.csv")

In [None]:
# Row counts of the "true" table and of the two "false" tables combined.
print(len(bts_trues), "bts true sources")
print(len(bts_dim_falses)+len(bts_var_falses), "bts false sources")

### Process list from Mat Smith

In [None]:
# Load the master list and expose its object-name column under "ZTFID", the
# column name the BTS tables are queried with further down.
MS_Ias = pd.read_csv('mat-smith/ztfdr2_masterlist.csv').rename(columns={"ztfname": "ZTFID"})
print("Total in MS list", len(MS_Ias))

In [None]:
# Rows whose identifier does not contain "ZTF" are discarded. na=False makes a
# missing identifier count as "no match"; without it the mask would contain NaN
# and the `~` negation below would fail.
nonZTF = ~MS_Ias['ZTFID'].str.contains('ZTF', na=False)
nonZTF_idxs = MS_Ias.index[nonZTF]

MS_Ias = MS_Ias.drop(index=nonZTF_idxs)
print("Total in MS list excluding non ZTF objects", len(MS_Ias))

In [None]:
# Flag every MS object whose ID already appears in one of the three BTS
# tables, then keep only the objects that are new.
_bts_id_columns = [tbl['ZTFID'] for tbl in (bts_trues, bts_dim_falses, bts_var_falses)]
inBTSSE = MS_Ias['ZTFID'].isin(pd.concat(_bts_id_columns))
inBTSSE_idxs = MS_Ias.index[inBTSSE]

MS_Ias = MS_Ias.drop(index=inBTSSE_idxs)
print("New in MS list", len(MS_Ias))

### Objects to remove

In [None]:
# Object IDs excluded from every table; the total row count is printed before
# and after so the effect of the removal is visible.
objs_to_remove = ["ZTF18abdiasx", "ZTF21abyazip", "ZTF18aaadqua", "ZTF18aarrwmi"]
print(sum(len(tbl) for tbl in (bts_trues, bts_var_falses, bts_dim_falses, MS_Ias)))

bts_trues      = bts_trues[~bts_trues["ZTFID"].isin(objs_to_remove)]
bts_var_falses = bts_var_falses[~bts_var_falses["ZTFID"].isin(objs_to_remove)]
bts_dim_falses = bts_dim_falses[~bts_dim_falses["ZTFID"].isin(objs_to_remove)]
MS_Ias         = MS_Ias[~MS_Ias["ZTFID"].isin(objs_to_remove)]

print(sum(len(tbl) for tbl in (bts_trues, bts_var_falses, bts_dim_falses, MS_Ias)))

### Helper functions for querying kowalski and processing alerts

In [None]:
# Credentials are read from a local JSON file instead of being written into
# the notebook; the file must provide "username" and "password".
with open('misc/credentials.json', 'r') as f:
    creds = json.load(f)
    
k = Kowalski(username=creds['username'], password=creds['password'])
# Stop here unless the ping returns a truthy value.
assert(k.ping())

In [None]:
def query_kowalski(ZTFID, kowalski):
    """
        Query kowalski for the alert packets (candidate data and stamps) of
        one or more objects.
        ADAPTED FROM https://github.com/growth-astro/ztfrest/
        https://zwickytransientfacility.github.io/ztf-avro-alert/schema.html

        Parameters
        ----------
        ZTFID : str, or list/tuple of str
            A single object ID or a sequence of object IDs. Subclasses of
            these types are accepted as well.
        kowalski : object
            Client with a ``query(dict)`` method whose result has a ``'data'``
            entry holding the matching alerts.

        Returns
        -------
        list or None
            All alerts returned for the requested objects, in query order.
            ``None`` (after printing a message) when ``ZTFID`` is neither a
            string nor a list/tuple.
    """
    if isinstance(ZTFID, str):
        list_ZTFID = [ZTFID]
    elif isinstance(ZTFID, (list, tuple)):
        list_ZTFID = list(ZTFID)
    else:
        print(f"{ZTFID} must be a list or a string")
        return None

    # The projection is identical for every object, so build it once.
    candidate_fields = (
        "jd", "fid", "programid", "isdiffpos", "rcid", "field", "ra", "dec",
        "magpsf", "sigmapsf", "distnr", "magnr", "sigmagnr", "sky", "fwhm",
        "classtar", "mindtoedge", "rb", "drb", "ndethist", "jdstarthist",
        "jdendhist", "scorr", "sgscore1", "distpsnr1", "sgscore2", "distpsnr2",
        "sgscore3", "distpsnr3", "magzpsci", "magzpsciunc", "neargaia", "maggaia",
    )
    classification_fields = ("acai_h", "acai_v", "acai_o", "acai_n", "acai_b")
    cutout_fields = ("cutoutScience", "cutoutTemplate", "cutoutDifference")

    projection = {"_id": 0, "candid": 1, "objectId": 1}
    projection.update({f"candidate.{name}": 1 for name in candidate_fields})
    projection.update({f"classifications.{name}": 1 for name in classification_fields})
    projection.update({name: 1 for name in cutout_fields})

    alerts = []

    for object_id in list_ZTFID:
        query = {
            "query_type": "find",
            "query": {
                "catalog": "ZTF_alerts",
                "filter": {
                    'objectId': object_id
                },
                # Each query gets its own copy so the client cannot alter the
                # projection used for the next object.
                "projection": dict(projection),
            }
        }

        r = kowalski.query(query)

        if not r['data']:
            print("  No data for", object_id)
        else:
            alerts += r['data']
            print("  Finished querying", object_id)
    print(BOLD+f"Finished all queries, got {len(alerts)} alerts"+END+"\n")
    return alerts
    

In [None]:
def make_triplet(alert, normalize: bool = True):
    """
        Build the science/template/difference image stack for one alert packet.
        ADAPTED FROM https://github.com/dmitryduev/braai

        Parameters
        ----------
        alert : mapping
            Alert packet holding gzipped FITS stamps under
            ``cutoutScience``, ``cutoutTemplate`` and ``cutoutDifference``
            (each with a ``stampData`` entry), plus ``candid``.
        normalize : bool
            If True, divide each cutout by its L2 norm.

        Returns
        -------
        triplet : numpy.ndarray
            Array of shape (63, 63, 3) with channels science, template,
            difference.
        drop : bool
            True when at least one cutout sums to exactly zero (this includes
            all-zero and all-NaN stamps), i.e. the alert should be discarded.
    """
    from bson.json_util import loads, dumps
    import gzip
    import io
    from astropy.io import fits

    cutout_dict = dict()
    drop = False

    for cutout in ('science', 'template', 'difference'):
        cutout_data = loads(dumps([alert[f'cutout{cutout.capitalize()}']['stampData']]))[0]
        # unzip
        with gzip.open(io.BytesIO(cutout_data), 'rb') as f:
            with fits.open(io.BytesIO(f.read())) as hdu:
                data = hdu[0].data
                # replace nans with the median of the remaining pixels; a stamp
                # made only of nans has no median, so it is filled with zeros
                # and ends up flagged by the zero-sum check below
                fill_value = np.nanmedian(data)
                if not np.isfinite(fill_value):
                    fill_value = 0.0
                image = np.nan_to_num(data, nan=fill_value)

                # normalize; a zero norm means an all-zero stamp, which is left
                # untouched instead of being turned into nans by 0/0
                if normalize:
                    norm = np.linalg.norm(image)
                    if norm > 0:
                        image = image / norm

                if np.sum(image) == 0:
                    drop = True

                cutout_dict[cutout] = image

        # pad to 63x63 if smaller
        shape = cutout_dict[cutout].shape
        if shape != (63, 63):
            print(shape, alert['candid'])
            cutout_dict[cutout] = np.pad(cutout_dict[cutout],
                                         [(0, 63 - shape[0]),
                                          (0, 63 - shape[1])],
                                         mode='constant', constant_values=1e-9)

    triplet = np.zeros((63, 63, 3))
    triplet[:, :, 0] = cutout_dict['science']
    triplet[:, :, 1] = cutout_dict['template']
    triplet[:, :, 2] = cutout_dict['difference']

    return triplet, drop


In [None]:
def plot_triplet(tr, show_fig: bool = True):
    """Draw the three channels of a triplet next to each other.

    ADAPTED FROM https://github.com/dmitryduev/braai

    ``tr`` is indexed as ``tr[:, :, k]`` for k = 0, 1, 2. The first two
    channels are shown on a logarithmic colour scale, the third on the default
    scale. With ``show_fig=True`` the figure is shown and nothing is returned;
    otherwise the figure object is returned without being shown.
    """
    # (subplot position, panel title, use a logarithmic colour scale)
    panels = (
        (131, 'Science', True),
        (132, 'Reference', True),
        (133, 'Difference', False),
    )

    fig = plt.figure(figsize=(8, 2), dpi=120)
    for channel, (position, title, log_scale) in enumerate(panels):
        ax = fig.add_subplot(position)
        ax.axis('off')
        scale = {'norm': LogNorm()} if log_scale else {}
        ax.imshow(tr[:, :, channel], origin='upper', cmap=plt.cm.bone, **scale)
        ax.title.set_text(title)

    if not show_fig:
        return fig
    plt.show()
    

In [None]:
def extract_triplets(alerts, normalize: bool = True):
    """Turn a list of alert packets into a stack of image triplets.

    Parameters
    ----------
    alerts : list of dict
        Alert packets; the three ``cutout*`` entries are removed from every
        packet in place once its triplet has been built.
    normalize : bool
        Forwarded to ``make_triplet``.

    Returns
    -------
    alerts : list of dict
        The packets whose triplet was not flagged for dropping. The input
        list itself is returned when nothing is dropped, a new list otherwise.
    triplets : numpy.ndarray
        Array of shape (n_kept, 63, 63, 3), row-aligned with ``alerts``.
    """
    n_alerts = len(alerts)
    triplets = np.empty((n_alerts, 63, 63, 3))
    keep = np.ones(n_alerts, dtype=bool)

    for i, alert in enumerate(alerts):
        triplets[i], drop = make_triplet(alert, normalize=normalize)
        # The raw stamps are not needed after stacking; remove them from the packet.
        for stamp_key in ('cutoutScience', 'cutoutTemplate', 'cutoutDifference'):
            alert.pop(stamp_key)
        if drop:
            keep[i] = False

    if not keep.all():
        triplets = triplets[keep]
        # Stay a plain list so the return type does not depend on whether
        # anything was dropped.
        alerts = [alert for alert, kept in zip(alerts, keep) if kept]

    return alerts, triplets


In [None]:
def process_cand_data(alerts, label):
    """Flatten alert packets into one labelled table.

    Parameters
    ----------
    alerts : sequence of dict
        Packets with ``objectId``, ``candid`` and ``candidate`` entries. A
        ``classifications`` entry is merged in when present.
    label : int or sequence of int
        Either one integer applied to every alert, or one label per alert
        (list, tuple, numpy array or pandas Series).

    Returns
    -------
    pandas.DataFrame
        Columns ``objectId``, ``candid``, ``label`` followed by the candidate
        and classification fields.

    Raises
    ------
    AssertionError
        If a label sequence does not have one entry per alert.
    TypeError
        If ``label`` is neither an integer nor a supported sequence. This
        replaces silently returning a table without a label column.
    """
    cand_class_data = [alert['candidate'] | (alert.get('classifications') or {}) for alert in alerts]

    # label must be an int or a sequence with one entry per alert
    if isinstance(label, (list, tuple, np.ndarray, pd.Series)):
        assert len(label) == len(alerts), "need exactly one label per alert"
        # Plain array so that a Series is inserted by position, not by index.
        labels = np.asarray(label)
    elif isinstance(label, (int, np.integer)):
        labels = np.full((len(alerts),), label, dtype=int)
    else:
        raise TypeError(f"label must be an int or a sequence of ints, got {type(label).__name__}")

    df = pd.DataFrame(cand_class_data)
    df.insert(0, "objectId", [alert['objectId'] for alert in alerts])
    df.insert(1, "candid", [alert['candid'] for alert in alerts])
    df.insert(2, "label", labels)
    print("Arranged candidate data and inserted labels")
    return df


### Query data from kowalski, separate and save triplets and candidate data  

In [None]:
# print(f"Querying kowalski for {len(bts_trues['ZTFID'])} objects")
# bts_true_alerts, bts_true_triplets = extract_triplets(query_kowalski(bts_trues['ZTFID'].to_list(), k), True)

# np.save("data/base_data/bts_true_triplets.npy", bts_true_triplets)
# del bts_true_triplets
# print("Saved and purged triplets\n")

# num_bts_true_alerts = len(bts_true_alerts)

# print(f"All {num_bts_true_alerts} alerts are trues")

# bts_true_cand_data = process_cand_data(bts_true_alerts, np.ones(num_bts_true_alerts, dtype=int))
# bts_true_cand_data.to_csv('data/base_data/bts_true_candidates.csv', index=False)
# del bts_true_cand_data
# print("Saved and purged candidate data")


In [None]:
# print(f"Querying kowalski for {len(bts_dim_falses['ZTFID'])} objects")
# bts_dim_false_alerts, bts_dim_false_triplets = extract_triplets(query_kowalski(bts_dim_falses['ZTFID'].to_list(), k), True)

# np.save("data/base_data/bts_dim_false_triplets.npy", bts_dim_false_triplets)
# del bts_dim_false_triplets
# print("Saved and purged triplets\n")

# num_bts_dim_false_alerts = len(bts_dim_false_alerts)

# print(f"All {num_bts_dim_false_alerts} alerts are falses")

# bts_dim_false_cand_data = process_cand_data(bts_dim_false_alerts, np.zeros(num_bts_dim_false_alerts, dtype=int))
# bts_dim_false_cand_data.to_csv('data/base_data/bts_dim_false_candidates.csv', index=False)
# del bts_dim_false_cand_data
# print("Saved and purged candidate data")


In [None]:
# print(f"Querying kowalski for {len(bts_var_falses['ZTFID'])} objects")
# bts_var_false_alerts, bts_var_false_triplets = extract_triplets(query_kowalski(bts_var_falses['ZTFID'].to_list(), k), True)

# np.save("data/base_data/bts_var_false_triplets.npy", bts_var_false_triplets)
# del bts_var_false_triplets
# print("Saved and purged triplets\n")

# num_bts_var_false_alerts = len(bts_var_false_alerts)

# print(f"All {num_bts_var_false_alerts} alerts are falses")

# bts_var_false_cand_data = process_cand_data(bts_var_false_alerts, np.zeros(num_bts_var_false_alerts, dtype=int))
# bts_var_false_cand_data.to_csv('data/base_data/bts_var_false_candidates.csv', index=False)
# del bts_var_false_cand_data
# print("Saved and purged candidate data")


In [None]:
# print(f"Querying kowalski for {len(MS_Ias)} objects")
# MS_alerts, MS_triplets = extract_triplets(query_kowalski(MS_Ias['ZTFID'].to_list(), k), True)

# np.save("data/base_data/MS_triplets.npy", MS_triplets)
# del MS_triplets
# print("Saved and purged triplets\n")

# num_MS_alerts = len(MS_alerts)

# MS_true_objs = set()
# for al in MS_alerts: 
#     if al['candidate']['magpsf'] < 18.5:
#         MS_true_objs.add(al['objectId'])
# MS_labels = [1 if al['objectId'] in MS_true_objs else 0 for al in MS_alerts]
# print(f"Generated labels: {np.sum(MS_labels)} trues, {len(MS_labels)-np.sum(MS_labels)} falses")

# # MS_true_objs = set()
# # for idx in cand.index: 
# #     if cand.loc[idx]['magpsf'] < 18.5:
# #         MS_true_objs.add(cand.loc[idx]['objectId'])
# # MS_labels = [1 if objectid in MS_true_objs else 0 for objectid in cand['objectId']]
# # print(f"Generated labels: {np.sum(MS_labels)} trues, {len(MS_labels)-np.sum(MS_labels)} falses")


# MS_cand_data = process_cand_data(MS_alerts, MS_labels)
# MS_cand_data.to_csv('data/base_data/MS_candidates.csv', index=False)
# del MS_cand_data
# print("Saved and purged candidate data")


### Thin datasets down to at most N alerts per object

In [None]:
def thin_by_alerts(set_name, mods, N_max: int, seed: int = 2):
    """Randomly thin a data set to at most ``N_max`` alerts per object.

    Reads ``data/base_data/{set_name}_triplets{mods}.npy`` and
    ``data/base_data/{set_name}_candidates{mods}.csv``, drops randomly chosen
    surplus alerts of every object that has more than ``N_max``, and writes
    the result to ``data/{set_name}_triplets{mods}_{N_max}max.npy`` and
    ``data/{set_name}_candidates{mods}_{N_max}max.csv``. A histogram of alerts
    per object is shown before thinning.

    Parameters
    ----------
    set_name : str
        File-name prefix of the data set.
    mods : str
        Extra file-name suffix ("" for none).
    N_max : int
        Maximum number of alerts kept per object; must not be negative.
    seed : int
        Seed passed to ``np.random.seed`` (the global NumPy generator is
        reseeded, as before).

    Raises
    ------
    ValueError
        If ``N_max`` is negative or the two input files differ in length.
    """
    if N_max < 0:
        raise ValueError(f"N_max must be non-negative, got {N_max}")

    np.random.seed(seed)
    print(f"Thinning {set_name}{mods} data to {N_max} alerts per object")
    trip_filename = f"data/base_data/{set_name}_triplets{mods}.npy"
    cand_filename = f"data/base_data/{set_name}_candidates{mods}.csv"

    # Read-only mapping is sufficient: the input file is never written to.
    triplets = np.load(trip_filename, mmap_mode='r')
    cand = pd.read_csv(cand_filename)
    if len(triplets) != len(cand):
        raise ValueError(
            f"{trip_filename} has {len(triplets)} rows but {cand_filename} has {len(cand)}"
        )

    counts_before = cand['objectId'].value_counts()
    plt.figure()
    _ = plt.hist(counts_before, histtype='step', bins=50)
    plt.tight_layout()
    plt.show()
    print(f"Initial median of {np.median(counts_before)} detections per object")

    # Row positions of each object's alerts. Objects are visited in sorted ID
    # order so the seeded draws hit the same objects on every run; iterating a
    # set of strings is not order-stable between interpreter sessions.
    rows_by_object = cand.groupby('objectId').indices
    surplus = []
    for object_id in sorted(rows_by_object):
        rows = rows_by_object[object_id]
        if len(rows) > N_max:
            surplus.append(np.random.choice(rows, len(rows) - N_max, replace=False))
    drops = np.concatenate(surplus) if surplus else np.empty((0,), dtype=int)

    triplets = np.delete(triplets, drops, axis=0)
    # `drops` holds row positions, so translate them to index labels first.
    cand = cand.drop(index=cand.index[drops])
    print(f"{BOLD}Dropped {len(drops)} {set_name} alerts{END}")

    np.save(f"data/{set_name}_triplets{mods}_{N_max}max.npy", triplets)
    cand.to_csv(f"data/{set_name}_candidates{mods}_{N_max}max.csv", index=False)

    print(f"Final median of {np.median(cand['objectId'].value_counts())} detections per object")
    print(f"Saved thinned {set_name}{mods} data to disk\n")
    

In [None]:
# Upper bound on the number of alerts kept per object.
N_max_alerts = 15

# NOTE(review): thin_by_alerts reads data/base_data/<set>_triplets.npy and
# <set>_candidates.csv. The commented-out query cells only write bts_true,
# bts_dim_false, bts_var_false and MS files, so the "bts_false" inputs
# presumably come from a merge step that is not in this notebook - confirm.
thin_by_alerts("bts_true", "", N_max_alerts)
thin_by_alerts("bts_false", "", N_max_alerts)
thin_by_alerts("MS", "", N_max_alerts)

### Concatenate triplets and candidate data from given sources into two primary files

In [None]:
def concat(source_sets, mods, N_max: int = None):
    """Merge the per-set triplet and candidate files into two combined files.

    For every name in ``source_sets`` this reads
    ``data/{name}_triplets{suffix}.npy`` and
    ``data/{name}_candidates{suffix}.csv`` and writes the row-wise
    concatenation to ``data/triplets{suffix}.npy`` and
    ``data/candidates{suffix}.csv``, where ``suffix`` is ``mods`` followed by
    ``_{N_max}max`` when ``N_max`` is given.

    Parameters
    ----------
    source_sets : iterable of str
        File-name prefixes to merge, in output order. An empty iterable
        produces empty output files.
    mods : str
        Extra file-name suffix ("" for none).
    N_max : int, optional
        Alerts-per-object limit encoded in the file names.
    """
    suffix = f"{mods}_{N_max}max" if N_max is not None else f"{mods}"
    limit_note = f'with {N_max} maximum alert(s) per object' if N_max is not None else ''
    print(f"Merging triplets and candidate data for {source_sets}{mods} {limit_note}")

    # Collect all parts and concatenate once instead of re-copying the growing
    # result on every iteration. The empty float64 block keeps the output
    # shape and dtype defined even when there is nothing to merge.
    triplet_parts = [np.empty((0, 63, 63, 3))]
    cand_parts = []

    for source_set in source_sets:
        # Inputs are only read, so a read-only memory map is enough.
        triplet_parts.append(np.load(f"data/{source_set}_triplets{suffix}.npy", mmap_mode='r'))
        cand_parts.append(pd.read_csv(f"data/{source_set}_candidates{suffix}.csv"))
        print(f"  Read and merged {source_set} data")

    triplets = np.concatenate(triplet_parts)
    cand = pd.concat(cand_parts) if cand_parts else pd.DataFrame()

    np.save(f"data/triplets{suffix}.npy", triplets)
    cand.to_csv(f"data/candidates{suffix}.csv", index=False)
    print("Wrote merged triplets and candidate data")
    del triplets, cand

In [None]:
# Merge the thinned BTS sets only; the thinned "MS" files written by the
# thinning cell are not part of this call.
concat(["bts_true", "bts_false"], "", N_max_alerts)

### Visualization helper functions

In [None]:
def plot_lightcurve(alerts):
    """Plot PSF magnitude against days since the earliest alert, per filter.

    ``alerts`` is a table with ``jd``, ``fid``, ``magpsf`` and ``sigmapsf``
    columns. Filter ids 1, 2 and 3 are drawn in green, red and orange and
    labelled ztfg, ztfr and ztfi. Returns the ``(fig, ax)`` pair.
    """
    # filter id -> (marker colour, band letter used in the legend)
    band_style = {
        1: ('green',  'g'),
        2: ('red',    'r'),
        3: ('orange', 'i')
    }

    ordered = alerts.sort_values(by='jd')
    days_since_first = ordered['jd'] - ordered['jd'].to_numpy()[0]

    fig, ax = plt.subplots(figsize=(8,5))

    for fid, (colour, band) in band_style.items():
        in_band = ordered['fid'] == fid
        band_alerts = ordered.loc[in_band]
        ax.errorbar(
            days_since_first[in_band],
            band_alerts['magpsf'],
            yerr=band_alerts['sigmapsf'],
            fmt='o',
            color=colour,
            label='ztf'+band,
        )

    # Smaller magnitudes are brighter, so the axis is flipped.
    ax.invert_yaxis()
    ax.set_xlabel("days since first detection", size=16, labelpad=10)
    ax.set_ylabel("PSF magnitude", size=16, labelpad=10)
    ax.legend(loc='upper right', bbox_to_anchor=(1.25, 1))

    return fig, ax

In [None]:
set_name = "MS"

# NOTE(review): this reads from data/, while the isdiffpos inventory further
# down reads data/base_data/{set_name}_candidates.csv, and no visible cell
# writes data/MS_candidates.csv - confirm the intended path.
cand = pd.read_csv(f"data/{set_name}_candidates.csv")

In [None]:
# Object IDs in random order. np.random.permutation shuffles a copy, so the
# array handed out by the pandas Index is never modified in place (an in-place
# shuffle fails if that array is read-only).
ztfids = np.random.permutation(cand['objectId'].value_counts().index.to_numpy())

# Plot the lightcurve of the first object in the shuffled order.
for ztfid in ztfids[0:1]:
    fig, ax = plot_lightcurve(cand[cand['objectId']==ztfid])
    ax.set_title(f"{ztfid} lightcurve", size=14)
    fig.tight_layout()
    plt.show()

### Inventory of $\texttt{isdiffpos = False}$

In [None]:
set_name = "bts_true"

# The triplets are memory-mapped read-only; the candidate table is read in full.
triplets = np.load(f"data/base_data/{set_name}_triplets.npy", mmap_mode='r')
cand = pd.read_csv(f"data/base_data/{set_name}_candidates.csv")

In [None]:
# Convert the subtraction-sign flag to a real boolean. The ZTF alert schema
# encodes a positive subtraction as 't' or '1', so both are accepted. Comparing
# string forms also keeps the cell safe to re-run: already-converted True/False
# values map back to themselves instead of all becoming False.
cand['isdiffpos'] = (
    cand['isdiffpos']
    .astype(str)
    .str.strip()
    .str.lower()
    .isin(['t', '1', '1.0', 'true'])
)

In [None]:
# Alerts from negative subtractions, and how many alerts each object has
# (negative-only and overall). Each tally is computed once and then split
# into its IDs and its counts.
negdiffs = cand[~cand['isdiffpos']]

_negdiff_tally = negdiffs['objectId'].value_counts()
negdiff_objids = _negdiff_tally.index.to_numpy()
negdiffs_counts = _negdiff_tally.to_numpy()

_cand_tally = cand['objectId'].value_counts()
cand_objids = _cand_tally.index.to_numpy()
cand_counts = _cand_tally.to_numpy()

print(f"Percent of alerts with negative diff {100*len(negdiffs)/len(cand):.2f}%")
print(f"Percent of objects that have at least one negative diff alert {100*len(negdiff_objids)/len(cand_objids):.2f}%")

# Values recorded from earlier runs:
# MS:        62.32% neg
# BTS_true:   6.48% neg
# BTS_false: 19.70% neg


In [None]:
# Share of each object's alerts that are negative differences. The totals come
# from a single value_counts() lookup instead of scanning the whole candidate
# table once per object (which was also done twice).
_alerts_per_object = cand['objectId'].value_counts()
_negdiff_totals = _alerts_per_object.reindex(negdiff_objids).to_numpy()

negdiff_fracs = list(negdiffs_counts / _negdiff_totals)

percs = 100 * negdiffs_counts / _negdiff_totals
for objid, perc in zip(negdiff_objids, percs):
    print(f"{objid} has {perc:.2f}% negative differences")

In [None]:
# Distribution of the per-object negative-difference percentage; labelled so
# the figure can be read on its own.
plt.figure()
plt.hist(percs)
plt.xlabel("alerts with negative difference per object (%)")
plt.ylabel("number of objects")
plt.show()

In [None]:
def plot_lightcurve_itr(alerts, triplets):
    """Show sample cutouts of one object, then its lightcurve by filter and sign.

    Parameters
    ----------
    alerts : pandas.DataFrame
        Alerts of one object with ``jd``, ``fid``, ``magpsf``, ``sigmapsf``
        and ``isdiffpos`` columns. ``isdiffpos`` is evaluated for truthiness.
    triplets : array-like
        Image triplets of the same object; every tenth one is displayed with
        ``plot_triplet``.

    Returns
    -------
    (fig, ax)
        The lightcurve figure. Truthy ``isdiffpos`` values are drawn with
        upward triangles, falsy ones with downward triangles.
    """
    fid_to_color = {
        1: ('green',  'g'),
        2: ('red',    'r'),
        3: ('orange', 'i')
    }

    first_detect = np.min(alerts['jd'].to_numpy())
    print(first_detect)
    # Use the `triplets` argument; the previous body read a notebook-global
    # `trips` that only exists if another cell happened to define it.
    for trip in triplets[::10]:
        plot_triplet(trip, show_fig=True)

    fig, ax = plt.subplots(figsize=(8,5))

    # Plain arrays so the masks below are applied by position.
    days = alerts['jd'].to_numpy() - first_detect
    mags = alerts['magpsf'].to_numpy()
    mag_errs = alerts['sigmapsf'].to_numpy()
    fids = alerts['fid'].to_numpy()
    positive = alerts['isdiffpos'].to_numpy().astype(bool)

    # One errorbar call per (filter, sign) group instead of one per alert.
    for fid, (color, _band) in fid_to_color.items():
        in_band = fids == fid
        for marker, selected in (('^', in_band & positive), ('v', in_band & ~positive)):
            if selected.any():
                ax.errorbar(days[selected], mags[selected], yerr=mag_errs[selected],
                            fmt=marker, fillstyle='none', alpha=0.5, color=color)

    # Empty proxy artists give the legend a fixed set of entries.
    for marker, sign in (('^', 'pos'), ('v', 'neg')):
        for color, band in (('red', 'r'), ('green', 'g'), ('orange', 'i')):
            ax.plot([], [], marker=marker, color=color, fillstyle='none', alpha=0.5,
                    label=f'ztf{band} {sign} diff')

    ax.invert_yaxis()
    ax.set_xlabel("days since first detection", size=16, labelpad=10)
    ax.set_ylabel("PSF magnitude", size=16, labelpad=10)
    ax.legend(loc='upper right')#, bbox_to_anchor=(1.2,1))

    return fig, ax

In [None]:
# ztfid = "ZTF18ablmduj"

# alerts = cand[cand['objectId']==ztfid]
# trips = triplets[cand['objectId']==ztfid]
# fig, ax = plot_lightcurve_itr(alerts, trips)
# ax.set_title(f"{ztfid} lightcurve", size=14)
# fig.tight_layout()
# plt.show()

### Inventory of $i$-band data

In [None]:
# NOTE(review): `set_name` is whatever an earlier cell last assigned, and these
# files are read from data/ rather than data/base_data/ as in the isdiffpos
# inventory - confirm both are intended.
triplets = np.load(f"data/{set_name}_triplets.npy", mmap_mode='r')
cand = pd.read_csv(f"data/{set_name}_candidates.csv")

# The ZTF alert schema encodes a positive subtraction as 't' or '1', so both
# are accepted. Comparing string forms also keeps the conversion safe to
# re-run on values that are already True/False.
cand['isdiffpos'] = (
    cand['isdiffpos']
    .astype(str)
    .str.strip()
    .str.lower()
    .isin(['t', '1', '1.0', 'true'])
)

In [None]:
# Alerts with filter id 3 (labelled "i" in the plotting helpers) and the
# per-object alert tallies, each computed once.
iband = cand[cand['fid']==3]

_iband_tally = iband['objectId'].value_counts()
iband_objids = _iband_tally.index.to_numpy()
iband_counts = _iband_tally.to_numpy()

_cand_tally = cand['objectId'].value_counts()
cand_objids = _cand_tally.index.to_numpy()
cand_counts = _cand_tally.to_numpy()

print(f"Percent of alerts that are in the i-band {100*len(iband)/len(cand):.2f}%")
print(f"Percent of objects that have at least one i band alert {100*len(iband_objids)/len(cand_objids):.2f}%")


# Values recorded from earlier runs:
# MS:        3.37% i-band
# BTS_true:  8.10% i-band
# BTS_false: 4.48% i-band


In [None]:
# i-band alerts that come from negative subtractions, tallied per object.
neg_iband = iband[~iband['isdiffpos']]

_neg_iband_tally = neg_iband['objectId'].value_counts()
neg_iband_objids = _neg_iband_tally.index.to_numpy()
neg_iband_counts = _neg_iband_tally.to_numpy()

print(f"Percent of i-band alerts that have negative differences {100*len(neg_iband)/len(iband):.2f}%")

# Values recorded from earlier runs:
# MS:        67.58% neg & i-band
# BTS_true:  30.24% neg & i-band
# BTS_false: 30.22% neg & i-band


In [None]:
# Keep positive-difference alerts with filter id 1 or 2 (labelled "g" and "r"
# in the plotting helpers). The mask is built once and applied to both the
# candidate table and the triplet stack so the two stay row-aligned.
_is_pd_gr = (cand['isdiffpos']) & ((cand['fid'] == 1) | (cand['fid'] == 2))
cand_pd_gr = cand[_is_pd_gr]
triplets_pd_gr = triplets[_is_pd_gr]

print("Positive difference alerts in g- or r-band", len(cand_pd_gr))

np.save(f"data/{set_name}_triplets_pd_gr.npy", triplets_pd_gr)
cand_pd_gr.to_csv(f"data/{set_name}_candidates_pd_gr.csv", index=False)


In [None]:
# # Purge thinned files
# import glob

# files = glob.glob("data/*max*")
# for file in files:
#     !rm "$file"