## Running JAX with Dask

Test whether JAX computations can run inside Dask tasks, on both CPU and GPU.

In [2]:
import dask
import jax
import lsdb

from dask.distributed import Client
from hats.pixel_math.healpix_pixel import HealpixPixel
from photod.bayes import makeBayesEstimates3d
from photod.parameters import GlobalParams
from photod.priors import readPriors
from photod.locus import LSSTsimsLocus, subsampleLocusData, get3DmodelList

Configure JAX to use the device selected by the `device` flag. Running on the GPU requires a smaller batch size.

In [None]:
# Need to restart the kernel every time we change this
device = "gpu"
jax.config.update("jax_platform_name", device)
print(f"Using {device}: {jax.devices()}")
batchSize = 1000 #if device == "gpu" else 10000

Using gpu: [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]
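
The commented-out conditional in the cell above hints at a per-device batch size. A minimal sketch of that choice follows, assuming the values 1000 (GPU) and 10000 (CPU) from that comment. The helper name `pick_batch_size` is hypothetical and is not part of photod or JAX.

```python
def pick_batch_size(device: str) -> int:
    """Return a batch size for the given JAX platform name.

    The values are assumptions taken from the commented-out conditional
    in the notebook, not tuned limits.
    """
    # The notebook states that the GPU needs the smaller batch size.
    return 1000 if device == "gpu" else 10000


for name in ("cpu", "gpu"):
    print(name, pick_batch_size(name))
```

In a live session the platform name could come from `jax.default_backend()` rather than a hand-set flag, so the batch size follows whichever backend JAX actually initialized.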


Let's load the Stripe 82 (S82) catalog:

In [4]:
s82 = lsdb.read_hats("/mnt/beegfs/scratch/data/S82_standards/S82_hats/S82_hats_fixed/")

And generate the model data:

In [6]:
locus_path = "/home/scampos/photoD/data/MSandRGBcolors_v1.3.txt"
fitColors = ("ug", "gr", "ri", "iz")
LSSTlocus = LSSTsimsLocus(fixForStripe82=False, datafile=locus_path)
OKlocus = LSSTlocus[(LSSTlocus["gi"] > 0.2) & (LSSTlocus["gi"] < 3.55)]
locusData = subsampleLocusData(OKlocus, kMr=1, kFeH=1)
ArGridList, locus3DList = get3DmodelList(locusData, fitColors)

subsampled locus 2D grid in FeH and Mr from 51 1559 to: 51 1559


Let's wrap `makeBayesEstimates3d` with `dask.delayed`.

In [None]:
def run_jax_with_dask(catalog, batchSize):
    # Select a single HEALPix pixel (order 5, pixel 0) and load it into memory
    pix = HealpixPixel(5, 0)
    partition_df = catalog.pixel_search([(pix.order, pix.pixel)]).compute().reset_index(drop=True)
    globalParams = GlobalParams(fitColors, locusData, ArGridList, locus3DList)
    return delayed_bayes_estimates_3d(partition_df, pix, globalParams, batchSize)

@dask.delayed
def delayed_bayes_estimates_3d(partition_df, pix, globalParams, batchSize):
    priorsRootName = f"/mnt/beegfs/scratch/scampos/photod/priors/TRILEGAL/S82/{pix.order}/{pix.pixel}"
    priorGrid = readPriors(priorsRootName, globalParams.locusData, globalParams.MrColumn)
    priorGrid = jax.numpy.array(list(priorGrid.values()))
    estimatesDf, _ = makeBayesEstimates3d(partition_df, priorGrid, globalParams, batchSize=batchSize)
    return estimatesDf
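
For readers new to `dask.delayed`, the pattern above can be reduced to a toy example. This is only an illustrative sketch: `slow_square` is a hypothetical stand-in for the real estimator, and no JAX or photod code is involved. Calling a delayed function builds a lazy task, and nothing runs until `.compute()` is called.

```python
import dask


@dask.delayed
def slow_square(x):
    # Hypothetical stand-in for an expensive function such as the Bayes estimator.
    return x * x


# Calling the decorated function returns a lazy Delayed object, not a number.
lazy_result = slow_square(7)

# The work runs only here, on whichever scheduler is active.
value = lazy_result.compute()
print(value)
```

When a `Client` is active, as in the cells that follow, `.compute()` sends the task to the distributed workers instead of the default local scheduler.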

Let's instantiate a Dask Client and run the workflow:

In [6]:
%%time
with Client(n_workers=3):
    results = run_jax_with_dask(s82, batchSize)
    results_cpu = results.compute()

Perhaps you already have a cluster running?
Hosting the HTTP server on port 44479 instead
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


CPU times: user 1.97 s, sys: 3.58 s, total: 5.55 s
Wall time: 15 s


In [8]:
%%time
with Client(n_workers=3):
    results = run_jax_with_dask(s82, batchSize)
    results_gpu = results.compute()

Perhaps you already have a cluster running?
Hosting the HTTP server on port 45987 instead
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


CPU times: user 1.75 s, sys: 3.03 s, total: 4.78 s
Wall time: 14.2 s
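
To put the two timings side by side, the arithmetic below computes the difference and ratio of the reported wall times. It assumes, from the variable names `results_cpu` and `results_gpu`, that the first run used the CPU and the second the GPU. Both measurements include starting the `Client`, since it is created inside the timed cell.

```python
# Wall times reported by %%time above, in seconds.
cpu_wall = 15.0  # run stored in results_cpu (assumed to be the CPU run)
gpu_wall = 14.2  # run stored in results_gpu (assumed to be the GPU run)

saved = cpu_wall - gpu_wall   # absolute difference in seconds
ratio = cpu_wall / gpu_wall   # how many times longer the first run took

print(f"difference: {saved:.1f} s, ratio: {ratio:.2f}x")
```

The ratio comes out at about 1.06, so the two runs differ by roughly 5% in wall time for this single pixel.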
