In [1]:
import numpy as np
import matplotlib.pyplot as plt
import astropy.table as at
import jax.numpy as jnp

import loss      as wobble_loss
import simulator as wobble_sim
import model     as wobble_model
import dataset   as wobble_data

In [None]:
try:
    profile  # injected into builtins when run under kernprof / line_profiler
except NameError:
    # In a plain Jupyter kernel (or `python file.py`) `profile` is undefined,
    # which made this cell raise NameError at definition time. Fall back to a
    # no-op decorator so the cell runs both with and without the profiler.
    def profile(func):
        return func


@profile
def main(sigma=80, l=0, r=200, n=256, maxiter=4,
         data_path='data/hat-p-20.fits', out_dir='out'):
    """Fit a JnpLin model to the HAT-P-20 spectra, then save a plot and the model.

    Pipeline: read the FITS table, build an ``AstroDataset`` from its
    ``flux``/``wavelength``/``mask``/``flux_err`` columns, interpolate the
    masked values, apply a Gaussian filter, extract ``(x, y)`` for the
    ``(l, r)`` subset, compute initial x-shifts from the ``BJD`` column, then
    optimize a ``JnpLin`` model under an ``'L2Loss'`` loss.

    Parameters
    ----------
    sigma : width passed to ``dataset.gauss_filter`` (units not visible here).
    l, r : bounds passed as ``subset=(l, r)`` to ``dataset.get_xy``
        (presumably an index range -- TODO confirm against ``AstroDataset``).
    n : first argument to ``wobble_model.JnpLin``
        (presumably a model grid size -- confirm).
    maxiter : maximum optimizer iterations passed to ``model.optimize``.
    data_path : path of the input FITS table.
    out_dir : directory for the saved figure (created if missing).

    Side effects: writes a PNG into ``out_dir`` and a ``.pt`` model file
    into the current working directory. Returns ``None``.
    """
    import os

    tbl = at.QTable.read(data_path)
    dataset = wobble_data.AstroDataset(
        tbl['flux'], tbl['wavelength'], tbl['mask'], tbl['flux_err']
    )
    dataset.interpolate_mask()
    dataset.gauss_filter(sigma=sigma)
    # The third value (flux errors) is not used by this model.
    x, y, _ = dataset.get_xy(subset=(l, r))

    x_shifts = wobble_data.getInitXShift(tbl['BJD'], 'HAT-P-20', 'APO')

    loss_1 = wobble_model.LossFunc('L2Loss')
    model = wobble_model.JnpLin(n, y, x, x_shifts)

    model.optimize(loss_1, maxiter=maxiter)
    model.plot()

    # Use the function's own parameters for the filenames; the original
    # referenced an undefined `args` (argparse leftover) and crashed here
    # after the optimization had already finished.
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, 'hatp20jnpL{}_R{}_N{}f.png'.format(l, r, n)))
    model.save_model('modeln{}_l{}_r{}_mI{}f.pt'.format(n, l, r, maxiter))
