## Running PhotoD with LSDB

In [1]:
import os
# Cap OpenBLAS at a fixed number of threads. The variable is exported here,
# ahead of the numerical imports that follow in this cell.
default_n_threads = 1
os.environ["OPENBLAS_NUM_THREADS"] = str(default_n_threads)

import jax
import lsdb
import nested_pandas as npd
import numpy as np
import pandas as pd

from dask import delayed
from dask.distributed import Client, get_worker
from scipy.interpolate import griddata
from photod.bayes import makeBayesEstimates3d
from photod.locus import LSSTsimsLocus, subsampleLocusData, get3DmodelList
from photod.parameters import GlobalParams

In [2]:
# Location of the S82 standards catalog in HATS format.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
s82_stripe_url = "/mnt/beegfs/scratch/data/S82_standards/S82_hats/S82_hats_fixed"
# Open the catalog with LSDB.
s82_stripe_catalog = lsdb.read_hats(s82_stripe_url)
# Bare expression: show the catalog summary (columns, dtypes, partitions) as the cell output.
s82_stripe_catalog

Unnamed: 0_level_0,CALIBSTARS,ra,dec,RArms,Decrms,Ntot,Ar,uNobs,umag,ummu,uErr,umrms,umchi2,gNobs,gmag,gmmu,gErr,gmrms,gmchi2,rNobs,rmag,rmmu,rErr,rmrms,rmchi2,iNobs,imag,immu,iErr,imrms,imchi2,zNobs,zmag,zmmu,zErr,zmrms,zmchi2,Norder,Dir,Npix,Mr,FeH,MrEst,MrEstUnc,FeHEst,ug,gr,gi,ri,iz,ugErr,grErr,giErr,riErr,izErr,glon,glat
npartitions=7,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1
"Order: 4, Pixel: 0",string[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int8[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow]
"Order: 4, Pixel: 768",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"Order: 4, Pixel: 2303",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"Order: 4, Pixel: 3071",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...


In [3]:
# Location of the prior-map catalog in HATS format.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
prior_map_url = "/mnt/beegfs/scratch/data/priors/hats/s82_priors"
# Open the prior maps with LSDB; the summary below lists the columns
# rmag, kde, xGrid and yGrid (the last three stored as binary blobs).
prior_map_catalog = lsdb.read_hats(prior_map_url)
# Bare expression: show the catalog summary as the cell output.
prior_map_catalog

Unnamed: 0_level_0,rmag,kde,xGrid,yGrid,Norder,Dir,Npix
npartitions=207,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
"Order: 5, Pixel: 0",double[pyarrow],binary[pyarrow],binary[pyarrow],binary[pyarrow],uint8[pyarrow],uint64[pyarrow],uint64[pyarrow]
"Order: 5, Pixel: 1",...,...,...,...,...,...,...
...,...,...,...,...,...,...,...
"Order: 5, Pixel: 12286",...,...,...,...,...,...,...
"Order: 5, Pixel: 12287",...,...,...,...,...,...,...


In [4]:
# Locus data file handed to LSSTsimsLocus.
# NOTE(review): absolute, user-specific path — adjust when running elsewhere.
locus_path = "/home/scampos/photoD/data/MSandRGBcolors_v1.3.txt"
# Colors passed to get3DmodelList and GlobalParams below.
fitColors = ("ug", "gr", "ri", "iz")
# Load the locus from the data file, with the Stripe 82 fix switched off.
LSSTlocus = LSSTsimsLocus(fixForStripe82=False, datafile=locus_path)
# Keep only locus rows whose "gi" value lies strictly between 0.2 and 3.55.
OKlocus = LSSTlocus[(LSSTlocus["gi"] > 0.2) & (LSSTlocus["gi"] < 3.55)]
# Subsample the locus with kMr=10 and kFeH=2; the cell output reports the
# FeH x Mr grid going from 51 x 1559 to 25 x 155.
locusData = subsampleLocusData(OKlocus, kMr=10, kFeH=2)
# Build the Ar grid list and the 3D locus model list for the chosen colors.
ArGridList, locus3DList = get3DmodelList(locusData, fitColors)
# Bundle everything into the single parameter object that merging_function receives.
globalParams = GlobalParams(fitColors, locusData, ArGridList, locus3DList)

subsampled locus 2D grid in FeH and Mr from 51 1559 to: 25 155


In [5]:
def merging_function(
    partition,
    map_partition,
    partition_pixel,
    map_pixel,
    globalParams,
    worker_dict,
    *kwargs,
    priorGridShape=(96, 36),
    batchSize=10,
):
    """Compute Bayesian estimates for one catalog partition using its prior maps.

    For every ``rmag`` value in ``map_partition`` (in ascending order) the stored
    prior map is decoded from its binary columns and linearly interpolated onto
    the locus grid given by ``globalParams.locusData["FeH"]`` and
    ``globalParams.locusData[globalParams.MrColumn]``. The stacked priors are
    then passed, together with ``partition``, to ``makeBayesEstimates3d`` on the
    JAX device assigned to the current Dask worker.

    Parameters
    ----------
    partition :
        Catalog partition forwarded unchanged to ``makeBayesEstimates3d``.
    map_partition :
        Prior-map partition with columns ``rmag``, ``kde``, ``xGrid`` and
        ``yGrid``; the last three hold raw float64 buffers.
    partition_pixel, map_pixel :
        Pixels of the two partitions. Accepted for the ``merge_map`` calling
        convention; not used here.
    globalParams :
        Object exposing ``locusData`` and ``MrColumn``; also forwarded to
        ``makeBayesEstimates3d``.
    worker_dict : dict
        Maps a Dask worker id to an integer device index.
    *kwargs :
        Extra positional arguments (despite the name this is a var-positional
        parameter); accepted and ignored.
    priorGridShape : tuple of int, keyword-only
        Shape of each stored ``kde`` / ``xGrid`` / ``yGrid`` buffer. Decoding
        raises ``ValueError`` if a buffer does not match this shape.
    batchSize : int, keyword-only
        Batch size forwarded to ``makeBayesEstimates3d``.

    Returns
    -------
    nested_pandas.NestedFrame
        The estimates frame returned by ``makeBayesEstimates3d``.

    Raises
    ------
    KeyError
        If the current worker's id is missing from ``worker_dict``.
    """
    # The interpolation targets do not depend on the prior map, so build them once.
    locusData = globalParams.locusData
    locusPoints = (locusData["FeH"], locusData[globalParams.MrColumn])

    priorSlices = []
    for r in np.sort(map_partition["rmag"].to_numpy()):
        # First prior-map row stored for this rmag value.
        row = map_partition[map_partition["rmag"] == r].iloc[0]
        Zval = np.frombuffer(row["kde"], dtype=np.float64).reshape(priorGridShape)
        X = np.frombuffer(row["xGrid"], dtype=np.float64).reshape(priorGridShape)
        Y = np.frombuffer(row["yGrid"], dtype=np.float64).reshape(priorGridShape)
        # flatten() copies, so griddata gets writable arrays rather than
        # read-only views of the underlying byte buffers.
        points = np.array((X.flatten(), Y.flatten())).T
        values = Zval.flatten()
        # Linear interpolation of the prior map onto the locus Mr-FeH grid;
        # locus points outside the map get a prior of 0.
        priorSlices.append(griddata(points, values, locusPoints, method="linear", fill_value=0))

    # Wrap the worker's index around the available devices so that running more
    # workers than devices (e.g. on a CPU-only host) shares devices instead of
    # raising IndexError.
    devices = jax.devices()
    device = devices[worker_dict[get_worker().id] % len(devices)]
    with jax.default_device(device):
        priorGrid = jax.numpy.array(priorSlices)
        estimatesDf, _ = makeBayesEstimates3d(partition, priorGrid, globalParams, batchSize=batchSize)
    return npd.NestedFrame(estimatesDf)

In [6]:
# Names of the output columns, in order. `meta` is an empty NestedFrame with one
# float column per name; it is passed as `meta=` to the merge_map call.
col_names = [
    "glon",
    "glat",
    "FeHEst",
    "FeHUnc",
    "MrEst",
    "MrUnc",
    "chi2min",
    "MrdS",
    "FeHdS",
    "ArEst",
    "ArUnc",
    "ArdS",
    "D",
    "DUnc",
]
meta = npd.NestedFrame.from_dict(
    {name: pd.Series([], dtype=float) for name in col_names}
)

In [7]:
def get_worker_dict():
    """Map every Dask worker id to a small consecutive integer index.

    The ids are collected by running a tiny function directly on each worker
    known to the current client, so every worker is guaranteed to appear in the
    mapping. Probing through a handful of scheduled tasks cannot guarantee
    that, because the scheduler is free to place several of them on the same
    worker and none on another.

    Returns
    -------
    dict
        ``{worker_id: index}`` with indices ``0..n_workers-1`` assigned in
        sorted worker-id order. The mapping is also printed.

    Raises
    ------
    ValueError
        Raised by ``Client.current()`` when no Dask client exists.
    """

    def _worker_id(dask_worker):
        # Client.run fills an argument named `dask_worker` with the Worker
        # instance the function is executing on.
        return dask_worker.id

    # Result is keyed by worker address; only the ids are needed here.
    ids_by_address = Client.current().run(_worker_id)
    worker_ids = sorted(set(ids_by_address.values()))
    worker_dict = {worker_id: index for index, worker_id in enumerate(worker_ids)}
    print(worker_dict)
    return worker_dict

In [8]:
# Run the whole computation on a Dask cluster with four workers; the client is
# closed again when the `with` block exits.
with Client(n_workers=4) as client:
    # Worker id -> integer index, consumed inside merging_function.
    worker_dict = get_worker_dict()
    # Hand globalParams to the per-partition calls as a delayed object.
    delayed_global_params = delayed(globalParams)
    # Pair the catalog with the prior maps and apply merging_function to each pair.
    merge_lazy = s82_stripe_catalog.merge_map(
        prior_map_catalog,
        merging_function,
        globalParams=delayed_global_params,
        worker_dict=worker_dict,
        meta=meta,
    )
    result = merge_lazy.compute()
result

Perhaps you already have a cluster running?
Hosting the HTTP server on port 46301 instead


{'Worker-2a307675-9585-4af2-a3d4-babc3a652806': 0, 'Worker-715e1c9c-214f-4ccb-a02b-2e38518a6e80': 1, 'Worker-e8f58724-5209-44b9-8e4c-75ce088dc8c0': 2, 'Worker-eeb3ac85-1a1e-4756-b284-83117288883b': 3}


This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


Unnamed: 0_level_0,glon,glat,FeHEst,FeHUnc,MrEst,MrUnc,chi2min,MrdS,FeHdS,ArEst,ArUnc,ArdS,D,DUnc
_healpix_29,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
122002702160,176.940106,-48.855926,-0.627993,0.432425,10.533989,0.343333,6.268703,-31.022161,-7.843840,0.280774,0.097274,-132.510666,1357.210992,223.207431
162211513082,176.914264,-48.879749,-0.379619,0.223275,10.601394,0.211412,0.447543,-37.352390,-14.186052,0.374928,0.036350,-206.702408,688.867406,68.078076
187874205331,176.875399,-48.898395,-0.608126,0.123861,6.405376,0.146170,17.926819,-39.963490,-17.045719,0.239046,0.033753,-248.966537,1961.171423,135.515142
268254148314,176.88689,-48.857814,-2.232135,0.158194,4.075833,0.578908,0.192304,-23.896175,-17.458168,0.393059,0.101834,-141.893982,9886.021366,2676.109538
282956553349,176.959307,-48.834366,-0.866308,0.055142,6.366106,0.056288,26.733196,-47.315948,-20.416243,0.391810,0.010461,-298.399414,438.951855,11.589071
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3458764488921378833,48.889173,-28.256075,-0.413612,0.247434,10.247563,0.231156,0.559832,-30.283442,-9.389055,0.214340,0.021931,-244.715988,401.253973,42.909476
3458764491323291543,48.891169,-28.255065,-0.322921,0.208012,10.096582,0.208678,6.671353,-34.317211,-16.164330,0.088037,0.060076,-173.111969,1316.061252,131.810726
3458764494738379595,48.895862,-28.255315,-0.374013,0.241811,10.492229,0.238628,9.550972,-32.273098,-14.045506,0.262515,0.085019,-142.177460,1118.024346,130.599012
3458764505128080304,48.915716,-28.267155,-1.997893,0.166214,5.934564,0.272493,0.098417,-28.421608,-14.802888,0.281323,0.078703,-148.049698,4221.071177,551.429854
