## Running PhotoD with LSDB

In [1]:
import os
# These environment variables are set before the numerical libraries below are
# imported; presumably they are only read at library start-up, so keep this
# block ahead of the jax / numpy imports.
# Cap OpenBLAS at one thread per process.
default_n_threads = 1
os.environ['OPENBLAS_NUM_THREADS'] = f"{default_n_threads}"
# Ask JAX/XLA not to preallocate device memory up front (see the JAX GPU memory
# allocation docs) -- presumably so that several worker processes can share a
# device; confirm against the deployment.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import jax
import lsdb
import matplotlib.pyplot as plt
import nested_pandas as npd
import numpy as np
import pandas as pd

from dask import delayed
from dask.distributed import Client, get_worker
from scipy.interpolate import griddata
from photod.bayes import makeBayesEstimates3d
from photod.locus import LSSTsimsLocus, subsampleLocusData, get3DmodelList
from photod.parameters import GlobalParams

In [2]:
# Open the S82 standards catalog stored in HATS format.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
s82_stripe_url = "/mnt/beegfs/scratch/data/S82_standards/S82_hats/S82_hats_fixed"
s82_stripe_catalog = lsdb.read_hats(s82_stripe_url)
# Display the catalog's representation (column dtypes per partition).
s82_stripe_catalog

Unnamed: 0_level_0,CALIBSTARS,ra,dec,RArms,Decrms,Ntot,Ar,uNobs,umag,ummu,uErr,umrms,umchi2,gNobs,gmag,gmmu,gErr,gmrms,gmchi2,rNobs,rmag,rmmu,rErr,rmrms,rmchi2,iNobs,imag,immu,iErr,imrms,imchi2,zNobs,zmag,zmmu,zErr,zmrms,zmchi2,Norder,Dir,Npix,Mr,FeH,MrEst,MrEstUnc,FeHEst,ug,gr,gi,ri,iz,ugErr,grErr,giErr,riErr,izErr,glon,glat
npartitions=7,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1
"Order: 4, Pixel: 0",string[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int8[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow]
"Order: 4, Pixel: 768",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"Order: 4, Pixel: 2303",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"Order: 4, Pixel: 3071",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...


In [3]:
# Open the prior-map catalog stored in HATS format; it is later passed to
# merge_map together with the S82 catalog.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
prior_map_url = "/mnt/beegfs/scratch/data/priors/hats/s82_priors"
prior_map_catalog = lsdb.read_hats(prior_map_url)
# Display the catalog's representation (column dtypes per partition).
prior_map_catalog

Unnamed: 0_level_0,rmag,kde,xGrid,yGrid,Norder,Dir,Npix
npartitions=207,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
"Order: 5, Pixel: 0",double[pyarrow],binary[pyarrow],binary[pyarrow],binary[pyarrow],uint8[pyarrow],uint64[pyarrow],uint64[pyarrow]
"Order: 5, Pixel: 1",...,...,...,...,...,...,...
...,...,...,...,...,...,...,...
"Order: 5, Pixel: 12286",...,...,...,...,...,...,...
"Order: 5, Pixel: 12287",...,...,...,...,...,...,...


In [4]:
def merging_function(partition, map_partition, partition_pixel, map_pixel, globalParams, worker_dict, *kwargs):
    """Compute Bayesian estimates for one catalog partition using its prior map.

    For every ``rmag`` value in ``map_partition`` (in ascending order) the
    stored prior is linearly interpolated onto the locus FeH/Mr points held by
    ``globalParams``; the stacked result is handed to ``makeBayesEstimates3d``
    together with ``partition``.

    Parameters
    ----------
    partition : DataFrame
        Catalog rows to process; must contain ``ra`` and ``dec`` columns.
    map_partition : DataFrame
        Prior-map rows with columns ``rmag``, ``kde``, ``xGrid`` and ``yGrid``.
        The last three are decoded as raw float64 buffers and must hold the
        same number of values within a row.
    partition_pixel, map_pixel
        Pixel descriptors supplied by the caller; not used here.
    globalParams
        Object exposing ``locusData`` and ``MrColumn``; also forwarded to
        ``makeBayesEstimates3d``.
    worker_dict : dict
        Maps a Dask worker id to an integer used to choose a JAX device.
    *kwargs
        Extra positional arguments; accepted and ignored.

    Returns
    -------
    DataFrame
        ``ra`` and ``dec`` from ``partition`` concatenated column-wise with the
        estimates, so the result can later be crossmatched.
    """
    # The interpolation targets are the same for every magnitude bin, so build
    # them once instead of on every loop iteration.
    locusPoints = (globalParams.locusData["FeH"], globalParams.locusData[globalParams.MrColumn])

    priorSlices = []
    for r in np.sort(map_partition["rmag"].to_numpy()):
        # interpolate prior map onto locus Mr-FeH grid
        row = map_partition[map_partition["rmag"] == r].iloc[0]
        # The buffers are consumed flattened, so no grid shape needs to be
        # assumed here. .copy() yields writable arrays (np.frombuffer on bytes
        # is read-only).
        values = np.frombuffer(row["kde"], dtype=np.float64).copy()
        xFlat = np.frombuffer(row["xGrid"], dtype=np.float64)
        yFlat = np.frombuffer(row["yGrid"], dtype=np.float64)
        points = np.array((xFlat, yFlat)).T
        # actual (linear) interpolation
        priorSlices.append(griddata(points, values, locusPoints, method="linear", fill_value=0))

    # Pick this worker's device. Unknown workers fall back to index 0 and the
    # index wraps around, so having more workers than devices (or a worker
    # missing from worker_dict) shares a device instead of raising.
    devices = jax.devices()
    deviceIndex = worker_dict.get(get_worker().id, 0) % len(devices)
    with jax.default_device(devices[deviceIndex]):
        priorGrid = jax.numpy.array(priorSlices)
        # The second return value of makeBayesEstimates3d is not needed here.
        estimatesDf, _ = makeBayesEstimates3d(partition, priorGrid, globalParams, batchSize=10)
    # Append ra and dec to be able to later crossmatch
    return pd.concat([partition[["ra", "dec"]], npd.NestedFrame(estimatesDf)], axis=1)

In [5]:
# Build the stellar-locus model and bundle it into the GlobalParams object that
# is later shipped to the workers.
# NOTE(review): absolute path inside a user's home directory -- adjust when
# running elsewhere.
locus_path = "/home/scampos/photoD/data/MSandRGBcolors_v1.3.txt"
# Colors passed to get3DmodelList and GlobalParams.
fitColors = ("ug", "gr", "ri", "iz")
LSSTlocus = LSSTsimsLocus(fixForStripe82=False, datafile=locus_path)
# Keep only locus rows with 0.2 < gi < 3.55.
OKlocus = LSSTlocus[(LSSTlocus["gi"] > 0.2) & (LSSTlocus["gi"] < 3.55)]
# Subsample the locus grid; the printed output reports the FeH x Mr grid going
# from 51 x 1559 to 25 x 155 for these settings.
locusData = subsampleLocusData(OKlocus, kMr=10, kFeH=2)
ArGridList, locus3DList = get3DmodelList(locusData, fitColors)
globalParams = GlobalParams(fitColors, locusData, ArGridList, locus3DList)

subsampled locus 2D grid in FeH and Mr from 51 1559 to: 25 155


In [6]:
# Describe the result schema for Dask: an empty NestedFrame with one float32
# column per output field, indexed by "_healpix_29".
quantile_cols = [
    f"{statisticsName}_quantile_{quantile}"
    for statisticsName in ("Mr", "FeH", "Ar", "Qr")
    for quantile in ("lo", "median", "hi")
]
# Estimate columns are listed alphabetically after the fixed leading columns.
estimate_cols = sorted(quantile_cols + ["MrdS", "FeHdS", "ArdS"])
col_names = ["ra", "dec", "glon", "glat", "chi2min"] + estimate_cols
meta = npd.NestedFrame.from_dict(
    dict((col, pd.Series([], dtype=np.float32)) for col in col_names)
)
meta.index.name = "_healpix_29"
meta

Unnamed: 0_level_0,ra,dec,glon,glat,chi2min,Ar_quantile_hi,Ar_quantile_lo,Ar_quantile_median,ArdS,FeH_quantile_hi,FeH_quantile_lo,FeH_quantile_median,FeHdS,Mr_quantile_hi,Mr_quantile_lo,Mr_quantile_median,MrdS,Qr_quantile_hi,Qr_quantile_lo,Qr_quantile_median
_healpix_29,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1


In [7]:
def get_worker_dict():
    """Map every worker id of the current Dask cluster to a consecutive integer.

    The workers are queried directly through ``Client.run`` rather than
    inferred from where a few partition tasks happened to execute, so the
    mapping covers all workers and no catalog data has to be loaded.

    Returns
    -------
    dict
        ``{worker_id: index}`` with indices ``0..n_workers-1`` assigned in
        sorted order of the worker ids. The mapping is also printed.

    Raises
    ------
    ValueError
        Raised by ``Client.current()`` when no Dask client exists.
    """
    def _worker_id(dask_worker):
        # Client.run fills the ``dask_worker`` argument with the Worker object.
        return dask_worker.id

    ids_by_address = Client.current().run(_worker_id)
    worker_ids = sorted(set(ids_by_address.values()))
    worker_dict = {worker_id: index for index, worker_id in enumerate(worker_ids)}
    print(worker_dict)
    return worker_dict

In [None]:
# Run the estimation on a local Dask cluster; the client is closed when the
# block exits.
with Client(n_workers=4) as client:
    # get_worker_dict() prints the mapping itself, so it is not printed again.
    worker_dict = get_worker_dict()
    # Wrap globalParams in a delayed object so it is passed to the tasks by
    # reference instead of being embedded in each of them.
    delayed_global_params = delayed(globalParams)
    merge_lazy = s82_stripe_catalog.merge_map(
        prior_map_catalog,
        merging_function,
        globalParams=delayed_global_params,
        worker_dict=worker_dict,
        meta=meta,
    )
    xmatch_result = merge_lazy.compute()
xmatch_result

Perhaps you already have a cluster running?
Hosting the HTTP server on port 37787 instead


{'Worker-25b91350-7e4a-437f-a380-69ea6a90a98b': 0, 'Worker-28b3acd6-9c58-4cdc-9ba0-666347e9276d': 1, 'Worker-93168bfc-2499-4d8c-8713-50a313fed928': 2, 'Worker-96d37e11-6427-49ec-8067-3f8d9e57fd63': 3}
{'Worker-25b91350-7e4a-437f-a380-69ea6a90a98b': 0, 'Worker-28b3acd6-9c58-4cdc-9ba0-666347e9276d': 1, 'Worker-93168bfc-2499-4d8c-8713-50a313fed928': 2, 'Worker-96d37e11-6427-49ec-8067-3f8d9e57fd63': 3}


This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
2024-12-11 22:19:38,526 - distributed.worker - ERROR - Failed to communicate with scheduler during heartbeat.
Traceback (most recent call last):
  File "/home/scampos/photoD/.venv/lib/python3.10/site-packages/distributed/comm/tcp.py", line 225, in read
    frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size)
tornado.iostream.StreamClosedError: Stream is closed

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/scampos/photoD/.venv/lib/python3.10/site-packages/distributed/worker.py", line 1250, in heartbeat
    response = await retry_operation(
  File "/home/scampos/photoD/.venv/lib/python3.10/site-packages/distributed/utils_comm.py", line 461, in

Unnamed: 0_level_0,ra,dec,glon,glat,chi2min,Ar_quantile_hi,Ar_quantile_lo,Ar_quantile_median,ArdS,FeH_quantile_hi,FeH_quantile_lo,FeH_quantile_median,FeHdS,Mr_quantile_hi,Mr_quantile_lo,Mr_quantile_median,MrdS,Qr_quantile_hi,Qr_quantile_lo,Qr_quantile_median
_healpix_29,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
122002702160,45.02398,0.042446,176.940106,-48.855926,6.268703,2.075961,0.229620,1.008921,-1.882393,-0.406211,-2.209344,-1.318472,1.057394,13.587354,1.389385,5.562404,0.152917,11.835555,11.085973,11.639488
162211513082,44.995031,0.038152,176.914264,-48.879749,0.447543,2.034478,0.216871,0.957247,-2.736221,-0.417717,-2.216061,-1.336751,2.567276,13.208927,1.133960,4.956499,0.242155,12.135562,11.095332,11.742871
187874205331,44.963869,0.043597,176.875399,-48.898395,17.926819,2.305922,0.614307,1.651837,-5.526733,-0.348573,-2.157651,-1.199505,7.828697,13.930016,8.673445,12.198126,-0.339550,6.828775,6.540124,6.659608
268254148314,44.998317,0.06634,176.88689,-48.857814,0.192304,2.297126,0.621026,1.631422,-4.986763,-0.350469,-2.155016,-1.199275,3.098622,13.752872,6.900836,11.391280,-0.152298,4.941794,3.304157,4.210094
282956553349,45.048274,0.048304,176.959307,-48.834366,26.733196,2.303081,0.600418,1.640683,-5.170105,-0.345306,-2.155912,-1.193337,33.961845,13.919404,8.499684,12.153976,0.693478,7.060387,6.769723,6.804479
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3458764488921378833,314.983478,-0.019971,48.889173,-28.256075,0.559832,2.121118,0.231395,1.042495,-1.602081,-0.406992,-2.210618,-1.320727,8.212790,13.597891,1.586069,5.822998,1.091535,11.507401,10.735964,11.113958
3458764491323291543,314.983523,-0.017944,48.891169,-28.255065,6.671353,2.138630,0.251128,1.099684,-0.930801,-0.394774,-2.202152,-1.299239,2.296966,13.779917,1.814878,6.826294,0.386984,11.285149,10.074786,10.586880
3458764494738379595,314.985876,-0.014536,48.895862,-28.255315,9.550972,2.062611,0.225541,0.992030,-2.133850,-0.410438,-2.211771,-1.325165,2.293740,13.472128,1.276860,5.308072,0.533657,11.972618,11.076015,11.734375
3458764505128080304,315.005058,-0.005702,48.915716,-28.267155,0.098417,2.306288,0.653148,1.662648,-5.977448,-0.351074,-2.156328,-1.201229,4.897235,13.761716,7.332669,11.440234,0.423956,6.490888,6.068348,6.297610
