In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np

from astropy.table import Table
from astropy import units as u
from astropy.time import Time

from matplotlib import pyplot as plt

from stellarphot.differential_photometry.aij_rel_fluxes import calc_aij_relative_flux

In [None]:
# Input tables: relative-flux photometry and the aperture locations.
flux_file = 'relative_flux-kelt-1.fits'
aperture_file = 'aperture_locations.fits'

# Photometric band to analyze. NOTE: the variable name is a misspelling of
# "passband"; it is kept because later cells refer to it by this name.
filter_bassband = 'r'

In [None]:
# Read the photometry table and the aperture-location table from disk.
phot, aperture_locations = Table.read(flux_file), Table.read(aperture_file)

In [None]:
# Mask of the photometry rows for the target, which this notebook treats as id 1.
target_star = (phot['id'] == 1)

In [None]:
# Mask of the rows measured in the chosen band (all stars).
band_filter = (phot['filter'] == filter_bassband)

# Mask of the rows that are the target star AND in the chosen band.
target_and_filter = (target_star & band_filter)

## When is this event supposed to happen?

You will need to look up the epoch and period for your object. If it is a TIC object, use [ExoFOP-TESS](https://exofop.ipac.caltech.edu/tess/). Otherwise, use the [NASA Exoplanet Archive](https://exoplanetarchive.ipac.caltech.edu/) or the [Exoplanet Transit Database](http://var2.astro.cz/ETD/).

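In either case, the epoch is a reference time of mid-transit. Enter it as a Julian date in the TDB time scale (BJD_TDB), because the code below treats it that way. The cells below use the epoch and the period to predict the time of the first mid-transit after the start of your observations.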

In [None]:
# Orbital period of your object -- replace with the published value.
period = 1.217494 * u.day

# Epoch: a reference time of mid-transit, entered as a Julian date in the
# TDB time scale (BJD_TDB) -- replace with the published value.
epoch = Time(2456583.78435, format='jd', scale='tdb')

# Do not edit: take the first BJD entry recorded for the target star as the
# start of the data series (assumes rows are in time order).
first_target_bjd = phot['BJD'][target_star][0]
then = Time(first_target_bjd, format='jd', scale='tdb')

In [None]:
# Number of whole periods from the epoch to the first transit after the start
# of the observations.
# - `np.int` was removed in NumPy 1.24, so the builtin `int` is used.
# - `floor(x) + 1` (rather than truncating `x + 1`) keeps this correct when
#   the observations precede the epoch, i.e. for negative cycle numbers.
elapsed_days = (then - epoch).to(u.day).value
cycles_since_epoch = elapsed_days / period.to(u.day).value
cycle_number = int(np.floor(cycles_since_epoch)) + 1
cycle_number

In [None]:
# Predicted mid-transit time: the epoch advanced by `cycle_number` periods.
that_transit = epoch + period * cycle_number
that_transit

In [None]:
# Subset of the photometry (every star) measured in the chosen band only.
phot_one_filter_only = phot[band_filter]

## Check for and remove any comparison stars that are "bad"

Here, "bad" means that a star's net aperture flux is `NaN` in at least one measurement. This can happen if the star is very faint or near the edge of the image.

In [None]:
# A NaN net aperture flux in any row flags that row's star as bad.
bad_flux = np.isnan(phot_one_filter_only['aperture_net_flux'])

# Unique ids of the flagged stars.
bad_ids = list({star_id for star_id in phot_one_filter_only['id'][bad_flux]})
print(bad_ids)

In [None]:
# Boolean mask over `aperture_locations`: True for every aperture whose id is
# in `bad_ids`. np.isin handles any number of bad ids; an empty list yields
# an all-False mask, so no special case is needed.
# Note: OR-ing masks as `mask | ids == bad_id` is wrong -- `|` binds more
# tightly than `==`, so it evaluates `(mask | ids) == bad_id`.
is_bad_comp = np.isin(aperture_locations['id'], bad_ids)

In [None]:
# Keep only the apertures that were not flagged as bad.
updated_apertures = aperture_locations[np.logical_not(is_bad_comp)]

In [None]:
# Comparison stars: the remaining apertures whose marker is 'APASS comparison'.
is_apass_comp = updated_apertures['marker name'] == 'APASS comparison'
comps = updated_apertures[is_apass_comp]

# Recompute the relative flux using only those comparison stars.
# in_place=False presumably returns a new table and leaves
# phot_one_filter_only untouched -- TODO confirm against stellarphot docs.
new_flux = calc_aij_relative_flux(
    phot_one_filter_only,
    comps,
    in_place=False,
    coord_column='coord',
)

In [None]:
# Mask of the recomputed-flux rows that belong to the target (id 1).
target = (new_flux['id'] == 1)

In [None]:
# Sanity check: after removing bad comparison stars, the target's relative
# flux should contain no NaN values. Report how many there are if it fails.
n_nan_target = int(np.isnan(new_flux['relative_flux'][target]).sum())
assert n_nan_target == 0, f'{n_nan_target} NaN relative-flux value(s) for the target'

In [None]:
# Target light curve with the predicted mid-transit time marked.
fig, ax = plt.subplots()
ax.plot(new_flux['BJD'][target], new_flux['relative_flux'][target], '.',
        label='Target')
# axvline always spans the full height of the axes. vlines at the y-limits
# read before drawing can leave the line short once autoscaling adds margins.
ax.axvline(that_transit.value, color='k', label='Predicted mid-transit')
ax.set_xlabel('BJD')
ax.set_ylabel('Relative flux')
ax.set_title(f'Target relative flux ({filter_bassband} band)')
ax.legend();
