# Tansu Demo
Very rough draft that runs. Note that this is meant to be run on epyc or baldur.
This takes the code provided [here](https://github.com/tdaylan/miletos/blob/71f7d18542f3ef82808eba588bad6361c8297351/miletos/main.py#L5192) and applies it to a subset of ZTF.

In [None]:
# pip install lsdb
# pip install lf_tape

import pyarrow.parquet as pq
import numpy as np
import pandas as pd
import os
import lsdb
import tape

from lsdb.core.search import BoxSearch, ConeSearch, PolygonSearch
from tape import Ensemble, ColumnMapper
from hipscat.io.file_io import read_parquet_metadata
print(lsdb.__version__)
print(tape.__version__)

import dask
dask.config.set({"temporary-directory" :'/epyc/ssd/users/wbeebe/tmp'})
dask.config.set({"dataframe.shuffle-compression": 'Snappy'})
dask.config.set({"dataframe.convert-string": False})
from dask.distributed import Client

In [None]:
# initialize dask client
# (8 workers, one thread per worker, memory_limit='40Gb' handed to dask as given)
client = Client(n_workers=8, threads_per_worker=1, memory_limit='40Gb')

In [None]:
# Locations of the ZTF object and source catalogs; both are opened with
# lsdb.read_hipscat in the next cell
ztf_object_path = "/epyc/data3/hipscat/catalogs/ztf_axs/ztf_dr14"
ztf_source_path = "/epyc/data3/hipscat/catalogs/ztf_axs/ztf_source"



In [None]:
# Both catalogs are restricted to the same cone on the sky
cone_search = ConeSearch(ra=-60, dec=20, radius_arcsec=1*1600)

# Source columns to load
source_columns = [
    'index', 'ps1_objid',
    'ra', 'dec',
    'catflags',
    'fieldID',
    'mjd', 'band', 'mag', 'magerr', 'Npix',
]

ztf_object = lsdb.read_hipscat(ztf_object_path, search_filter=cone_search)

# sources load takes a minute, since it creates a healpix alignment on load
ztf_source = lsdb.read_hipscat(
    ztf_source_path,
    columns=source_columns,
    search_filter=cone_search,
)

In [None]:
# Select objects whose nobs_g and nobs_r columns both exceed 100
# (presumably per-band observation counts -- confirm against the catalog schema)
ztf_object_100 = ztf_object.query("nobs_g > 100 and nobs_r > 100")
# We do this to get the source catalog indexed by the objects hipscat index.
# The join key is ps1_objid on both sides; object-side columns that clash get
# the "_object" suffix, source-side columns keep their names.
ztf_joined_source_cat = ztf_object_100.join(
    ztf_source, left_on="ps1_objid", right_on="ps1_objid", suffixes=("_object", "")
)

In [None]:
# Tell TAPE which catalog columns play the id / time / flux / error / band roles.
# Note that the "flux" role is filled by the magnitude column here.
colmap = ColumnMapper(
    id_col="_hipscat_index",
    time_col="mjd",
    flux_col="mag",
    err_col="magerr",
    band_col="band",
)

# Pass the Client *instance* created earlier in the notebook. Passing the
# Client class itself is not an existing client, so the configured workers
# and memory limit would not be the ones used by the Ensemble.
ens = Ensemble(client=client)

# We just pass in the catalog objects
# NOTE(review): the joined sources come from ztf_object_100 while the object
# table passed here is the unfiltered ztf_object -- confirm this is intended.
ens.from_lsdb(ztf_joined_source_cat, ztf_object, column_mapper=colmap)

ens.object.compute()

In [None]:
# Defining a simple function
def my_flux_average(flux_array, band_array, method="mean", band=None):
    """Return the average of the fluxes observed in a single band.

    Parameters
    ----------
    flux_array : array-like
        Flux (or magnitude) values, one per observation.
    band_array : array-like
        Band label of each observation; same length as ``flux_array``.
    method : {"mean", "median"}
        Which average to compute.
    band : optional
        Band label to select. When ``None`` no selection is made and the
        function returns ``None``.

    Returns
    -------
    The mean or median of the selected fluxes, or ``None`` when ``band`` is
    ``None``.

    Raises
    ------
    ValueError
        If a band is given and ``method`` is neither "mean" nor "median".
    """
    if band is None:
        return None

    if method == "mean":
        reducer = np.mean
    elif method == "median":
        reducer = np.median
    else:
        # Previously this fell through and failed on an unbound local variable
        raise ValueError(f"method must be 'mean' or 'median', got {method!r}")

    # Coerce to arrays so the element-wise comparison also works for lists
    band_mask = np.asarray(band_array) == band
    band_flux = np.asarray(flux_array)[band_mask]
    return reducer(band_flux)

In [None]:
# Applying the function to the ensemble
# "mag" and "band" name the columns handed to my_flux_average (presumably as
# flux_array and band_array -- confirm against ens.batch); method and band are
# passed as keywords. The result is then materialised with compute().
res = ens.batch(my_flux_average, "mag", "band", meta=None, method="median", band="g")
res_computed = res.compute()

In [None]:
# Show the computed batch result as the cell output
res_computed

In [None]:
# Pull one object out of the ensemble as a time series and show its data
ts = ens.to_timeseries(3647494584189059072)  # provided a target object id
ts.data

In [None]:
import matplotlib.pyplot as plt


# Keep only the g-band rows of the extracted time series
ts_g = ts.data[ts.band == "g"]

plt.figure(figsize=(8, 5))
plt.errorbar(ts_g.mjd, ts_g.mag, ts_g.magerr, fmt="o", color="green", alpha=0.8, label="g")
plt.xlabel("Time (MJD)")
# The plotted column is "mag", so the axis is a magnitude, not a flux in mJy
plt.ylabel("Magnitude")
plt.minorticks_on()
plt.legend(title="Band", loc="upper left")

In [None]:
def srch_outlperi(
    # time of samples
    time,
    # relative flux of samples
    flux,
    # relative flux error of samples (accepted for interface compatibility; not used)
    stdvflux,
    # number of outliers to include in the search
    numboutl=5,
    # Boolean flag to diagnose (accepted for interface compatibility; not used)
    booldiag=True,
):
    '''
    Search for periodic outliers in a computationally efficient way

    The numboutl samples with the largest flux values are taken as outliers.
    All pairwise time differences between them are sorted, and the pair of
    neighbouring differences that agree best (smallest fractional difference)
    is used as the period estimate. A detection is reported when that
    fractional difference is below 0.1.

    Returns a dictionary with
        boolposi: True when a periodic candidate was found
        pericomp: [period estimate] (only when boolposi is True)
        epocmtracomp: [time of the largest-flux sample] (only when boolposi is True)
        minmfrddtimeoutlsort: [smallest fractional difference], [nan] when undefined
        timeoutl: times of the outliers, ordered by decreasing flux

    When fewer than three outliers are available, or no fractional difference
    is defined, a negative result is returned instead of raising.
    '''

    time = np.asarray(time)
    flux = np.asarray(flux)

    # number of outliers actually used; cannot exceed the number of samples
    numboutlused = max(0, min(int(numboutl), flux.size))

    # indices of the outliers, ordered by decreasing flux
    indxtimesort = np.argsort(flux)[::-1][:numboutlused]

    # the times of the outliers
    timeoutl = time[indxtimesort]

    # indices (a, b) with a < b of every pair of outliers
    indxfrst, indxseco = np.triu_indices(numboutlused, k=1)

    # sorted differences between times of outlier samples
    difftimeoutlsort = np.sort(np.abs(timeoutl[indxfrst] - timeoutl[indxseco]).astype(float))

    # fractional differences between differences of times of outlier samples
    # (two zero differences, i.e. repeated time stamps, give 0/0 = nan here)
    with np.errstate(divide='ignore', invalid='ignore'):
        frddtimeoutlsort = (difftimeoutlsort[1:] - difftimeoutlsort[:-1]) / ((difftimeoutlsort[1:] + difftimeoutlsort[:-1]) / 2.)

    # entries for which the fractional difference is defined
    boolfini = np.isfinite(frddtimeoutlsort)

    boolposi = False
    minmfrddtimeoutlsort = np.nan
    if boolfini.any():
        # index of the minimum fractional difference, skipping undefined entries
        # so that they cannot mask a genuine candidate
        indxfrddtimeoutlsort = np.argmin(np.where(boolfini, frddtimeoutlsort, np.inf))

        # minimum fractional difference between differences of times of outlier samples
        minmfrddtimeoutlsort = frddtimeoutlsort[indxfrddtimeoutlsort]

        if minmfrddtimeoutlsort < 0.1:
            boolposi = True

    # output dictionary
    dictoutp = dict()

    # populate the output dictionary
    if boolposi:
        dictoutp['boolposi'] = True
        # estimate of the period
        dictoutp['pericomp'] = [difftimeoutlsort[indxfrddtimeoutlsort]]
        # estimate of the epoch
        dictoutp['epocmtracomp'] = [timeoutl[0]]
    else:
        dictoutp['boolposi'] = False
    dictoutp['minmfrddtimeoutlsort'] = [minmfrddtimeoutlsort]
    dictoutp['timeoutl'] = timeoutl

    return dictoutp

In [None]:
# Run srch_outlperi through ens.batch, naming the ensemble's time, flux and
# error columns as its inputs (presumably mapped onto time, flux and stdvflux
# in that order -- confirm against ens.batch), then materialise the result.
res = ens.batch(
    srch_outlperi,
    ens._time_col,
    ens._flux_col,
    ens._err_col, # TODO should be std dev of flux error? Is this right?
    meta=None)
res_computed = res.compute()

In [None]:
# Show the batch result object itself (not res_computed) as the cell output
res

In [None]:
# Preview the batch result via head(10)
res.head(10)