In [1]:
from upath import UPath

import jax
import jax.numpy as jnp
import optax
from dask.distributed import Client
from flax import nnx
from tqdm.auto import tqdm

from uncle_val.datasets.dp1 import dp1_catalog_single_band
from uncle_val.learning.lsdb_dataset import LSDBDataGenerator
from uncle_val.learning.models import LinearModel, MLPModel
from uncle_val.learning.losses import minus_ln_chi2_prob, kl_divergence_whiten
from uncle_val.learning.training import train_step

  import pynvml


In [2]:
BAND = "r"

N_SRC = 20
BATCH_SIZE = 32

FLUX_SCALER_SCALE = 10 ** (-0.4 * (23 - 31.2))
FLUX_SCALER_MAX = 10 ** (-0.4 * (14 - 31.2))
# Rounded values for "valid" r-band forced photometry
LG_FLUXERR_SCALER_MIN = -1.0
LG_FLUXERR_SCALER_MAX = 4.0

In [3]:
@jax.jit
def norm_flux(flux):
    return jnp.arcsinh(flux / FLUX_SCALER_SCALE) / jnp.arcsinh(FLUX_SCALER_MAX / FLUX_SCALER_SCALE)


@jax.jit
def norm_fluxerr(err):
    lg_err = jnp.log10(err)
    return (lg_err - LG_FLUXERR_SCALER_MIN) / (LG_FLUXERR_SCALER_MAX - LG_FLUXERR_SCALER_MIN)

In [4]:
# DP1_ROOT = UPath("ssh://kmalanch@cmu.data.lsdb.io:/mnt/data/hats/catalogs/dp1")
DP1_ROOT = UPath("../../data/dp1")
assert DP1_ROOT.exists()

In [5]:
catalog = dp1_catalog_single_band(
    root=DP1_ROOT,
    band="r",
    obj="science",
    img="cal",
    phot="PSF",
    mode="forced",
).map_partitions(
    lambda df: df.drop(
        columns=["r_psfMag", "objectId", "coord_ra", "coord_dec"],
    ),
)
catalog

Unnamed: 0_level_0,objectForcedSource
npartitions=389,Unnamed: 1_level_1
"Order: 6, Pixel: 130","nested<psfFlux: [float], psfFluxErr: [float]>"
"Order: 8, Pixel: 2176",...
...,...
"Order: 9, Pixel: 2302101",...
"Order: 7, Pixel: 143884",...


In [6]:
# with Client(n_workers=8, memory_limit="8GB", threads_per_worker=1) as client:
#     mini = catalog["objectForcedSource.psfFlux"].min().compute()
#     maxi = catalog["objectForcedSource.psfFlux"].max().compute()
# print(jnp.log10(mini))
# print(jnp.log10(maxi))

# with Client(n_workers=8, memory_limit="8GB", threads_per_worker=1) as client:
#     mini = catalog["objectForcedSource.psfFluxErr"].min().compute()
#     maxi = catalog["objectForcedSource.psfFluxErr"].max().compute()
# print(jnp.log10(mini))
# print(jnp.log10(maxi))

# with Client(n_workers=8, memory_limit="8GB", threads_per_worker=1) as client:
#     mediani = catalog["objectForcedSource.psfFluxErr"].median_approximate().compute()
# print(jnp.log10(mediani))

In [7]:
# model = MLPModel(
#     d_input=2,
#     d_middle=(300, 300, 400),
#     dropout=0.2,
#     rngs=nnx.Rngs(42),
# )
model = LinearModel(d_input=2, d_output=1, rngs=nnx.Rngs(42))
# step = nnx.jit(lambda **kwargs: train_step(loss=minus_ln_chi2_prob, **kwargs))
step = nnx.jit(lambda **kwargs: train_step(loss=kl_divergence_whiten, **kwargs))

with Client(n_workers=8, memory_limit="8GB", threads_per_worker=1) as client:
    display(client)
    data_gen = LSDBDataGenerator(
        catalog=catalog,
        client=client,
        n_src=N_SRC,
        partitions_per_chunk=20,  # number of partitions per chunk
        seed=42,
    )
    n_lc = 0
    optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
    for chunk in tqdm(data_gen):
        n_obj = len(chunk)
        flux_2d = jnp.asarray(chunk.nest["psfFlux"]).reshape(n_obj, N_SRC)
        err_2d = jnp.asarray(chunk.nest["psfFluxErr"]).reshape(n_obj, N_SRC)
        for i_obj_start in tqdm(range(0, n_obj - BATCH_SIZE, BATCH_SIZE)):
            n_lc += BATCH_SIZE
            i_obj_end = i_obj_start + BATCH_SIZE
            flux = flux_2d[i_obj_start:i_obj_end]
            err = err_2d[i_obj_start:i_obj_end]
            theta = jnp.stack([norm_flux(flux), norm_fluxerr(err)], axis=-1)
            step(
                model=model,
                optimizer=optimizer,
                theta=theta,
                flux=flux,
                err=err,
            )


print(f"{n_lc = }")
model(jnp.asarray([norm_flux(10 ** (-0.4 * (20 - 31.4))), norm_fluxerr(10**2)]))

  import pynvml
  import pynvml
  import pynvml
  import pynvml
  import pynvml
  import pynvml
  import pynvml
  import pynvml


0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 8
Total threads: 8,Total memory: 59.60 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:63809,Workers: 0
Dashboard: http://127.0.0.1:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B

0,1
Comm: tcp://127.0.0.1:63828,Total threads: 1
Dashboard: http://127.0.0.1:63830/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:63812,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-tb0kjp0l,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-tb0kjp0l

0,1
Comm: tcp://127.0.0.1:63829,Total threads: 1
Dashboard: http://127.0.0.1:63832/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:63814,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-ry2pm8r4,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-ry2pm8r4

0,1
Comm: tcp://127.0.0.1:63834,Total threads: 1
Dashboard: http://127.0.0.1:63838/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:63816,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-53mskvcb,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-53mskvcb

0,1
Comm: tcp://127.0.0.1:63835,Total threads: 1
Dashboard: http://127.0.0.1:63837/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:63818,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-gwcvx028,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-gwcvx028

0,1
Comm: tcp://127.0.0.1:63836,Total threads: 1
Dashboard: http://127.0.0.1:63841/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:63820,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-1ps9r10d,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-1ps9r10d

0,1
Comm: tcp://127.0.0.1:63844,Total threads: 1
Dashboard: http://127.0.0.1:63846/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:63822,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-wub93vrk,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-wub93vrk

0,1
Comm: tcp://127.0.0.1:63843,Total threads: 1
Dashboard: http://127.0.0.1:63845/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:63824,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-6nvzbu10,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-6nvzbu10

0,1
Comm: tcp://127.0.0.1:63849,Total threads: 1
Dashboard: http://127.0.0.1:63850/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:63826,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-3rp198a7,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-3rp198a7




  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/574 [00:00<?, ?it/s]

  0%|          | 0/1153 [00:00<?, ?it/s]

  0%|          | 0/980 [00:00<?, ?it/s]

  0%|          | 0/1134 [00:00<?, ?it/s]

  0%|          | 0/1321 [00:00<?, ?it/s]

  0%|          | 0/1613 [00:00<?, ?it/s]

  0%|          | 0/1281 [00:00<?, ?it/s]

  0%|          | 0/1108 [00:00<?, ?it/s]

  0%|          | 0/1210 [00:00<?, ?it/s]

  0%|          | 0/907 [00:00<?, ?it/s]

  0%|          | 0/877 [00:00<?, ?it/s]

  0%|          | 0/1876 [00:00<?, ?it/s]

  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/1157 [00:00<?, ?it/s]

  0%|          | 0/1249 [00:00<?, ?it/s]

  0%|          | 0/1255 [00:00<?, ?it/s]

  0%|          | 0/992 [00:00<?, ?it/s]

  0%|          | 0/922 [00:00<?, ?it/s]

  0%|          | 0/1357 [00:00<?, ?it/s]

  0%|          | 0/285 [00:00<?, ?it/s]

n_lc = 718944


Array([2.8908763], dtype=float32)

In [8]:
model(jnp.asarray([norm_flux(10 ** (-0.4 * (20 - 31.4))), norm_fluxerr(10**0)]))

Array([2.4432547], dtype=float32)