In [None]:
#!/usr/bin/env python3
"""
Neven Caplar 
Last updated: 2023-11-30

Goals: 
Fit the data

Each section can/should run independently;
only these initial imports should be shared among all sections.

Open questions:
None at the moment
"""

import os

import numpy as np
import pandas as pd
import pyarrow as pa

# from scipy.spatial import KDTree
import matplotlib.pyplot as plt

import JaxPeriodDrwFit


from tape.ensemble import Ensemble
from tape.utils import ColumnMapper

from tqdm import tqdm

from warnings import simplefilter
# Ignore pandas PerformanceWarning messages for the rest of the session.
simplefilter(action="ignore", category=pd.errors.PerformanceWarning)

In [None]:
import dask
from dask.distributed import Client, LocalCluster

# Set dask's "temporary-directory" option before the cluster is created.
dask.config.set({"temporary-directory": "/epyc/ssd/users/ncaplar/tmp"})

# Alternatives that were tried and are left disabled:
#   dask.config.set(scheduler='threads')      # many workers
#   dask.config.set(scheduler='processes')    # one worker
#   from multiprocessing.pool import ThreadPool
#   dask.config.set(pool=ThreadPool(20))      # does not work
#   cluster.adapt(minimum=10, maximum=40)

# Local cluster with 8 single-threaded workers, and the client that later
# cells hand to the Ensemble.
cluster = LocalCluster(
    n_workers=8,
    threads_per_worker=1,
)
client = Client(cluster)

In [None]:
# Show the "temporary-directory" value currently held in the dask configuration.
dask.config.get("temporary-directory")

In [None]:

# Create the TAPE Ensemble on top of the dask client defined earlier in the notebook.
ens = Ensemble(client=client)
# NOTE(review): client_info() presumably reports the client/cluster details — confirm.
ens.client_info()

In [None]:
# Show the "temporary-directory" setting again, this time after the Ensemble has been created.
dask.config.get("temporary-directory")

In [None]:
# Base directory (per user) under which output files are saved.
username = "ncaplar"
basedir = f"/astro/users/{username}/data/"

In [None]:
# Base directory (per user) under which output files are saved.
# NOTE(review): this cell repeats the assignments made in the preceding cell.
username = "ncaplar"
basedir = f"/astro/users/{username}/data/"

# ZTF

## ZTF - Loading original data 

In [None]:
# TODO - move to shuffling notebook

In [None]:
from pyarrow import parquet

# Path of the ztf_dr14_x_agns_source catalog directory.
# NOTE(review): rel_path is not used by the active call below, which reads from a
# different hard-coded location — confirm which path is intended.
rel_path = "/data3/epyc/data3/hipscat/catalogs/ztf_dr14_x_agns_source"
# parquet.read_schema(f"{rel_path}/Norder=8/Dir=210000/Npix=217286.parquet", memory_map=True)

# Read the schema of a single parquet file to inspect the available columns.
parquet.read_schema(
    "/epyc/projects3/sean_hipscat/agns_x_ztf_source/Norder=8/Dir=210000/Npix=217286.parquet",
    memory_map=True,
)

In [None]:
# The code below is wrapped in a string literal, so it is NOT executed.
# As written, it would load the original per-pixel parquet files into the ensemble
# (partition_size="1000MB") and write the source table, re-indexed on
# "_hipscat_index", to the ztf_dr14_x_agns_source_repar_big/ directory.
"""
colmap = ColumnMapper(id_col="_hipscat_index",
                      time_col="mjd_ztf_source",
                      flux_col="mag_ztf_source",
                      err_col="magerr_ztf_source",
                      band_col="band_ztf_source")
ens.from_parquet(source_file="/epyc/projects3/sean_hipscat/agns_x_ztf_source/Norder*/Dir*/Npix*.parquet",
                 #object_file=datapath+"object/*.parquet",
                 column_mapper=colmap,
                 partition_size="1000MB")
ens._source.reset_index().set_index("_hipscat_index").to_parquet("/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big/")
"""


## ZTF - Repartitioned dataset / partitioning and fitting and shuffling

In [None]:
# TODO - move the shuffling part to the shuffling notebook

In [None]:
import glob

# List the repartitioned parquet parts whose file name starts with "part.25".
directory_path = '/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big/'
file_pattern = 'part.25*.parquet'

# glob.glob() returns matches in a filesystem-dependent order; sort them so the
# listing is the same on every run.
matching_files = sorted(glob.glob(f'{directory_path}{file_pattern}'))

for file_path in matching_files:
    print(file_path)

In [None]:
# Paths of parts 250-253 of the repartitioned source catalog.
names = [
    f"/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big/part.{part_number}.parquet"
    for part_number in range(250, 254)
]

names

In [None]:
# Tell TAPE which dataset columns play the id / time / flux / error / band roles.
# Note that the "flux" role is filled by the column named mag_ztf_source.
colmap = ColumnMapper(id_col="_hipscat_index",
                      time_col="mjd_ztf_source",
                      flux_col="mag_ztf_source",
                      err_col="magerr_ztf_source",
                      band_col="band_ztf_source")
# Load the parquet parts listed in `names` as the ensemble's source data; no object file is passed.
# NOTE(review): sorted=True presumably declares the input already sorted on the id column — confirm.
ens.from_parquet(source_file=names,
                 #object_file=datapath+"object/*.parquet",
                 column_mapper=colmap,
                 sorted=True)

### load into memory

In [None]:
# Evaluate the ensemble's object and source tables (via .compute() on the private
# _object/_source attributes) and keep the results as local variables.
ens_object = ens._object.compute()
ens_source = ens._source.compute()

In [None]:
# Total memory footprint of ens_source in GB (summed memory_usage divided by 1e9).
ens_source.memory_usage(deep=True).sum()/1e9

In [None]:
# Number of rows in the in-memory object table.
len(ens_object)

In [None]:
# Number of distinct values in the SDSS_NAME_dr16q_constant column of the source table.
ens_source['SDSS_NAME_dr16q_constant'].nunique()

### investigating the sample and fit 

In [None]:
# Run the ensemble's lightcurve-cohesion check.
# NOTE(review): exact semantics are not visible in this notebook — see the TAPE documentation.
ens.check_lightcurve_cohesion()

In [None]:
# Keep only g-band rows of the source table, then apply prune(50).
# NOTE(review): prune(50) presumably drops objects with fewer than 50 observations — confirm.
ens.query("band_ztf_source == 'g'", table="source").prune(50)
ens._lazy_sync_tables(table="object")
# NOTE(review): temporary=False presumably keeps the nobs columns on the object table — confirm.
ens.calc_nobs(temporary=False)

data = ens.compute('object')['nobs_total'].values.astype(int)

# Histogram of nobs_total in bins of width 20 covering 0-340.
bin_edges = range(0, 321 + 21, 20)
plt.hist(data, bins=bin_edges, edgecolor='k')
# Label the figure so it can be read on its own, and end with plt.show() so the
# cell does not echo the raw return value of plt.hist.
plt.xlabel("nobs_total")
plt.ylabel("Number of objects")
plt.title("nobs_total per object (g band, after prune(50))")
plt.show()

In [None]:
def count_rows(partition):
    """Return the number of rows in ``partition``, as reported by ``len()``."""
    n_rows = len(partition)
    return n_rows

# Apply count_rows to every partition of the source table and collect the results,
# i.e. the number of source rows held by each partition.
# meta=int is passed to map_partitions as the description of the function's output.
n_sources_per_div = ens._source.map_partitions(count_rows, meta=int).compute()

print("Number of rows in each partition:", n_sources_per_div)

In [None]:
# 18 minutes, 95.17 GB used
# Run optimize_map_drw over the ensemble through ens.batch, handing it the time,
# magnitude and magnitude-error columns.
# NOTE(review): compute=True presumably triggers immediate evaluation, and n_init=100 is
# forwarded as a keyword argument (presumably the number of optimizer
# initializations) — confirm in TAPE / JaxPeriodDrwFit.
JaxPeriodDrwFit_instance = JaxPeriodDrwFit.JaxPeriodDrwFit()
res_tsp_drw = ens.batch(JaxPeriodDrwFit_instance.optimize_map_drw, 'mjd_ztf_source', "mag_ztf_source", "magerr_ztf_source",
                compute=True, meta=None, n_init=100)

## ZTF - reshuffled and fit 

In [None]:
import glob

# List every parquet part of the shuffled test dataset.
directory_path = '/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big_shuffled_test/'
file_pattern = 'part.*.parquet'

# glob.glob() returns matches in a filesystem-dependent order. Sort them so that
# the listing, and matching_files[0] read in the next cell, are the same on every run.
matching_files = sorted(glob.glob(f'{directory_path}{file_pattern}'))

for file_path in matching_files:
    print(file_path)

In [None]:
# Read the first matched part into a pandas DataFrame to inspect its contents.
# Raises IndexError if the glob above matched no files.
test_pd = pd.read_parquet(matching_files[0])
test_pd


In [None]:
# Paths of parts 0-3 of the shuffled test dataset.
names = [
    f"/astro/store/epyc3/data3/hipscat/catalogs/ztf_dr14_x_agns_source_repar_big_shuffled_test/part.{part_number}.parquet"
    for part_number in range(0, 4)
]

names

In [None]:
# Column roles for the shuffled dataset. Unlike the earlier load in this notebook,
# the id column here is "count" rather than "_hipscat_index".
colmap = ColumnMapper(id_col="count",
                      time_col="mjd_ztf_source",
                      flux_col="mag_ztf_source",
                      err_col="magerr_ztf_source",
                      band_col="band_ztf_source")
# Load the shuffled parquet parts listed in `names` as the ensemble's source data.
# NOTE(review): sorted=True presumably declares the input already sorted on the id column — confirm.
ens.from_parquet(source_file=names,
                 #object_file=datapath+"object/*.parquet",
                 column_mapper=colmap,
                 sorted=True)

In [None]:
# Display the ensemble's object table.
ens._object

In [None]:
# Compute the _hipscat_index column of the source table and keep it as a local variable.
ens_c_ix = ens._source._hipscat_index.compute()

In [None]:
# Number of source rows per unique _hipscat_index value, sorted ascending.
# np.unique(..., return_counts=True) returns the tuple (values, counts); only the
# counts (element 1) are needed, so sort that array alone instead of passing the
# whole tuple to np.sort.
n_count = np.sort(np.unique(ens_c_ix.values, return_counts=True)[1]).astype(int)

In [None]:
# Values of the sorted per-id counts at the 1/4, 2/4 and 3/4 positions
# (candidate interior points for custom divisions).
n_count[[k * len(n_count) // 4 for k in (1, 2, 3)]]


In [None]:
# Display the current divisions of the source table, before any repartitioning.
ens._source.divisions

In [None]:
# Repartition the source table on hand-picked division boundaries and show the result.
# NOTE(review): the interior values (54, 135, 207) look like they were read off the
# n_count quartile lookup earlier in this section — confirm. They are hard-coded, so
# they will not follow a change in the input data.
custom_divisions = (1.0, 54.0, 135.0, 207.0, 1012.0)
ens._source = ens._source.repartition(divisions=custom_divisions)
ens._source.divisions


In [None]:
# (unique _hipscat_index values, rows per value) for partitions 0 and 3 of the source table.
# The column is selected before .compute() so that only '_hipscat_index' is
# materialized, not every column of the partition.
c1 = np.unique(ens._source.get_partition(0)['_hipscat_index'].compute().values, return_counts=True)
c3 = np.unique(ens._source.get_partition(3)['_hipscat_index'].compute().values, return_counts=True)

In [None]:
# Side-by-side histograms of the number of rows per _hipscat_index in
# partition 0 (c1) and partition 3 (c3), on identical bins and y-range.
bin_edges = range(0, 721 + 21, 20)  # bins of width 20 covering 0-740

plt.figure(figsize=(12, 5))

plt.subplot(121)
plt.hist(c1[1], bins=bin_edges, edgecolor='k')
plt.ylim(0, 5000)
plt.xlabel("rows per _hipscat_index")
plt.ylabel("number of _hipscat_index values")
plt.title("partition 0")

plt.subplot(122)
plt.hist(c3[1], bins=bin_edges, edgecolor='k')
plt.ylim(0, 5000)
plt.xlabel("rows per _hipscat_index")
plt.ylabel("number of _hipscat_index values")
plt.title("partition 3")

# Keep the two panels' labels from overlapping, then render the figure.
plt.tight_layout()
plt.show()

In [None]:
# 13.5 minutes, 22.84 GB used without repartitioning

# Run optimize_map_drw over the ensemble through ens.batch, handing it the time,
# magnitude and magnitude-error columns (same call as for the unshuffled data).
# NOTE(review): compute=True presumably triggers immediate evaluation, and n_init=100 is
# forwarded as a keyword argument (presumably the number of optimizer
# initializations) — confirm in TAPE / JaxPeriodDrwFit.
JaxPeriodDrwFit_instance = JaxPeriodDrwFit.JaxPeriodDrwFit()
res_tsp_drw = ens.batch(JaxPeriodDrwFit_instance.optimize_map_drw, 'mjd_ztf_source', "mag_ztf_source", "magerr_ztf_source",
                compute=True, meta=None, n_init=100)