In [12]:
# Load and process ZTF AXS (DR14) with LSDB and Tape

In [1]:
from time import monotonic

import light_curve as licu
import lsdb
import numpy as np
from dask.distributed import Client
from tape import Ensemble, ColumnMapper

### Use LSDB to load 'object' and 'source' catalogs

Nothing except some metadata is actually read or processed until `.compute()` is called at the very end.

We load only a few columns, and the analysis below doesn't actually use all of them.

Paths are for PSC Bridges-2.

In [2]:
%%time

objects = lsdb.read_hipscat(
    '/ocean/projects/phy210048p/shared/hipscat/catalogs/ztf_axs/ztf_dr14',
    columns=['ps1_objid', 'ra', 'dec', 'nobs_g', 'nobs_r', 'nobs_i',
             'Norder', 'Dir', 'Npix']
)
sources = lsdb.read_hipscat(
    '/ocean/projects/phy210048p/shared/hipscat/catalogs/ztf_axs/ztf_source',
    columns=['ps1_objid', 'mjd', 'mag', 'magerr', 'catflags', 'band',
             'Norder', 'Dir', 'Npix']
)

CPU times: user 28.8 s, sys: 2.81 s, total: 31.6 s
Wall time: 37 s


### Use LSDB to join objects and sources

This assigns each source its object's `_hipscat_index`, which we are going to use as the primary key.

In [None]:
%%time
joined_sources = objects.join(
    sources,
    left_on='ps1_objid',
    right_on='ps1_objid',
    output_catalog_name='ztf_axs_sources'
)

In [None]:
# Display the lazy joined catalog; nothing has been computed yet
joined_sources

### Take only the first few partitions of the object table

In [None]:
# Number of object partitions to process, clipped to what the catalog actually has
n_object_partitions = min(5, objects._ddf.npartitions)
# Dask divisions: partition i covers [divisions[i], divisions[i+1]) (the last one is
# closed on the right), so the first n partitions end exactly at divisions[n].
last_object_division = objects._ddf.divisions[n_object_partitions]

# Sources carry their object's `_hipscat_index` as index, so only joined-source
# partitions that start before the boundary can hold sources of the selected objects.
# This stays correct even if the joined-source partitioning differs from the objects'.
n_joined_source_partitions = min(
    int(np.searchsorted(joined_sources._ddf.divisions, last_object_division, side='left')),
    joined_sources._ddf.npartitions,
)

print(f'{n_object_partitions = } / {objects._ddf.npartitions}')
print(f'{n_joined_source_partitions = } / {joined_sources._ddf.npartitions}')

object_frame = objects._ddf.partitions[:n_object_partitions]
source_frame = joined_sources._ddf.partitions[:n_joined_source_partitions]

### Create Dask client

On a SLURM cluster we can use `dask_jobqueue.SLURMCluster` to scale our job.
In this case the current node acts as a manager node and runs none of the Dask workers itself.
Instead, it submits SLURM jobs, each running a few workers, and assigns Dask tasks to them.

In [None]:
# Option A: scale out with a SLURM cluster (uncomment to use)
#
# from dask_jobqueue import SLURMCluster
#
# cluster = SLURMCluster(
#     # Dask worker processes started in each SLURM job (one node per job)
#     processes=16,
#     # "RM" is the regular-memory node type on PSC Bridges-2
#     queue="RM",
#     # dask_jobqueue requires per-job cores and memory; these match an RM node
#     cores=128,
#     memory="256GB",
#     walltime="12:00:00",
# )
# # Let Dask scale the number of SLURM jobs on demand, up to 10
# cluster.adapt(maximum_jobs=10)
# client = Client(cluster)

# Option B (default): a local cluster on the current node
client = Client()

client

### Show how to access the Dask dashboard

In [None]:
# Print an ssh command that tunnels the Dask dashboard to your local machine

import socket
from getpass import getuser
from urllib.parse import urlparse

local_addr = '127.0.0.1:8787'
remote_host = 'bridges2.psc.edu'


def _outbound_ip():
    """Return the IPv4 address of the local interface that routes to 1.1.1.1.

    Connecting a UDP socket sends no packets; it only makes the OS choose a
    local address, which is then read back with getsockname().
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
        probe.connect(('1.1.1.1', 53))
        return probe.getsockname()[0]


ip = _outbound_ip()
username = getuser()
dashboard_port = urlparse(client.dashboard_link).port

print(f'''
Copy-paste and run in your terminal:

ssh -N -L {local_addr}:{ip}:{dashboard_port} {username}@{remote_host}

And open this URL in your browser to see the dashboard:
http://{local_addr}/
''')

### Create Tape Ensemble and plan the pipeline

In [None]:
%%time

ens = Ensemble(client)
column_mapper = ColumnMapper(
    id_col='_hipscat_index',
    time_col='mjd_ztf_source',
    flux_col='mag_ztf_source',
    err_col='magerr_ztf_source',
    band_col='band_ztf_source',
)
ens.from_dask_dataframe(
    object_frame=object_frame,
    source_frame=source_frame,
    sorted=True,
    sort=False,
    sync_tables=True,
    column_mapper=column_mapper,
)

ens.source.query('catflags_ztf_source == 0 & magerr_ztf_source > 0').update_ensemble()
ens = ens.calc_nobs(by_band=False, label="nobs", temporary=False)
ens.object.query('nobs_total >= 200').update_ensemble()
features = ens.batch(licu.Amplitude(), band_to_calc=None, label='features', compute=False)
max_amplitude = features['amplitude'].max()

### Run the pipeline

In [None]:
%%time

t = monotonic()
result = max_amplitude.compute()
dt = monotonic() - t
dt, result

### Save results to disk, just in case

In [None]:
# Persist the timing and the result, one `name = value` line each
with open('cluster_result.txt', 'w') as out_file:
    print(f'{dt = }', f'{result = }', sep='\n', file=out_file)

### Shut down Dask cluster

If we run a SLURM cluster, closing it also cancels all associated SLURM jobs.

In [None]:
# Client.close() shuts down only a cluster the client started itself (plain Client()).
# A cluster passed in, e.g. Client(SLURMCluster(...)), must be closed explicitly,
# which also cancels its SLURM jobs. Closing an already-closed cluster is a no-op.
cluster_to_close = client.cluster
client.close()
if cluster_to_close is not None:
    cluster_to_close.close()