In [1]:
#!/usr/bin/env python3
"""
Neven Caplar 
Last updated: 2023-10-07

Goals: 
Fit the data

Each Section can/should run independently,
only these initial imports should be shared among all sections

Questions:
What determines the memory limit of the workers?
How should the dataframe be partitioned so that more workers are active?


"""
import numpy as np
import pandas as pd
import pyarrow as pa

# from scipy.spatial import KDTree
import matplotlib.pyplot as plt

import JaxPeriodDrwFit


from tape.ensemble import Ensemble
from tape.utils import ColumnMapper


from warnings import simplefilter
simplefilter(action="ignore", category=pd.errors.PerformanceWarning)

In [2]:
import dask
# Scheduler configurations tried so far, kept commented out for reference.
# Threaded scheduler: many workers
# dask.config.set(scheduler='threads')

# Custom thread pool: does not work
# from multiprocessing.pool import ThreadPool
# dask.config.set(pool=ThreadPool(20))

# Process scheduler: one worker
# dask.config.set(scheduler='processes')

# Configuration actually used: a local distributed cluster created with its
# default settings, plus a client connected to it. The client is passed to the
# Ensemble in the next cell.
from dask.distributed import Client, LocalCluster
cluster = LocalCluster()
client = Client(cluster)
# Adaptive scaling of the cluster, currently disabled
# cluster.adapt(minimum=10, maximum=40)

In [3]:
# Create the ensemble object, attached to the dask client set up earlier
ens = Ensemble(client=client)
# Last expression of the cell: shows the client/cluster summary as cell output
ens.client_info()


0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 8
Total threads: 64,Total memory: 251.68 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:42239,Workers: 8
Dashboard: http://127.0.0.1:8787/status,Total threads: 64
Started: Just now,Total memory: 251.68 GiB

0,1
Comm: tcp://127.0.0.1:33541,Total threads: 8
Dashboard: http://127.0.0.1:39024/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:38733,
Local directory: /tmp/dask-scratch-space-1398143/worker-gabjnm_4,Local directory: /tmp/dask-scratch-space-1398143/worker-gabjnm_4
GPU: NVIDIA GeForce RTX 2080 Ti,GPU memory: 11.00 GiB

0,1
Comm: tcp://127.0.0.1:34321,Total threads: 8
Dashboard: http://127.0.0.1:34939/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:46689,
Local directory: /tmp/dask-scratch-space-1398143/worker-9mmpkv22,Local directory: /tmp/dask-scratch-space-1398143/worker-9mmpkv22
GPU: NVIDIA GeForce RTX 2080 Ti,GPU memory: 11.00 GiB

0,1
Comm: tcp://127.0.0.1:36782,Total threads: 8
Dashboard: http://127.0.0.1:46542/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:38527,
Local directory: /tmp/dask-scratch-space-1398143/worker-5skcbxwz,Local directory: /tmp/dask-scratch-space-1398143/worker-5skcbxwz
GPU: NVIDIA GeForce RTX 2080 Ti,GPU memory: 11.00 GiB

0,1
Comm: tcp://127.0.0.1:39229,Total threads: 8
Dashboard: http://127.0.0.1:46538/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:43117,
Local directory: /tmp/dask-scratch-space-1398143/worker-tok5ku9a,Local directory: /tmp/dask-scratch-space-1398143/worker-tok5ku9a
GPU: NVIDIA GeForce RTX 2080 Ti,GPU memory: 11.00 GiB

0,1
Comm: tcp://127.0.0.1:33206,Total threads: 8
Dashboard: http://127.0.0.1:41855/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:39730,
Local directory: /tmp/dask-scratch-space-1398143/worker-dgvph7yd,Local directory: /tmp/dask-scratch-space-1398143/worker-dgvph7yd
GPU: NVIDIA GeForce RTX 2080 Ti,GPU memory: 11.00 GiB

0,1
Comm: tcp://127.0.0.1:37473,Total threads: 8
Dashboard: http://127.0.0.1:44010/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:37068,
Local directory: /tmp/dask-scratch-space-1398143/worker-4vj2jl7v,Local directory: /tmp/dask-scratch-space-1398143/worker-4vj2jl7v
GPU: NVIDIA GeForce RTX 2080 Ti,GPU memory: 11.00 GiB

0,1
Comm: tcp://127.0.0.1:44544,Total threads: 8
Dashboard: http://127.0.0.1:46579/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:45808,
Local directory: /tmp/dask-scratch-space-1398143/worker-xhz7c5to,Local directory: /tmp/dask-scratch-space-1398143/worker-xhz7c5to
GPU: NVIDIA GeForce RTX 2080 Ti,GPU memory: 11.00 GiB

0,1
Comm: tcp://127.0.0.1:46037,Total threads: 8
Dashboard: http://127.0.0.1:41808/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:35341,
Local directory: /tmp/dask-scratch-space-1398143/worker-6ljqo5a1,Local directory: /tmp/dask-scratch-space-1398143/worker-6ljqo5a1
GPU: NVIDIA GeForce RTX 2080 Ti,GPU memory: 11.00 GiB


# Setup base directory for saving output files

In [4]:
# User name that is interpolated into the output paths
username = "wbeebe"
# Base directory for saved output files, built from the user name
basedir = f"/astro/users/{username}/data/"

# Tape Single Pixel - real data 

In [5]:
# Catalog location; this path applies when running on baldur
data_path = "/astro/store/epyc/data3/hipscat/catalogs/tape_test/"

# Map the catalog's column names onto the id/time/flux/error/band roles
col_map = ColumnMapper(
    id_col="SDSS_NAME_dr16q_constant",
    time_col="mjd_ztf_source",
    flux_col="mag_ztf_source",
    err_col="magerr_ztf_source",
    band_col="band_ztf_source",
)

# Load the source and object tables into the ensemble.
# Kept as the last expression so the returned ensemble is shown as cell output.
ens.from_hipscat(
    data_path,
    source_subdir="tape_test_sources",
    object_subdir="tape_test_obj",
    column_mapper=col_map,
    additional_cols=True,
    sync_tables=True,
    npartitions=10,
)



<tape.ensemble.Ensemble at 0x7f0025f7f6d0>

Filter data

In [6]:
# Source table: keep only rows whose band is 'g'
ens.query("band_ztf_source == 'g'", table="source")
# presumably removes objects left with fewer than 10 sources -- confirm against the TAPE docs
ens.prune(10)
# Object table: keep only rows with rMeanPSFMag_ps1_otmo below 20
ens.query("rMeanPSFMag_ps1_otmo < 20", table="object")

<tape.ensemble.Ensemble at 0x7f0025f7f6d0>

In [7]:
# Wall-clock timings recorded for this cell on baldur:
#   9 min 13 s - 603 sources in 4 partitions
#   7 min 30 s - 603 sources in 4 partitions (Nov 7)
#   5 min 14 s - 603 sources in 10 partitions (Nov 7)
#   3 min  3 s - with padding
JaxPeriodDrwFit_instance = JaxPeriodDrwFit.JaxPeriodDrwFit()

# Run `optimize_map` over the ensemble using the time, magnitude and
# magnitude-error columns, and compute the result immediately.
res_tsp = ens.batch(
    JaxPeriodDrwFit_instance.optimize_map,
    "mjd_ztf_source",
    "mag_ztf_source",
    "magerr_ztf_source",
    compute=True,
    meta=None,
    n_init=100,
)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [8]:
JaxPeriodDrwFit_instance = JaxPeriodDrwFit.JaxPeriodDrwFit()

# Same batch call as for `res_tsp`, but with `optimize_map_drw` as the
# per-object function.
res_tsp_drw = ens.batch(
    JaxPeriodDrwFit_instance.optimize_map_drw,
    "mjd_ztf_source",
    "mag_ztf_source",
    "magerr_ztf_source",
    compute=True,
    meta=None,
    n_init=100,
)

In [74]:
# User name for the output paths, set again here so this section does not
# rely on the earlier cell that also defines it
username = "wbeebe"
def pack_output_to_parquet(result, cols, output_dir, output_filename, drop_cols=None, full=False):
    """Pack fit results into a dataframe, write it to a parquet file and return it.

    Parameters
    ----------
    result : pandas.Series
        One entry per object, indexed by object id (presumably the output of
        ``ens.batch(..., compute=True)``). With ``full=False`` each entry is a
        1D array with one value per name in ``cols``. With ``full=True`` each
        entry is a 2D array with one column per name in ``cols``; every row of
        that array becomes its own row in the output.
    cols : list of str
        Column names of the resulting dataframe.
    output_dir : str
        The file is written to ``{output_dir}/data/{output_filename}.parquet``.
        This function does not create that directory.
    output_filename : str
        File name without the ``.parquet`` extension.
    drop_cols : list of str, optional
        Columns to remove before writing. Nothing is dropped when empty or None.
    full : bool, default False
        Selects which of the two ``result`` layouts described above is expected.

    Returns
    -------
    pandas.DataFrame
        The dataframe that was written, returned for inspection.
    """
    # `import pyarrow as pa` alone does not guarantee that the `pyarrow.parquet`
    # submodule has been loaded, so `pa.parquet` can raise AttributeError unless
    # some other library happened to import it. Import it explicitly instead.
    import pyarrow.parquet as pq

    if full:
        # Construct dataframes with the results for each object.
        dfs = []
        for i in range(len(result)):
            obj_data = result.iloc[i]
            # Construct a series representing the index: the object's label
            # repeated once per row of its result array.
            obj_index = pd.Series(np.full(len(obj_data), result.index[i]), name=result.index.name)
            dfs.append(pd.DataFrame(data=obj_data, columns=cols, index=obj_index))

        if dfs:
            # Concatenate all of the per-object dataframes
            result_df = pd.concat(dfs)
        else:
            # pd.concat raises on an empty list; return an empty frame instead.
            result_df = pd.DataFrame(columns=cols, index=pd.Index([], name=result.index.name))
    elif len(result) > 0:
        # Each object has a single 1D array, so the output has one row per object.
        # Rows are taken by position (.iloc): plain `result[i]` on a label-indexed
        # Series relies on a deprecated positional fallback. Stacking the rows
        # gives numeric columns instead of filling an object-dtype frame row by row.
        rows = [np.asarray(result.iloc[i]) for i in range(len(result))]
        result_df = pd.DataFrame(np.vstack(rows), columns=cols, index=result.index)
    else:
        # Nothing to stack: keep the (empty) index and the requested columns.
        result_df = pd.DataFrame(columns=cols, index=result.index)

    # Drop any columns if requested.
    if drop_cols:
        result_df = result_df.drop(columns=drop_cols)

    # Write the output to the parquet file
    pa_table = pa.Table.from_pandas(result_df)
    pq.write_table(pa_table, f"{output_dir}/data/{output_filename}.parquet")
    return result_df

# Column names for the DRW-only results: the two likelihood columns, the
# fitted parameters, then their "init_"-prefixed counterparts.
param_cols = ['log_drw_scale', 'log_drw_amp']
init_param_cols = [f"init_{name}" for name in param_cols]
drw_columns = ['min_neg_log_lh', 'neg_log_lh', *param_cols, *init_param_cols]

# Column names for the combined DRW + periodic results, built the same way.
# Note that `param_cols` and `init_param_cols` are overwritten here.
param_cols = ['log_drw_scale', 'log_drw_amp', 'log_per_scale', 'log_per_amp']
init_param_cols = [f"init_{name}" for name in param_cols]
combined_columns = ['min_neg_log_lh', 'neg_log_lh', *param_cols, *init_param_cols]


In [68]:
# Write the DRW-only results to parquet and keep the dataframe for display
drw_df = pack_output_to_parquet(
    res_tsp_drw,
    drw_columns,
    f"/astro/users/{username}",
    "res_tsp_run_g_0_drw",
)
drw_df

  result_df.iloc[i] = result[i]
  if _pandas_api.is_sparse(col):


Unnamed: 0_level_0,min_neg_log_lh,neg_log_lh,log_drw_scale,log_drw_amp,init_log_drw_scale,init_log_drw_amp
SDSS_NAME_dr16q_constant,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
b'024050.36-003109.0',40.08991,40.08991,2.980099,-0.760837,3.534287,0.225864
b'024052.82-004110.9',-34.805998,-34.805998,-3.807629,-4.933484,0.979914,0.808098
b'024126.71-004526.3',-28.434609,-28.434609,2.066565,-1.24888,1.872701,-2.842854
b'024151.76-001953.6',35.285865,35.285865,2.441909,-1.277017,3.854836,-1.304851
b'024154.42-004757.6',-12.792811,-12.792811,1.545028,-4.170314,1.554912,-1.710292
...,...,...,...,...,...,...
b'024400.64+004723.0',-107.04529,-107.04529,-20.117763,-10.030782,1.061696,1.040602
b'024419.10+005539.2',-94.10112,-94.10112,2.633371,-0.931073,2.623782,-2.06715
b'024200.53+005322.2',-229.193073,-229.193073,2.69227,-1.333217,2.087055,-0.912945
b'024202.28+005740.3',2.743977,2.743977,2.718112,-3.644382,2.733551,-2.742606


In [69]:
# Write the combined DRW + periodic results to parquet and keep the dataframe for display
combined_df = pack_output_to_parquet(
    res_tsp,
    combined_columns,
    f"/astro/users/{username}",
    "res_tsp_run_g_0",
)
combined_df

  result_df.iloc[i] = result[i]
  if _pandas_api.is_sparse(col):


Unnamed: 0_level_0,min_neg_log_lh,neg_log_lh,log_drw_scale,log_drw_amp,log_per_scale,log_per_amp,init_log_drw_scale,init_log_drw_amp,init_log_per_scale,init_log_per_amp
SDSS_NAME_dr16q_constant,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
b'024050.36-003109.0',39.447842,39.447842,11.906401,-0.971743,3.427691,-0.847186,1.504392,-0.830742,3.358503,-2.099382
b'024052.82-004110.9',-35.156772,-35.156772,-12.49335,-10.902907,2.738975,-1.331822,0.232252,-0.446263,3.155693,-2.31874
b'024126.71-004526.3',-29.632743,-29.632743,-10.076967,-6.401982,2.944183,-1.101969,0.780093,1.537832,3.032145,-1.003249
b'024151.76-001953.6',31.643813,31.643813,2.550391,-1.240236,1.469939,-1.185938,4.436064,-0.351747,1.432706,-0.935809
b'024154.42-004757.6',-12.792812,-12.792812,3.320252,-1.58099,-0.159079,-2.115329,2.137705,1.485551,1.079105,-1.820266
...,...,...,...,...,...,...,...,...,...,...
b'024400.64+004723.0',-107.04529,-107.04529,-14.353206,-7.598119,6.842181,-8.98097,0.697469,1.480456,2.841543,-1.034861
b'024419.10+005539.2',-97.450038,-97.450038,2.66236,-1.459577,2.504225,-1.096036,3.645036,-2.128168,2.507581,-2.122817
b'024200.53+005322.2',-230.257416,-230.257416,1.953016,-2.081511,3.348743,-1.408624,1.523069,-0.406047,1.403862,-0.647483
b'024202.28+005740.3',2.743976,2.743976,0.576741,-11.326638,-12.233244,-6.243121,1.061696,1.040602,1.121347,-0.352717


# Redo but save all results (full=True)

In [71]:
JaxPeriodDrwFit_instance = JaxPeriodDrwFit.JaxPeriodDrwFit()

# Same `optimize_map` batch call as for `res_tsp`, with full=True added
res_tsp_full = ens.batch(
    JaxPeriodDrwFit_instance.optimize_map,
    "mjd_ztf_source",
    "mag_ztf_source",
    "magerr_ztf_source",
    compute=True,
    meta=None,
    n_init=100,
    full=True,
)

In [72]:
JaxPeriodDrwFit_instance = JaxPeriodDrwFit.JaxPeriodDrwFit()

# Same `optimize_map_drw` batch call as for `res_tsp_drw`, with full=True added
res_tsp_drw_full = ens.batch(
    JaxPeriodDrwFit_instance.optimize_map_drw,
    "mjd_ztf_source",
    "mag_ztf_source",
    "magerr_ztf_source",
    compute=True,
    meta=None,
    n_init=100,
    full=True,
)

In [75]:
# Write the full DRW-only results (full=True layout) to parquet and keep the dataframe for display
drw_df_full = pack_output_to_parquet(
    res_tsp_drw_full,
    drw_columns,
    f"/astro/users/{username}",
    "res_tsp_run_g_0_drw_full",
    full=True,
)
drw_df_full

  if _pandas_api.is_sparse(col):


Unnamed: 0_level_0,min_neg_log_lh,neg_log_lh,log_drw_scale,log_drw_amp,init_log_drw_scale,init_log_drw_amp
SDSS_NAME_dr16q_constant,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
b'024050.36-003109.0',40.089910,41.100533,1.874229,-2.834812,1.872701,-2.842854
b'024050.36-003109.0',40.089910,40.089910,2.980098,-0.760837,4.753572,0.182052
b'024050.36-003109.0',40.089910,40.406710,2.691709,0.039295,3.659970,-1.428220
b'024050.36-003109.0',40.089910,40.089910,2.980099,-0.760837,2.993292,-0.457147
b'024050.36-003109.0',40.089910,40.583519,8.919517,0.937089,0.780093,1.537832
...,...,...,...,...,...,...
b'024240.31+005727.1',-763.660976,-763.660976,2.877154,-1.107215,2.468978,-1.253952
b'024240.31+005727.1',-763.660976,-743.254656,1.575367,-1.266480,2.613664,0.629778
b'024240.31+005727.1',-763.660976,-763.660976,2.877154,-1.107215,2.137705,1.485551
b'024240.31+005727.1',-763.660976,-671.623311,-16.918065,-7.352328,0.127096,1.435432


In [76]:
# Write the full combined DRW + periodic results (full=True layout) to parquet
# and keep the dataframe for display
combined_df_full = pack_output_to_parquet(
    res_tsp_full,
    combined_columns,
    f"/astro/users/{username}",
    "res_tsp_run_g_0_full",
    full=True,
)
combined_df_full

  if _pandas_api.is_sparse(col):


Unnamed: 0_level_0,min_neg_log_lh,neg_log_lh,log_drw_scale,log_drw_amp,log_per_scale,log_per_amp,init_log_drw_scale,init_log_drw_amp,init_log_per_scale,init_log_per_amp
SDSS_NAME_dr16q_constant,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
b'024050.36-003109.0',39.083982,44.316890,1.898980,-2.704153,3.652927,-2.559627,1.872701,-2.842854,3.210158,-2.857875
b'024050.36-003109.0',39.083982,39.773864,3.215409,-0.760633,1.437213,-1.479241,4.753572,0.182052,0.420700,-1.538775
b'024050.36-003109.0',39.083982,43.786997,2.703628,0.050113,-41.261877,-2.965467,3.659970,-1.428220,0.808144,-1.513253
b'024050.36-003109.0',39.083982,40.089910,2.980099,-0.760837,4.084577,-3.842707,2.993292,-0.457147,4.492771,-1.247068
b'024050.36-003109.0',39.083982,44.385650,-47.359028,-16.907984,-23.533388,-6.329036,0.780093,1.537832,3.032145,-1.003249
...,...,...,...,...,...,...,...,...,...,...
b'024240.31+005727.1',-768.437432,-758.787013,2.475945,-1.252895,2.576596,-0.444865,2.468978,-1.253952,2.611216,-0.440417
b'024240.31+005727.1',-768.437432,-763.630661,2.921942,-1.105673,3.969455,-3.761862,2.613664,0.629778,3.849968,-0.639365
b'024240.31+005727.1',-768.437432,-763.660975,2.877154,-1.107215,3.944142,-3.925888,2.137705,1.485551,1.079105,-1.820266
b'024240.31+005727.1',-768.437432,-761.861483,-16.927255,-7.454603,3.169676,-1.116270,0.127096,1.435432,3.114452,-0.935105
