# Babamul streaming example

In this notebook we'll read alerts from Babamul streams, save them locally
(e.g., to use in a machine learning pipeline),
and then query the saved alerts for candidates of interest.

Before running, be sure to copy `.env.example` to `.env` and fill in
your Babamul Kafka credentials from https://babamul.caltech.edu/profile.

In [None]:
# All imports
import os
from datetime import UTC, datetime
from glob import glob
from itertools import islice

import dotenv
import duckdb
import polars as pl

import babamul
from babamul import ZtfAlert, LsstAlert

In [None]:
# First, load secrets from a local .env file if present.
# These are the Babamul Kafka credentials described at the top of this
# notebook (copy `.env.example` to `.env` and fill them in).

dotenv.load_dotenv()

In [None]:
# Define some parameters
limit = 1000  # maximum number of alerts to collect in one run
group_id = "example-group"
# Topic subscription.  The leading "^" presumably marks this entry as a regex
# pattern covering every Babamul topic -- confirm against the babamul docs.
# To read only the ZTF/LSST cross-matched hosted streams, list them
# explicitly instead: "babamul.lsst.ztf-match.hosted" and
# "babamul.ztf.lsst-match.hosted".
topics = ["^babamul.*"]

In [None]:
# Collect at most `limit` alerts from the configured topics.
alerts = []

with babamul.AlertConsumer(
    topics=topics,
    offset="latest",
    group_id=group_id,
    timeout=10,
) as consumer:
    # islice stops asking the consumer for more as soon as `limit` alerts have
    # been taken.  An enumerate()/break loop would first pull one extra alert
    # from the stream (waiting for it if the stream is idle) and then throw it
    # away.  max() keeps a negative limit meaning "take nothing".
    alerts.extend(islice(consumer, max(limit, 0)))

# Create a Polars DataFrame of all alerts and merge them with the rest of our
# local cache
alerts_dir = "data/alerts"
os.makedirs(alerts_dir, exist_ok=True)
parquet_glob = f"{alerts_dir}/**/*.parquet"

if not alerts:
    # With no alerts the DataFrame would have no "candidate"/"properties"
    # columns to unnest, so skip the merge and leave the cache untouched.
    print(f"No new alerts received; {alerts_dir} left unchanged")
else:
    new_df = (
        pl.DataFrame([alert.model_dump() for alert in alerts])
        # Flatten the nested structs into candidate_* / properties_* columns
        .unnest("candidate", separator="_")
        .unnest("properties", separator="_")
        .with_columns(
            # Julian date -> UTC calendar date.  JD 2440587.5 is the Unix
            # epoch, so the floored difference is whole days since
            # 1970-01-01; this is computed column-wise instead of calling a
            # Python lambda per row.
            pl.from_epoch(
                (pl.col("candidate_jd") - 2440587.5).floor().cast(pl.Int32),
                time_unit="d",
            ).alias("observation_date")
        )
    )

    # Merge with existing alerts cache
    parquet_files = glob(parquet_glob, recursive=True)
    if parquet_files:
        existing = pl.scan_parquet(parquet_files).collect()
        n_existing = len(existing)
        df_combined = pl.concat([existing, new_df], how="vertical_relaxed")
    else:
        n_existing = 0
        df_combined = new_df

    # The whole combined frame is rewritten, one partition per observation
    # date
    df_combined.write_parquet(alerts_dir, partition_by="observation_date")
    print(
        f"Saved {len(df_combined)} alerts to {alerts_dir} "
        f"({len(new_df)} new, {n_existing} existing)"
    )

In [None]:
# Query every cached alert file with DuckDB and keep only the rows that pass
# all of the cuts below:
#   - candidate_drb is at least 0.8
#   - candidate_jd is at most 30 after candidate_jdstarthist (a NULL
#     jdstarthist falls back to jd itself, so such rows pass this cut)
#   - candidate_ssdistnr is NULL, negative, or at least 12
#   - properties_star and properties_near_brightstar are NULL or FALSE
#   - candidate_isdiffpos is NULL or TRUE
#   - properties_stationary is TRUE

parquet_glob = f"{alerts_dir}/**/*.parquet"

candidate_query = f"""
    SELECT *
    FROM read_parquet('{parquet_glob}')
    WHERE
        candidate_drb >= 0.8
        AND (candidate_jd - COALESCE(candidate_jdstarthist, candidate_jd)) <= 30
        AND (
            candidate_ssdistnr IS NULL
            OR candidate_ssdistnr >= 12
            OR candidate_ssdistnr < 0
        )
        AND (properties_star IS NULL OR properties_star = FALSE)
        AND (properties_near_brightstar IS NULL OR properties_near_brightstar = FALSE)
        AND (candidate_isdiffpos IS NULL OR candidate_isdiffpos = TRUE)
        AND properties_stationary = TRUE
    """
# Run the query and hand the result back as a Polars DataFrame
relevant_alerts = duckdb.sql(candidate_query).pl()

relevant_alerts

In [None]:
# Let's view one of our relevant alerts' cutouts
def row_to_alert(row: dict) -> ZtfAlert | LsstAlert:
    """Rebuild an alert model from one flattened row of the local cache.

    Columns named ``candidate_*`` and ``properties_*`` are gathered back into
    nested ``candidate`` and ``properties`` dicts (prefix stripped).  Every
    other non-null column except ``candid`` and ``objectId`` is passed
    through unchanged at the top level.

    Args:
        row: Mapping of column name to value for a single row, e.g. the
            result of ``DataFrame.row(i, named=True)``.  Must contain
            ``candid`` and ``objectId``.

    Returns:
        A ``ZtfAlert`` or ``LsstAlert``, selected by the row's ``survey``
        value (case-insensitive) and built with ``model_validate``.

    Raises:
        ValueError: If ``survey`` is missing, null, or neither "ztf" nor
            "lsst".
        KeyError: If ``candid`` or ``objectId`` is absent from ``row``.
    """
    # A null survey column arrives as None, which has no .lower(); treat it
    # like a missing key so it is reported as an unknown survey below.
    survey = (row.get("survey") or "").lower()
    if survey == "ztf":
        cls = ZtfAlert
    elif survey == "lsst":
        cls = LsstAlert
    else:
        raise ValueError(f"Unknown survey: {survey}")

    # Extract flattened candidate fields
    candidate = {
        k.removeprefix("candidate_"): v
        for k, v in row.items()
        if k.startswith("candidate_")
    }
    # Extract flattened properties fields
    properties = {
        k.removeprefix("properties_"): v
        for k, v in row.items()
        if k.startswith("properties_")
    }
    # Build alert dict with the required fields
    alert_dict = {
        "candid": row["candid"],
        "objectId": row["objectId"],
        "candidate": candidate,
        "properties": properties,
    }
    # Pass through every remaining non-null top-level field (prv_candidates,
    # fp_hists, cutouts, survey, ...).  Null values are omitted rather than
    # being handed to the model as explicit None.
    flattened_prefixes = ("candidate_", "properties_")
    already_set = {"candid", "objectId"}
    for key, value in row.items():
        if key in already_set or key.startswith(flattened_prefixes):
            continue
        if value is not None:
            alert_dict[key] = value

    return cls.model_validate(alert_dict)


if relevant_alerts.is_empty():
    # Nothing passed the query; indexing a row here would raise
    print("No relevant alerts found; nothing to plot")
else:
    # Show row 1 when it exists; with a single match fall back to row 0
    # instead of failing on an out-of-range index.
    row_index = 1 if relevant_alerts.height > 1 else 0
    alert = row_to_alert(relevant_alerts.row(row_index, named=True))
    alert.plot_cutouts()
    print(alert.properties.model_dump_json(indent=2))