# Babamul streaming example

In this notebook we'll read from Babamul streams to fetch alerts of interest
and display them interactively.

🚨 **Before running, be sure to copy `.env.example` to `.env` and fill in
your Babamul Kafka credentials from https://babamul.caltech.edu/profile.** 🚨

In [None]:
import dotenv
from astropy.coordinates import SkyCoord
from astropy.table import Table
from tqdm.notebook import tqdm

import babamul
from babamul import LsstAlert, ZtfAlert

In [None]:
# Load the settings from the local `.env` file (created from `.env.example`, see above)
# into environment variables; presumably this is where babamul picks up the Kafka
# credentials — TODO confirm.
dotenv.load_dotenv()

In this notebook, we'll listen to LSST alerts that are likely to be hosted (based on matches to the LSPSC catalog), and that match with a ZTF object:

In [None]:
# Kafka topic(s) to subscribe to: LSST alerts likely to be hosted that also match a ZTF object
# (see the description above).
topics = ["babamul.lsst.ztf-match.hosted"]

In [None]:
# First, let's just grab one alert and inspect it.
# Note: initializing a consumer can take a few seconds, depending on your distance
#       from the Kafka cluster and the number of topics you subscribe to. Once
#       alerts start flowing, things get a lot faster!
with babamul.AlertConsumer(
    topics=topics,
    offset="earliest",
    auto_commit=False,
    timeout=15,
) as consumer:
    for alert in consumer:
        alert.show()
        break
    else:
        # The loop finished without yielding a single alert (e.g. an empty topic
        # or a stream that stopped before anything arrived): say so explicitly
        # instead of leaving the cell silently empty.
        print("No alert received from the stream.")

Next, we define a basic filtering function to apply to incoming alerts, selecting
the subset we might be interested in displaying:

In [None]:
def is_relevant(alert: ZtfAlert | LsstAlert, min_drb: float = 0.4):
    """Return True if an alert passes our basic quality and novelty cuts.

    Survey-specific cuts are applied first (ZTF: object age; LSST: dipole,
    PSF-fit and measurement-flag quality). They are followed by cuts shared by
    both surveys: real-bogus score, known-SSO / star / near-bright-star flags,
    positive subtraction and stationarity.

    Args:
        alert: a ZTF or LSST alert, from the Kafka stream or from the API.
        min_drb: minimum real-bogus score (``alert.drb``) required to keep the
            alert. Defaults to 0.4.

    Returns:
        True if the alert should be kept, False otherwise. Alerts with a
        missing real-bogus score or missing/zero PSF-fit statistics are treated
        as not relevant instead of raising, so a single malformed alert cannot
        abort a long consumer loop.
    """
    if isinstance(alert, ZtfAlert):
        age = alert.candidate.jd - alert.candidate.jdstarthist
        # Only consider objects first observed within the last 60 days, and more than 3 days old
        if age < 3 or age > 60:
            return False
    if isinstance(alert, LsstAlert):
        candidate = alert.candidate
        if candidate.isDipole:
            return False
        if candidate.psfFlux_flag:
            return False
        # Only consider alerts with a reasonable PSF fit (threshold on chi2 per data point).
        # psfNdata == 0 (or missing values) would make the division below raise.
        if candidate.psfChi2 is None or not candidate.psfNdata:
            return False
        if candidate.psfChi2 / candidate.psfNdata > 10.0:
            return False
        if candidate.extendedness is None:
            return False
        if candidate.shape_flag:
            return False
        if candidate.centroid_flag:
            return False

    # Only consider alerts with a real-bogus score of at least `min_drb` (0.4 by default)
    if alert.drb is None or alert.drb < min_drb:
        return False

    # Exclude alerts that are likely to be known SSOs
    if alert.properties.rock:
        return False

    # Exclude alerts that are likely to be stars or near bright stars
    if alert.properties.star:  # using PS1 PSC for ZTF, using LSPSC for LSST
        return False

    if alert.properties.near_brightstar:  # same here
        return False

    # Only consider positive subtractions (i.e. candidate is brighter than the reference)
    if not alert.candidate.isdiffpos:
        return False

    # Exclude alerts that are not "stationary", i.e. not detected at least twice with sufficient time separation
    # (helps rule out uncataloged asteroids, some boguses, ...)
    return alert.properties.stationary

Let's run the filter on the stream, select up to `limit` alerts that match our criteria, and display them in a table:

In [None]:
limit = 2000  # stop once 2000 relevant alerts are collected; plenty for this demo
alerts = []

# Read the topic(s) from the beginning, keeping only alerts that pass `is_relevant`.
with babamul.AlertConsumer(
    topics=topics,
    offset="earliest",
    auto_commit=False,
    timeout=15,
) as consumer:
    for incoming in tqdm(consumer, desc="Filtering relevant alerts"):
        if not is_relevant(incoming):
            continue
        alerts.append(incoming)
        if len(alerts) >= limit:
            break

print(f"Fetched {len(alerts)} alerts.")

Before we display the alerts, since we may have more than one alert per object,
let's deduplicate the alerts by objectId, to make our life a little easier and
avoid fetching data for the same object multiple times when we add cross-matches in the next step:

In [None]:
# Keep only the most recent alert (largest candidate.jd) per objectId, so each
# object is scanned -- and cross-matched -- only once. On equal jd, the first
# alert seen wins.
latest_by_object = {}
for candidate_alert in alerts:
    previous = latest_by_object.get(candidate_alert.objectId)
    if previous is None or candidate_alert.candidate.jd > previous.candidate.jd:
        latest_by_object[candidate_alert.objectId] = candidate_alert
alerts_to_scan = list(latest_by_object.values())
print(f"After deduplication, {len(alerts_to_scan)} unique alerts remain.")

Next, let's use the `add_cross_matches` helper function to fetch cross-match information
(from the API) for our selected alerts, so we can use it to filter them further:

In [None]:
# Fetch cross-matches from the API for every deduplicated alert in one call; the next
# cells read them back through `alert.get_cross_matches()` (presumably served from what
# this call fetched — TODO confirm).
babamul.add_cross_matches(alerts_to_scan)

In [None]:
def is_revelant_using_crossmatches(alert: ZtfAlert | LsstAlert):
    """Return True if the alert's cross-matches point to an off-nuclear, hosted transient.

    The alert is rejected when any of the following holds:
      * it has no NED counterpart at all (we want a candidate host);
      * a NED counterpart lies within 1 arcsec (likely a nuclear source);
      * it has any Gaia match;
      * it is close to an LSPSC source with score > 0.5 (presumably a
        star-likeness score). "Close" depends on that source's `mag_white`:
        < 2" always, < 5" if mag < 17, < 10" if mag < 16, < 30" if mag < 15;
      * it has any MilliQuasar match.

    The (misspelled) name is kept because other cells call it.
    """

    def _near_likely_star(source) -> bool:
        # brighter LSPSC sources exclude a wider radius around them
        if not source.score > 0.5:
            return False
        sep = source.distance_arcsec
        return (
            sep < 2.0
            or (sep < 30.0 and source.mag_white < 15.0)
            or (sep < 10.0 and source.mag_white < 16.0)
            or (sep < 5.0 and source.mag_white < 17.0)
        )

    xmatch: babamul.models.CrossMatches = alert.get_cross_matches()

    hosts = xmatch.ned
    if not hosts:
        return False
    if any(host.distance_arcsec < 1.0 for host in hosts):
        return False
    if xmatch.gaia:
        return False
    if xmatch.lspsc and any(_near_likely_star(src) for src in xmatch.lspsc):
        return False
    return not xmatch.milliquasar

In [None]:
# Keep the deduplicated alerts whose cross-matches look like an off-nuclear, hosted transient.
alerts_to_scan_filtered = [
    candidate_alert
    for candidate_alert in tqdm(alerts_to_scan)
    if is_revelant_using_crossmatches(candidate_alert)
]

print(
    f"After excluding alerts using cross-matches with NED, Gaia, LSPSC, and MilliQuasar, {len(alerts_to_scan_filtered)} alerts remain that are likely not nuclear transients."
)

## View alerts interactively

The cell below allows us to page through the alerts we have selected and view their contents interactively.

In [None]:
# Page through the filtered alerts ordered by decreasing magpsf, i.e. faintest first.
scan_order = sorted(
    alerts_to_scan_filtered,
    key=lambda candidate_alert: candidate_alert.candidate.magpsf,
    reverse=True,
)
babamul.jupyter.scan_alerts(scan_order)

Now let's get a little more specific and only look at alerts that have a nearby host in NED. First, we collect a fresh batch of alerts that pass both the alert-level and the cross-match filters:

In [None]:
# Collect up to `limit` alerts that pass BOTH the alert-level filter (`is_relevant`)
# and the cross-match filter (`is_revelant_using_crossmatches`). New candidates are
# buffered in `pending` and cross-matched in batches, so each alert is sent to the
# API only once, and nothing un-cross-matched can end up in `alerts`.


def _keep_crossmatch_relevant(batch):
    """Fetch cross-matches for `batch` and return the alerts passing the cross-match filter."""
    if not batch:
        return []
    babamul.add_cross_matches(batch)
    return [a for a in batch if is_revelant_using_crossmatches(a)]


alerts = []
pending = []  # passed `is_relevant`, not cross-matched yet
limit = 100  # stop once we have 100 alerts that pass both filters

with babamul.AlertConsumer(
    topics=topics,
    offset="earliest",
    auto_commit=False,
    timeout=15,
) as consumer:
    for alert in tqdm(consumer, desc="Filtering relevant alerts"):
        if not is_relevant(alert):
            continue
        pending.append(alert)
        # cross-match as soon as kept + pending alerts could fill the quota
        if len(alerts) + len(pending) >= limit:
            alerts.extend(_keep_crossmatch_relevant(pending))
            pending = []
            if len(alerts) >= limit:
                break

# The stream may end with candidates still buffered: filter those too.
alerts.extend(_keep_crossmatch_relevant(pending))
pending = []
print(f"Kept {len(alerts)} alerts passing both the alert and cross-match filters.")

In [None]:
def has_nearby_host(
    alert: ZtfAlert | LsstAlert,
    max_z: float = 0.033,
    max_distance_kpc: float = 30.0,
):
    """Return True if the alert has at least one nearby candidate host in NED.

    A NED counterpart counts as a nearby host when it has a redshift
    ``z <= max_z`` (default 0.033, roughly 140 Mpc for H0 ~ 70 km/s/Mpc) and a
    ``distance_kpc`` (presumably the projected separation from the alert)
    below ``max_distance_kpc`` (default 30). Counterparts missing either value
    are ignored instead of raising on the comparison.

    Args:
        alert: a ZTF or LSST alert; its cross-matches are retrieved through
            ``alert.get_cross_matches()``.
        max_z: maximum redshift of the host.
        max_distance_kpc: maximum alert-host distance, in kpc.

    Returns:
        True if at least one NED counterpart satisfies both cuts.
    """
    cross_matches: babamul.models.CrossMatches = alert.get_cross_matches()
    # no NED counterpart at all -> no host to check
    if not cross_matches.ned:
        return False
    return any(
        m.z is not None
        and m.distance_kpc is not None
        and m.z <= max_z
        and m.distance_kpc < max_distance_kpc
        for m in cross_matches.ned
    )

In [None]:
# Among the alerts collected above, keep up to `limit` that have a nearby NED host.
limit = 10
nearby_alerts = []

# Batch-fetch cross-matches for the list we actually iterate over below (`alerts`,
# not `alerts_to_scan`).
if alerts:
    babamul.add_cross_matches(alerts, n_threads=10)

for a in tqdm(alerts, desc="Filtering nearby alerts"):
    if len(nearby_alerts) >= limit:
        break
    if has_nearby_host(a):
        nearby_alerts.append(a)

print(f"Selected {len(nearby_alerts)} nearby alerts.")

In [None]:
# Page through the alerts with a nearby NED host interactively.
babamul.jupyter.scan_alerts(nearby_alerts)

In [None]:
# Now let's demonstrate using the API to look for alerts
survey, object_id = nearby_alerts[0].survey, nearby_alerts[0].objectId
obj = babamul.api.get_object(survey, object_id)
obj.show()

In [None]:
# Query the API directly for ZTF alerts in a one-day JD window, letting the server
# apply the coarse cuts (no SSO / star / bright-star flags, stationary, min_drb=0.5).
survey = "ZTF"
alerts_from_api = babamul.api.get_alerts(
    survey,
    start_jd=2460745.5,
    end_jd=2460746.5,
    is_rock=False,
    is_star=False,
    is_near_brightstar=False,
    is_stationary=True,
    min_drb=0.5,
)

print(f"Found {len(alerts_from_api)} alerts matching initial criteria.")

# Apply our own alert-level filter before spending API calls on cross-matches.
alerts_filtered_from_api = [
    a
    for a in tqdm(alerts_from_api, desc="Filtering relevant alerts")
    if is_relevant(a)
]
print(f"{len(alerts_filtered_from_api)} alerts pass the alert-level filter.")

# Only the alerts that passed `is_relevant` are cross-matched and checked for a nearby host.
nearby_alerts_from_api = []
if alerts_filtered_from_api:
    babamul.add_cross_matches(alerts_filtered_from_api, n_threads=10)
    nearby_alerts_from_api = [
        a
        for a in tqdm(alerts_filtered_from_api, desc="Filtering nearby alerts")
        if has_nearby_host(a)
    ]

print(
    f"Found {len(nearby_alerts_from_api)}/{len(alerts_filtered_from_api)} nearby alerts."
)

In [None]:
# Page through the API-retrieved alerts that have a nearby NED host.
babamul.jupyter.scan_alerts(nearby_alerts_from_api)

In [None]:
# Cone search for ZTF alerts around the position of the first nearby alert found via the API.
survey = "ZTF"
if nearby_alerts_from_api:
    ra, dec = (
        nearby_alerts_from_api[0].candidate.ra,
        nearby_alerts_from_api[0].candidate.dec,
    )
    coordinates = SkyCoord(ra=ra, dec=dec, unit="deg", frame="icrs")
    radius_arcsec = 2.0
    conesearch_results = babamul.api.cone_search_alerts(
        survey,
        coordinates,
        radius_arcsec,
        is_rock=False,
        is_star=False,
        is_near_brightstar=False,
        is_stationary=True,
        min_drb=0.5,
    )

    # Use a dedicated loop variable so the notebook-level `alerts` list is not overwritten.
    for pos_name, pos_alerts in conesearch_results.items():
        print(f"Found {len(pos_alerts)} alerts for position {pos_name}.")
        for a in pos_alerts:
            a.show()
else:
    print("No nearby alerts were found via the API; skipping the cone search.")

In [None]:
# Let's load the NED Local Volume Sample (LVS) catalog from data/NEDLVS_20250602.fits
# (you can download it from https://ned.ipac.caltech.edu/NED::LVS/).
ned_lvs: Table = Table.read("data/NEDLVS_20250602.fits")
print(f"Loaded NED LVS catalog with {len(ned_lvs):,} entries.")
# Keep the nearby galaxies with z <= max_z. With Hubble's law (d ~ c*z/H0,
# H0 ~ 70 km/s/Mpc), z = 0.03 corresponds to roughly 130 Mpc.
max_z = 0.03
nearby_galaxies = ned_lvs[ned_lvs["z"] <= max_z]
print(
    f"Found {len(nearby_galaxies):,} nearby galaxies with z <= {max_z} (~130 Mpc)."
)

In [None]:
# Cone search LSST alerts around every nearby galaxy. The result maps each queried
# position's name to the (possibly empty) list of alerts found around it.
survey, radius_arcsec = "LSST", 60.0
conesearch_results: dict[str, list[ZtfAlert | LsstAlert]] = (
    babamul.api.cone_search_alerts(
        survey,
        nearby_galaxies,
        radius_arcsec=radius_arcsec,
        is_rock=False,
        is_star=False,
        is_near_brightstar=False,
        is_stationary=True,
        min_drb=0.8,
        n_threads=8,
        batch_size=500,
    )
)

galaxies_with_alerts = [
    name for name, galaxy_alerts in conesearch_results.items() if galaxy_alerts
]
print(
    f"Found {len(galaxies_with_alerts)} nearby galaxies with alerts within {radius_arcsec} arcseconds."
)

From here on out, the notebook deserves some refactoring to use the latest features of the API, but I haven't gotten around to it yet!

In [None]:
def _age_and_amplitude(alert, photometry):
    """Return (age_days, amplitude_mag) for `alert` from its positive detections.

    Only photometry points with a truthy `isdiffpos` and a `magpsf` are used,
    together with the alert's own candidate. Age is this alert's JD minus the
    earliest of those JDs. Amplitude is the faintest minus the brightest magpsf.
    """
    detections = [p for p in photometry if p.isdiffpos and p.magpsf]
    first_jd = min([alert.candidate.jd, *(p.jd for p in detections)])
    mags = [alert.candidate.magpsf, *(p.magpsf for p in detections)]
    return alert.candidate.jd - first_jd, max(mags) - min(mags)


# Cuts applied to the alerts found around nearby galaxies
max_magpsf = 21.5
min_age_days, max_age_days = 10, 60
min_amplitude_mag = 1.0

# most recent passing alert per objectId
hosted_by_object: dict[str, ZtfAlert | LsstAlert] = {}
for galaxy_alerts in conesearch_results.values():
    for a in galaxy_alerts:
        # cheap candidate-level cuts first, before any API call
        if not a.candidate.isdiffpos:
            continue
        if a.candidate.isDipole:
            continue
        if a.candidate.magpsf > max_magpsf:
            continue
        if a.candidate.psfFlux_flag:
            continue
        # reject sources matching the MilliQuasar catalog
        if a.get_cross_matches().milliquasar:
            continue
        # keep objects first detected 10-60 days before this alert,
        # with an amplitude of at least 1 mag
        age, amplitude = _age_and_amplitude(a, a.get_photometry())
        if not (min_age_days <= age <= max_age_days):
            continue
        if amplitude < min_amplitude_mag:
            continue
        current = hosted_by_object.get(a.objectId)
        if current is None or a.candidate.jd > current.candidate.jd:
            hosted_by_object[a.objectId] = a
hosted_alerts = list(hosted_by_object.values())

print(
    f"Found {len(hosted_alerts)} unique hosted alerts around nearby galaxies."
)

In [None]:
# Page through the hosted alerts selected around nearby galaxies.
babamul.jupyter.scan_alerts(hosted_alerts)

In [None]:
# let's load data/tns_public_objects.csv, which is a dump of all public objects in the Transient Name Server as of 2026-02-12
tns_objects: Table = Table.read("data/tns_public_objects.csv")
print(f"Loaded TNS public objects catalog with {len(tns_objects):,} entries.")

# let's remove those with type AGN
tns_objects = tns_objects[tns_objects["type"] != "AGN"]
# let's remove the already classified supernovae
tns_objects = tns_objects[tns_objects["name_prefix"] != "SN"]
print(
    f"Filtered TNS objects to {len(tns_objects):,} entries after removing AGN."
)

In [None]:
# do a cone search against the alerts for the tns_sne
survey, radius_arcsec = "LSST", 1.5
conesearch_results = babamul.api.cone_search_alerts(
    survey,
    tns_objects,
    radius_arcsec=radius_arcsec,
    n_threads=8,
    batch_size=500,
)

tns_sne_candidates = [
    name for name, alerts in conesearch_results.items() if alerts
]
print(
    f"Found {len(tns_sne_candidates)} TNS supernovae with alerts within {radius_arcsec} arcseconds:"
)

In [None]:
# scan for these alerts
alerts_to_scan = {}
for _, alerts in conesearch_results.items():
    for a in alerts:
        if (
            a.objectId not in alerts_to_scan
            or a.candidate.jd > alerts_to_scan[a.objectId].candidate.jd
        ):
            alerts_to_scan[a.objectId] = a
print(f"Found {len(alerts_to_scan)} unique alerts around TNS supernovae.")

babamul.jupyter.scan_alerts(list(alerts_to_scan.values()))