# Compute Bayes estimates

In [1]:
import os
# Thread budget for OpenBLAS. The environment is configured here, ahead of the
# numerical-library imports that follow in this cell.
default_n_threads = 1
os.environ.update(
    {
        # Exported as the decimal string of the thread budget.
        "OPENBLAS_NUM_THREADS": str(default_n_threads),
        # Disable GPU memory pre-allocation
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    }
)

import jax
import lsdb
import nested_pandas as npd
import numpy as np
import pandas as pd

from dask import delayed
from dask.distributed import Client, get_worker
from photod.bayes import makeBayesEstimates3d, getEstimatesMeta
from photod.locus import LSSTsimsLocus, subsampleLocusData, get3DmodelList
from photod.parameters import GlobalParams
from photod.priors import initializePriorGrid

In [2]:
# Absolute path of the HATS catalog to process (judging by the path, the
# Stripe 82 standard stars). NOTE(review): machine-specific location; adjust
# when running on another system.
s82StripeUrl = "/mnt/beegfs/scratch/data/S82_standards/S82_hats/S82_hats_fixed"
# Open the catalog with lsdb; the cell output below shows only the column
# schema and partition layout, not row values.
s82StripeCatalog = lsdb.read_hats(s82StripeUrl)
# Bare expression so the notebook renders the catalog summary.
s82StripeCatalog

Unnamed: 0_level_0,CALIBSTARS,ra,dec,RArms,Decrms,Ntot,Ar,uNobs,umag,ummu,uErr,umrms,umchi2,gNobs,gmag,gmmu,gErr,gmrms,gmchi2,rNobs,rmag,rmmu,rErr,rmrms,rmchi2,iNobs,imag,immu,iErr,imrms,imchi2,zNobs,zmag,zmmu,zErr,zmrms,zmchi2,Norder,Dir,Npix,Mr,FeH,MrEst,MrEstUnc,FeHEst,ug,gr,gi,ri,iz,ugErr,grErr,giErr,riErr,izErr,glon,glat
npartitions=7,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1
"Order: 4, Pixel: 0",string[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int8[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow]
"Order: 4, Pixel: 768",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"Order: 4, Pixel: 2303",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"Order: 4, Pixel: 3071",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...


In [3]:
# Absolute path of the HATS catalog holding the prior maps.
# NOTE(review): machine-specific location; adjust when running on another system.
priorMapUrl = "/mnt/beegfs/scratch/data/priors/hats/s82_priors"
# Open the prior-map catalog with lsdb; per the cell output its columns include
# rmag, kde, xGrid and yGrid.
priorMapCatalog = lsdb.read_hats(priorMapUrl)
# Bare expression so the notebook renders the catalog summary.
priorMapCatalog

Unnamed: 0_level_0,rmag,kde,xGrid,yGrid,Norder,Dir,Npix
npartitions=207,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
"Order: 5, Pixel: 0",double[pyarrow],binary[pyarrow],binary[pyarrow],binary[pyarrow],uint8[pyarrow],uint64[pyarrow],uint64[pyarrow]
"Order: 5, Pixel: 1",...,...,...,...,...,...,...
...,...,...,...,...,...,...,...
"Order: 5, Pixel: 12286",...,...,...,...,...,...,...
"Order: 5, Pixel: 12287",...,...,...,...,...,...,...


In [4]:
# Stellar-locus data file, given relative to the notebook's working directory.
locusPath = "../../data/MSandRGBcolors_v1.3.txt"
# Colors passed to get3DmodelList and GlobalParams below.
fitColors = ("ug", "gr", "ri", "iz")
LSSTlocus = LSSTsimsLocus(fixForStripe82=False, datafile=locusPath)
# Keep only locus entries whose "gi" value lies strictly between 0.2 and 3.55.
OKlocus = LSSTlocus[(LSSTlocus["gi"] > 0.2) & (LSSTlocus["gi"] < 3.55)]
# With kMr=1 and kFeH=1 the message printed below this cell reports the same
# grid size before and after the call (51 x 1559).
locusData = subsampleLocusData(OKlocus, kMr=1, kFeH=1)
ArGridList, locus3DList = get3DmodelList(locusData, fitColors)
# Bundle the inputs above into the object that is later handed to every
# mergingFunction call.
globalParams = GlobalParams(fitColors, locusData, ArGridList, locus3DList)

subsampled locus 2D grid in FeH and Mr from 51 1559 to: 51 1559


In [5]:
def mergingFunction(
    partition,
    mapPartition,
    partitionPixel,
    mapPixel,
    globalParams,
    workerDict,
    batchSize=10,
    **kwargs,
):
    """Function used by lsdb `merge_map`

    Builds the prior grid from the prior-map partition, picks the JAX device
    assigned to the Dask worker executing this call, and runs
    `makeBayesEstimates3d` for the catalog partition on that device.

    Parameters
    ----------
    partition
        Catalog partition whose estimates are computed; passed unchanged to
        `makeBayesEstimates3d`.
    mapPartition
        Prior-map partition paired with `partition`; passed to
        `initializePriorGrid`.
    partitionPixel, mapPixel
        Supplied by the caller (presumably the pixels of the two partitions);
        not used here.
    globalParams
        Forwarded to `initializePriorGrid` and `makeBayesEstimates3d`.
    workerDict : dict
        Maps a Dask worker id to an integer device index.
    batchSize : int, optional
        Forwarded to `makeBayesEstimates3d`.
    **kwargs
        Accepted for compatibility with the caller and ignored.

    Returns
    -------
    nested_pandas.NestedFrame
        The estimates returned by `makeBayesEstimates3d`, wrapped in a
        NestedFrame.

    Raises
    ------
    KeyError
        If the id of the current worker is not a key of `workerDict`.
    """
    priorGridByKey = initializePriorGrid(mapPartition, globalParams)
    availableDevices = jax.devices()
    # Wrap the assigned index around the number of visible devices so that a
    # run with more workers than devices (or on a CPU-only host) shares
    # devices instead of failing with IndexError.
    deviceIndex = workerDict[get_worker().id] % len(availableDevices)
    gpuDevice = availableDevices[deviceIndex]
    with jax.default_device(gpuDevice):
        # Stack the grid values into one array while the target device is the default.
        priorGrid = jax.numpy.array(list(priorGridByKey.values()))
        estimatesDf, _ = makeBayesEstimates3d(partition, priorGrid, globalParams, batchSize=batchSize)
    return npd.NestedFrame(estimatesDf)


In [6]:
def getWorkerToGpuMapping(nWorkers):
    """Create a mapping between each worker and a GPU

    Every worker attached to the currently active Dask client is asked for its
    id directly. The ids are numbered in sorted order, so each worker gets a
    distinct index. Querying the workers directly avoids loading catalog
    partitions just to discover the ids, and does not depend on the scheduler
    placing one task on each worker.

    Parameters
    ----------
    nWorkers : int
        Expected number of workers. Kept for interface compatibility; the
        mapping always covers all workers the client currently sees.

    Returns
    -------
    dict
        Worker id -> integer index (0, 1, 2, ...).

    Raises
    ------
    ValueError
        If no Dask client is active (raised by `Client.current`).
    """

    def _workerId(dask_worker):
        # `Client.run` fills an argument named `dask_worker` with the Worker
        # instance the function is executed on.
        return dask_worker.id

    client = Client.current()
    # `Client.run` returns {worker address: function result} for all workers.
    idsByAddress = client.run(_workerId)
    workerIds = sorted(set(idsByAddress.values()))
    return {workerId: gpuIndex for gpuIndex, workerId in enumerate(workerIds)}

In [7]:
# Number of Dask workers to start; also passed to getWorkerToGpuMapping.
nWorkers = 4

with Client(n_workers=nWorkers) as client:
    # Scatter globalParams to the cluster once and hand the resulting future
    # (rather than the object itself) to merge_map.
    future = client.scatter(globalParams)
    # Needs the running cluster: the helper discovers the worker ids from it.
    workerToGpuMapping = getWorkerToGpuMapping(nWorkers)
    # Pair the star catalog with the prior map. The keyword arguments below
    # match mergingFunction's `globalParams` and `workerDict` parameters.
    # NOTE(review): `meta` presumably describes the output schema — confirm
    # against getEstimatesMeta.
    mergeLazy = s82StripeCatalog.merge_map(
        priorMapCatalog, 
        mergingFunction, 
        globalParams=future, 
        workerDict=workerToGpuMapping,
        meta=getEstimatesMeta(),
    )
    # Compute inside the `with` block, i.e. before the client is closed.
    mergeResult = mergeLazy.compute()

# Bare expression so the notebook renders the computed estimates.
mergeResult



Unnamed: 0_level_0,glon,glat,chi2min,Ar_quantile_hi,Ar_quantile_lo,Ar_quantile_median,ArdS,FeH_quantile_hi,FeH_quantile_lo,FeH_quantile_median,FeHdS,Mr_quantile_hi,Mr_quantile_lo,Mr_quantile_median,MrdS,Qr_quantile_hi,Qr_quantile_lo,Qr_quantile_median
_healpix_29,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
122002702160,176.940106,-48.855926,6.268703,0.386949,0.174107,0.280442,-132.473633,-0.227780,-0.964648,-0.523272,-15.111343,10.874006,10.172916,10.476210,-306.561066,11.128321,10.469028,10.755150
162211513082,176.914264,-48.879749,0.364965,0.414760,0.332011,0.373790,-204.202469,-0.091860,-0.608757,-0.324400,-26.431089,10.827707,10.317924,10.558826,-359.967987,11.195699,10.696110,10.932160
187874205331,176.875399,-48.898395,17.801376,0.281888,0.191349,0.237319,-195.978607,-0.465124,-0.754734,-0.608852,-33.365372,6.583451,6.234819,6.410890,-382.531738,6.786264,6.503552,6.646283
268254148314,176.88689,-48.857814,0.055334,0.493730,0.299565,0.414506,-142.391022,-2.050107,-2.418640,-2.218580,-36.739525,4.663245,3.311964,4.256591,-242.156769,5.037274,3.637865,4.721453
282956553349,176.959307,-48.834366,21.198009,0.412039,0.351518,0.381228,-231.246979,-0.800082,-0.984448,-0.895246,-31.138645,6.499413,6.285667,6.408841,-400.166687,6.857510,6.690443,6.790002
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3458764488921378833,48.889173,-28.256075,0.503932,0.233284,0.190768,0.212024,-270.155884,-0.053726,-0.512604,-0.243493,-23.487144,10.355577,9.897295,10.096420,-322.765930,10.565968,10.109548,10.309708
3458764491323291543,48.891169,-28.255065,6.493702,0.156990,0.020727,0.081544,-173.114441,-0.047796,-0.520705,-0.239814,-30.629833,10.302315,9.813795,10.025238,-333.407837,10.382231,9.908035,10.110762
3458764494738379595,48.895862,-28.255315,9.436792,0.355219,0.168725,0.262869,-142.162201,-0.079470,-0.594446,-0.304783,-26.398663,10.727048,10.189111,10.439279,-314.661377,10.971685,10.466538,10.698907
3458764505128080304,48.915716,-28.267155,0.091048,0.368464,0.196369,0.278210,-148.013855,-1.812392,-2.172716,-2.011806,-30.696327,6.212793,5.642663,5.964199,-287.271667,6.414802,6.005089,6.242133
