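## Running PhotoD with LSDB

This notebook estimates $M_r$, [Fe/H] and $A_r$ for the Stripe 82 standard stars with PhotoD. Per-pixel prior maps are joined in through `lsdb`'s `merge_map`, and the results are then crossmatched with Gaia EDR3 distances for validation.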

In [1]:
import os
# Limit OpenBLAS to a single thread. This has to happen before numpy/scipy
# (and jax, which imports numpy) are imported below, because OpenBLAS reads
# the variable when the library is loaded.
# Presumably meant to keep the Dask worker processes started later from
# oversubscribing CPU cores — confirm.
default_n_threads = 1
os.environ['OPENBLAS_NUM_THREADS'] = f"{default_n_threads}"

import jax
import lsdb
import matplotlib.pyplot as plt
import nested_pandas as npd
import numpy as np
import pandas as pd

from dask import delayed
from dask.distributed import Client, get_worker
from scipy.interpolate import griddata
from photod.bayes import makeBayesEstimates3d
from photod.locus import LSSTsimsLocus, subsampleLocusData, get3DmodelList
from photod.parameters import GlobalParams

In [2]:
# Stripe 82 standard-star catalog (HATS format) on shared scratch storage.
# read_hats opens it lazily: the repr below shows the schema and the
# partition layout, not the data itself.
s82_stripe_url = (
    "/mnt/beegfs/scratch/data/S82_standards/S82_hats/S82_hats_fixed"
)
s82_stripe_catalog = lsdb.read_hats(
    s82_stripe_url,
)
s82_stripe_catalog

Unnamed: 0_level_0,CALIBSTARS,ra,dec,RArms,Decrms,Ntot,Ar,uNobs,umag,ummu,uErr,umrms,umchi2,gNobs,gmag,gmmu,gErr,gmrms,gmchi2,rNobs,rmag,rmmu,rErr,rmrms,rmchi2,iNobs,imag,immu,iErr,imrms,imchi2,zNobs,zmag,zmmu,zErr,zmrms,zmchi2,Norder,Dir,Npix,Mr,FeH,MrEst,MrEstUnc,FeHEst,ug,gr,gi,ri,iz,ugErr,grErr,giErr,riErr,izErr,glon,glat
npartitions=7,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1
"Order: 4, Pixel: 0",string[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int8[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow]
"Order: 4, Pixel: 768",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"Order: 4, Pixel: 2303",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"Order: 4, Pixel: 3071",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...


In [3]:
# Per-pixel prior maps (HATS format). Each row holds an `rmag` value plus the
# binary buffers `kde`, `xGrid` and `yGrid`, which merging_function decodes.
prior_map_url = (
    "/mnt/beegfs/scratch/data/priors/hats/s82_priors"
)
prior_map_catalog = lsdb.read_hats(
    prior_map_url,
)
prior_map_catalog

Unnamed: 0_level_0,rmag,kde,xGrid,yGrid,Norder,Dir,Npix
npartitions=207,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
"Order: 5, Pixel: 0",double[pyarrow],binary[pyarrow],binary[pyarrow],binary[pyarrow],uint8[pyarrow],uint64[pyarrow],uint64[pyarrow]
"Order: 5, Pixel: 1",...,...,...,...,...,...,...
...,...,...,...,...,...,...,...
"Order: 5, Pixel: 12286",...,...,...,...,...,...,...
"Order: 5, Pixel: 12287",...,...,...,...,...,...,...


In [4]:
# Gaia EDR3 distance catalog, read remotely from data.lsdb.io.
# The margin cache (10 arcsec, per its name) is opened alongside it,
# presumably so the crossmatch further down can also match sources lying
# near partition borders.
gaia_distances = lsdb.read_hats(
    'https://data.lsdb.io/hats/gaia_dr3/gaia_edr3_distances',
    margin_cache='https://data.lsdb.io/hats/gaia_dr3/gaia_edr3_distances_10arcs',
)
gaia_distances

Unnamed: 0_level_0,source_id,r_med_geo,r_lo_geo,r_hi_geo,r_med_photogeo,r_lo_photogeo,r_hi_photogeo,flag,ra,dec,Norder,Dir,Npix
npartitions=3243,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
"Order: 2, Pixel: 0",int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],int8[pyarrow],int64[pyarrow],int64[pyarrow]
"Order: 3, Pixel: 4",...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...
"Order: 3, Pixel: 766",...,...,...,...,...,...,...,...,...,...,...,...,...
"Order: 3, Pixel: 767",...,...,...,...,...,...,...,...,...,...,...,...,...


In [5]:
# Stellar-locus model used by PhotoD. The default path is inside one user's
# home directory, so it can be overridden with the PHOTOD_LOCUS_PATH
# environment variable to run the notebook elsewhere.
locus_path = os.environ.get(
    "PHOTOD_LOCUS_PATH", "/home/scampos/photoD/data/MSandRGBcolors_v1.3.txt"
)
# Colors used in the fit.
fitColors = ("ug", "gr", "ri", "iz")
LSSTlocus = LSSTsimsLocus(fixForStripe82=False, datafile=locus_path)
# Keep locus entries with 0.2 < gi < 3.55.
# TODO: document where these bounds come from.
OKlocus = LSSTlocus[(LSSTlocus["gi"] > 0.2) & (LSSTlocus["gi"] < 3.55)]
# Subsample the FeH/Mr grid (kMr=10, kFeH=2). Per the printed output, this
# shrinks the 2D grid from 51 x 1559 to 25 x 155.
locusData = subsampleLocusData(OKlocus, kMr=10, kFeH=2)
ArGridList, locus3DList = get3DmodelList(locusData, fitColors)
globalParams = GlobalParams(fitColors, locusData, ArGridList, locus3DList)

subsampled locus 2D grid in FeH and Mr from 51 1559 to: 25 155


In [6]:
def merging_function(
    partition,
    map_partition,
    partition_pixel,
    map_pixel,
    globalParams,
    worker_dict,
    *args,
    grid_shape=(96, 36),
    batch_size=10,
    **kwargs,
):
    """Compute PhotoD estimates for one catalog partition using its prior-map partition.

    Used as the ``func`` of ``merge_map``: it receives one partition of the
    catalog and the matching partition of the prior-map catalog.

    For every prior-map row (in ascending ``rmag`` order), the stored KDE is
    linearly interpolated from its (xGrid, yGrid) sample points onto the locus
    (FeH, Mr) points. Targets outside the stored grid get 0. The stacked slices
    form the prior passed to ``makeBayesEstimates3d``, which runs on the JAX
    device chosen for the current Dask worker.

    Args:
        partition: catalog rows. Must contain ``ra`` and ``dec`` plus whatever
            ``makeBayesEstimates3d`` reads.
        map_partition: prior-map rows with an ``rmag`` column and ``kde``,
            ``xGrid`` and ``yGrid`` columns holding raw float64 buffers.
        partition_pixel, map_pixel: pixels of the two partitions (unused).
        globalParams: PhotoD global parameters. ``locusData["FeH"]`` and
            ``locusData[globalParams.MrColumn]`` are the interpolation targets.
        worker_dict: maps Dask worker id to an integer. That integer, modulo
            the number of JAX devices, selects the device for this task.
        *args, **kwargs: extra arguments forwarded by ``merge_map``; ignored.
        grid_shape: shape every stored buffer must decode to (validated via
            ``reshape``).
        batch_size: forwarded to ``makeBayesEstimates3d`` as ``batchSize``.

    Returns:
        ``partition[["ra", "dec"]]`` concatenated column-wise with the
        estimates, so the result can be crossmatched later.

    Raises:
        KeyError: the current worker is not in ``worker_dict``.
        ValueError: a stored buffer does not match ``grid_shape``, or the
            function is not running inside a Dask worker.
    """

    def _decode(buffer):
        # Buffers hold raw float64 bytes. The reshape validates the size, and
        # flatten() returns a writable copy (frombuffer views are read-only).
        return np.frombuffer(buffer, dtype=np.float64).reshape(grid_shape).flatten()

    locus = globalParams.locusData
    # The interpolation targets are the same for every rmag slice, so build
    # them once.
    targets = (locus["FeH"], locus[globalParams.MrColumn])

    # Sort once and walk the rows instead of re-filtering map_partition for
    # each rmag value. The stable sort keeps rows with equal rmag in their
    # original order.
    ordered = map_partition.sort_values("rmag", kind="stable")
    prior_slices = []
    for kde_buf, x_buf, y_buf in zip(ordered["kde"], ordered["xGrid"], ordered["yGrid"]):
        points = np.column_stack((_decode(x_buf), _decode(y_buf)))
        prior_slices.append(
            griddata(points, _decode(kde_buf), targets, method="linear", fill_value=0)
        )

    devices = jax.devices()
    # Wrap around so that more workers than devices (e.g. a CPU-only run with
    # a single device) share devices instead of raising IndexError.
    device = devices[worker_dict[get_worker().id] % len(devices)]
    with jax.default_device(device):
        priorGrid = jax.numpy.array(prior_slices)
        estimatesDf, _ = makeBayesEstimates3d(
            partition, priorGrid, globalParams, batchSize=batch_size
        )
    # Append ra and dec to be able to later crossmatch
    return pd.concat([partition[["ra", "dec"]], npd.NestedFrame(estimatesDf)], axis=1)

In [7]:
def get_worker_dict(client=None):
    """Assign every Dask worker a consecutive integer index.

    merging_function uses this index to pick a JAX device, so every worker
    that may run a task must appear in the mapping. Querying all workers
    directly guarantees that. Inferring workers from where a few tasks happen
    to run does not, and it also reads catalog data for no reason. Workers
    that join the cluster after this call are not included.

    Args:
        client: Dask client to query. Defaults to ``Client.current()``.

    Returns:
        dict mapping worker id -> index, with indices assigned in sorted id
        order. The mapping is also printed.

    Raises:
        ValueError: no client was given and no current client exists.
    """
    if client is None:
        client = Client.current()
    # Client.run executes the function on every worker. A function whose
    # argument is named ``dask_worker`` receives the Worker object itself.
    ids_by_address = client.run(lambda dask_worker: dask_worker.id)
    worker_ids = sorted(set(ids_by_address.values()))
    worker_dict = {worker_id: index for index, worker_id in enumerate(worker_ids)}
    print(worker_dict)
    return worker_dict

In [8]:
# Empty frame describing what merging_function returns. It is passed to
# merge_map as `meta` so the lazy pipeline can be built without running it.
# Estimate columns are three quantiles for each of Mr, FeH and Ar, plus
# MrdS/FeHdS/ArdS, all sorted alphabetically.
# NOTE(review): every column is declared float32, but `ra`/`dec` are copied
# from the S82 catalog, where they are double[pyarrow]. Confirm these meta
# dtypes match the real output.
quantile_cols = [
    f"{stat}_quantile_{q}"
    for stat in ("Mr", "FeH", "Ar")
    for q in range(3)
]
estimate_cols = sorted(quantile_cols + ["MrdS", "FeHdS", "ArdS"])
col_names = ["ra", "dec", "glon", "glat", "chi2min"] + estimate_cols
meta = npd.NestedFrame.from_dict(
    {name: pd.Series([], dtype=np.float32) for name in col_names}
)
meta.index.name = "_healpix_29"
meta

Unnamed: 0_level_0,ra,dec,glon,glat,chi2min,Ar_quantile_0,Ar_quantile_1,Ar_quantile_2,ArdS,FeH_quantile_0,FeH_quantile_1,FeH_quantile_2,FeHdS,Mr_quantile_0,Mr_quantile_1,Mr_quantile_2,MrdS
_healpix_29,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1


In [9]:
# Start a local 4-worker cluster and build the lazy pipeline: prior-map merge,
# then PhotoD estimates, then Gaia crossmatch. The result is materialized
# before the client shuts down.
with Client(n_workers=4) as client:
    # Wrapping globalParams in delayed makes it a single graph node instead of
    # a value copied into every task.
    # NOTE(review): the large-graph warning in the output suggests something
    # is still embedded repeatedly — check.
    delayed_global_params = delayed(globalParams)
    worker_dict = get_worker_dict()
    merge_lazy = s82_stripe_catalog.merge_map(
        prior_map_catalog,
        merging_function,
        globalParams=delayed_global_params,
        worker_dict=worker_dict,
        meta=meta,
    )
    crossmatched = merge_lazy.crossmatch(gaia_distances)
    xmatch_result = crossmatched.compute()
xmatch_result

{'Worker-03b264fa-5d21-436c-b8a7-618914e368dc': 0, 'Worker-04a63540-e3d4-40e5-90b8-59332237ca46': 1, 'Worker-40cdb118-a667-4653-bade-bb1200de49df': 2, 'Worker-83de145a-0ff0-4d12-9377-8a23e8b2f1ad': 3}


This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
2024-12-10 21:15:32,824 - distributed.worker - ERROR - Failed to communicate with scheduler during heartbeat.
Traceback (most recent call last):
  File "/home/scampos/photoD/.venv/lib/python3.10/site-packages/distributed/comm/tcp.py", line 225, in read
    frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size)
tornado.iostream.StreamClosedError: Stream is closed

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/scampos/photoD/.venv/lib/python3.10/site-packages/distributed/worker.py", line 1250, in heartbeat
    response = await retry_operation(
  File "/home/scampos/photoD/.venv/lib/python3.10/site-packages/distributed/utils_comm.py", line 461, in

Unnamed: 0_level_0,ra_S82_fixed,dec_S82_fixed,glon_S82_fixed,glat_S82_fixed,chi2min_S82_fixed,Ar_quantile_0_S82_fixed,Ar_quantile_1_S82_fixed,Ar_quantile_2_S82_fixed,ArdS_S82_fixed,FeH_quantile_0_S82_fixed,...,r_med_photogeo_gaia_edr3_distances,r_lo_photogeo_gaia_edr3_distances,r_hi_photogeo_gaia_edr3_distances,flag_gaia_edr3_distances,ra_gaia_edr3_distances,dec_gaia_edr3_distances,Norder_gaia_edr3_distances,Dir_gaia_edr3_distances,Npix_gaia_edr3_distances,_dist_arcsec
_healpix_29,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
162211513082,44.995031,0.038152,176.914264,-48.879749,0.447543,,,,,,...,894.886536,740.259399,1037.2533,10033,44.995037,0.038152,2,0.0,0.0,0.022146
187874205331,44.963869,0.043597,176.875399,-48.898395,17.926819,,,,,,...,2312.65527,2050.8374,2602.06006,10033,44.963896,0.043595,2,0.0,0.0,0.098435
268254148314,44.998317,0.06634,176.88689,-48.857814,0.192304,,,,,,...,5598.27002,4505.23193,6764.95508,11033,44.998327,0.066333,2,0.0,0.0,0.044784
282956553349,45.048274,0.048304,176.959307,-48.834366,26.733196,,,,,,...,616.651428,603.039551,628.16156,10033,45.048282,0.048254,2,0.0,0.0,0.182617
425727624950,45.023562,0.068453,176.911213,-48.838178,35.408913,,,,,,...,811.626465,768.958496,852.594788,10033,45.02362,0.068419,2,0.0,0.0,0.241828
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3458764488921378833,314.983478,-0.019971,48.889173,-28.256075,0.559832,,,,,,...,400.299591,376.472382,422.925659,10033,314.983418,-0.020028,3,0.0,767.0,0.296106
3458764491323291543,314.983523,-0.017944,48.891169,-28.255065,6.671353,,,,,,...,2046.27917,1743.60352,2379.9353,10033,314.983513,-0.017962,3,0.0,767.0,0.07487
3458764494738379595,314.985876,-0.014536,48.895862,-28.255315,9.550972,,,,,,...,2105.90015,1608.70532,2510.09351,10033,314.985882,-0.014514,3,0.0,767.0,0.083264
3458764505128080304,315.005058,-0.005702,48.915716,-28.267155,0.098417,,,,,,...,6418.42676,5273.95117,7988.17871,10033,315.005038,-0.005738,3,0.0,767.0,0.148563


In [10]:
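# NOTE: the validation cells below are disabled. They read `D_S82_fixed` and
# `DUnc_S82_fixed`, which are not among the columns declared in `meta`, so they
# are presumably absent from `xmatch_result`. This cell also pairs the
# photogeo median with the *geo* lo/hi bounds, whereas the residual cell uses
# the photogeo bounds.
#xmatch_result[["r_med_photogeo_gaia_edr3_distances", "r_lo_geo_gaia_edr3_distances", "r_hi_geo_gaia_edr3_distances", "D_S82_fixed", "DUnc_S82_fixed"]]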

In [11]:
#mod_err = (xmatch_result["r_hi_photogeo_gaia_edr3_distances"].to_numpy() - xmatch_result["r_lo_photogeo_gaia_edr3_distances"].to_numpy()) / 2
#obs_err = xmatch_result["DUnc_S82_fixed"].to_numpy()
#normalized_residual = (xmatch_result["D_S82_fixed"].to_numpy() - xmatch_result["r_med_photogeo_gaia_edr3_distances"].to_numpy()) / np.sqrt(obs_err ** 2 + mod_err ** 2)
#normalized_residual

In [12]:
#mean_kh = np.nanmean(normalized_residual)
#std_kh = np.nanstd(normalized_residual)
#mean_kh, std_kh

In [13]:
#histog = plt.hist(normalized_residual, bins=np.linspace(-10, 10, 100))
#plt.title(f"Normalized residual with mean {mean_kh:.2f} and std {std_kh:.2f}")