In [1]:
from upath import UPath

import torch
from dask.distributed import Client
from torch.optim import Adam
from tqdm.auto import tqdm

from uncle_val.datasets.dp1 import dp1_catalog_single_band
from uncle_val.learning.lsdb_dataset import LSDBIterableDataset
from uncle_val.learning.models import ConstantModel, LinearModel, MLPModel
from uncle_val.learning.losses import minus_ln_chi2_prob, kl_divergence_whiten
from uncle_val.learning.training import train_step

In [2]:
BAND = "r"

N_SRC = 10
BATCH_SIZE = 32

In [3]:
# DP1_ROOT = UPath("ssh://kmalanch@cmu.data.lsdb.io:/mnt/data/hats/catalogs/dp1")
DP1_ROOT = UPath("../../data/dp1")
assert DP1_ROOT.exists()

In [4]:
catalog = dp1_catalog_single_band(
    root=DP1_ROOT,
    band="r",
    obj="science",
    img="cal",
    phot="PSF",
    mode="forced",
).map_partitions(
    lambda df: df.query(
        "extendedness == 0",
    ).drop(
        columns=["id", "r_psfMag", "coord_ra", "coord_dec", "extendedness"],
    ),
)
catalog

Unnamed: 0_level_0,lc
npartitions=389,Unnamed: 1_level_1
"Order: 6, Pixel: 130","nested<x: [float], err: [float]>"
"Order: 8, Pixel: 2176",...
...,...
"Order: 9, Pixel: 2302101",...
"Order: 7, Pixel: 143884",...


In [5]:
model = LinearModel(d_input=2, d_output=1)
# model = torch.compile(model, mode='reduce-overhead')
model.train()

with Client(n_workers=8, memory_limit="8GB", threads_per_worker=1) as client:
    display(client)
    dataset = LSDBIterableDataset(
        catalog=catalog,
        columns=None,
        client=client,
        batch_lc=BATCH_SIZE,
        n_src=N_SRC,
        partitions_per_chunk=10,
        seed=0,
    )
    n_lc = 0
    optimizer = Adam(model.parameters(), lr=1e-3)
    for batch in tqdm(dataset):
        n_lc += len(batch)
        train_step(
            model=model,
            optimizer=optimizer,
            loss=minus_ln_chi2_prob,
            batch=batch,
        )
print(f"{n_lc = }")

model.eval()
torch.save(model, "model.pt")

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 8
Total threads: 8,Total memory: 59.60 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:60507,Workers: 0
Dashboard: http://127.0.0.1:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B

0,1
Comm: tcp://127.0.0.1:60526,Total threads: 1
Dashboard: http://127.0.0.1:60529/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:60510,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-qfrcx5f8,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-qfrcx5f8

0,1
Comm: tcp://127.0.0.1:60527,Total threads: 1
Dashboard: http://127.0.0.1:60531/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:60512,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-ffd4cjn4,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-ffd4cjn4

0,1
Comm: tcp://127.0.0.1:60528,Total threads: 1
Dashboard: http://127.0.0.1:60532/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:60514,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-292nhhr5,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-292nhhr5

0,1
Comm: tcp://127.0.0.1:60536,Total threads: 1
Dashboard: http://127.0.0.1:60540/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:60516,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-v1r6erf6,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-v1r6erf6

0,1
Comm: tcp://127.0.0.1:60535,Total threads: 1
Dashboard: http://127.0.0.1:60537/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:60518,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-xsbjbehs,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-xsbjbehs

0,1
Comm: tcp://127.0.0.1:60539,Total threads: 1
Dashboard: http://127.0.0.1:60542/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:60520,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-7uzcax57,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-7uzcax57

0,1
Comm: tcp://127.0.0.1:60547,Total threads: 1
Dashboard: http://127.0.0.1:60548/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:60522,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-08_bbpix,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-08_bbpix

0,1
Comm: tcp://127.0.0.1:60544,Total threads: 1
Dashboard: http://127.0.0.1:60545/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:60524,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-1ua3hk4u,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-1ua3hk4u




0it [00:00, ?it/s]

2025-10-06 16:56:04,035 - distributed.worker - ERROR - Failed to communicate with scheduler during heartbeat.
Traceback (most recent call last):
  File "/Users/hombit/.virtualenvs/uncle-val/lib/python3.13/site-packages/distributed/comm/tcp.py", line 226, in read
    frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tornado.iostream.StreamClosedError: Stream is closed

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/hombit/.virtualenvs/uncle-val/lib/python3.13/site-packages/distributed/worker.py", line 1267, in heartbeat
    response = await retry_operation(
               ^^^^^^^^^^^^^^^^^^^^^^
    ...<14 lines>...
    )
    ^
  File "/Users/hombit/.virtualenvs/uncle-val/lib/python3.13/site-packages/distributed/utils_comm.py", line 416, in retry_operation
    return await retry(
           ^^^^^^^^^^^^
    ...<5 lines>...
    )
    ^
  Fi

n_lc = 95808


In [6]:
flux = 10 ** (-0.4 * (21 - 31.2))
err = 0.02 * flux
model(torch.tensor([flux, err]))

tensor([2.0004], grad_fn=<AsStridedBackward0>)