## Running PhotoD with LSDB

In [1]:
import os
# Thread count handed to OpenBLAS. The variable is set ahead of the numeric
# imports further down - presumably so the library reads it when it is loaded.
default_n_threads = 1
os.environ["OPENBLAS_NUM_THREADS"] = str(default_n_threads)

# Disable GPU memory pre-allocation
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import lsdb
import nested_pandas as npd
import numpy as np
import pandas as pd

from dask import delayed
from dask.distributed import Client, get_worker
from photod.bayes import makeBayesEstimates3d, getEstimatesMeta
from photod.locus import LSSTsimsLocus, subsampleLocusData, get3DmodelList
from photod.parameters import GlobalParams
from photod.priors import initializePriorGrid

In [2]:
# Filesystem path of the S82 standards catalog stored in HATS format.
s82StripeUrl = "/mnt/beegfs/scratch/data/S82_standards/S82_hats/S82_hats_fixed"
# Open the HATS catalog with LSDB (the rendered output below reports 7 partitions).
s82StripeCatalog = lsdb.read_hats(s82StripeUrl)
# Bare expression: makes the notebook display the catalog summary.
s82StripeCatalog

Unnamed: 0_level_0,CALIBSTARS,ra,dec,RArms,Decrms,Ntot,Ar,uNobs,umag,ummu,uErr,umrms,umchi2,gNobs,gmag,gmmu,gErr,gmrms,gmchi2,rNobs,rmag,rmmu,rErr,rmrms,rmchi2,iNobs,imag,immu,iErr,imrms,imchi2,zNobs,zmag,zmmu,zErr,zmrms,zmchi2,Norder,Dir,Npix,Mr,FeH,MrEst,MrEstUnc,FeHEst,ug,gr,gi,ri,iz,ugErr,grErr,giErr,riErr,izErr,glon,glat
npartitions=7,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1
"Order: 4, Pixel: 0",string[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int8[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow]
"Order: 4, Pixel: 768",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"Order: 4, Pixel: 2303",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"Order: 4, Pixel: 3071",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...


In [3]:
# Filesystem path of the prior-map catalog stored in HATS format.
priorMapUrl = "/mnt/beegfs/scratch/data/priors/hats/s82_priors"
# Open the HATS catalog with LSDB (the rendered output below reports 207 partitions).
priorMapCatalog = lsdb.read_hats(priorMapUrl)
# Bare expression: makes the notebook display the catalog summary.
priorMapCatalog

Unnamed: 0_level_0,rmag,kde,xGrid,yGrid,Norder,Dir,Npix
npartitions=207,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
"Order: 5, Pixel: 0",double[pyarrow],binary[pyarrow],binary[pyarrow],binary[pyarrow],uint8[pyarrow],uint64[pyarrow],uint64[pyarrow]
"Order: 5, Pixel: 1",...,...,...,...,...,...,...
...,...,...,...,...,...,...,...
"Order: 5, Pixel: 12286",...,...,...,...,...,...,...
"Order: 5, Pixel: 12287",...,...,...,...,...,...,...


In [4]:
# Colors used for the fit and the data file holding the stellar locus.
fitColors = ("ug", "gr", "ri", "iz")
locusPath = "../../data/MSandRGBcolors_v1.3.txt"
LSSTlocus = LSSTsimsLocus(datafile=locusPath, fixForStripe82=False)

# Keep only locus entries whose "gi" value lies strictly between 0.2 and 3.55.
giColor = LSSTlocus["gi"]
OKlocus = LSSTlocus[(giColor > 0.2) & (giColor < 3.55)]

# Subsample the locus, build the 3D model lists, and bundle everything the
# estimator needs into a single GlobalParams object.
locusData = subsampleLocusData(OKlocus, kMr=10, kFeH=2)
ArGridList, locus3DList = get3DmodelList(locusData, fitColors)
globalParams = GlobalParams(fitColors, locusData, ArGridList, locus3DList)

subsampled locus 2D grid in FeH and Mr from 51 1559 to: 25 155


In [5]:
def mergingFunction(
    partition,
    mapPartition,
    partitionPixel,
    mapPixel,
    globalParams,
    workerDict,
    batchSize=10,
    **kwargs,
):
    """Function used by lsdb `merge_map`

    Builds the prior grid for one map partition and runs the Bayesian
    estimator on the matching catalog partition, pinned to the JAX device
    assigned to the Dask worker executing the task.

    Parameters
    ----------
    partition
        Catalog partition passed on to `makeBayesEstimates3d`.
    mapPartition
        Prior-map partition passed on to `initializePriorGrid`.
    partitionPixel, mapPixel
        Supplied by `merge_map`; not used here.
    globalParams
        Shared parameters forwarded to `initializePriorGrid` and
        `makeBayesEstimates3d`.
    workerDict
        Mapping from Dask worker id to an integer device index.
    batchSize
        Forwarded to `makeBayesEstimates3d`.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    npd.NestedFrame
        The estimates frame returned by `makeBayesEstimates3d`.

    Raises
    ------
    KeyError
        If the current worker's id is missing from `workerDict`.
    """
    priorGrid = initializePriorGrid(mapPartition, globalParams)
    devices = jax.devices()
    # Wrap the index so that having more workers than visible devices shares
    # devices round-robin instead of raising IndexError.
    gpuDevice = devices[workerDict[get_worker().id] % len(devices)]
    with jax.default_device(gpuDevice):
        # Stack the grid values into one array created on the selected device.
        priorGrid = jax.numpy.array(list(priorGrid.values()))
        # The second element of the returned pair is not needed here.
        estimatesDf, _ = makeBayesEstimates3d(partition, priorGrid, globalParams, batchSize=batchSize)
    return npd.NestedFrame(estimatesDf)


In [6]:
def getWorkerToGpuMapping(nWorkers, client=None):
    """Create a mapping between each worker and a GPU

    Every worker connected to the client is asked for its id directly, so the
    mapping is complete regardless of how the scheduler would place tasks and
    no catalog data has to be loaded.

    Parameters
    ----------
    nWorkers
        Retained for backward compatibility; the mapping now covers all
        workers connected to the client rather than at most `nWorkers`.
    client
        Dask client to query. Defaults to the current client
        (`Client.current()`), which raises ValueError if none exists.

    Returns
    -------
    dict
        Worker id -> consecutive integer index (0, 1, ...), assigned in
        sorted order of the worker ids.
    """
    if client is None:
        client = Client.current()
    # `Client.run` calls the function once on every worker; an argument named
    # `dask_worker` is filled in with the worker object itself.
    idsByAddress = client.run(lambda dask_worker: dask_worker.id)
    workerIds = sorted(set(idsByAddress.values()))
    return {workerId: gpuIndex for gpuIndex, workerId in enumerate(workerIds)}

In [7]:
# Number of Dask workers to start for this run.
nWorkers = 4

# The cluster lives only for the duration of the `with` block.
with Client(n_workers=nWorkers) as client:
    workerToGpuMapping = getWorkerToGpuMapping(nWorkers)
    # Extra keyword arguments forwarded through `merge_map`; `globalParams`
    # is wrapped in `delayed` before being handed over.
    mergeKwargs = {
        "globalParams": delayed(globalParams),
        "workerDict": workerToGpuMapping,
        "meta": getEstimatesMeta(),
    }
    mergeLazy = s82StripeCatalog.merge_map(priorMapCatalog, mergingFunction, **mergeKwargs)
    # Must be computed while the client is still open.
    mergeResult = mergeLazy.compute()

# Bare expression: makes the notebook display the result.
mergeResult

This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


Unnamed: 0_level_0,glon,glat,chi2min,Ar_quantile_hi,Ar_quantile_lo,Ar_quantile_median,ArdS,FeH_quantile_hi,FeH_quantile_lo,FeH_quantile_median,FeHdS,Mr_quantile_hi,Mr_quantile_lo,Mr_quantile_median,MrdS,Qr_quantile_hi,Qr_quantile_lo,Qr_quantile_median
_healpix_29,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
122002702160,176.940106,-48.855926,6.268703,0.387065,0.174300,0.280527,-132.510651,-0.245328,-0.980604,-0.531740,-7.843840,10.887256,10.191526,10.484066,-31.022160,11.135601,10.491317,10.761868
162211513082,176.914264,-48.879749,0.447543,0.414165,0.328856,0.380125,-206.702423,-0.139455,-0.625581,-0.350933,-14.186054,10.845268,10.361386,10.582030,-37.352390,11.201021,10.743488,10.956738
187874205331,176.875399,-48.898395,17.926819,0.259215,0.216018,0.237981,-248.966492,-0.460716,-0.756895,-0.608496,-17.045713,6.567495,6.244486,6.409079,-39.963490,6.755890,6.528317,6.643213
268254148314,176.88689,-48.857814,0.192304,0.492900,0.290464,0.412680,-141.893982,-2.040463,-2.433770,-2.226372,-17.458170,4.664255,3.283507,4.237616,-23.896175,5.040036,3.610987,4.701930
282956553349,176.959307,-48.834366,26.733196,0.410532,0.373446,0.391706,-298.399414,-0.775559,-0.961157,-0.865502,-20.416246,6.461159,6.275530,6.365502,-47.315956,6.804713,6.691720,6.776640
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3458764488921378833,48.889173,-28.256075,0.559832,0.241524,0.185057,0.215584,-244.715988,-0.131298,-0.740228,-0.370222,-9.389055,10.541523,9.970162,10.229635,-30.283442,10.737287,10.186606,10.428654
3458764491323291543,48.891169,-28.255065,6.671353,0.157090,0.020824,0.081469,-173.111969,-0.111893,-0.553180,-0.284827,-16.164330,10.334628,9.869857,10.072090,-34.317211,10.406882,9.976911,10.153548
3458764494738379595,48.895862,-28.255315,9.550972,0.355292,0.168895,0.263082,-142.177429,-0.132357,-0.614308,-0.335406,-14.045504,10.750250,10.235509,10.470230,-32.273094,10.985167,10.523683,10.727262
3458764505128080304,48.915716,-28.267155,0.098417,0.368252,0.196223,0.277904,-148.049713,-1.808016,-2.180966,-2.012542,-14.802888,6.220438,5.635613,5.963743,-28.421606,6.414466,6.003981,6.244876
