## Running PhotoD with LSDB

In [13]:
import os
# Thread / memory settings exported to the environment before the numerical
# libraries below are imported.
default_n_threads = 1
os.environ.update(
    {
        # Number of threads OpenBLAS may use.
        'OPENBLAS_NUM_THREADS': str(default_n_threads),
        # Ask XLA (JAX backend) not to preallocate device memory.
        'XLA_PYTHON_CLIENT_PREALLOCATE': 'false',
    }
)

import jax
import lsdb
import matplotlib.pyplot as plt
import nested_pandas as npd
import numpy as np
import pandas as pd

from dask import delayed
from dask.distributed import Client, get_worker
from scipy.interpolate import griddata
from photod.bayes import makeBayesEstimates3d
from photod.locus import LSSTsimsLocus, subsampleLocusData, get3DmodelList
from photod.parameters import GlobalParams

In [14]:
# Object catalog (HATS format) opened with LSDB. Displaying the catalog shows
# only the column schema and partition list.
# NOTE(review): absolute, cluster-specific path — must be adapted to run elsewhere.
# Previously used input, kept for reference:
# s82_stripe_url = "/mnt/beegfs/scratch/data/S82_standards/S82_hats/S82_hats_fixed"
s82_stripe_url = "/mnt/beegfs/scratch/data/Gaia-SDSS/hats/Gaia-SDSS_w_extinctions"
s82_stripe_catalog = lsdb.read_hats(s82_stripe_url)
s82_stripe_catalog

Unnamed: 0_level_0,random_index,ra,dec,phot_g_mean_mag,parallax,parallax_error,source_id,r_med_geo,r_lo_geo,r_hi_geo,r_med_photogeo,r_lo_photogeo,r_hi_photogeo,flag,objid,type,psfmag_u,psfmag_g,psfmag_r,psfmag_i,psfmag_z,psfmagerr_u,psfmagerr_g,psfmagerr_r,psfmagerr_i,psfmagerr_z,extinction_u,extinction_g,extinction_r,extinction_i,extinction_z,Norder,Dir,Npix
npartitions=63,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1
"Order: 1, Pixel: 0",int64[pyarrow],double[pyarrow],double[pyarrow],float[pyarrow],double[pyarrow],float[pyarrow],int64[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],string[pyarrow],int64[pyarrow],int16[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],float[pyarrow],uint8[pyarrow],uint64[pyarrow],uint64[pyarrow]
"Order: 1, Pixel: 1",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"Order: 0, Pixel: 10",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"Order: 0, Pixel: 11",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...


In [15]:
# Prior-map catalog (HATS format) opened with LSDB; it is later passed to
# `merge_map` together with the object catalog.
# NOTE(review): the recorded output of this cell is a FileNotFoundError for this
# location, so `prior_map_catalog` is not defined by the saved run — confirm the
# correct path before re-running the notebook.
# Previously used input, kept for reference:
# prior_map_url = "/mnt/beegfs/scratch/data/priors/hats/s82_priors"
prior_map_url = "/mnt/beegfs/scratch/data/priors/TRILEGAL/level_5/"
prior_map_catalog = lsdb.read_hats(prior_map_url)
prior_map_catalog

FileNotFoundError: Failed to read HATS at location /mnt/beegfs/scratch/data/priors/TRILEGAL/level_5

In [8]:
def merging_function(partition, map_partition, partition_pixel, map_pixel, globalParams, worker_dict, *kwargs, kdeShape=(96, 36), batchSize=10):
    """Compute Bayes estimates for one catalog partition using its prior-map partition.

    For every ``rmag`` value in ``map_partition`` (in ascending order) the stored
    prior is interpolated onto the locus grid taken from ``globalParams``; the
    stacked priors are then handed to ``makeBayesEstimates3d`` on the JAX device
    assigned to the current Dask worker.

    Parameters
    ----------
    partition : DataFrame
        Object-catalog partition; must contain ``ra`` and ``dec`` columns and is
        passed unchanged to ``makeBayesEstimates3d``.
    map_partition : DataFrame
        Prior-map partition with columns ``rmag``, ``kde``, ``xGrid`` and
        ``yGrid``. The last three hold float64 byte buffers that are reshaped to
        ``kdeShape``.
    partition_pixel, map_pixel
        Pixels of the two partitions; accepted for the ``merge_map`` call
        convention but not used here.
    globalParams
        Object exposing ``locusData`` and ``MrColumn``; also forwarded to
        ``makeBayesEstimates3d``.
    worker_dict : dict
        Maps Dask worker ids to device indices.
    *kwargs
        Extra positional arguments; ignored.
    kdeShape : tuple of int, optional
        Shape of the stored ``kde`` / ``xGrid`` / ``yGrid`` grids.
    batchSize : int, optional
        Forwarded as ``batchSize`` to ``makeBayesEstimates3d``.

    Returns
    -------
    DataFrame
        ``ra`` and ``dec`` of ``partition`` concatenated column-wise with the
        estimates wrapped in a ``NestedFrame``.
    """
    # The locus points to interpolate onto are the same for every magnitude
    # slice, so they are looked up once instead of inside the loop.
    locusPoints = (globalParams.locusData["FeH"], globalParams.locusData[globalParams.MrColumn])

    priorSlices = []
    for r in np.sort(map_partition["rmag"].to_numpy()):
        # first prior-map row stored for this magnitude
        row = map_partition[map_partition["rmag"] == r].iloc[0]
        Zval = np.frombuffer(row["kde"], dtype=np.float64).reshape(kdeShape)
        X = np.frombuffer(row["xGrid"], dtype=np.float64).reshape(kdeShape)
        Y = np.frombuffer(row["yGrid"], dtype=np.float64).reshape(kdeShape)
        points = np.column_stack((X.ravel(), Y.ravel()))
        # actual (linear) interpolation; locus points outside the map get prior 0
        priorSlices.append(
            griddata(points, Zval.ravel(), locusPoints, method="linear", fill_value=0)
        )

    # Pick the device for this worker. The index wraps around the number of
    # available devices so that having more workers than devices does not raise
    # IndexError, and a worker missing from worker_dict falls back to index 0
    # instead of raising KeyError in the middle of the computation.
    devices = jax.devices()
    deviceIndex = worker_dict.get(get_worker().id, 0) % len(devices)
    gpu_device = devices[deviceIndex]
    with jax.default_device(gpu_device):
        priorGrid = jax.numpy.array(priorSlices)
        estimatesDf, _ = makeBayesEstimates3d(partition, priorGrid, globalParams, batchSize=batchSize)
    # Append ra and dec to be able to later crossmatch
    return pd.concat([partition[["ra", "dec"]], npd.NestedFrame(estimatesDf)], axis=1)

In [9]:
# Locus colour table, addressed relative to the notebook's working directory.
# Alternative absolute location, kept for reference:
# locus_path = "/home/scampos/photoD/data/MSandRGBcolors_v1.3.txt"
locus_path = "../../data/MSandRGBcolors_v1.3.txt"
LSSTlocus = LSSTsimsLocus(fixForStripe82=False, datafile=locus_path)

# Keep only locus rows with 0.2 < gi < 3.55 (both bounds exclusive).
giColor = LSSTlocus["gi"]
OKlocus = LSSTlocus[(giColor > 0.2) & (giColor < 3.55)]

# Thin the locus grid (every 10th step in Mr, every 2nd in FeH per the keyword names).
locusData = subsampleLocusData(OKlocus, kMr=10, kFeH=2)

# Colours used in the fit, and the model lists built from the subsampled locus.
fitColors = ("ug", "gr", "ri", "iz")
ArGridList, locus3DList = get3DmodelList(locusData, fitColors)
globalParams = GlobalParams(fitColors, locusData, ArGridList, locus3DList)

subsampled locus 2D grid in FeH and Mr from 51 1559 to: 25 155


In [10]:
# Column names of the estimates frame: a lo/median/hi quantile column for each
# of Mr, FeH, Ar and Qr, plus the three "dS" columns, in sorted order.
quantile_cols = [
    f"{statisticsName}_quantile_{quantile}"
    for statisticsName in ["Mr", "FeH", "Ar", "Qr"]
    for quantile in ["lo", "median", "hi"]
]
estimate_cols = sorted(quantile_cols + ["MrdS", "FeHdS", "ArdS"])
col_names = ["ra", "dec", "glon", "glat", "chi2min"] + estimate_cols

# Empty frame describing the expected output schema; it is passed as `meta` to
# `merge_map` further down. Every column is declared as float32.
# NOTE(review): the catalog lists ra/dec as double[pyarrow], while meta declares
# them float32 — confirm this mismatch is intended.
meta = npd.NestedFrame.from_dict(
    {col: pd.Series([], dtype=np.float32) for col in col_names}
)
meta.index.name = "_healpix_29"
meta

Unnamed: 0_level_0,ra,dec,glon,glat,chi2min,Ar_quantile_hi,Ar_quantile_lo,Ar_quantile_median,ArdS,FeH_quantile_hi,FeH_quantile_lo,FeH_quantile_median,FeHdS,Mr_quantile_hi,Mr_quantile_lo,Mr_quantile_median,MrdS,Qr_quantile_hi,Qr_quantile_lo,Qr_quantile_median
_healpix_29,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1


In [11]:
def get_worker_dict():
    """Map the id of every worker of the current Dask client to a small integer.

    Each worker is asked for its id directly via ``Client.run``, so the mapping
    covers all workers of the cluster. Discovering workers by running a few
    partition tasks can miss any worker that happens to receive none of them.

    The ids are sorted and numbered from 0; the mapping is printed and returned.

    Returns
    -------
    dict
        ``{worker_id: index}`` with ``index`` in ``range(number_of_workers)``.

    Raises
    ------
    ValueError
        Raised by ``Client.current()`` when no Dask client is active.
    """
    # Client.run fills an argument named `dask_worker` with the worker object
    # executing the call; the result is keyed by worker address.
    ids_by_address = Client.current().run(lambda dask_worker: dask_worker.id)
    worker_ids = sorted(set(ids_by_address.values()))
    worker_dict = {worker_id: i for i, worker_id in enumerate(worker_ids)}
    print(worker_dict)
    return worker_dict

In [12]:
%%time
# Run the estimation on a local Dask cluster; the client is opened as a
# context manager around the whole computation.
with Client(n_workers=4, dashboard_address=':41987') as client:  # threads_per_worker=1, memory_limit="auto",
    # get_worker_dict() already prints the mapping, so it is not printed a
    # second time here.
    worker_dict = get_worker_dict()
    # globalParams is wrapped in `delayed` before being handed to merge_map.
    delayed_global_params = delayed(globalParams)
    # Lazily pair catalog partitions with prior-map partitions and apply
    # merging_function to each pair; `meta` describes the output schema.
    merge_lazy = s82_stripe_catalog.merge_map(
        prior_map_catalog,
        merging_function,
        globalParams=delayed_global_params,
        worker_dict=worker_dict,
        meta=meta,
    )
    xmatch_result = merge_lazy.compute()
xmatch_result

Perhaps you already have a cluster running?
Hosting the HTTP server on port 35519 instead


{'Worker-33aed7a3-94be-4691-b674-55721006b50a': 0, 'Worker-7971a1ed-2e9b-4f2c-946d-ecd0b2530eea': 1, 'Worker-90b3cfaa-b426-41a2-a806-aa4cdbffad59': 2, 'Worker-c3183a67-6e94-4020-a897-edecee95187e': 3}
{'Worker-33aed7a3-94be-4691-b674-55721006b50a': 0, 'Worker-7971a1ed-2e9b-4f2c-946d-ecd0b2530eea': 1, 'Worker-90b3cfaa-b426-41a2-a806-aa4cdbffad59': 2, 'Worker-c3183a67-6e94-4020-a897-edecee95187e': 3}


This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
2024-12-18 17:11:16,017 - distributed.worker - ERROR - Compute Failed
Key:       ('delayed-container-717cdf6ad6baaf15b881330b0453b96b', 137)
State:     executing
Function:  execute_task
args:      ((<function apply at 0x7f195c7793a0>, <function perform_merge_map at 0x7f18d2b62980>, [                     random_index          ra       dec  ...  Norder  Dir  Npix
_healpix_29                                              ...                   
1333082597903076271    1229045708  326.210375 -9.070541  ...       2    0    74
1333082695467910450      91393013  326.190244 -9.068081  ...       2    0    74
1333082774583744613     615844200  326.198001   -9.0499  ...       2    0    74
1333082777060530646     135807859  326.200813 -9

TimeoutError: 

In [9]:
# Explicitly close the Dask client.
# NOTE(review): `client` is created by `with Client(...) as client:` in the run
# cell, which should already close it on exit — presumably this is a leftover
# cleanup step for interrupted runs; confirm whether it is still needed.
client.close()