In [1]:
#!/usr/bin/env python3
"""
Neven Caplar 
Last updated: 2023-11-30

Goals: 
Fit the data

Each Section can/should run independently,
only these initial imports should be shared among all sections

Open questions:
None at the moment
"""

import os

import numpy as np
import pandas as pd
import pyarrow as pa

# from scipy.spatial import KDTree
import matplotlib.pyplot as plt

import JaxPeriodDrwFit


from tape.ensemble import Ensemble
from tape.utils import ColumnMapper

from tqdm import tqdm

from warnings import simplefilter
simplefilter(action="ignore", category=pd.errors.PerformanceWarning)

In [2]:
import dask
# many workers
# dask.config.set(scheduler='threads') 

dask.config.set({"temporary-directory" :'/epyc/ssd/users/ncaplar/tmp'})

# does not work
# from multiprocessing.pool import ThreadPool
# dask.config.set(pool=ThreadPool(20))

# one worker
# dask.config.set(scheduler='processes')  
from dask.distributed import Client, LocalCluster
cluster = LocalCluster(n_workers=8, threads_per_worker=1)
client = Client(cluster)
# cluster.adapt(minimum=10, maximum=40) 

In [3]:
dask.config.get("temporary-directory")

'/epyc/ssd/users/ncaplar/tmp'

In [4]:

ens = Ensemble(client = client)  # initialize an ensemble object
ens.client_info()

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 8
Total threads: 8,Total memory: 251.68 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:44876,Workers: 8
Dashboard: http://127.0.0.1:8787/status,Total threads: 8
Started: Just now,Total memory: 251.68 GiB

0,1
Comm: tcp://127.0.0.1:43908,Total threads: 1
Dashboard: http://127.0.0.1:41354/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:41625,
Local directory: /epyc/ssd/users/ncaplar/tmp/dask-scratch-space/worker-4ucia93p,Local directory: /epyc/ssd/users/ncaplar/tmp/dask-scratch-space/worker-4ucia93p
GPU: NVIDIA GeForce RTX 2080 Ti,GPU memory: 11.00 GiB

0,1
Comm: tcp://127.0.0.1:45024,Total threads: 1
Dashboard: http://127.0.0.1:43490/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:34369,
Local directory: /epyc/ssd/users/ncaplar/tmp/dask-scratch-space/worker-ohd313yf,Local directory: /epyc/ssd/users/ncaplar/tmp/dask-scratch-space/worker-ohd313yf
GPU: NVIDIA GeForce RTX 2080 Ti,GPU memory: 11.00 GiB

0,1
Comm: tcp://127.0.0.1:46440,Total threads: 1
Dashboard: http://127.0.0.1:46367/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:34607,
Local directory: /epyc/ssd/users/ncaplar/tmp/dask-scratch-space/worker-34wqc3kz,Local directory: /epyc/ssd/users/ncaplar/tmp/dask-scratch-space/worker-34wqc3kz
GPU: NVIDIA GeForce RTX 2080 Ti,GPU memory: 11.00 GiB

0,1
Comm: tcp://127.0.0.1:43606,Total threads: 1
Dashboard: http://127.0.0.1:33806/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:33704,
Local directory: /epyc/ssd/users/ncaplar/tmp/dask-scratch-space/worker-tg10uwlz,Local directory: /epyc/ssd/users/ncaplar/tmp/dask-scratch-space/worker-tg10uwlz
GPU: NVIDIA GeForce RTX 2080 Ti,GPU memory: 11.00 GiB

0,1
Comm: tcp://127.0.0.1:43154,Total threads: 1
Dashboard: http://127.0.0.1:36603/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:38712,
Local directory: /epyc/ssd/users/ncaplar/tmp/dask-scratch-space/worker-62j3xan6,Local directory: /epyc/ssd/users/ncaplar/tmp/dask-scratch-space/worker-62j3xan6
GPU: NVIDIA GeForce RTX 2080 Ti,GPU memory: 11.00 GiB

0,1
Comm: tcp://127.0.0.1:42095,Total threads: 1
Dashboard: http://127.0.0.1:46247/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:36471,
Local directory: /epyc/ssd/users/ncaplar/tmp/dask-scratch-space/worker-96_7sh1w,Local directory: /epyc/ssd/users/ncaplar/tmp/dask-scratch-space/worker-96_7sh1w
GPU: NVIDIA GeForce RTX 2080 Ti,GPU memory: 11.00 GiB

0,1
Comm: tcp://127.0.0.1:44407,Total threads: 1
Dashboard: http://127.0.0.1:46256/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:35681,
Local directory: /epyc/ssd/users/ncaplar/tmp/dask-scratch-space/worker-_yh3cqtx,Local directory: /epyc/ssd/users/ncaplar/tmp/dask-scratch-space/worker-_yh3cqtx
GPU: NVIDIA GeForce RTX 2080 Ti,GPU memory: 11.00 GiB

0,1
Comm: tcp://127.0.0.1:46846,Total threads: 1
Dashboard: http://127.0.0.1:42147/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:39032,
Local directory: /epyc/ssd/users/ncaplar/tmp/dask-scratch-space/worker-skmmzlqc,Local directory: /epyc/ssd/users/ncaplar/tmp/dask-scratch-space/worker-skmmzlqc
GPU: NVIDIA GeForce RTX 2080 Ti,GPU memory: 11.00 GiB


In [5]:
dask.config.get("temporary-directory")

'/epyc/ssd/users/ncaplar/tmp'

In [None]:
# Setup base directory for saving output files
username= "ncaplar"
basedir = f"/astro/users/{username}/data/"

In [6]:
# Setup base directory for saving output files
username= "ncaplar"
basedir = f"/astro/users/{username}/data/"

# ZTF

## ZTF - Loading original data 

In [None]:
# TODO - move to shuffling notebook

In [None]:
rel_path = "/data3/epyc/data3/hipscat/catalogs/ztf_dr14_x_agns_source"

from pyarrow import parquet
# parquet.read_schema(f"{rel_path}/Norder=8/Dir=210000/Npix=217286.parquet", memory_map=True)

parquet.read_schema("/epyc/projects3/sean_hipscat/agns_x_ztf_source/Norder=8/Dir=210000/Npix=217286.parquet", memory_map=True)

In [None]:
"""
colmap = ColumnMapper(id_col="_hipscat_index",
                      time_col="mjd_ztf_source",
                      flux_col="mag_ztf_source",
                      err_col="magerr_ztf_source",
                      band_col="band_ztf_source")
ens.from_parquet(source_file="/epyc/projects3/sean_hipscat/agns_x_ztf_source/Norder*/Dir*/Npix*.parquet",
                 #object_file=datapath+"object/*.parquet",
                 column_mapper=colmap,
                 partition_size="1000MB")
ens._source.reset_index().set_index("_hipscat_index").to_parquet("/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big/")
"""


## ZTF - Repartitioned dataset / partitioning and fitting and shuffling

In [None]:
# TODO - move the shuffling part to the shuffling notebook

In [None]:
import glob

directory_path = '/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big/'
file_pattern = 'part.25*.parquet'

matching_files = glob.glob(f'{directory_path}{file_pattern}')

for file_path in matching_files:
    print(file_path)

In [None]:
names = []
for i in range(250,254):
    name = f"/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big/part.{i}.parquet"
    names.append(name)
    
names

In [None]:
colmap = ColumnMapper(id_col="_hipscat_index",
                      time_col="mjd_ztf_source",
                      flux_col="mag_ztf_source",
                      err_col="magerr_ztf_source",
                      band_col="band_ztf_source")
ens.from_parquet(source_file=names,
                 #object_file=datapath+"object/*.parquet",
                 column_mapper=colmap,
                 sorted=True)

### load into memory

In [None]:
ens_object = ens._object.compute()
ens_source = ens._source.compute()

In [None]:
# in GB
ens_source.memory_usage(deep=True).sum()/1e9

In [None]:
len(ens_object)

In [None]:
ens_source['SDSS_NAME_dr16q_constant'].nunique()

### investigating the sample and fit 

In [None]:
ens.check_lightcurve_cohesion()

In [None]:
ens.query("band_ztf_source == 'g'", table="source").prune(50)
ens._lazy_sync_tables(table="object")
ens.calc_nobs(temporary=False)

data = ens.compute('object')['nobs_total'].values.astype(int)

bin_edges = range(0, 321 + 21, 20) 
plt.hist(data, bins=bin_edges, edgecolor='k')

In [None]:
def count_rows(partition):
    return len(partition)

# Let us try the same thing again
n_sources_per_div = ens._source.map_partitions(count_rows, meta=int).compute()

print("Number of rows in each partition:", n_sources_per_div)

In [None]:
# 18 minutes, 95.17 GB used
JaxPeriodDrwFit_instance = JaxPeriodDrwFit.JaxPeriodDrwFit()
res_tsp_drw = ens.batch(JaxPeriodDrwFit_instance.optimize_map_drw, 'mjd_ztf_source', "mag_ztf_source", "magerr_ztf_source",
                compute=True, meta=None, n_init=100)

## ZTF - reshuffled and fit 

In [6]:
import glob

directory_path = '/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big_shuffled_test/'
file_pattern = 'part.*.parquet'

matching_files = glob.glob(f'{directory_path}{file_pattern}')

for file_path in matching_files:
    print(file_path)

/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big_shuffled_test/part.3.parquet
/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big_shuffled_test/part.2.parquet
/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big_shuffled_test/part.1.parquet
/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big_shuffled_test/part.0.parquet


In [7]:
test_pd = pd.read_parquet(matching_files[0])
test_pd


Unnamed: 0_level_0,count,_hipscat_index,SDSS_NAME_dr16q_constant,PLATE_dr16q_constant,MJD_dr16q_constant,FIBERID_dr16q_constant,RA_dr16q_constant,DEC_dr16q_constant,OBJID_dr16q_constant,IF_BOSS_SDSS_dr16q_constant,...,magerr_ztf_source,mjd_ztf_source,rcID_ztf_source,band_ztf_source,Norder_ztf_source,Dir_ztf_source,Npix_ztf_source,Norder,Dir,provenance
__null_dask_index__,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,298.0,7343808069173772288,135628.50+045054.4,856,52339,458,209.118759,4.848461,0856-52339-0458,SDSS,...,0.050652,59274.44767,46,g,6,20000,26090,6,20000,survey_1
1,298.0,7416242961983209472,122749.00+053201.9,2880,54509,581,186.954179,5.533870,2880-54509-0581,SDSS,...,0.079153,58956.25185,39,g,6,20000,26347,6,20000,survey_1
2,298.0,7416242961983209472,122749.00+053201.9,2880,54509,581,186.954179,5.533870,2880-54509-0581,SDSS,...,0.083438,58955.30475,39,g,6,20000,26347,6,20000,survey_1
3,298.0,7416242961983209472,122749.00+053201.9,2880,54509,581,186.954179,5.533870,2880-54509-0581,SDSS,...,0.078276,58955.30152,39,g,6,20000,26347,6,20000,survey_1
4,298.0,7416242961983209472,122749.00+053201.9,2880,54509,581,186.954179,5.533870,2880-54509-0581,SDSS,...,0.069537,58944.25345,39,g,6,20000,26347,6,20000,survey_1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
638022,1012.0,7488300675512664064,131254.92+151528.7,1773,53112,337,198.228874,15.257989,1773-53112-0337,SDSS,...,0.074948,59639.40354,60,g,6,20000,26603,6,20000,survey_1
638023,1012.0,7488300675512664064,131254.92+151528.7,1773,53112,337,198.228874,15.257989,1773-53112-0337,SDSS,...,0.067980,59625.38157,60,g,6,20000,26603,6,20000,survey_1
638024,1012.0,7488300675512664064,131254.92+151528.7,1773,53112,337,198.228874,15.257989,1773-53112-0337,SDSS,...,0.078997,59625.36090,3,g,6,20000,26603,6,20000,survey_1
638025,1012.0,7488300675512664064,131254.92+151528.7,1773,53112,337,198.228874,15.257989,1773-53112-0337,SDSS,...,0.072453,59623.32152,3,g,6,20000,26603,6,20000,survey_1


In [13]:
ls = test_pd[test_pd['_hipscat_index'] == 7416242961983209472]
ls['mjd_ztf_source'].values

array([58956.25185, 58955.30475, 58955.30152, 58944.25345, 58941.27305,
       58968.27727, 59290.30288, 58994.22464, 58995.20446, 59278.32219,
       59275.44972, 59272.41197, 59270.40558, 59268.44115, 59262.42162,
       59257.38351, 59255.38222, 59253.38323, 58962.23679, 58962.23589,
       58964.1943 , 58963.2283 , 58911.40306, 58546.29754, 58547.32199,
       58482.54343, 59309.31494, 59309.2138 , 59308.3652 , 59321.25142,
       59308.27701, 59306.3602 , 59305.44069, 59305.28297, 58482.54389,
       58556.37929, 58487.48479, 58524.42666, 58233.2387 , 58559.3599 ,
       58286.21694, 58442.50248, 58912.2794 , 58913.33778, 59216.53664,
       59198.52503, 58662.19698, 58249.19353, 58249.19444, 58272.19203,
       58886.46379, 59225.46684, 58995.20353, 59223.40131, 59298.33867,
       58997.18914, 59303.36476, 58587.23765, 58674.20227, 58914.42757,
       58274.23006, 58560.27857, 58558.38408, 58283.23547, 58280.21328,
       58653.25615, 58898.40147, 58838.49827, 58828.54516, 59204

In [15]:
ls.columns

Index(['count', '_hipscat_index', 'SDSS_NAME_dr16q_constant',
       'PLATE_dr16q_constant', 'MJD_dr16q_constant', 'FIBERID_dr16q_constant',
       'RA_dr16q_constant', 'DEC_dr16q_constant', 'OBJID_dr16q_constant',
       'IF_BOSS_SDSS_dr16q_constant', 'Z_DR16Q_dr16q_constant',
       'SOURCE_Z_DR16Q_dr16q_constant', 'Z_FIT_dr16q_constant',
       'Z_SYS_dr16q_constant', 'Z_SYS_ERR_dr16q_constant',
       'EBV_dr16q_constant', 'SN_MEDIAN_ALL_dr16q_constant',
       'FEII_UV_EW_dr16q_constant', 'FEII_UV_EW_ERR_dr16q_constant',
       'FEII_OPT_EW_dr16q_constant', 'FEII_OPT_EW_ERR_dr16q_constant',
       'LOGL1350_dr16q_constant', 'LOGL1350_ERR_dr16q_constant',
       'LOGL1700_dr16q_constant', 'LOGL1700_ERR_dr16q_constant',
       'LOGL2500_dr16q_constant', 'LOGL2500_ERR_dr16q_constant',
       'LOGL3000_dr16q_constant', 'LOGL3000_ERR_dr16q_constant',
       'LOGL5100_dr16q_constant', 'LOGL5100_ERR_dr16q_constant',
       'LOGLBOL_dr16q_constant', 'LOGLBOL_ERR_dr16q_constant',
       'L

In [17]:
ls['mag_ztf_source'].values

array([19.312641, 19.383255, 19.29778 , 19.140474, 19.288923, 19.351116,
       19.249687, 19.46319 , 19.3588  , 19.206415, 19.153954, 19.065487,
       19.399252, 19.282434, 19.217   , 19.264502, 19.301836, 19.202887,
       19.416456, 19.35193 , 19.353617, 19.365479, 19.289461, 19.37284 ,
       19.408007, 19.307314, 19.2956  , 19.27231 , 19.334406, 19.326107,
       19.238342, 19.311033, 19.129662, 19.664099, 19.252045, 19.325909,
       19.385979, 19.40222 , 19.454357, 19.190248, 19.478886, 19.2441  ,
       19.235447, 19.25444 , 19.092367, 19.241459, 19.272612, 19.3093  ,
       19.363537, 19.398252, 19.324017, 19.174726, 19.425188, 19.2279  ,
       19.211605, 19.453007, 19.194225, 19.194786, 19.165335, 19.300331,
       19.4272  , 19.304312, 19.428223, 19.408361, 19.39347 , 19.251125,
       19.336624, 19.164324, 19.0892  , 19.171383, 19.384426, 19.051455,
       19.19374 , 19.24575 , 19.280582, 19.442562, 19.263414, 19.309263,
       19.423101, 19.278053, 19.331123, 19.375303, 

In [14]:
ls['magerr_ztf_source'].values

array([0.07915251, 0.0834377 , 0.07827622, 0.0695373 , 0.07775825,
       0.08146235, 0.07550106, 0.08853268, 0.08193084, 0.07308243,
       0.07024831, 0.06570782, 0.08443651, 0.07738075, 0.07366718,
       0.07634608, 0.07851453, 0.07288845, 0.08552233, 0.08151188,
       0.08161458, 0.08233994, 0.0777896 , 0.08279303, 0.08498754,
       0.07883736, 0.07814847, 0.07679506, 0.08045178, 0.07995415,
       0.0748598 , 0.07905726, 0.06897193, 0.10247952, 0.07563499,
       0.07994224, 0.08360704, 0.084623  , 0.08795694, 0.07219763,
       0.0895637 , 0.07518469, 0.07469703, 0.07577123, 0.06705619,
       0.07503549, 0.07681245, 0.07895472, 0.08222079, 0.08437383,
       0.0798292 , 0.07135777, 0.08607806, 0.07427403, 0.07336865,
       0.08786917, 0.07241432, 0.07244499, 0.07085408, 0.07842601,
       0.08620656, 0.07866032, 0.08627194, 0.08500998, 0.08407439,
       0.07558276, 0.0805853 , 0.07080007, 0.06689584, 0.07117806,
       0.08351043, 0.06501463, 0.07238792, 0.07527799, 0.07727

In [9]:
names = []
for i in range(0,4):
    name = f"/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big_shuffled_test/part.{i}.parquet"
    names.append(name)
    
names

['/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big_shuffled_test/part.0.parquet',
 '/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big_shuffled_test/part.1.parquet',
 '/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big_shuffled_test/part.2.parquet',
 '/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big_shuffled_test/part.3.parquet']

In [10]:
colmap = ColumnMapper(id_col="count",
                      time_col="mjd_ztf_source",
                      flux_col="mag_ztf_source",
                      err_col="magerr_ztf_source",
                      band_col="band_ztf_source")
ens.from_parquet(source_file=names,
                 #object_file=datapath+"object/*.parquet",
                 column_mapper=colmap,
                 sorted=True)

<tape.ensemble.Ensemble at 0x7fa28edbfc70>

In [12]:
ens._object

1.0
146.0
204.0
298.0
1012.0


In [13]:
ens_c_ix = ens._source._hipscat_index.compute()

In [15]:
len(ens_c_ix)

2494598

In [16]:
n_count = np.sort(np.unique(ens_c_ix.values, return_counts=True))[1].astype(int)
n_count

array([   1,    1,    1, ...,  898,  929, 1012])

In [17]:
len(n_count)

16923

In [None]:
# custom divisions in the middle
n_count[np.array([len(n_count)/4, 2* len(n_count)/4, 3* len(n_count)/4]).astype(int)]


In [None]:
ens._source.divisions

In [None]:
custom_divisions = (1.0, 54.0, 135.0, 207.0, 1012.0)
ens._source = ens._source.repartition(divisions=custom_divisions)
ens._source.divisions


In [None]:
c1 = np.unique(ens._source.get_partition(0).compute()['_hipscat_index'].values, return_counts=True)  
c3 = np.unique(ens._source.get_partition(3).compute()['_hipscat_index'].values, return_counts=True)  

In [None]:
bin_edges = range(0, 721 + 21, 20) 

plt.figure(figsize=(12, 5))

plt.subplot(121 )
plt.hist(c1[1], bins=bin_edges, edgecolor='k')
plt.ylim(0, 5000)

plt.subplot(122 )
plt.hist(c3[1], bins=bin_edges, edgecolor='k')
plt.ylim(0, 5000)

In [None]:
# 13.5 minutes, 22.84 GB used without repartitioning

JaxPeriodDrwFit_instance = JaxPeriodDrwFit.JaxPeriodDrwFit()
res_tsp_drw = ens.batch(JaxPeriodDrwFit_instance.optimize_map_drw, 'mjd_ztf_source', "mag_ztf_source", "magerr_ztf_source",
                compute=True, meta=None, n_init=100)