In [1]:
#!/usr/bin/env python
# coding: utf-8

# Import required libraries and modules
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from astropy.coordinates import SkyCoord
import astropy.units as u
from tqdm import tqdm

# Define a threshold for matching errors
space_match_threshold = 1 * u.arcsec
MJD_tolerance = 0.00034 #30 sec in units of day 

# Define file paths.
sum_path = {}
sum_path["star"] = '../truth_star/truth_star_summary_v1-0-0.parquet'
#'/sdf/data/rubin/shared/dc2_run2.2i_truth/truth_star/truth_star_summary_v1-0-0.parquet'
sum_path["sn"] = "../truth_sn/truth_sn_summary_v1-0-0.parquet" 
#'/sdf/data/rubin/shared/dc2_run2.2i_truth/truth_sn/truth_sn_summary_v1-0-0.parquet'

var_path = {}
var_path["star"] = '../truth_star/truth_star_variability_v1-0-0.parquet'
#'/sdf/data/rubin/shared/dc2_run2.2i_truth/truth_star/truth_star_variability_v1-0-0.parquet'
var_path["sn"] = '../truth_sn/truth_sn_variability_v1-0-0.parquet'
#'/sdf/data/rubin/shared/dc2_run2.2i_truth/truth_sn/truth_sn_variability_v1-0-0.parquet'

detection_csv_pth = "../sources_with_labels.csv" #'exported_sources.csv'

#get DIA detections
dia_detections = pd.read_csv(detection_csv_pth, index_col="diaSourceId") #formerly known as exported_csv
#only for now based on the csv we are reading
dia_detections.drop(["real"], axis=1, inplace=True)    

dia_detections

Unnamed: 0_level_0,ra,dec,midpointMjdTai,type
diaSourceId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1257927201521665,55.760339,-32.260622,59583.125051,
1257927201521666,55.674078,-32.283857,59583.125051,
1257927201521667,55.552914,-32.306395,59583.125051,
1257927201521668,55.547689,-32.309278,59583.125051,
1257927201521669,55.570127,-32.306400,59583.125051,
...,...,...,...,...
660667525163384915,55.889519,-32.485637,61392.194195,star
661047079476396040,55.863218,-32.167598,61393.204087,star
662500331589992590,55.971559,-32.358853,61404.195949,star
662500331589992596,55.881022,-32.482719,61404.195949,star


In [2]:
# Get mind and max ra and dec values to filter out unnecessary records.
max_exp_ra, min_exp_ra = dia_detections.ra.max(), dia_detections.ra.min()
max_exp_dec, min_exp_dec = dia_detections.dec.max(), dia_detections.dec.min()


catalog = {}
result_sum = {}

# Stage 1: Match sources in Space.
for s in ["star", "sn"]:
    # Read Parquet and CSV files to begin ground truth derivation.
    result_sum[s] = pd.read_parquet(sum_path[s])

    # Keep only those records from summary tables which are within the max ra and dec values in the exported sources.
    result_sum[s] = result_sum[s][(result_sum[s]['ra'] >= min_exp_ra) & (result_sum[s]['ra'] <= max_exp_ra) &\
                                    (result_sum[s]['dec'] >= min_exp_dec) & (result_sum[s]['dec'] <= max_exp_dec)]


    # Initialize astropy.coordinates.SkyCoord class for matching in space.
    catalog[s] = SkyCoord(ra=result_sum[s].ra, dec=result_sum[s].dec, unit=u.deg)

# Match exported sources with stars and supernovae.
detections_cat = SkyCoord(ra=dia_detections.ra, dec=dia_detections.dec, unit=u.deg)

star_idx, star_d2d, star_d3d = catalog['star'].match_to_catalog_sky(detections_cat)
print(f"{len(star_idx)} of {len(detections_cat)} stars matched before applying threshold")
sn_idx, sn_d2d, sn_d3d = catalog['sn'].match_to_catalog_sky(detections_cat)
print(f"{len(sn_idx)} of {len(detections_cat)} sne matched before applying threshold")

star_mask = star_d2d < space_match_threshold #remove matches that are too far
sn_mask = sn_d2d < space_match_threshold #remove matches that are too far

# Assign the variability sources catalog id to the detections
dia_detections["id"] = None
dia_detections.loc[dia_detections.index[star_idx], "id"] = result_sum["star"]["id"].to_numpy()
dia_detections.loc[dia_detections.index[sn_idx], "id"] = result_sum['sn']["id"].to_numpy()

# Get all matched stars
matched_star_idx = star_idx[star_mask] #index in detections_cat of matched stars
print(f"Number of matched stars in Stage #1: {len(matched_star_idx)}")

# Get all matched supernovae
matched_sn_idx = sn_idx[sn_mask] #index in detections_cat of matched sn
print(f"Number of matched sne in Stage #1: {len(matched_sn_idx)}")


# By default, set on_source = 0 and real=0 (bogus) for all values in the exported sources.
dia_detections['on_source'] = 0
dia_detections['real'] = 0
dia_detections['type'] = None

dia_detections

2018 of 25446 stars matched before applying threshold
217 of 25446 sne matched before applying threshold
Number of matched stars in Stage #1: 399
Number of matched sne in Stage #1: 19


Unnamed: 0_level_0,ra,dec,midpointMjdTai,type,id,on_source,real
diaSourceId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1257927201521665,55.760339,-32.260622,59583.125051,,,0,0
1257927201521666,55.674078,-32.283857,59583.125051,,,0,0
1257927201521667,55.552914,-32.306395,59583.125051,,,0,0
1257927201521668,55.547689,-32.309278,59583.125051,,,0,0
1257927201521669,55.570127,-32.306400,59583.125051,,31102012090,0,0
...,...,...,...,...,...,...,...
660667525163384915,55.889519,-32.485637,61392.194195,,31411443281,0,0
661047079476396040,55.863218,-32.167598,61393.204087,,31102009372,0,0
662500331589992590,55.971559,-32.358853,61404.195949,,31405685742,0,0
662500331589992596,55.881022,-32.482719,61404.195949,,31411442918,0,0


In [4]:
# The spatially matched detections get on_source = 1
listid_sn = dia_detections.iloc[matched_sn_idx]
listid_stars = dia_detections.iloc[matched_star_idx]
dia_detections.loc[listid_sn.index, "on_source"] = 1
dia_detections.loc[listid_stars.index , "on_source"] = 1
dia_detections.loc[listid_sn.index, "type"] = "sn"
dia_detections.loc[listid_stars.index, "type"] = "star"
dia_detections

Unnamed: 0_level_0,ra,dec,midpointMjdTai,type,id,on_source,real
diaSourceId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1257927201521665,55.760339,-32.260622,59583.125051,,,0,0
1257927201521666,55.674078,-32.283857,59583.125051,,,0,0
1257927201521667,55.552914,-32.306395,59583.125051,,,0,0
1257927201521668,55.547689,-32.309278,59583.125051,,,0,0
1257927201521669,55.570127,-32.306400,59583.125051,,31102012090,0,0
...,...,...,...,...,...,...,...
660667525163384915,55.889519,-32.485637,61392.194195,star,31411443281,1,0
661047079476396040,55.863218,-32.167598,61393.204087,star,31102009372,1,0
662500331589992590,55.971559,-32.358853,61404.195949,star,31405685742,1,0
662500331589992596,55.881022,-32.482719,61404.195949,star,31411442918,1,0


In [47]:
# Print a summary at the end of first round of matching.
print("Summary at the end of First Stage:")
print(f"detections on a source", dia_detections["on_source"].sum(), "\n")
print(f"class detection: {dia_detections.groupby('type').count().iloc[:,0]}")

Summary at the end of First Stage:
detections on a source 418 

class detection: type
sn       19
star    399
Name: ra, dtype: int64


In [12]:
# Stage 2: Match sources in time.

matched_index = dia_detections["on_source"] > 0
matched = {}
matched["sn"] = dia_detections.loc[listid_sn.index]
matched["star"] = dia_detections.loc[listid_stars.index]


for s in ["sn", "star"]:
    print(f"working on class: {s}")
    
    # Get a list of all the unique MJDs of sources that matched in the previous stage for the object type.
    mjd_matched_in_space = matched[s].midpointMjdTai.unique()

    # Get min and max MJD values required for matching.
    max_mjd, min_mjd = mjd_matched_in_space.max(), mjd_matched_in_space.min()

    # Read star/sn lightcurve variability parquet for the object type
    df_var = pd.read_parquet(var_path[s])
    
    # Filter out records with unwanted MJDs.
    df_var = df_var[(df_var.MJD >= min_mjd) & (df_var.MJD <= max_mjd)]
    print(f"need to examine {len(df_var)} variability entries")
    
    for detected in tqdm(matched[s].index): #loop over indices of on_source detection 
        mask_matching_ids = df_var.id == matched[s].loc[detected].id #mask for sources with matching id in variability file 
        if np.any(np.abs(df_var[mask_matching_ids].MJD - matched[s].loc[detected].midpointMjdTai) <= MJD_tolerance):
            dia_detections.loc[detected, "real"] = 1
            print(detected)
    del df_var

dia_detections[dia_detections.real == 1]

working on class: sn
need to examin 11752895 entries


100%|██████████████████████████████████████████████████████████████████████| 19/19 [00:00<00:00, 174.29it/s]

working on class: star





need to examin 379673003 entries


 30%|████████████████████▉                                                | 121/399 [00:30<01:01,  4.54it/s]

121861151112822817


 41%|████████████████████████████▌                                        | 165/399 [00:41<00:54,  4.26it/s]

645816020564443153


 54%|█████████████████████████████████████▌                               | 217/399 [00:53<00:40,  4.50it/s]

225905106065817658


 60%|█████████████████████████████████████████▏                           | 238/399 [00:58<00:36,  4.44it/s]

351362700657295476


 62%|██████████████████████████████████████████▉                          | 248/399 [01:01<00:41,  3.65it/s]

488448449977516453


 63%|███████████████████████████████████████████▊                         | 253/399 [01:04<01:37,  1.50it/s]

647699403989058266


 65%|████████████████████████████████████████████▌                        | 258/399 [01:07<01:55,  1.22it/s]

372716278936240222


 76%|████████████████████████████████████████████████████▌                | 304/399 [01:20<00:24,  3.95it/s]

379922626857926673


 79%|██████████████████████████████████████████████████████▊              | 317/399 [01:23<00:18,  4.45it/s]

510734470488260694


 86%|███████████████████████████████████████████████████████████▍         | 344/399 [01:29<00:13,  4.15it/s]

242418712005574717


 86%|███████████████████████████████████████████████████████████▋         | 345/399 [01:30<00:13,  3.99it/s]

554641942630105130


 97%|██████████████████████████████████████████████████████████████████▊  | 386/399 [01:40<00:03,  4.27it/s]

627179613616865357


100%|█████████████████████████████████████████████████████████████████████| 399/399 [01:43<00:00,  3.86it/s]


Unnamed: 0_level_0,ra,dec,midpointMjdTai,type,id,on_source,real
diaSourceId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
121861151112822817,55.815878,-32.37276,59886.113408,star,31107745011,1,1
225905106065817658,55.86023,-32.489914,60179.353241,star,31411443609,1,1
242418712005574717,55.816639,-32.261288,60232.172006,star,31411419399,1,1
351362700657295476,55.75577,-32.298335,60549.272587,star,31411422848,1,1
372716278936240222,55.582887,-32.369382,60605.310963,star,30317265804,1,1
379922626857926673,55.917179,-32.321105,60625.068296,star,31411425954,1,1
488448449977516453,55.642648,-32.175003,60905.319335,star,31411409555,1,1
510734470488260694,55.966513,-32.222557,60973.167688,star,30321355633,1,1
554641942630105130,55.792227,-32.292353,61101.060926,star,30830343259,1,1
627179613616865357,55.831304,-32.192921,61297.243962,star,31405665055,1,1
