In [1]:
import pyspark.sql.functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType, udf
from pyspark.sql.types import IntegerType, BinaryType, ArrayType
from pyspark.sql.types import StringType, BooleanType
from pyspark.sql import SparkSession

from astropy.coordinates import SkyCoord
from astropy import units as u

import pandas as pd
import numpy as np

from fink_filters.classification import extract_fink_classification
from fink_utils.spark.utils import concat_col
from fink_science.xmatch.utils import cross_match_astropy
from fink_science.xmatch.processor import crossmatch_mangrove, crossmatch_other_catalog



In [2]:
# Adapted from https://github.com/astrolabsoftware/fink-science-portal/blob/b7326ed4febe0e106c1e93565ea622c1c95218e8/assets/spark_ztf_transfer.py#L112C1-L202C14
# (copied on 19 FEB 2024). Local changes: `np.float` (removed in NumPy 1.24)
# replaced by `float`, the unused TNS->ZTF match dropped, the chained pandas
# assignment replaced by array indexing, the per-row objectId lookup made
# linear, and the match radius exposed as `radius_arcsec`.
def add_classification(spark, df, path_to_tns, radius_arcsec=1.5):
    """ Add classification from Fink & TNS

    Parameters
    ----------
    spark: SparkSession
        Active Spark session, used to broadcast the TNS table to executors.
    df: DataFrame
        Spark DataFrame containing ZTF alert data
    path_to_tns: str
        Path to TNS data (parquet). The table must contain the columns
        `ra`, `declination` (degrees) and `type`.
    radius_arcsec: float, optional
        Maximum separation, in arcseconds, between an alert and its nearest
        TNS source for the TNS type to be attached. Default is 1.5.

    Returns
    -------
    df: DataFrame
        Input DataFrame with 2 new columns: `finkclass` (Fink classification)
        and `v:tns_classification` (type of the nearest TNS source within
        `radius_arcsec`, or 'Unknown').
    """
    # extract Fink classification
    df = df.withColumn(
        'finkclass',
        extract_fink_classification(
            df['cdsxmatch'],
            df['roid'],
            df['mulens'],
            df['snn_snia_vs_nonia'],
            df['snn_sn_vs_all'],
            df['rf_snia_vs_nonia'],
            df['candidate.ndethist'],
            df['candidate.drb'],
            df['candidate.classtar'],
            df['candidate.jd'],
            df['candidate.jdstarthist'],
            df['rf_kn_vs_nonkn'],
            df['tracklet']
        )
    )

    # Read TNS once on the driver and broadcast it to the executors.
    pdf_tns_filt = pd.read_parquet(path_to_tns)
    pdf_tns_filt_b = spark.sparkContext.broadcast(pdf_tns_filt)

    # Separation threshold in degrees, computed once outside the UDF.
    max_sep_deg = radius_arcsec / 3600

    @pandas_udf(StringType(), PandasUDFType.SCALAR)
    def crossmatch_with_tns(objectid, ra, dec):
        """Return, per alert, the TNS type of the nearest TNS source within
        `radius_arcsec`, or 'Unknown' when there is none."""
        if len(objectid) == 0:
            return pd.Series([], dtype=object)

        pdf = pdf_tns_filt_b.value

        # create catalogs
        catalog_ztf = SkyCoord(
            ra=np.asarray(ra, dtype=float) * u.degree,
            dec=np.asarray(dec, dtype=float) * u.degree
        )
        catalog_tns = SkyCoord(
            ra=np.asarray(pdf['ra'], dtype=float) * u.degree,
            dec=np.asarray(pdf['declination'], dtype=float) * u.degree
        )

        # nearest TNS source for every alert of the batch
        idx, d2d, _ = catalog_ztf.match_to_catalog_sky(catalog_tns)
        is_match = d2d.degree < max_sep_deg

        tns_class = np.full(len(objectid), 'Unknown', dtype=object)
        tns_class[is_match] = pdf['type'].values[idx[is_match]]

        # As in the original implementation, every alert of a given objectId
        # within the batch receives the classification of the first alert of
        # that objectId (done here with a single map instead of a per-row scan).
        per_alert = pd.Series(tns_class, index=objectid.values)
        first_per_object = per_alert[~per_alert.index.duplicated(keep='first')]

        return objectid.map(first_per_object)

    df = df.withColumn(
        'v:tns_classification',
        crossmatch_with_tns(
            df['objectId'],
            df['candidate.ra'],
            df['candidate.dec']
        )
    )

    return df

In [3]:
# Column placed first in every selection (also the aggregation key later on).
cols0 = ['objectId']

# Cutout images: `<cutout>.stampData` is exposed as `b:<cutout>_stampData`.
_cutouts = ['cutoutScience', 'cutoutTemplate', 'cutoutDifference']

# Candidate fields: `candidate.<field>` is exposed as `i:<field>`.
_candidate_fields = [
    'aimage', 'aimagerat', 'bimage', 'bimagerat', 'candid', 'chinr', 'chipsf',
    'classtar', 'dec', 'fid', 'fwhm', 'isdiffpos', 'jd', 'maggaia',
    'maggaiabright', 'magpsf', 'neargaia', 'neargaiabright', 'ra', 'sigmapsf',
]

# Columns kept without renaming: the same inputs that are passed to
# extract_fink_classification.
_passthrough = [
    'cdsxmatch', 'roid', 'mulens', 'snn_snia_vs_nonia', 'snn_sn_vs_all',
    'rf_snia_vs_nonia', 'candidate.ndethist', 'candidate.drb',
    'candidate.classtar', 'candidate.jd', 'candidate.jdstarthist',
    'rf_kn_vs_nonkn', 'tracklet',
]

cols = (
    [F.col(f'{name}.stampData').alias(f'b:{name}_stampData') for name in _cutouts]
    + [F.col(f'candidate.{name}').alias(f'i:{name}') for name in _candidate_fields]
    + [F.col(name) for name in _passthrough]
    + [
        F.col('finkclass').alias('v:classification'),
        F.col('v:tns_classification'),
    ]
)

# Science archive partitions grouped by epoch.
_archive = '../julien.peloton/archive/science'
epochs = {
    'epoch1': [f'{_archive}/year={year}' for year in (2019, 2020, 2021)],
    'epoch2': [f'{_archive}/year={year}' for year in (2022, 2023)],
}

# path to TNS data
path_to_tns = '/spark_mongo_tmp/julien.peloton/tns.parquet'

In [4]:
# TNS classes to discard (all TNS labels carry the '(TNS) ' prefix).
not_wanted_tns = [f'(TNS) {label}' for label in ('CV', 'Varstar', 'M dwarf')]

In [5]:
# Fink classes (SIMBAD labels and ML classifier labels) considered
# extragalactic; alerts whose Fink classification is in this list are kept.
# Every entry needs its trailing comma: Python silently concatenates adjacent
# string literals, which would merge two classes into one bogus label.
extragalactic_uniqclass_best = [
    # --- list provided by Priscila ---
    'AGN',
    'AGN_Candidate',
    'SN',
    'BH_Candidate',
    'Blazar',
    'Blazar_Candidate',
    'BLLac',
    'BLLac_Candidate',
    'Candidate_Nova',
    'Candidate_NS',
    'Candidate_SN*',
    'LINER',
    'QSO',
    'QSO_Candidate',
    'Seyfert',
    'Seyfert1',
    'Seyfert2',
    'Seyfert_1',
    'Seyfert_2',
    'SN*_Candidate',
    'Supernova',
    'ULX',
    'ULX?',
    'ULX_Candidate',
    # --- Fink ML classifier labels ---
    'SN candidate',             # from SuperNNova
    'Early SN Ia candidate',    # from AL + random forest
    'Kilonova candidate',       # from PCA + random forest
]

In [6]:
# Column predicates used to filter alerts (originally written by Emille as
# Python UDFs). They are native Spark SQL expressions, so Spark evaluates
# them in the JVM instead of sending every row to a Python worker. Each takes
# a Column (or a column name) and returns a boolean Column.


def _as_column(column):
    """Return `column` as a Spark Column, also accepting a column name."""
    return F.col(column) if isinstance(column, str) else column


def filter_known(column):
    """True where the label is different from 'Unknown'.

    A null label counts as known (same result as the former Python UDF,
    where `None != 'Unknown'` is True).
    """
    column = _as_column(column)
    return column.isNull() | (column != 'Unknown')


def filter_classes(column):
    """True where the label is one of `extragalactic_uniqclass_best`.

    A null label gives False (same result as `None in list` in the former UDF).
    """
    column = _as_column(column)
    return F.coalesce(column.isin(extragalactic_uniqclass_best), F.lit(False))


def filter_hyperLEDA(mangrove):
    """True where the Mangrove match has NO HyperLEDA counterpart, i.e. its
    `HyperLEDA_name` is the string 'None'.

    A null name gives False; a null Mangrove value also gives False (the
    former UDF raised an error on it and failed the job).
    """
    mangrove = _as_column(mangrove)
    return F.coalesce(mangrove['HyperLEDA_name'] == 'None', F.lit(False))


def filter_2MASS(mangrove):
    """True where the Mangrove match has NO 2MASS counterpart, i.e. its
    `2MASS_name` is the string 'None'.

    A null name gives False; a null Mangrove value also gives False (the
    former UDF raised an error on it and failed the job).
    """
    mangrove = _as_column(mangrove)
    return F.coalesce(mangrove['2MASS_name'] == 'None', F.lit(False))

In [7]:
# Set to True to print intermediate counts and samples. Every count triggers
# a full pass over the data, so this takes a long time -- use with caution.
debug = False

# Make sure a Spark session exists before reading: on a fresh kernel `spark`
# may not be defined yet, and getOrCreate reuses an existing session if any.
spark = SparkSession.builder.getOrCreate()

# For quick testing, read a single month instead:
# df1 = spark.read.format('parquet').load('../julien.peloton/archive/science/year=2023/month=04/')

# Read the entire second epoch (years 2022 and 2023). `basePath` lets Spark
# discover the partition directories below it (e.g. `year=`) as columns.
df1 = (
    spark.read.format('parquet')
    .option('basePath', '../julien.peloton/archive/science')
    .load(epochs['epoch2'])
)


24/02/20 05:23:49 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


In [8]:
# Optional: count the alerts just read (full scan of the input). `tot` is the
# denominator of the surviving fractions printed by the later debug cells.
if debug:
    tot = df1.count()
    print('Just read: ', tot)

In [9]:
# Cross-match with TNS and add the Fink classification.
spark = SparkSession.builder.getOrCreate()
df1_with_classes = add_classification(spark, df=df1, path_to_tns=path_to_tns)

# get the columns we want
df1_renamed = df1_with_classes.select(cols0 + cols)

# Cross-match with Mangrove. The 4th argument (60.0) is forwarded unchanged to
# crossmatch_mangrove -- presumably a search radius; confirm units in fink_science.
df1_with_classes = df1_renamed.withColumn(
    'mangrove',
    crossmatch_mangrove(F.col('i:candid'), F.col('i:ra'), F.col('i:dec'), F.lit(60.0)),
)

# Cross-match with the other catalogs supported by crossmatch_other_catalog;
# each one adds a column named after the catalog.
other_catalogs = ['gcvs', 'vsx', '3hsp', '4lac']
for catalog in other_catalogs:
    df1_with_classes = df1_with_classes.withColumn(
        catalog,
        crossmatch_other_catalog(
            F.col('i:candid'), F.col('i:ra'), F.col('i:dec'), F.lit(catalog)
        ),
    )

In [10]:
# Peek at the classifications and catalog cross-matches (debug only).
# Only columns created above are selected; the former 'spicy' column is never
# produced by this notebook and made the select fail.
if debug:
    df1_with_classes.select(
        'v:classification', 'v:tns_classification', 'gcvs', 'vsx', '3hsp', '4lac'
    ).show()

In [11]:
# Keep alerts that carry a label from at least one source: TNS, or Fink
# (ML classifiers + SIMBAD). Alerts unknown to both are removed.
has_tns_label = filter_known(F.col('v:tns_classification'))
has_fink_label = filter_known(F.col('v:classification'))
df1_known = df1_with_classes.filter(has_tns_label | has_fink_label)

if debug:
    # number of surviving alerts, and fraction of everything read
    tot_df1_known = df1_known.count()
    print('After unknown in both TNS and Fink classes: ', tot_df1_known, tot_df1_known / tot)

In [12]:
# Peek at the labels of the alerts kept so far (debug only).
if debug:
    columns_to_show = ['v:tns_classification', 'v:classification']
    df1_known.select(columns_to_show).show()

In [13]:
# Drop alerts whose TNS label is one of the unwanted TNS classes.
is_unwanted_tns = F.col('v:tns_classification').isin(not_wanted_tns)
df1_filtered_tns = df1_known.filter(~is_unwanted_tns)

if debug:
    # number of surviving alerts, and fraction of everything read
    tot_filtered_tns = df1_filtered_tns.count()
    print('After unwanted TNS classes: ', tot_filtered_tns, tot_filtered_tns / tot)

In [14]:
# Peek at the remaining TNS labels (debug only).
if debug:
    tns_column = F.col('v:tns_classification')
    df1_filtered_tns.select(tns_column).show()

In [15]:
# Keep alerts whose Fink label (ML or SIMBAD) is one of the requested
# extragalactic classes, or that have any TNS label at all.
in_wanted_fink_class = filter_classes(F.col('v:classification'))
has_tns_label = filter_known(F.col('v:tns_classification'))
df1_filtered_classes = df1_filtered_tns.filter(in_wanted_fink_class | has_tns_label)

if debug:
    # number of surviving alerts, and fraction of everything read
    tot_df1_filtered_classes = df1_filtered_classes.count()
    print('After filtering classes: ', tot_df1_filtered_classes, tot_df1_filtered_classes / tot)

In [16]:
# Peek at the labels after class filtering (debug only).
if debug:
    label_columns = ['v:classification', 'v:tns_classification']
    df1_filtered_classes.select(label_columns).show()

In [17]:
# Keep only alerts WITHOUT a Mangrove counterpart: both the HyperLEDA and the
# 2MASS names of the Mangrove match must be the string 'None'.
no_hyperleda = filter_hyperLEDA(F.col('mangrove'))
no_2mass = filter_2MASS(F.col('mangrove'))
df1_filtered_final = df1_filtered_classes.filter(no_hyperleda & no_2mass)

if debug:
    # number of alerts without a Mangrove host, and fraction of everything read
    tot_df1_filtered_final = df1_filtered_final.count()
    print('After filtering mangrove: ', tot_df1_filtered_final, tot_df1_filtered_final / tot)

In [18]:
# Peek at the Mangrove cross-match of the surviving alerts (debug only).
if debug:
    mangrove_column = F.col('mangrove')
    df1_filtered_final.select(mangrove_column).show()

In [19]:
# select which columns to keep (objectId first)
cols_ = [c for c in df1_filtered_final.columns if c != 'objectId']
df1_output = df1_filtered_final.select(cols0 + cols_)

# Aggregate alerts by objectId: one row per object, one array per column.
# Whole rows are collected as structs and then split into per-column arrays.
# Calling F.collect_list on each column separately would skip null values,
# so columns containing nulls would end up shorter than (and misaligned with)
# the others. Element i of every array therefore belongs to the same alert.
# NOTE: the order of alerts inside the arrays is not guaranteed by Spark.
df1_agg = (
    df1_output
    .groupBy('objectId')
    .agg(F.collect_list(F.struct(*cols_)).alias('_alerts'))
    .select('objectId', *[F.col('_alerts').getField(c).alias(c) for c in cols_])
)

In [20]:
# Peek at the aggregated Fink classifications per object (debug only).
if debug:
    classification_column = F.col('v:classification')
    df1_agg.select(classification_column).show()

In [21]:
if debug:
    # count number of unique objects (one row per objectId after the groupBy)
    tot_obj = df1_agg.count()

    print(tot_obj)

In [None]:
# Write the per-object table to parquet, replacing any previous output.
# The suffix matches the years read in the loading cell (epoch2).
output_path = 'SIMBAD_not_in_MANGROVE_with_candidates_2022_2023'
df1_agg.write.parquet(output_path, mode='overwrite')



In [None]:
# Show the disk usage of the output written above. `hdfs dfs -du -h` is given
# no path, so it lists the user's HDFS home directory -- presumably where the
# relative output path is resolved; confirm for this cluster.
! hdfs dfs -du -h | grep SIMBAD_not_in_MANGROVE_with_candidates

In [None]:
# Print the local working directory (IPython automagic for `%pwd`).
pwd