# Join SFD map with a point source catalog

We need LSDB for that

In [1]:
from pathlib import Path
from typing import Literal

import dask
import lsdb
import numpy as np
import pandas as pd
import ray
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN, hipscat_id_to_healpix
from lsdb.core.crossmatch.abstract_crossmatch_algorithm import AbstractCrossmatchAlgorithm
from ray.util.dask import enable_dask_on_ray

  from .autonotebook import tqdm as notebook_tqdm
2023-12-18 16:39:12,708	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


Linear search is faster than `np.searchsorted` in most cases.
See benchmarks and tests here:
https://github.com/hombit/linear-search

In [2]:
import numpy as np
from numba import njit, uint64
from numpy.typing import NDArray


# @njit(uint64[:](uint64[:], uint64[:]), boundscheck=False, fastmath=True)
@njit(boundscheck=False, fastmath=True)
def linear_search_numba(a: NDArray, b: NDArray) -> NDArray:
    """Locate every element of `b` inside `a` with one merge-like pass.

    Both `a` and `b` must be sorted in ascending order. For each `b[j]`
    the result holds the position of the first element of `a` that is
    strictly greater than `b[j]`, or `a.size` if there is no such element.
    For sorted inputs these are the positions
    `np.searchsorted(a, b, side='right')` gives.

    Returns
    -------
    NDArray
        `uint64` array with one position per element of `b`.
    """
    n_a = a.size
    n_b = b.size

    # Every slot starts as "past the end of a"; the scan below overwrites
    # the slots of those b elements that have a larger element in a.
    positions = np.full(shape=n_b, fill_value=n_a, dtype=np.uint64)

    cursor = 0  # first element of b that has no position assigned yet
    for k in range(n_a):
        if cursor >= n_b:
            break
        upper = a[k]
        # All not-yet-placed b elements below a[k] are placed at k
        while cursor < n_b and b[cursor] < upper:
            positions[cursor] = k
            cursor += 1

    return positions


# Call the function once on tiny inputs so that it is compiled here rather
# than during the first real cross-match; the returned array is the cell output.
linear_search_numba(np.zeros(2, dtype=np.uint64), np.zeros(2, dtype=np.uint64))

array([2, 2], dtype=uint64)

Data paths

Hardcoded paths to the PS1 DR2 object table (OTMO) and the SFD map at PSC

In [3]:
# Point-source catalog that is joined with the map: PS1 DR2 object table (OTMO)
STARS_PATH = Path('/ocean/projects/phy210048p/shared/hipscat/catalogs/ps1/ps1_otmo')

# Alternative point-source catalog (currently disabled): SDSS DR16 Quasar catalog
# STARS_PATH = Path('/ocean/projects/phy210048p/shared/hipscat/catalogs/agns_dr16q_prop_May16')

# Alternative map (currently disabled): fixed order 14 SFD map
# SFD_PATH = Path('/ocean/projects/phy210048p/shared/hipscat/catalogs/sfd/sfd_order14_map')
# Map in use: multiorder SFD map, interpolation error is <1%
SFD_PATH = Path('/ocean/projects/phy210048p/shared/hipscat/catalogs/sfd/sfd_multiorder_map')

### We are using LSDB's cross-matching interface for joining

In [4]:
class JoinWithContinuousMap(AbstractCrossmatchAlgorithm):
    """Join each left-table row with one row of a sorted right-table map.

    For every left row the selected right row is the one with the largest
    hipscat index that does not exceed the left row's hipscat index
    (presumably the map pixel containing the point — confirm against the
    map layout). No distance is computed: the distance column is filled
    with zeros.
    """

    # Name of the distance column added to the output; it is always 0.0 here
    DISTANCE_COLUMN_NAME = '_DIST'

    def crossmatch(
            self,
            search_algo: Literal['auto', 'numpy', 'linear'] = 'auto',
    ) -> pd.DataFrame:
        """Perform cross-match

        Parameters
        ----------
        search_algo : 'auto' or 'numpy' or 'linear'
            Index join algorithm, one of the following:
            - 'numpy' - `np.searchsorted(right, left, side='right')`,
              it is faster for smaller left tables, it is
              O(n_left * log(n_right)). Right table hipscat index must
              be sorted.
            - 'linear' - linear search algorithm, it is faster for
              smaller right tables, it is O(n_left + n_right).
              Both tables' hipscat index must be sorted.
            - 'auto' - use the algorithm which is faster by the following
              heuristics based on algorithmic complexities with some
              coefficient driven by experiments on sizes between
              thousand and million:
              if `n_left + n_right > 5 * n_left * lb(n_right)` use
              'numpy', and 'linear' otherwise.

        Returns
        -------
        pd.DataFrame
            Left columns and the matched right columns side by side (both
            renamed with their suffixes), plus the `_DIST` column, indexed
            by the left hipscat index.

        Raises
        ------
        AssertionError
            If either hipscat index is not strictly increasing, or some
            left row precedes the first right row.
        ValueError
            If `search_algo` is not one of the supported names.
        """
        left_index = np.asarray(self.left.index)
        right_index = np.asarray(self.right[HIPSCAT_ID_COLUMN])

        # Check that both catalogs are sorted by HIPSCAT_ID_COLUMN.
        # Neighbours are compared directly rather than with
        # `np.diff(...) > 0`: for unsigned integers a decreasing step wraps
        # around to a large positive value and would pass that test.
        assert np.all(left_index[1:] > left_index[:-1]), \
            'left hipscat index must be strictly increasing'
        assert np.all(right_index[1:] > right_index[:-1]), \
            'right hipscat index must be strictly increasing'

        n_left = left_index.size
        n_right = right_index.size

        if search_algo == 'auto':
            # max(..., 1) keeps log2 finite for an empty right table;
            # the selected branch is the same as without it.
            if n_left + n_right > 5.0 * n_left * np.log2(max(n_right, 1)):
                search_algo = 'numpy'
            else:
                search_algo = 'linear'
        if search_algo == 'numpy':
            idx = np.searchsorted(
                self.right[HIPSCAT_ID_COLUMN],
                self.left.index,
                side='right',
            ) - 1
        elif search_algo == 'linear':
            # linear_search_numba returns uint64 positions; convert to a
            # signed type before subtracting, otherwise 0 - 1 wraps around
            # and the "no match" check below can never fail.
            idx = linear_search_numba(
                np.asarray(self.right[HIPSCAT_ID_COLUMN], dtype=np.uint64),
                np.asarray(self.left.index, dtype=np.uint64),
            ).astype(np.int64) - 1
        else:
            raise ValueError(f'Unknown search algo "{search_algo}"')

        # Search output is between 0 and N, so after subtracting one only
        # the -1 case (left row before the first right row) is invalid.
        assert np.all(idx >= 0), \
            'some left rows precede the first right row'

        self._rename_columns_with_suffix(self.left, self.suffixes[0])
        self._rename_columns_with_suffix(self.right, self.suffixes[1])

        left_join_part = self.left.reset_index()
        right_join_part = self.right.iloc[idx].reset_index(drop=True)

        out = pd.concat(
            [
                left_join_part,
                right_join_part,
            ],
            axis=1,
        )
        out[self.DISTANCE_COLUMN_NAME] = 0.0
        out = out.set_index(HIPSCAT_ID_COLUMN)

        return out

In [5]:
# Make a command for dashboard ssh-tunneling

import socket
from getpass import getuser
from urllib.parse import urlparse

# Local end of the ssh tunnel and the login host the tunnel goes through
local_addr = '127.0.0.1:8787'
remote_host = 'bridges2.psc.edu'


def print_client_info(client):
    """Display the client and print an ssh command that tunnels its dashboard.

    The command forwards `local_addr` on the user's machine to the
    dashboard port of `client` on this node, going through `remote_host`.
    """
    display(client)

    # Point a UDP socket at a public address and read back the local
    # address it was bound to: that is this node's outgoing IP.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
        probe.connect(('1.1.1.1', 53))
        ip, _port = probe.getsockname()

    username = getuser()
    dashboard_port = urlparse(client.dashboard_link).port

    print(f'''
    Copy-paste and run in your terminal:

    ssh -N -L {local_addr}:{ip}:{dashboard_port} {username}@{remote_host}

    And open this URL in your browser to see the dashboard:
    http://{local_addr}/
    ''')

In [None]:
%%time

# I have some connect issues running on PSC...
dask.config.set({
    'distributed.comm.timeouts.connect': '60s',
    'distributed.comm.timeouts.tcp': '60s',
})

# The Ray instance is left running on purpose: the dustmaps cell below
# enters enable_dask_on_ray() without calling ray.init() itself.
context = ray.init()
print(context.dashboard_url)
with enable_dask_on_ray():

# Alternative schedulers, kept for reference. The body below is indented by
# two levels so that it also fits under the nested `with` blocks when one of
# these alternatives is uncommented.
#
# with SLURMCluster(
#     # Number of Dask workers per node
#     processes=4,
#     # Regular memory node type on PSC bridges2
#     queue="RM",
#     # dask_jobqueue requires cores and memory to be specified
#     # We set them to match RM specs
#     cores=128,
#     memory="256GB",
#     walltime="12:00:00",
# ) as cluster:
#     # Run multiple jobs
#     # cluster.scale(jobs=10)
#     # Allow to run more jobs
#     cluster.adapt(maximum_jobs=10)

#     with Client(cluster) as client:
# with Client(n_workers=4) as client:
        # print_client_info(client)

        stars = lsdb.read_hipscat(STARS_PATH)
        sfd = lsdb.read_hipscat(SFD_PATH)
        matched = stars.crossmatch(
            sfd,
            algorithm=JoinWithContinuousMap,
            search_algo='auto',
        )
        # Mean E(B-V) over all matched rows; the column name carries the
        # map catalog name as suffix (presumably LSDB's default suffix).
        mean_sfd = matched._ddf[f'ebv_{sfd.name}'].mean().compute()

# A bare `mean_sfd` here would not be displayed because it is not the last
# expression of the cell, so print it explicitly.
print(mean_sfd)

with open('ps1-multiorder.txt', 'w') as f:
    f.write(f'{mean_sfd = }\n')

2023-12-18 16:39:15,374	INFO worker.py:1664 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m


127.0.0.1:8265




Alternative approach: use the `dustmaps` package

In [None]:
%%time
import dask
import pandas as pd
from astropy.coordinates import SkyCoord
from dustmaps.sfd import SFDQuery

# Get original SFD FITS file location, INPUT_DIR
from paths import *


def worker(df, query):
    """Evaluate `query` at the sky position of every row of `df`.

    The RA/Dec column names are read from the catalog info of the global
    `stars` catalog, so `stars` must exist when this function runs.

    Returns a single-column (`ebv`) DataFrame that shares the index of `df`.
    """
    catalog_info = stars.hc_structure.catalog_info
    positions = SkyCoord(
        ra=df[catalog_info.ra_column],
        dec=df[catalog_info.dec_column],
        unit='deg',
    )
    return pd.DataFrame({'ebv': query(positions)}, index=df.index)


# NOTE(review): ray.init() is disabled here, so this cell relies on the Ray
# instance started by the cross-match cell above.
# context = ray.init()
# print(context.dashboard_url)
with enable_dask_on_ray():

# Alternative scheduler (disabled): a local Dask cluster
# with Client(n_workers=24) as client:
    # print_client_info(client)
    
    # Wrap the SFDQuery construction into a delayed object and hand it to
    # every partition (presumably so that the map is loaded by the workers
    # rather than shipped from the notebook process - TODO confirm).
    # INPUT_DIR is presumably provided by `from paths import *`.
    query = dask.delayed(SFDQuery, pure=True, traverse=False)(INPUT_DIR)
    
    # `stars` must be assigned before the graph is computed: `worker`
    # reads the RA/Dec column names from this global.
    stars = lsdb.read_hipscat(STARS_PATH)
    values = stars._ddf.map_partitions(worker, query, meta={'ebv': np.float32})
    # Column-wise mean of the per-row values
    mean_values = values.mean().compute()
    
print(mean_values)

# Keep the result next to the one written by the cross-match cell
with open('ps1-dustmaps.txt', 'w') as f:
    f.write(f'{mean_values = }\n')

### Validation

First, we check that the hipscat indexes and the SFD pixel (order, index) pairs are all consistent.

In [None]:
# NOTE(review): `result` and `SFD_NAME` are not defined in any cell visible
# in this notebook; they must come from an earlier session - confirm.
norder_column = result[f'pixel_Norder_{SFD_NAME}']
npix_column = result[f'pixel_Npix_{SFD_NAME}']

# Both the map's own hipscat index and the joined table's index must fall
# into the map pixel recorded for the row.
for hipscat_ids in (result[f'_hipscat_index_{SFD_NAME}'], result.index):
    np.testing.assert_array_equal(
        hipscat_id_to_healpix(hipscat_ids, norder_column),
        npix_column,
    )

Check that SFD map values are close enough to the ones from the `dustmaps` module.
The relative difference must be below 16% for the fixed-order map and below 1% for the multiorder map.

In [None]:
# Validate
from astropy.coordinates import SkyCoord
from dustmaps.sfd import SFDQuery

# NOTE(review): `result` and `SFD_NAME` are not defined in any visible cell,
# and INPUT_DIR presumably comes from `from paths import *` - confirm.
sfd_query = SFDQuery(INPUT_DIR)
# NOTE(review): the coordinate columns carry a `small_sky_order1` suffix,
# which does not match the PS1 catalog read elsewhere in this notebook -
# confirm which `result` table this cell expects.
coord = SkyCoord(ra=result['ra_small_sky_order1'], dec=result['dec_small_sky_order1'], unit='deg')
dustmaps_sfd_values = sfd_query(coord)

# Relative difference, normalised by the larger of the two values
diff = (
    np.abs(result[f'ebv_{SFD_NAME}'] - dustmaps_sfd_values)
    / np.where(result[f'ebv_{SFD_NAME}'] > dustmaps_sfd_values, result[f'ebv_{SFD_NAME}'], dustmaps_sfd_values)
)
# Positions of the rows ordered from the largest difference to the smallest
i = np.argsort(diff)[::-1]
# Show the ten worst rows together with both values
display(result.assign(diff=diff, ebv_dustmap=dustmaps_sfd_values).iloc[i[:10]])
# Cell output: the largest relative difference
diff.max()

In [None]:
# Weight every map pixel by 4 ** (17 - order), presumably its area in
# units of order-17 pixels. Uses `sfd` from the cross-match cell.
area17 = 4 ** (17 - sfd._ddf['pixel_Norder'].astype(np.uint64))
# Cell output: the summed weights next to 12 * 4**17; the two numbers are
# expected to agree if the map tiles the whole sphere without overlaps
# (assumption - TODO confirm).
area17.sum().compute(), 12 * 4 ** 17

In [None]:
import pyarrow.parquet as pq

# For every order, compare the number of map rows with the row count stored
# in the metadata of the matching parquet file; each printed difference is
# expected to be zero.
# NOTE(review): PARQUET_DIR is not defined in any visible cell, presumably
# it comes from `from paths import *` - confirm.
for norder in range(8, 18):
    parquet_file = PARQUET_DIR / f'pixel_Norder={norder:02d}.parquet'
    count_real = pq.read_metadata(parquet_file).num_rows
    count = (sfd._ddf['pixel_Norder'] == norder).sum().compute()
    print(norder, count - count_real)

In [None]:
import dask.array as da

# Hipscat index of the map as a dask array with known chunk lengths
index = sfd._ddf['_hipscat_index'].to_dask_array(lengths=True)
# Number of neighbouring pairs that are not strictly increasing (expected 0).
# Neighbours are compared directly: with an unsigned index `da.diff(index)`
# wraps around on a decreasing step, so `diff <= 0` would only ever catch
# duplicates.
display(da.sum(index[1:] <= index[:-1]).compute())
# Cell output: position of the smallest index value
index.argmin().compute()

In [None]:
from hipscat.pixel_math.hipscat_id import healpix_to_hipscat_id

# Step between neighbouring hipscat indexes of the map
index = sfd._ddf['_hipscat_index'].to_dask_array(lengths=True)
diff_index = da.diff(index)

# Step predicted from the pixel order of each row: the hipscat id of pixel 1
# at that order (presumably the index span of one pixel - TODO confirm).
# The last row has no following neighbour, hence `[:-1]`.
norder = sfd._ddf['pixel_Norder'].to_dask_array(lengths=True).astype(np.uint64)
diff_index_from_norder = norder.map_blocks(
    lambda order: healpix_to_hipscat_id(order, 1)
)[:-1]

# Cell output: number of rows where the two steps disagree
mismatch = (diff_index != diff_index_from_norder).astype(np.uint64)
da.sum(mismatch).compute()