In [1]:
# Load the TensorBoard notebook extension and start a dashboard that reads
# event files from the local "runs" directory (the same directory the
# training cell writes its per-epoch model checkpoints into).
%load_ext tensorboard
%tensorboard --logdir=runs

In [2]:
from functools import partial
from upath import UPath

import torch
from dask.distributed import Client
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

from uncle_val.datasets.dp1 import dp1_catalog_single_band
from uncle_val.learning.lsdb_dataset import LSDBIterableDataset
from uncle_val.learning.models import ConstantModel, LinearModel, MLPModel
from uncle_val.learning.losses import minus_ln_chi2_prob, kl_divergence_whiten
from uncle_val.learning.training import train_step, evaluate_loss

In [3]:
# Photometric band label for this run.
BAND = "r"

# Passed to LSDBIterableDataset as `n_src` and `batch_lc` respectively
# (presumably sources per light curve and light curves per training batch —
# TODO confirm against LSDBIterableDataset).
N_SRC = 10
BATCH_SIZE = 32
# Only used below to derive how many validation rounds the training runs for.
N_LCS = 400_000

# Number of training batches processed between two validation evaluations.
VALIDATION_PER_BATCHES = 128
# One validation batch is as large as all training batches of a single round.
VALIDATION_BATCH_SIZE = BATCH_SIZE * VALIDATION_PER_BATCHES
# Total number of training/validation rounds (the outer loop length).
N_VALIDATION_BATCHES = N_LCS // VALIDATION_BATCH_SIZE

# Loss used for both training and validation, with its `soft` keyword bound.
SOFT = 20
LOSS_FN = partial(minus_ln_chi2_prob, soft=SOFT)

In [4]:
# Root directory of the DP1 catalogs. The remote alternative is kept for reference:
# DP1_ROOT = UPath("ssh://kmalanch@cmu.data.lsdb.io:/mnt/data/hats/catalogs/dp1")
DP1_ROOT = UPath("../../data/dp1")
# Fail fast, and say which path was looked up, instead of failing later
# somewhere inside the catalog loading code.
assert DP1_ROOT.exists(), f"DP1 catalog root not found: {DP1_ROOT}"

In [5]:
# Load the single-band DP1 catalog for the band chosen in the configuration
# cell (BAND) instead of repeating the band letter here.
# Per partition: keep only rows with `extendedness == 0`, then drop the
# columns that are not wanted downstream.
# NOTE(review): the magnitude column is assumed to be named "<band>_psfMag";
# this matches the original "r_psfMag" for BAND == "r" — confirm for other bands.
catalog = dp1_catalog_single_band(
    root=DP1_ROOT,
    band=BAND,
    obj="science",
    img="cal",
    phot="PSF",
    mode="forced",
).map_partitions(
    lambda df: df.query(
        "extendedness == 0",
    ).drop(
        columns=[f"{BAND}_psfMag", "coord_ra", "coord_dec", "extendedness"],
    ),
)
# Last expression: show the lazy catalog summary.
catalog

Unnamed: 0_level_0,id,lc
npartitions=389,Unnamed: 1_level_1,Unnamed: 2_level_1
"Order: 6, Pixel: 130",int64[pyarrow],"nested<x: [float], err: [float]>"
"Order: 8, Pixel: 2176",...,...
...,...,...
"Order: 9, Pixel: 2302101",...,...
"Order: 7, Pixel: 143884",...,...


In [6]:
# Train a LinearModel on streamed light-curve batches, validating once per
# round, logging to TensorBoard and checkpointing the model after every round.
# The model and the TensorBoard writer are created outside the Dask client
# context so they stay usable after the cluster has been shut down.
model = LinearModel(d_input=2, d_output=1)
# model = torch.compile(model, mode='reduce-overhead')
model.train()

summary_writer = SummaryWriter()

with Client(n_workers=8, memory_limit="8GB", threads_per_worker=1) as client:
    display(client)
    # Training and validation streams read the same catalog with
    # non-overlapping `hash_range` intervals (0.00-0.70 vs 0.70-0.85);
    # presumably this selects disjoint subsets — confirm against
    # LSDBIterableDataset. Both loop forever (`loop=True`), so the loops
    # below are bounded by explicit counters.
    training_dataset = iter(
        LSDBIterableDataset(
            catalog=catalog,
            columns=None,
            client=client,
            batch_lc=BATCH_SIZE,
            n_src=N_SRC,
            partitions_per_chunk=10,
            loop=True,
            hash_range=(0.00, 0.70),
            seed=0,
        )
    )
    validation_dataset = iter(
        LSDBIterableDataset(
            catalog=catalog,
            columns=None,
            client=client,
            batch_lc=VALIDATION_BATCH_SIZE,
            n_src=N_SRC,
            partitions_per_chunk=30,
            loop=True,
            hash_range=(0.70, 0.85),
            seed=0,
        )
    )
    optimizer = Adam(model.parameters(), lr=1e-3)
    # Inner progress bar is created once and reset every round.
    train_tqdm = tqdm(total=VALIDATION_PER_BATCHES, desc="Training batches", position=1)
    # `range` comes first in each zip so that, once the counter is exhausted,
    # zip stops without requesting one more batch from the endless dataset.
    for epoch, val_batch in tqdm(
        zip(range(N_VALIDATION_BATCHES), validation_dataset),
        total=N_VALIDATION_BATCHES,
        desc="Validation batches",
        position=0,
    ):
        sum_train_loss = 0.0
        train_tqdm.reset()
        for _i_train_batch, train_batch in zip(range(VALIDATION_PER_BATCHES), training_dataset):
            train_tqdm.update(1)
            sum_train_loss += train_step(
                model=model,
                optimizer=optimizer,
                loss=LOSS_FN,
                batch=train_batch,
            )

        # Switch to eval mode only for the validation pass, then back.
        model.eval()
        validation_loss = evaluate_loss(
            model=model,
            loss=LOSS_FN,
            batch=val_batch,
        )
        model.train()

        summary_writer.add_scalar("Sum train loss", sum_train_loss, epoch)
        summary_writer.add_scalar("Validation loss", validation_loss, epoch)
        # Per-round checkpoint of the whole model object.
        torch.save(model, f"runs/model_{epoch:06d}.pt")

    # Release the manually driven progress bar.
    train_tqdm.close()

model.eval()
# Trace the model graph using the inputs of the last training batch
# (assumes element 0 of a batch is the model input — TODO confirm).
summary_writer.add_graph(model, train_batch[0])
# Flush pending events to disk and release the event file.
summary_writer.close()
torch.save(model, "model.pt")

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 8
Total threads: 8,Total memory: 59.60 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:60322,Workers: 0
Dashboard: http://127.0.0.1:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B

0,1
Comm: tcp://127.0.0.1:60342,Total threads: 1
Dashboard: http://127.0.0.1:60344/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:60325,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-bjpkyqdz,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-bjpkyqdz

0,1
Comm: tcp://127.0.0.1:60343,Total threads: 1
Dashboard: http://127.0.0.1:60346/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:60327,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-79xa3n54,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-79xa3n54

0,1
Comm: tcp://127.0.0.1:60348,Total threads: 1
Dashboard: http://127.0.0.1:60350/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:60329,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-l10lkedx,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-l10lkedx

0,1
Comm: tcp://127.0.0.1:60349,Total threads: 1
Dashboard: http://127.0.0.1:60352/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:60331,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-2j2mz5lv,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-2j2mz5lv

0,1
Comm: tcp://127.0.0.1:60354,Total threads: 1
Dashboard: http://127.0.0.1:60356/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:60333,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-fqdn58xq,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-fqdn58xq

0,1
Comm: tcp://127.0.0.1:60360,Total threads: 1
Dashboard: http://127.0.0.1:60362/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:60335,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-z3o604gl,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-z3o604gl

0,1
Comm: tcp://127.0.0.1:60355,Total threads: 1
Dashboard: http://127.0.0.1:60358/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:60337,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-buiif5hg,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-buiif5hg

0,1
Comm: tcp://127.0.0.1:60361,Total threads: 1
Dashboard: http://127.0.0.1:60364/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:60339,
Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-p0027a58,Local directory: /var/folders/w1/lh3h4s7d5g10rdlfj4h0mshw0000gn/T/dask-scratch-space/worker-p0027a58




Training batches:   0%|          | 0/128 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/97 [00:00<?, ?it/s]



In [7]:
# Evaluate the trained model on one hand-built (flux, err) input.
# Flux corresponding to magnitude 21 with a zero point of 31.2, and a 2% error.
# NOTE(review): the AB zero point for fluxes in nJy is 31.4; confirm that 31.2
# is the intended value here.
flux = 10 ** (-0.4 * (21 - 31.2))
err = 0.02 * flux
# Pure inference: no autograd graph is needed for this forward pass.
with torch.no_grad():
    prediction = model(torch.tensor([flux, err]))
prediction

tensor([2.0915], grad_fn=<AsStridedBackward0>)