# Load and process ZTF AXS (DR14) with LSDB and Tape

In [1]:
import socket
from getpass import getuser
from time import monotonic
from urllib.parse import urlparse

import light_curve as licu
import lsdb
import numpy as np
from dask.distributed import Client
from tape import Ensemble, ColumnMapper

### Use LSDB to load 'object' and 'source' catalogs

Nothing is actually read or processed (apart from some metadata) until `.compute()` is called at the very end

We load only a few columns, but the actual analysis below doesn't use all of them

Paths are for PSC

In [2]:
%%time

# HiPSCat bookkeeping columns, requested from both catalogs
hipscat_columns = ['Norder', 'Dir', 'Npix']

# Object table: identifier, sky position and per-band observation counts
objects = lsdb.read_hipscat(
    '/ocean/projects/phy210048p/shared/hipscat/catalogs/ztf_axs/ztf_dr14',
    columns=['ps1_objid', 'ra', 'dec', 'nobs_g', 'nobs_r', 'nobs_i'] + hipscat_columns,
)
# Source table: identifier, time, magnitude with its error, flags and band
sources = lsdb.read_hipscat(
    '/ocean/projects/phy210048p/shared/hipscat/catalogs/ztf_axs/ztf_source',
    columns=['ps1_objid', 'mjd', 'mag', 'magerr', 'catflags', 'band'] + hipscat_columns,
)

CPU times: user 28.6 s, sys: 2.57 s, total: 31.2 s
Wall time: 39.2 s


### Use LSDB to join objects and sources

This assigns each source its object's `_hipscat_index`, which we are going to use as the primary key

In [3]:
%%time
# Both catalogs carry the same identifier column, so it serves as the key on each side
join_column = 'ps1_objid'
joined_sources = objects.join(
    sources,
    left_on=join_column,
    right_on=join_column,
    output_catalog_name='ztf_axs_sources',
)

CPU times: user 2min 22s, sys: 1.72 s, total: 2min 24s
Wall time: 2min 24s


In [4]:
# Preview of the lazy joined catalog: the output lists column dtypes and partition
# divisions only, no rows are loaded at this point
joined_sources

Unnamed: 0_level_0,ps1_objid_ztf_dr14,ra_ztf_dr14,dec_ztf_dr14,nobs_g_ztf_dr14,nobs_r_ztf_dr14,nobs_i_ztf_dr14,Norder_ztf_dr14,Dir_ztf_dr14,Npix_ztf_dr14,ps1_objid_ztf_source,mjd_ztf_source,mag_ztf_source,magerr_ztf_source,catflags_ztf_source,band_ztf_source,Norder_ztf_source,Dir_ztf_source,Npix_ztf_source
npartitions=311037,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
0,int64,float64,float64,int32,int32,int32,int32,int32,int32,int64,float64,float32,float32,int16,string,int32,int32,int32
281474976710656,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
13834987686537986048,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18446744073709551615,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...


### Let's take only the first few partitions of the object table

In [5]:
# Optional alternative to the "take all the data" cell: uncomment everything below
# to build `object_frame` / `source_frame` from a subset of partitions only.
# NOTE(review): with standard Dask divisions, divisions[n] is already the upper bound
# of partitions[:n]; confirm that the extra +1 below (one partition further) is intended.
# n_object_partitions = 5
# # off by one because we need last partition **end**
# last_object_division = objects._ddf.divisions[n_object_partitions+1]

# n_joined_source_partitions = np.searchsorted(joined_sources._ddf.divisions, last_object_division) - 1

# print(f'{n_object_partitions = } / {objects._ddf.npartitions}')
# print(f'{n_joined_source_partitions = } / {joined_sources._ddf.npartitions}')

# object_frame = objects._ddf.partitions[:n_object_partitions]
# source_frame = joined_sources._ddf.partitions[:n_joined_source_partitions]

### Or just take all the data

In [6]:
# Use the complete catalogs: take the frames underlying the LSDB catalog objects,
# which are later handed to Ensemble.from_dask_dataframe()
object_frame, source_frame = objects._ddf, joined_sources._ddf

### Create Dask client

On a SLURM cluster we may use `dask_jobqueue.SLURMCluster` to scale our job.
In this case the current node acts as a manager node and runs no Dask workers itself.
Instead it submits SLURM jobs, each running a few workers, and assigns Dask tasks to them.

In [7]:
### Set larger timeouts

import dask

# Allow up to an hour for establishing and keeping worker connections
dask.config.set({
    'distributed.comm.timeouts.connect': '3600s',
    'distributed.comm.timeouts.tcp': '3600s',
})

### Create a SLURM cluster

from dask_jobqueue import SLURMCluster

# Per-job resources; dask_jobqueue requires `cores` and `memory` to be given,
# and they are set here to match the RM node specs
slurm_options = dict(
    # Number of Dask workers per node
    processes=8,
    # Regular memory node type on PSC bridges2
    queue="RM",
    # Infiniband should be faster, but it doesn't work well =(
    # interface='ib0',
    cores=128,
    memory="256GB",
    walltime="12:00:00",
    death_timeout=7200.0,
)
cluster = SLURMCluster(**slurm_options)

# Request a fixed 10 SLURM jobs; the commented-out adapt() call is the
# alternative that scales adaptively up to 20 jobs
cluster.scale(jobs=10)
# cluster.adapt(maximum_jobs=20)
client = Client(cluster)

### Or create a local cluster
# client = Client()

client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: http://10.8.9.34:8787/status,

0,1
Dashboard: http://10.8.9.34:8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.8.9.34:45931,Workers: 0
Dashboard: http://10.8.9.34:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


### Show how to access Dask dashboard

In [8]:
# Make a command for dashboard ssh-tunneling
# (socket, getuser and urlparse come from the import cell at the top of the notebook)

# Local end of the tunnel and the login host to tunnel through
local_addr = '127.0.0.1:8787'
remote_host = 'bridges2.psc.edu'

# Find this node's outward-facing IP address: connect() on a UDP socket makes the OS
# pick the outgoing interface, whose address is then read back with getsockname()
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
    s.connect(('1.1.1.1', 53))
    ip = s.getsockname()[0]
username = getuser()
# The port the Dask dashboard is served on, taken from the client's dashboard URL
dashboard_port = urlparse(client.dashboard_link).port

print(f'''
Copy-paste and run in your terminal:

ssh -N -L {local_addr}:{ip}:{dashboard_port} {username}@{remote_host}

And open this URL in your browser to see the dashboard:
http://{local_addr}/
''')


Copy-paste and run in your terminal:

ssh -N -L 127.0.0.1:8787:10.8.9.34:8787 malanche@bridges2.psc.edu

And open this URL in your browser to see the dashboard:
http://127.0.0.1:8787/



### Create Tape Ensemble and plan the pipeline

In [9]:
%%time

# Build the TAPE Ensemble on top of the Dask client and plan (but do not run) the pipeline
ens = Ensemble(client)
# Tell TAPE which columns of the joined source table play which role
column_mapper = ColumnMapper(
    id_col='_hipscat_index',
    time_col='mjd_ztf_source',
    flux_col='mag_ztf_source',
    err_col='magerr_ztf_source',
    band_col='band_ztf_source',
)
# NOTE(review): sorted=True / sort=False presumably declare the frames as already
# ordered by the id column so no re-sorting is done; sync_tables=False presumably
# skips the initial object/source synchronization -- confirm against the TAPE docs
ens.from_dask_dataframe(
    object_frame=object_frame,
    source_frame=source_frame,
    sorted=True,
    sort=False,
    sync_tables=False,
    column_mapper=column_mapper,
)

# Earlier variants of the pipeline, kept for reference
# ens.source.query('catflags_ztf_source == 0 & magerr_ztf_source > 0').update_ensemble()
# ens = ens.calc_nobs(by_band=False, label="nobs", temporary=False)
# ens.object.query('nobs_total >= 2000').update_ensemble()
# features = ens.batch(licu.ReducedChi2(), band_to_calc=None, label='features', compute=False)
# ens.object.merge(features).update_ensemble()
# ens.object.query('chi2 >= 3.0').update_ensemble()
# Keep only objects with exactly 1000 r-band observations and none in g or i
ens.object.query('nobs_r == 1000 and nobs_g == 0 and nobs_i == 0').update_ensemble()
# Size of the object table's Dask graph after the filter
print(len(ens.object.dask))
# Collapse the filtered object table into a single partition
new_object = ens.object.repartition(npartitions=1)
# NOTE(review): presumably repartition() returns a frame without the ensemble
# reference, so it is re-attached by hand before update_ensemble() -- confirm
new_object.ensemble = ens
new_object.update_ensemble()
# ens.source.query('catflags_ztf_source == 0 & magerr_ztf_source > 0').update_ensemble()
# Amplitude feature from the light_curve package, applied via ens.batch()
# NOTE(review): band_to_calc=None presumably means "do not split by band" -- confirm
features = ens.batch(licu.Amplitude(), band_to_calc=None, label='features')
print(len(features.dask))
# Still lazy: the maximum is only evaluated when .compute() is called in a later cell
max_amplitude = features['amplitude'].max()
print(len(max_amplitude.dask))



7056
1557537
2179612
CPU times: user 1.35 s, sys: 105 ms, total: 1.46 s
Wall time: 1.46 s


In [10]:
# Size of the final Dask graph, in millions of entries
len(max_amplitude.dask) / 1e6
# len(ens.source.dask) / 1e6

2.179612

### Run the pipeline

In [None]:
%%time

# Measure the elapsed time by hand as well, so that `dt` is available
# as a variable next to `result` after the cell finishes
started_at = monotonic()
result = max_amplitude.compute()
dt = monotonic() - started_at
dt, result

This may cause some slowdown.
Consider scattering data ahead of time and using futures.


### Save results to disk, just in case

In [None]:
# Write the elapsed time and the computed value to a small text file,
# one `name = value` line each
with open('cluster_result.txt', 'w') as report_file:
    report_file.write(f'{dt = }\n{result = }\n')

### Shut down Dask cluster

If we are running a SLURM cluster, this also cancels all associated SLURM jobs

In [None]:
# A client created as Client(cluster) does not own the cluster object it was given,
# so closing the client alone leaves the cluster (and its SLURM jobs) running.
# Grab the attached cluster first, then close both.
attached_cluster = client.cluster
client.close()
if attached_cluster is not None:
    # Harmless if the client already shut down a cluster it had created itself
    attached_cluster.close()