In [1]:
#!/usr/bin/env python3
import numpy as np
import time
import JaxPeriodDrwFit
# import dill as pickle
import cloudpickle as pickle

import dask

from tape.ensemble import Ensemble
from tape.utils import ColumnMapper



In [2]:
# generated in create_data script, to avoid epyc failure
t_multi = np.load('/astro/users/ncaplar/data/t_multi.npy')
y_multi = np.load('/astro/users/ncaplar/data/y_multi.npy')
yerr_multi = np.load('/astro/users/ncaplar/data/yerr_multi.npy')

In [3]:

id, t, y, yerr, filter = np.array([]), np.array([]), np.array([]), np.array([]), np.array([])

for i in range(100):
    # data = data_all[()].get(i)

    # get time for a single lightcurve
    t_true = t_multi[i]
    # sample 100 points from 200
    downsample_int = np.sort(np.random.choice(np.arange(len(t_true)), 100))
    # extract 100 times from 200
    t_single = t_true[downsample_int]

    id = np.append(id, np.full(len(downsample_int), i))
    filter_single = np.full(len(t_single), 'r')
    t = np.append(t, t_single)
    filter = np.append(filter, filter_single)

    # create custom errors
    y_err_single = np.full(len(t_single), 0.001)
    yerr = np.append(yerr, np.full(len(t_single), 0.001))

    # extract measurements; 100 from each lightcurve
    # y_pre = data['y_tot'][downsample_int]
    y_pre = y_multi[i][downsample_int]

    # create noise and add to lightcurves
    noise = np.random.normal(0, y_err_single)
    y = np.append(y, y_pre + noise)

# columns assigned manually
manual_colmap = ColumnMapper().assign(
    id_col="id", time_col="t", flux_col="y", err_col="yerr", band_col="filter"
)

ens = Ensemble(silence_logs='error')
ens.from_source_dict({'id': id, "t": t, 'y': y, 'yerr': yerr, 'filter': filter},
                        column_mapper=manual_colmap)
single_lc = ens.compute("source")[id == 0]
# comment out line below if trying to run ensamble.batch
# ens.client.close()
##########



Process Dask Worker process (from Nanny):
Traceback (most recent call last):
  File "/astro/users/ncaplar/miniconda3/envs/tiny_lsst/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/astro/users/ncaplar/miniconda3/envs/tiny_lsst/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/astro/users/ncaplar/miniconda3/envs/tiny_lsst/lib/python3.10/site-packages/distributed/process.py", line 202, in _run
    target(*args, **kwargs)
  File "/astro/users/ncaplar/miniconda3/envs/tiny_lsst/lib/python3.10/site-packages/distributed/nanny.py", line 997, in _run
    logger.setLevel(silence_logs)
  File "/astro/users/ncaplar/miniconda3/envs/tiny_lsst/lib/python3.10/logging/__init__.py", line 1452, in setLevel
    self.level = _checkLevel(level)
  File "/astro/users/ncaplar/miniconda3/envs/tiny_lsst/lib/python3.10/logging/__init__.py", line 198, in _checkLevel
    raise ValueError("Unknown level: %r" 

In [4]:
JaxPeriodDrwFit_instance = JaxPeriodDrwFit.JaxPeriodDrwFit()

In [5]:
ens.head("object", 5)

filter,nobs_r,nobs_total
id,Unnamed: 1_level_1,Unnamed: 2_level_1
0.0,100,100
1.0,100,100
2.0,100,100
3.0,100,100
4.0,100,100


In [6]:
ens.head("source", 5)

Unnamed: 0_level_0,t,y,yerr,filter
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0.0,98.559856,-0.30356,0.001,r
0.0,98.559856,-0.302442,0.001,r
0.0,101.845185,-0.248341,0.001,r
0.0,104.035404,-0.17551,0.001,r
0.0,104.035404,-0.177207,0.001,r


In [7]:
# it works here !?!?
res = ens.batch(JaxPeriodDrwFit_instance.optimize_map, 't', 'y', 'yerr',
                 compute=True, meta=None, n_init=100)
print(res)



id
0.0     [-121.85756390067574, 2.025424465740777, -0.77...
1.0     [-119.35462093629818, 2.216255388328045, -0.70...
2.0     [-118.72923596254645, 1.8914790763868512, -0.8...
3.0     [-115.88897805675023, 1.920796463606562, -0.78...
4.0     [-123.6160197780911, 1.955781803159924, -0.772...
                              ...                        
95.0    [-123.73661896129992, 2.0674828883714675, -0.7...
96.0    [-126.4304628722924, 1.7525564690047732, -0.89...
97.0    [-111.38293780603686, 2.228549800187429, -0.67...
98.0    [-125.22108223493117, 1.5800022120793813, -0.9...
99.0    [-117.17786745943317, 2.1117239049128402, -0.7...
Name: id, Length: 100, dtype: object


In [8]:
ens.client.close()

INFO:distributed.nanny:Closing Nanny at 'tcp://127.0.0.1:40054'. Reason: nanny-close
INFO:distributed.nanny:Nanny asking worker to close. Reason: nanny-close
INFO:distributed.nanny:Closing Nanny at 'tcp://127.0.0.1:37859'. Reason: nanny-close
INFO:distributed.nanny:Nanny asking worker to close. Reason: nanny-close
INFO:distributed.nanny:Closing Nanny at 'tcp://127.0.0.1:44361'. Reason: nanny-close
INFO:distributed.nanny:Nanny asking worker to close. Reason: nanny-close
INFO:distributed.nanny:Closing Nanny at 'tcp://127.0.0.1:44961'. Reason: nanny-close
INFO:distributed.nanny:Nanny asking worker to close. Reason: nanny-close
INFO:distributed.nanny:Closing Nanny at 'tcp://127.0.0.1:39614'. Reason: nanny-close
INFO:distributed.nanny:Nanny asking worker to close. Reason: nanny-close
INFO:distributed.nanny:Closing Nanny at 'tcp://127.0.0.1:43959'. Reason: nanny-close
INFO:distributed.nanny:Nanny asking worker to close. Reason: nanny-close
INFO:distributed.nanny:Closing Nanny at 'tcp://127.0

In [9]:


t = single_lc['t'].values
y = single_lc['y'].values
yerr = single_lc['yerr'].values

# This block shows that the code works on a single lightcurve
# And it is faster second time
t1 = time.time()
test_single_lc_res = JaxPeriodDrwFit_instance.optimize_map(t, y, yerr, n_init=100)
t2 = time.time()
print(f'Execution time for single lc is {t2 - t1} sec')
print('Best result is:' + str(test_single_lc_res))
t1 = time.time()
test_single_lc_res = JaxPeriodDrwFit_instance.optimize_map(t, y, yerr, n_init=100)
t2 = time.time()
print(f'Execution time for second run with single lc is {t2 - t1} sec')
print('Best result is:' + str(test_single_lc_res))



Execution time for single lc is 4.077223062515259 sec
Best result is:[-130.76855682    2.02563672   -0.78154641    0.79178059   -1.65381861]
Execution time for second run with single lc is 0.1410675048828125 sec
Best result is:[-131.83079473    1.90012218   -0.83150308    2.79918596   -1.01825057]
