# Source clustering

This notebook clusters sources from nightly validation to generate object light curves.

In [None]:
import lsdb
import matplotlib.pyplot as plt
import pandas as pd
import tempfile

import lsst.daf.butler as dafButler
from lsst.summit.utils import ConsDbClient

from dask.distributed import Client
from lsdb.core.search import ConeSearch
from pathlib import Path
from tqdm import tqdm

pd.set_option('display.max_rows', 100)

In [None]:
#%pip install git+https://github.com/astronomy-commons/lsdb.git@sean/nested-crossmatch
# Base directory under which the intermediate "object" and "source" catalogs are written
base_output_dir = Path("/sdf/data/rubin/shared/lsdb_commissioning/tmp")

### Query for all recent visits

First, let's get all the visits from April 18 to April 20, 2025.

In [None]:
# Inclusive day_obs range (YYYYMMDD integers) used in the visit query
start_day_obs, end_day_obs = 20250418, 20250420

In [None]:
# Read the ConsDB access token from the local "token" file. Surrounding
# whitespace (typically a trailing newline) is stripped so that it cannot end
# up inside the credentials part of the URL.
with open("token", "r") as f:
    token = f.read().strip()
client = ConsDbClient(f"https://user:{token}@usdf-rsp.slac.stanford.edu/consdb")
# Select every 'BLOCK-365' visit whose day_obs lies in [start_day_obs, end_day_obs]
visits = client.query(
    "SELECT * FROM cdb_lsstcam.visit1 "
    f"WHERE day_obs >= {start_day_obs} AND day_obs <= {end_day_obs} "
    "and science_program = 'BLOCK-365'"
).to_pandas()

In [None]:
# Number of visits returned by the query; reused later as n_neighbors in the crossmatch
num_visits = len(visits)
print(f"Found {num_visits} visits from {start_day_obs} to {end_day_obs}")

### Initialize the Butler

In [None]:
# Butler configuration: repository, instrument and the parent nightly validation collection
repo = "embargo"
instrument = "LSSTCam"
collection_all = "LSSTCam/runs/nightlyValidation"
# This Butler's registry is used below to look up the per-day collections
butler = dafButler.Butler(repo, collections=collection_all, instrument=instrument)

### Create object table

In [None]:
# Find the visit with the best (lowest) dimm_seeing.
# The sorted/filtered rows go into a separate frame so that `visits` keeps
# every visit, including those without a dimm_seeing value: `visits` is used
# again further down to load the sources of all the visits.
visits_with_seeing = visits.sort_values("dimm_seeing").dropna(subset=["dimm_seeing"])
visit_best_dimm_seeing = visits_with_seeing.iloc[0]
visit_best_dimm_seeing

In [None]:
# Find the name of the collection for the day_obs of the best-seeing visit
day_obs = visit_best_dimm_seeing["day_obs"]
# The first collection matching the pattern is used.
# NOTE(review): the trailing "*7" presumably selects a specific run suffix - confirm
day_collection = butler.registry.queryCollections(f"LSSTCam/runs/nightlyValidation/{day_obs}*7")[0]
day_collection

In [None]:
# Re-create the Butler, now bound to the collection of the selected day
butler = dafButler.Butler(repo, collections=day_collection, instrument=instrument)
# The 'single_visit_star' table of the best-seeing visit is the basis of the object table
object_df = butler.get('single_visit_star', visit=visit_best_dimm_seeing["visit_id"], instrument=instrument).to_pandas()
object_df

Let's transform this object dataframe into a HATS catalog:

In [None]:
# Build an lsdb catalog from the object dataframe
object_cat = lsdb.from_dataframe(object_df)
# There is a bug using the from_dataframe output directly:
# A workaround is to save the catalog to transient storage and load it back
object_cat.to_hats(base_output_dir / "object")
object_cat = lsdb.read_hats(base_output_dir / "object")
object_cat

In [None]:
# Plot the pixels of the object catalog
object_cat.plot_pixels()

### Query for all sources

Let's query the Butler to get the sources for all the visits.

In [None]:
def _get_butler_for_day(day_obs):
    """Return a Butler bound to the nightly validation collection of ``day_obs``.

    The collection is looked up through the registry of the module-level
    ``butler``; the first collection matching the pattern is used.
    """
    pattern = f"LSSTCam/runs/nightlyValidation/{day_obs}*7"
    matching_collections = butler.registry.queryCollections(pattern)
    return dafButler.Butler(
        repo,
        collections=matching_collections[0],
        instrument=instrument,
    )

def _filter_source_df(df):
    """Select the usable "i" band sources and keep only the light curve columns.

    A row is kept when all of the following hold:

    - ``detect_isPrimary`` is true (primary detection),
    - ``coord_ra`` is not null,
    - ``sky_source`` is false (not a fake detection),
    - ``band`` equals ``"i"``.

    Parameters
    ----------
    df : pandas.DataFrame
        Source table holding the columns listed above plus the returned ones.

    Returns
    -------
    pandas.DataFrame
        Independent copy of the selected rows (original index labels) with
        the columns ``ra``, ``dec``, ``sourceId``, ``band``, ``psfFlux`` and
        ``psfFluxErr``. Because it is a copy, callers can add columns to it
        without touching ``df`` and without a ``SettingWithCopyWarning``.
    """
    keep = (
        df["detect_isPrimary"]
        & df["coord_ra"].notna()
        & df["sky_source"].eq(False)
        & df["band"].eq("i")
    )
    # Reduce number of columns (for efficiency)
    columns = ["ra", "dec", "sourceId", "band", "psfFlux", "psfFluxErr"]
    return df.loc[keep, columns].copy()

def get_sources_for_day(day_visits):
    """Load the filtered sources of every visit of a single day.

    Parameters
    ----------
    day_visits : pandas.DataFrame
        Visits of one day. The columns ``day_obs``, ``visit_id`` and
        ``exp_midpt_mjd`` are read; the ``day_obs`` of the first row selects
        the Butler used for all the visits.

    Returns
    -------
    pandas.DataFrame
        Concatenation of the filtered per-visit sources, each row tagged with
        its ``visit_id`` and ``mjd``. An empty frame is returned when no
        visit produced any source.
    """
    # Initialize butler for current day
    day_obs = day_visits["day_obs"].iloc[0]
    day_butler = _get_butler_for_day(day_obs)
    ids, mjds = day_visits["visit_id"], day_visits["exp_midpt_mjd"]

    day_dfs = []
    # Get the sources for each visit. zip() has no length, so the total is
    # passed explicitly to let tqdm display the actual progress.
    for visit_id, visit_mjd in tqdm(zip(ids, mjds), total=len(day_visits)):
        try:
            df = day_butler.get(
                'single_visit_star', visit=visit_id, instrument=instrument
            ).to_pandas()
            df = _filter_source_df(df)
            if not df.empty:
                # assign() builds a new frame, so the filtered slice is never
                # written to in place
                day_dfs.append(df.assign(visit_id=visit_id, mjd=visit_mjd))
        except Exception as e:
            # Best effort: a visit that cannot be loaded or filtered is
            # reported and skipped instead of aborting the whole day
            print(f"Skipping visit {visit_id} due to error: {e}")

    # Report the day being processed rather than the global `day_collection`,
    # which refers to a different (fixed) day
    print(f"Loaded {len(day_dfs)} dataframes for day_obs {day_obs}")
    if not day_dfs:
        # pd.concat raises ValueError on an empty list
        return pd.DataFrame()
    return pd.concat(day_dfs, ignore_index=True)

In [None]:
# Took roughly 2min
# Load the sources one day at a time, then stack all the days into a single dataframe
all_dfs = [get_sources_for_day(day_visits) for _, day_visits in visits.groupby("day_obs")]
sources_df = pd.concat(all_dfs, ignore_index=True)
sources_df

In [None]:
# Import with lsdb
source_cat = lsdb.from_dataframe(sources_df)
# There is a bug using the from_dataframe output directly:
# A workaround is to save the catalog to transient storage and load it back
source_cat.to_hats(base_output_dir / "source")
source_cat = lsdb.read_hats(base_output_dir / "source")
source_cat

Let's remove the few sources that are distant from the main cluster:

In [None]:
# Plot the pixels of the source catalog and overlay the selection cone
source_cat.plot_pixels()
# Cone centered on (ra=218, dec=-15) with a radius of 12 degrees (12*3600 arcsec)
cone = ConeSearch(ra=218, dec=-15, radius_arcsec=12*3600)
# Transparent face, red edge
cone.plot(fc="#00000000", ec="red")

In [None]:
# Restrict the source catalog to the cone defined above
source_cat = source_cat.cone_search(ra=cone.ra, dec=cone.dec, radius_arcsec=cone.radius_arcsec)
source_cat

### Construct light curves

In [None]:
# Temporary directory handed to Dask as its local_directory
tmp_path = tempfile.TemporaryDirectory()
tmp_dir = tmp_path.name
# Dask client with 16 single-threaded workers.
# NOTE: this rebinds `client`, which previously held the ConsDbClient
client = Client(n_workers=16, threads_per_worker=1, local_directory=tmp_dir)
client

In [None]:
# Get light curves for the catalog
# Sources within 0.2 arcsec of an object are nested under the "lc" column.
# NOTE(review): n_neighbors=num_visits presumably caps the matches at one per visit - confirm
lc_cat = object_cat.crossmatch_nested(source_cat, radius_arcsec=0.2, n_neighbors=num_visits, nested_column_name="lc")
lc_cat

In [None]:
# Took roughly 30sec
# Append "nobs", the number of "lc.mjd" entries of each object
object_lc = lc_cat.reduce(lambda mjd: {"nobs": mjd.size}, "lc.mjd", meta={"nobs": int}, append_columns=True)
# Keep only the objects with more than 10 observations
object_lc = object_lc.query("nobs > 10")
# Run the computation and collect the result
object_lc = object_lc.compute()
object_lc.head()

In [None]:
# Close the Dask client and remove its temporary directory
client.close()
tmp_path.cleanup()

### Plot light curves

In [None]:
# Grab a single light curve (the row at position 10), ordered by mjd
lc = object_lc.iloc[10]["lc"].sort_values("mjd")
lc

In [None]:
# Plot color of each band
COLORS = {
    "u": "#56b4e9",
    "g": "#009e73",
    "r": "#f0e442",
    "i": "#cc79a7",
    "z": "#d55e00",
    "y": "#0072b2",
}

def plot_rubin_lc(lc, flux_col, fluxerr_col, *, invert_y=False):
    """Plot a light curve as per-band error bars versus MJD.

    Parameters
    ----------
    lc : pandas.DataFrame
        Light curve with the columns ``mjd`` and ``band`` plus the two
        columns named by ``flux_col`` and ``fluxerr_col``.
    flux_col : str
        Name of the column plotted on the y axis.
    fluxerr_col : str
        Name of the column used as the y error bar.
    invert_y : bool, optional
        Invert the y axis. Off by default because the axis shows fluxes,
        where larger values are brighter; enable it only when the plotted
        column holds magnitudes.
    """
    _, ax = plt.subplots()
    for band, color in COLORS.items():
        band_lc = lc.query(f"band == '{band}'")
        if band_lc.empty:
            # Bands without data would only add empty entries to the legend
            continue
        ax.errorbar(
            band_lc["mjd"],
            band_lc[flux_col],
            band_lc[fluxerr_col],
            fmt="o",
            label=band,
            color=color,
            alpha=1,
            markersize=5,
            capsize=3,
            elinewidth=1,
        )
    ax.set_xlabel("MJD")
    ax.set_ylabel("Flux")
    if invert_y:
        ax.invert_yaxis()
    # Only draw a legend when at least one band was plotted
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc="lower right", fontsize=12)
plot_rubin_lc(lc, "psfFlux", "psfFluxErr")