# Babamul streaming example

In this notebook we'll read alerts of interest from the Babamul streams and
save them locally, e.g., for use in a machine learning pipeline.

Before running, be sure to copy `.env.example` to `.env` and fill in
your Babamul Kafka credentials from https://babamul.caltech.edu/profile.

In [None]:
# All imports
import os
from datetime import UTC, datetime
from glob import glob
from itertools import islice

import dotenv
import duckdb
import ipywidgets as widgets
import polars as pl
from IPython.display import display

import babamul
from babamul import LsstAlert, ZtfAlert

In [None]:
# First, load secrets from a local .env file if present.
# `load_dotenv` returns False when it found no values to load (e.g. there is
# no .env file). Credentials may still come from the real environment, so
# only point out where to put them instead of failing. Using the return
# value here also keeps a bare `True`/`False` from being displayed.
if not dotenv.load_dotenv():
    print(
        "No variables loaded from .env; copy .env.example to .env and fill "
        "in your Babamul Kafka credentials if they are not already set in "
        "the environment."
    )

In [None]:
# Define some parameters
limit = 1000  # maximum number of alerts to pull from the stream per run
group_id = "example-group"  # Kafka consumer group id used by the consumer

# Subscribe to every Babamul topic. The leading `^` marks a regex
# subscription (librdkafka convention, presumably honoured by babamul's
# consumer). To only receive cross-matched alerts, use instead:
#     ["babamul.lsst.ztf-match.hosted", "babamul.ztf.lsst-match.hosted"]
topics = ["^babamul.*"]

## Fetch new alerts from the streams

Below we read from the stream and merge new alerts into our local cache
in `data/alerts` so we can further process them later.

In [None]:
# Pull at most `limit` alerts from the stream. `islice` stops requesting
# messages once `limit` alerts have arrived. Breaking out of an `enumerate`
# loop at `n >= limit` would first pull one extra alert and then discard it,
# and that alert's offset may still be committed for this consumer group.
with babamul.AlertConsumer(
    topics=topics,
    offset="latest",
    group_id=group_id,
    timeout=10,
) as consumer:
    alerts = list(islice(consumer, limit))

# Create a Polars DataFrame of all alerts and merge them with the rest of our
# local cache
alerts_dir = "data/alerts"
os.makedirs(alerts_dir, exist_ok=True)

# Only process if we got new alerts
parquet_glob = f"{alerts_dir}/**/*.parquet"
if alerts:
    new_df = pl.DataFrame(
        [alert.model_dump() for alert in alerts],
        # Infer the schema from every alert, not just the first rows, so
        # fields that only appear in later alerts (e.g. when ZTF and LSST
        # alerts are mixed) are not left out.
        infer_schema_length=None,
    ).with_columns(
        # UTC observation date from the candidate's Julian Date;
        # JD 2440587.5 corresponds to the Unix epoch (1970-01-01T00:00 UTC).
        pl.col("candidate")
        .struct.field("jd")
        .map_elements(
            lambda jd: datetime.fromtimestamp(
                (jd - 2440587.5) * 86400, tz=UTC
            ).date(),
            return_dtype=pl.Date,
        )
        .alias("obs_date")
    )
    # Merge with existing alerts cache
    parquet_files = glob(parquet_glob, recursive=True)
    if parquet_files:
        # Fully materialized before the cache is rewritten below.
        existing = pl.scan_parquet(parquet_files).collect()
        n_existing = len(existing)
        # Align columns by name (filling missing ones with nulls) so batches
        # whose inferred schemas differ can still be combined.
        df_combined = pl.concat([existing, new_df], how="diagonal_relaxed")
    else:
        n_existing = 0
        df_combined = new_df
    # Rewrite the whole cache, partitioned by observation date.
    df_combined.write_parquet(alerts_dir, partition_by="obs_date")
    print(
        f"Saved {len(df_combined)} alerts to {alerts_dir} "
        f"({len(new_df)} new, {n_existing} existing)"
    )
else:
    print("No new alerts received from stream")

## View alerts interactively

The cell below selects alerts of interest from the local cache with a SQL
query and lets us page through them one at a time.

In [None]:
# Define a SQL query to find candidates that are likely real astrophysical
# transients. It keeps rows with:
# - drb >= 0.8;
# - at most 30 days between jdstarthist and jd;
# - ssdistnr null, negative or >= 12;
# - not flagged as a star or as near a bright star;
# - isdiffpos true (or null);
# - stationary.
# Results are ordered newest observation date first. union_by_name=true lets
# files with differing columns be read together; missing columns become NULL.
sql = f"""
    SELECT *
    FROM read_parquet('{parquet_glob}', union_by_name=true)
    WHERE
        candidate.drb >= 0.8
        AND (
            candidate.jd - COALESCE(candidate.jdstarthist, candidate.jd)
        ) <= 30
        AND (
            candidate.ssdistnr IS NULL
            OR candidate.ssdistnr >= 12
            OR candidate.ssdistnr < 0
        )
        AND (properties.star IS NULL OR properties.star = FALSE)
        AND (
            properties.near_brightstar IS NULL
            OR properties.near_brightstar = FALSE
        )
        AND (candidate.isdiffpos IS NULL OR candidate.isdiffpos = TRUE)
        AND properties.stationary = TRUE
    ORDER BY obs_date DESC
"""

# DuckDB raises an error when the glob matches no files (e.g. no alerts have
# been cached yet), so fall back to an empty frame that the viewer handles.
if glob(parquet_glob, recursive=True):
    relevant_alerts = duckdb.sql(sql).pl()
else:
    relevant_alerts = pl.DataFrame()


def row_to_alert(row: dict) -> ZtfAlert | LsstAlert:
    """Convert a Polars row (as a dict) back into a ZtfAlert or LsstAlert
    instance.

    The model class is chosen from the row's ``survey`` value, matched
    case-insensitively against "ztf" and "lsst". The whole row is then passed
    to that class's ``model_validate``.

    Args:
        row: One row of a Polars DataFrame, e.g. from
            ``DataFrame.row(i, named=True)``.

    Returns:
        A ``ZtfAlert`` for "ztf" rows, an ``LsstAlert`` for "lsst" rows.

    Raises:
        ValueError: If ``survey`` is missing, null, or not a known survey.
            Validation errors from ``model_validate`` also propagate.
    """
    # `or ""` also covers a survey column that is present but null (e.g. a
    # NULL filled in by `union_by_name`). `row.get("survey", "")` would
    # return None there, and `.lower()` would raise AttributeError.
    survey = (row.get("survey") or "").lower()
    models = {"ztf": ZtfAlert, "lsst": LsstAlert}
    if survey not in models:
        raise ValueError(f"Unknown survey: {survey}")
    return models[survey].model_validate(row)


# Widgets: a header label, previous/next navigation buttons, and an output
# area for the current alert's cutouts and properties.
info_label = widgets.HTML()
prev_button, next_button = (
    widgets.Button(description=text) for text in ("← Previous", "Next →")
)
output = widgets.Output()

# Position of the alert being shown. It is kept in a one-element list so the
# callbacks can update it in place without a `global` statement.
current_idx = [0]


def update_display():
    """Show the alert at position ``current_idx[0]`` of ``relevant_alerts``.

    Sets ``info_label`` to a position/observation summary. Redraws ``output``
    with the alert's cutouts (via ``plot_cutouts``) followed by its
    ``properties`` as indented JSON. When there are no alerts, only the label
    is updated. If the row cannot be turned into an alert, the error is
    printed in ``output`` instead of being raised.
    """
    if len(relevant_alerts) == 0:
        info_label.value = "No alerts found"
        return
    idx = current_idx[0]
    row = relevant_alerts.row(idx, named=True)
    obs_date = row["obs_date"]
    header = f"<b>Alert {idx + 1} of {len(relevant_alerts)}</b>"
    with output:
        output.clear_output(wait=True)
        try:
            alert = row_to_alert(row)
        except ValueError as e:
            # Unknown survey, or failed validation (pydantic's
            # ValidationError subclasses ValueError). Report it here: an
            # exception raised from a button callback typically ends up in
            # the Jupyter log rather than the notebook, so the view would
            # silently stop updating.
            info_label.value = f"{header} (observed {obs_date})"
            print(f"Could not load alert: {e}")
            return
        info_label.value = f"{header} (observed {obs_date} by {alert.survey})"
        try:
            alert.plot_cutouts()
        except Exception as e:
            # Plotting is best-effort; the properties are still shown below.
            print(f"Could not plot cutouts: {e}")
        print("\n" + "=" * 50)
        print(alert.properties.model_dump_json(indent=2))


def on_prev(b):
    """Button callback: step back one alert, stopping at the first.

    ``b`` is the clicked button passed in by ipywidgets (unused).
    """
    if current_idx[0] <= 0:
        return
    current_idx[0] -= 1
    update_display()


def on_next(b):
    """Button callback: step forward one alert, stopping at the last.

    ``b`` is the clicked button passed in by ipywidgets (unused).
    """
    last_idx = len(relevant_alerts) - 1
    if current_idx[0] >= last_idx:
        return
    current_idx[0] += 1
    update_display()


# Hook each navigation button up to its callback.
for button, handler in ((prev_button, on_prev), (next_button, on_next)):
    button.on_click(handler)

# Layout: header label on top, then the buttons, then the output area.
buttons = widgets.HBox([prev_button, next_button])
container = widgets.VBox([info_label, buttons, output])
display(container)

# Render the first alert (or the "No alerts found" message).
update_display()