In [12]:
#!/usr/bin/env python
# coding: utf-8

# Import required libraries and modules
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from astropy.coordinates import SkyCoord
import astropy.units as u

# Define a threshold for matching errors
threshold = 1 * u.arcsec

# Define file paths.
sum_path = {}
sum_path["star"] = '../truth_star/truth_star_summary_v1-0-0.parquet'
#'/sdf/data/rubin/shared/dc2_run2.2i_truth/truth_star/truth_star_summary_v1-0-0.parquet'
#'/sdf/data/rubin/shared/dc2_run2.2i_truth/truth_star/truth_star_variability_v1-0-0.parquet'
sum_path["sn"] = "../truth_sn/truth_sn_summary_v1-0-0.parquet" #'/sdf/data/rubin/shared/dc2_run2.2i_truth/truth_sn/truth_sn_summary_v1-0-0.parquet'

var_path = {}
var_path["star"] = '../truth_star/truth_star_variability_v1-0-0.parquet'
var_path["sn"] = '../truth_sn/truth_sn_variability_v1-0-0.parquet'
#'/sdf/data/rubin/shared/dc2_run2.2i_truth/truth_sn/truth_sn_variability_v1-0-0.parquet'
exported_csv_pth = "../sources_with_labels.csv" #'exported_sources.csv'

#get DIA detections
dia_detections = pd.read_csv(exported_csv_pth, index_col="diaSourceId") #formerly known as exported_csv
#only for now based on the csv we are reading
dia_detections.drop(["real"], axis=1, inplace=True)

# Get mind and max ra and dec values to filter out unnecessary records.
max_exp_ra, min_exp_ra = dia_detections.ra.max(), dia_detections.ra.min()
max_exp_dec, min_exp_dec = dia_detections.dec.max(), dia_detections.dec.min()

result_sum = {}
catalog = {}
# Stage 1: Match sources in Space.
for s in ["sn", "star"]:
    # Read Parquet and CSV files to begin ground truth derivation.
    result_sum[s] = pd.read_parquet(sum_path[s])

    # Keep only those records from summary tables which are within the max ra and dec values in the exported sources.
    result_sum[s] = result_sum[s][(result_sum[s]['ra'] >= min_exp_ra) & (result_sum[s]['ra'] <= max_exp_ra) &\
                                    (result_sum[s]['dec'] >= min_exp_dec) & (result_sum[s]['dec'] <= max_exp_dec)]


    # Initialize astropy.coordinates.SkyCoord class for matching in space.
    exported_cat = SkyCoord(ra=dia_detections.ra, dec=dia_detections.dec, unit=u.deg)
    catalog[s] = SkyCoord(ra=result_sum[s].ra, dec=result_sum[s].dec, unit=u.deg)

# Match exported sources with stars and supernovae.
star_idx, star_d2d, star_d3d = catalog['star'].match_to_catalog_sky(exported_cat)
sn_idx, sn_d2d, sn_d3d = catalog['sn'].match_to_catalog_sky(exported_cat)

star_mask = star_d2d < threshold
sn_mask = sn_d2d < threshold

# Get all matched stars
matched_star_idx = star_idx[star_mask]
#matched_stars = exported_csv.iloc[matched_star_idx]
print(f"Number of matched stars in Stage #1: {len(matched_star_idx)}")

# Get all matched supernovae
matched_sn_idx = sn_idx[sn_mask]
#matched_sn = exported_csv.iloc[matched_sn_idx]
print(f"Number of matched sne in Stage #1: {len(matched_sn_idx)}")

# By default, set real=0 (bogus) for all values in the exported sources.
dia_detections['detected'] = 0
dia_detections['real'] = 0
dia_detections['type'] = None

dia_detections

Number of matched stars in Stage #1: 399
Number of matched sne in Stage #1: 19


Unnamed: 0_level_0,ra,dec,midpointMjdTai,type,detected,real
diaSourceId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1257927201521665,55.760339,-32.260622,59583.125051,,0,0
1257927201521666,55.674078,-32.283857,59583.125051,,0,0
1257927201521667,55.552914,-32.306395,59583.125051,,0,0
1257927201521668,55.547689,-32.309278,59583.125051,,0,0
1257927201521669,55.570127,-32.306400,59583.125051,,0,0
...,...,...,...,...,...,...
660667525163384915,55.889519,-32.485637,61392.194195,,0,0
661047079476396040,55.863218,-32.167598,61393.204087,,0,0
662500331589992590,55.971559,-32.358853,61404.195949,,0,0
662500331589992596,55.881022,-32.482719,61404.195949,,0,0


In [18]:
dia_detections

Unnamed: 0_level_0,ra,dec,midpointMjdTai,type,detected,real
diaSourceId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1257927201521665,55.760339,-32.260622,59583.125051,,0,0
1257927201521666,55.674078,-32.283857,59583.125051,,0,0
1257927201521667,55.552914,-32.306395,59583.125051,,0,0
1257927201521668,55.547689,-32.309278,59583.125051,,0,0
1257927201521669,55.570127,-32.306400,59583.125051,,0,0
...,...,...,...,...,...,...
660667525163384915,55.889519,-32.485637,61392.194195,star,1,0
661047079476396040,55.863218,-32.167598,61393.204087,star,1,0
662500331589992590,55.971559,-32.358853,61404.195949,star,1,0
662500331589992596,55.881022,-32.482719,61404.195949,star,1,0


In [13]:
#FBB
listid_sn = dia_detections.iloc[matched_sn_idx]
listid_stars = dia_detections.iloc[matched_star_idx]

listid_stars.index

Index([484684926365466639, 484684926365466635, 141467103440928810,
       494605264658366505, 121861151112822800, 225930931704168548,
       374633902913880562, 242966879218434098, 636296203004280949,
       620367796559151117,
       ...
       373046440051605539, 343039123830866034,  93738742493216825,
       374632866216149155, 494605264658366482, 647130291831308763,
       657916603442135098, 242418712005574696, 374632866216149063,
       354457242847674454],
      dtype='int64', name='diaSourceId', length=399)

In [20]:

dia_detections.loc[listid_sn.index, "detected"] = 1
dia_detections.loc[listid_stars.index , "detected"] = 1
dia_detections.loc[listid_sn.index, "type"] = "sn"
dia_detections.loc[listid_stars.index, "type"] = "star"
dia_detections.loc[dia_detections.type == "sn"]

Unnamed: 0_level_0,ra,dec,midpointMjdTai,type,detected,real
diaSourceId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
93738742493216770,55.734081,-32.316066,59810.381776,sn,1,0
128110290947538969,55.685306,-32.30808,59909.044045,sn,1,0
225905618777538634,55.619536,-32.227419,60179.35369,sn,1,0
278114224394207312,55.619952,-32.23514,60338.13911,sn,1,0
289069573032902730,55.746311,-32.272632,60369.030714,sn,1,0
343017058973253667,55.929655,-32.427467,60531.374974,sn,1,0
343018160632365243,55.756722,-32.300056,60531.375875,sn,1,0
343039123830866004,55.876253,-32.161075,60531.394166,sn,1,0
373030845025353839,55.732614,-32.251144,60606.210164,sn,1,0
374633902913880566,55.743163,-32.253603,60610.105066,sn,1,0


In [26]:

# Print a summary at the end of first round of matching.
print("Summary at the end of First Stage:")
print(dia_detections[["detected"]].sum())
#print(df_combined.groupby("real").count())

# Stage 2: Match sources in time.

"""
# Separate the matched and unmatched sources.
df_matched = df_combined[df_combined.real==1]
df_matched.reset_index(drop=True, inplace=True)

df_unmatched = df_combined[df_combined.real==0]
df_unmatched.reset_index(drop=True, inplace=True)

"""
matched_index = dia_detections["detected"] > 0


# Get a list of all the unique MJDs of sources that matched in previous stage.
mjd_matched_in_space = dia_detections.loc[matched_index].midpointMjdTai.unique()

# Get min and max MJD values required for matching.
max_mjd, min_mjd = mjd_matched_in_space.max(), mjd_matched_in_space.min()

# Define tolerance for errors while matching.
tolerance=0.00034
matched_mjd = []

for s in ["sn"]:#"star", ]:
    # Read star lightcurve variability parquet
    df_var = pd.read_parquet(var_path[s])
    # Filter out records with unwanted MJDs.
    df_var = df_var[(df_var.MJD>=min_mjd) & (df_var.MJD<=max_mjd)]

    # Get a list of unique star MJDs for matching.
    unique_mjd_var = df_var.MJD.unique()
    
    # Perform MJD matching for detection.
    for v in mjd_matched_in_space:
        #FBBB here there should be an index that matches IDs and runs the diff < tolerance only on the spatially matched ids
        if np.any(np.abs(unique_mjd_var-v) <= tolerance):
            matched_mjd.append(v)

    # Delete data not required anymore to save memory
    del df_var, unique_mjd_var
# Create a set of matched MJD to get rid of duplicate MJD matches across stars and sn.
matched_mjd = list(set(matched_mjd))
print(matched_mjd)

Summary at the end of First Stage:
detected    418
dtype: int64
[60934.24178422337, 60937.25068922338, 60937.26380622454, 60942.201734223374, 59920.070471224535, 59926.15045022106, 59932.11353622454, 60957.20876122454, 60961.18990821991, 60961.23309222222, 59946.21070022107, 60973.15184022106, 60973.16768821991, 60976.119601224535, 60991.07399922338, 59968.10273422338, 60991.08932022107, 60993.114815221066, 60996.16909422338, 61006.042328219904, 61017.18042122453, 61017.25204321991, 61028.11725422338, 60530.355550221066, 60531.32797922338, 60531.37497422338, 60531.375875221056, 60531.39416622453, 60540.34908522107, 60542.31262821991, 60542.2946382199, 60549.27258722222, 60549.28710522107, 60557.25793321991, 60557.25838022107, 60557.25927422338, 60557.25972122453, 60559.30283122454, 61070.09395422338, 60568.319254223374, 61080.111384223375, 60570.26115722222, 60578.23621122453, 60578.35116622454, 61093.06416222222, 60584.20355722222, 60584.20400321991, 60586.191569223374, 60587.23074821

In [16]:
# Match sources that were real in the first stage.
mask = dia_detections.loc[matched_index]['midpointMjdTai'].isin(matched_mjd)
dia_detections.loc[matched_index]['real'] = mask.astype(int)
# print(df_matched.groupby('real').count())

print(dia_detections.groupby("real").count())
# Write final CSV
dia_detections

         ra    dec  midpointMjdTai  type  detected
real                                              
0     25446  25446           25446   418     25446


Unnamed: 0_level_0,ra,dec,midpointMjdTai,type,detected,real
diaSourceId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1257927201521665,55.760339,-32.260622,59583.125051,,0,0
1257927201521666,55.674078,-32.283857,59583.125051,,0,0
1257927201521667,55.552914,-32.306395,59583.125051,,0,0
1257927201521668,55.547689,-32.309278,59583.125051,,0,0
1257927201521669,55.570127,-32.306400,59583.125051,,0,0
...,...,...,...,...,...,...
660667525163384915,55.889519,-32.485637,61392.194195,star,1,0
661047079476396040,55.863218,-32.167598,61393.204087,star,1,0
662500331589992590,55.971559,-32.358853,61404.195949,star,1,0
662500331589992596,55.881022,-32.482719,61404.195949,star,1,0
