# Source clustering

This notebook clusters sources from nightly validation to generate object light curves.

In [None]:
import lsdb
import matplotlib.pyplot as plt
import os
import pandas as pd
import tempfile

import lsst.daf.butler as dafButler

from dask.distributed import Client
from pathlib import Path
from tqdm import tqdm

# Let pandas print up to 100 rows before truncating a displayed table.
pd.options.display.max_rows = 100

In [None]:
# Install lsdb from the `sean/nested-crossmatch` branch (presumably the branch that
# provides `crossmatch_nested`, used further down — TODO confirm).
# NOTE: `lsdb` is already imported in the first cell, so restart the kernel (or run
# this cell before the imports) for the newly installed version to be the one in use.
%pip install git+https://github.com/astronomy-commons/lsdb.git@sean/nested-crossmatch

# Directory under which the (currently commented-out) `to_hats` / `read_hats`
# calls below would write and read the catalogs.
base_output_dir = Path("/sdf/data/rubin/shared/lsdb_commissioning/tmp")

### Querying for recent day_obs

First, let's get the visit IDs as well as each exposure's midpoint MJD.

In [None]:
# Exclude the `.consdb` domain from proxying so the ConsDB URL below is reached
# directly. The list is rebuilt from the current value so that this works when
# `no_proxy` is not set at all, and so that re-running the cell does not append
# the same entry again.
no_proxy_hosts = [host for host in os.environ.get("no_proxy", "").split(",") if host]
if ".consdb" not in no_proxy_hosts:
    no_proxy_hosts.append(".consdb")
os.environ["no_proxy"] = ",".join(no_proxy_hosts)

from lsst.summit.utils import ConsDbClient

# All columns of the visits with science program BLOCK-365 taken on
# day_obs 20250418 through 20250419 (inclusive), as a pandas DataFrame.
client = ConsDbClient("http://consdb-pq.consdb:8080/consdb")
visits = client.query(
    "SELECT * FROM cdb_lsstcam.visit1 WHERE day_obs >= 20250418 AND day_obs <= 20250419 and science_program = 'BLOCK-365'"
).to_pandas()

In [None]:
# Remember how many visits the query returned; this count is passed later as
# `n_neighbors` to the crossmatch.
num_visits = visits.shape[0]

### Initializing the Butler

In [None]:
# Butler settings; `repo` and `instrument` are reused by the cells below.
repo = "embargo"
instrument = "LSSTCam"
collection_all = "LSSTCam/runs/nightlyValidation"

# Butler over the parent nightly-validation collection.
butler = dafButler.Butler(
    repo,
    collections=collection_all,
    instrument=instrument,
)

### Create object table

In [None]:
# Order the visits by ascending dimm_seeing and keep only those that have a
# value; the row displayed is then the visit with the smallest dimm_seeing.
# NOTE(review): `visits` itself is overwritten here, so visits without a
# dimm_seeing value are also absent from every later cell that uses `visits`
# — confirm this is intended.
visits = visits.sort_values(by="dimm_seeing")
has_seeing = visits["dimm_seeing"].notna()
visits = visits[has_seeing]
visits.iloc[0]

In [None]:
# Look up the collection for day_obs 20250418. The glob may match more than
# one collection; the first one returned by the registry is used.
matching_collections = butler.registry.queryCollections(
    "LSSTCam/runs/nightlyValidation/20250418*7"
)
day_collection = matching_collections[0]
day_collection

In [None]:
# Point the butler at that day's collection and read the `single_visit_star`
# table of one visit into pandas.
# NOTE(review): the visit id is hard-coded — presumably the best-seeing visit
# found above; confirm when changing the date range.
butler = dafButler.Butler(repo, collections=day_collection, instrument=instrument)
star_table = butler.get("single_visit_star", visit=2025041800655, instrument=instrument)
object_df = star_table.to_pandas()

In [None]:
# Import with lsdb: turn the single-visit table into a catalog. It is used
# below as the object side (the caller) of `crossmatch_nested`.
object_cat = lsdb.from_dataframe(object_df)
# object_cat.to_hats(base_output_dir / "object")
# object_cat = lsdb.read_hats(base_output_dir / "object")
# object_cat

In [None]:
# choose a single day: 2025_04_20
# visits = visits[visits["day_obs"] == 20250420]
# visits = visits[["visit_id","day_obs","exp_midpt_mjd","dimm_seeing"]]
# visits

### Getting sources for the available nightly runs

In [None]:
def get_sources_for_day(day_visits):
    """Load the filtered ``single_visit_star`` tables for one observing day.

    Meant to be applied per group, e.g.
    ``visits.groupby("day_obs").apply(get_sources_for_day)``: the day is read
    from ``day_visits.name``, which pandas sets to the group key.

    Parameters
    ----------
    day_visits : pandas.DataFrame
        The visits of a single ``day_obs``. Must contain the ``visit_id`` and
        ``exp_midpt_mjd`` columns.

    Returns
    -------
    list of pandas.DataFrame
        One frame per visit that could be loaded and still had rows after
        filtering, reduced to the columns listed in ``keep_columns``. Visits
        that raise any error are reported with ``print`` and skipped.

    Raises
    ------
    IndexError
        If no collection matches the pattern built from the day.
    """
    # Only these columns are carried forward (for efficiency).
    keep_columns = [
        "ra",
        "dec",
        "sourceId",
        "band",
        "mjd",
        "psfFlux",
        "psfFluxErr",
        "visit_id",
    ]
    day_dfs = []

    day_obs = day_visits.name
    pattern = f"LSSTCam/runs/nightlyValidation/{day_obs}*7"
    matches = list(butler.registry.queryCollections(pattern))
    if not matches:
        # Same exception type as indexing an empty result, but with a message
        # that says which day has no collection.
        raise IndexError(f"No collection matches {pattern!r}")
    day_collection = matches[0]
    print(f"Day collection: {day_collection}")
    day_butler = dafButler.Butler(
        repo, collections=day_collection, instrument=instrument
    )

    visit_times = zip(day_visits["visit_id"], day_visits["exp_midpt_mjd"])
    # zip() has no length, so pass the total for tqdm to show real progress.
    for visit_id, visit_mjd in tqdm(visit_times, total=len(day_visits)):
        try:
            # Get all sources for visit
            df = day_butler.get(
                "single_visit_star", visit=visit_id, instrument=instrument
            ).to_pandas()

            # Keep primary detections that have a coordinate and are not sky sources
            keep = (
                df["detect_isPrimary"].eq(True)
                & df["coord_ra"].notna()
                & df["sky_source"].eq(False)
            )

            # Skip if nothing survives the filtering
            if not keep.any():
                continue

            # .assign() returns a new frame, so the visit_id / mjd columns are
            # never written into a slice of `df` (avoids SettingWithCopyWarning).
            sources = df.loc[keep].assign(visit_id=visit_id, mjd=visit_mjd)

            day_dfs.append(sources[keep_columns])
        except Exception as e:
            # Deliberately best-effort: one unreadable visit must not abort the day.
            print(f"Skipping visit {visit_id} due to error: {e}")

    print(f"Loaded {len(day_dfs)} dataframes.")
    return day_dfs


# Run the loader once per observing day; each entry of `all_dfs` is that
# day's list of per-visit DataFrames.
visits_by_day = visits.groupby("day_obs")
all_dfs = visits_by_day.apply(get_sources_for_day)

In [None]:
# Flatten the per-day lists into one list of visit tables, then stack them
# into a single DataFrame with a fresh index.
result = []
for day_frames in all_dfs:
    result.extend(day_frames)
final_df = pd.concat(result, ignore_index=True)
final_df

In [None]:
# Restrict the combined table to the sources observed in the "i" band
in_i_band = final_df["band"] == "i"
final_df = final_df.loc[in_i_band]
final_df

In [None]:
# Check the sky footprint of the sources: a 2D RA/Dec histogram makes any
# distant/outlier visit stand out.
fig, ax = plt.subplots(figsize=(8, 6))
*_, density = ax.hist2d(final_df["ra"], final_df["dec"], bins=200, cmap="viridis")
fig.colorbar(density, ax=ax, label="Counts")
ax.set_xlabel("RA (degrees)")
ax.set_ylabel("Dec (degrees)")
ax.set_title("2D Histogram of RA/Dec")
plt.show()

In [None]:
# Import with lsdb: turn the combined i-band source table into a catalog. It is
# the catalog passed to `crossmatch_nested` below, whose matches become the
# nested "lc" column.
source_cat = lsdb.from_dataframe(final_df)
# source_cat.to_hats(base_output_dir / "source")
# source_cat = lsdb.read_hats(base_output_dir / "source")
# source_cat

### Construct light curves

In [None]:
# Scratch directory handed to the Dask workers. `tmp_path` keeps the
# TemporaryDirectory object referenced so the directory is not removed early.
tmp_path = tempfile.TemporaryDirectory()
tmp_dir = tmp_path.name

# Match every object to the sources within 0.2 arcsec, allowing up to
# `num_visits` matches per object, and nest the matches in an "lc" column.
# NOTE(review): the result is only materialised by the `.compute()` call in a
# later cell, so this Client context appears to cover graph construction
# only, not the actual computation — confirm.
dask_cluster_options = dict(n_workers=16, threads_per_worker=1, local_directory=tmp_dir)
with Client(**dask_cluster_options):
    lc_cat = object_cat.crossmatch_nested(
        source_cat,
        radius_arcsec=0.2,
        n_neighbors=num_visits,
        nested_column_name="lc",
    )

lc_cat

In [None]:
# Get objects with more than 10 observations: count the entries of each nested
# light curve into a new "nobs" column and filter on it.
object_lc = lc_cat.reduce(
    lambda mjd: {"nobs": mjd.size}, "lc.mjd", meta={"nobs": int}, append_columns=True
)
object_lc = object_lc.query("nobs > 10")

# reduce/query only extend the lazy task graph; the work happens in .compute().
# Run that call inside a Dask client so the workers are actually used for the
# crossmatch and the reduction.
with Client(n_workers=16, threads_per_worker=1, local_directory=tmp_dir):
    object_lc = object_lc.compute()

In [None]:
# Pick one object (the row at position 10) and order its nested light curve
# by observation time.
selected_object = object_lc.iloc[10]
lc = selected_object["lc"].sort_values(by="mjd")
lc

In [None]:
# Plot it

# Hex colour used to draw each band; the insertion order (u, g, r, i, z, y)
# is the order in which the bands are plotted.
COLORS = dict(
    u="#56b4e9",
    g="#009e73",
    r="#f0e442",
    i="#cc79a7",
    z="#d55e00",
    y="#0072b2",
)


def plot_rubin_lc(lc, flux_col, fluxerr_col, *, invert_yaxis=False):
    """Plot a light curve with error bars, one colour per band.

    Parameters
    ----------
    lc : pandas.DataFrame
        Light-curve table with ``band`` and ``mjd`` columns plus the two
        columns named by ``flux_col`` and ``fluxerr_col``.
    flux_col : str
        Name of the column plotted on the y axis.
    fluxerr_col : str
        Name of the column used as the y error bars.
    invert_yaxis : bool, optional
        Flip the y axis. Off by default because the axis is labelled as a
        flux, where larger values are brighter; enable it only when plotting
        magnitude-like columns.
    """
    _, ax = plt.subplots()
    plotted_any = False
    for band, color in COLORS.items():
        band_lc = lc[lc["band"] == band]
        if band_lc.empty:
            # Skip bands without points so the legend lists only what is drawn.
            continue
        ax.errorbar(
            band_lc["mjd"],
            band_lc[flux_col],
            band_lc[fluxerr_col],
            fmt="o",
            label=band,
            color=color,
            alpha=1,
            markersize=5,
            capsize=3,
            elinewidth=1,
        )
        plotted_any = True
    ax.set_xlabel("MJD")
    ax.set_ylabel("Flux")
    if invert_yaxis:
        ax.invert_yaxis()
    if plotted_any:
        # A legend without labelled artists only produces a warning.
        ax.legend(loc="lower right", fontsize=12)


# Draw the selected light curve from its PSF flux and PSF flux error columns.
plot_rubin_lc(lc, flux_col="psfFlux", fluxerr_col="psfFluxErr")