In [1]:
#!/usr/bin/env python3
import numpy as np
import time
import JaxPeriodDrwFit
# import dill as pickle
import cloudpickle as pickle

from tape.ensemble import Ensemble
from tape.utils import ColumnMapper

In [2]:
# generated in create_data script, to avoid epyc failure
t_multi = np.load('/astro/users/ncaplar/data/t_multi.npy')
y_multi = np.load('/astro/users/ncaplar/data/y_multi.npy')
yerr_multi = np.load('/astro/users/ncaplar/data/yerr_multi.npy')

In [3]:

id, t, y, yerr, filter = np.array([]), np.array([]), np.array([]), np.array([]), np.array([])

for i in range(100):
    # data = data_all[()].get(i)

    # get time for a single lightcurve
    t_true = t_multi[i]
    # sample 100 points from 200
    downsample_int = np.sort(np.random.choice(np.arange(len(t_true)), 100))
    # extract 100 times from 200
    t_single = t_true[downsample_int]

    id = np.append(id, np.full(len(downsample_int), i))
    filter_single = np.full(len(t_single), 'r')
    t = np.append(t, t_single)
    filter = np.append(filter, filter_single)

    # create custom errors
    y_err_single = np.full(len(t_single), 0.001)
    yerr = np.append(yerr, np.full(len(t_single), 0.001))

    # extract measurements; 100 from each lightcurve
    # y_pre = data['y_tot'][downsample_int]
    y_pre = y_multi[i][downsample_int]

    # create noise and add to lightcurves
    noise = np.random.normal(0, y_err_single)
    y = np.append(y, y_pre + noise)

# columns assigned manually
manual_colmap = ColumnMapper().assign(
    id_col="id", time_col="t", flux_col="y", err_col="yerr", band_col="filter"
)

ens = Ensemble()
ens.from_source_dict({'id': id, "t": t, 'y': y, 'yerr': yerr, 'filter': filter},
                        column_mapper=manual_colmap)
single_lc = ens.compute("source")[id == 0]
# comment out line below if trying to run ensamble.batch
# ens.client.close()
##########



INFO:distributed.http.proxy:To route to workers diagnostics web server please install jupyter-server-proxy: python -m pip install jupyter-server-proxy
INFO:distributed.scheduler:State start
INFO:distributed.scheduler:  Scheduler at:     tcp://127.0.0.1:44160
INFO:distributed.scheduler:  dashboard at:  http://127.0.0.1:8787/status
INFO:distributed.nanny:        Start Nanny at: 'tcp://127.0.0.1:41520'
INFO:distributed.nanny:        Start Nanny at: 'tcp://127.0.0.1:37047'
INFO:distributed.nanny:        Start Nanny at: 'tcp://127.0.0.1:35172'
INFO:distributed.nanny:        Start Nanny at: 'tcp://127.0.0.1:36807'
INFO:distributed.nanny:        Start Nanny at: 'tcp://127.0.0.1:34044'
INFO:distributed.nanny:        Start Nanny at: 'tcp://127.0.0.1:42496'
INFO:distributed.nanny:        Start Nanny at: 'tcp://127.0.0.1:38329'
INFO:distributed.nanny:        Start Nanny at: 'tcp://127.0.0.1:33869'
INFO:distributed.scheduler:Register worker <WorkerState 'tcp://127.0.0.1:42917', name: 0, status: in

In [4]:
JaxPeriodDrwFit_instance = JaxPeriodDrwFit.JaxPeriodDrwFit()

In [5]:
ens.head("object", 5)

filter,nobs_r,nobs_total
id,Unnamed: 1_level_1,Unnamed: 2_level_1
0.0,100,100
1.0,100,100
2.0,100,100
3.0,100,100
4.0,100,100


In [6]:
ens.head("source", 5)

Unnamed: 0_level_0,t,y,yerr,filter
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0.0,100.750075,-0.279083,0.001,r
0.0,101.845185,-0.250719,0.001,r
0.0,104.035404,-0.17937,0.001,r
0.0,160.981098,0.099968,0.001,r
0.0,181.423142,0.037428,0.001,r


In [7]:
# it works here !?!?
res = ens.batch(JaxPeriodDrwFit_instance.optimize_map, 't', 'y', 'yerr',
                 compute=True, meta=None, n_init=100)
print(res)



id
0.0     [-130.9124282376623, 1.9940229510834508, -0.79...
1.0     [-121.419540418573, 2.1799845126346407, -0.710...
2.0     [-130.07752684151336, 1.9703068037069376, -0.8...
3.0     [-121.0830567877342, 2.0018271214347214, -0.78...
4.0     [-128.58517651196684, 1.9173932408575038, -0.8...
                              ...                        
95.0    [-133.53417385356724, 2.2511764821387024, -0.7...
96.0    [-131.89283492618694, 1.8781759511073668, -0.8...
97.0    [-126.5910879300794, 2.325249051470435, -0.668...
98.0    [-135.91441737620806, 1.936658843185543, -0.84...
99.0    [-122.93198917600216, 2.176840636105713, -0.72...
Name: id, Length: 100, dtype: object


In [8]:
ens.client.close()

INFO:distributed.nanny:Closing Nanny at 'tcp://127.0.0.1:41520'. Reason: nanny-close
INFO:distributed.nanny:Nanny asking worker to close. Reason: nanny-close
INFO:distributed.nanny:Closing Nanny at 'tcp://127.0.0.1:37047'. Reason: nanny-close
INFO:distributed.nanny:Nanny asking worker to close. Reason: nanny-close
INFO:distributed.nanny:Closing Nanny at 'tcp://127.0.0.1:35172'. Reason: nanny-close
INFO:distributed.nanny:Nanny asking worker to close. Reason: nanny-close
INFO:distributed.nanny:Closing Nanny at 'tcp://127.0.0.1:36807'. Reason: nanny-close
INFO:distributed.nanny:Nanny asking worker to close. Reason: nanny-close
INFO:distributed.nanny:Closing Nanny at 'tcp://127.0.0.1:34044'. Reason: nanny-close
INFO:distributed.nanny:Nanny asking worker to close. Reason: nanny-close
INFO:distributed.nanny:Closing Nanny at 'tcp://127.0.0.1:42496'. Reason: nanny-close
INFO:distributed.nanny:Nanny asking worker to close. Reason: nanny-close
INFO:distributed.nanny:Closing Nanny at 'tcp://127.0

In [9]:


t = single_lc['t'].values
y = single_lc['y'].values
yerr = single_lc['yerr'].values

# This block shows that the code works on a single lightcurve
# And it is faster second time
t1 = time.time()
test_single_lc_res = JaxPeriodDrwFit_instance.optimize_map(t, y, yerr, n_init=100)
t2 = time.time()
print(f'Execution time for single lc is {t2 - t1} sec')
print('Best result is:' + str(test_single_lc_res))
t1 = time.time()
test_single_lc_res = JaxPeriodDrwFit_instance.optimize_map(t, y, yerr, n_init=100)
t2 = time.time()
print(f'Execution time for second run with single lc is {t2 - t1} sec')
print('Best result is:' + str(test_single_lc_res))



Execution time for single lc is 4.077223062515259 sec
Best result is:[-130.76855682    2.02563672   -0.78154641    0.79178059   -1.65381861]
Execution time for second run with single lc is 0.1410675048828125 sec
Best result is:[-131.83079473    1.90012218   -0.83150308    2.79918596   -1.01825057]
