Skip to content

Commit

Permalink
add template_dx kwarg to model, change epoch dropping
Browse files Browse the repository at this point in the history
  • Loading branch information
megbedell committed Jul 27, 2022
1 parent 9771c89 commit e65ce86
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 15 deletions.
31 changes: 19 additions & 12 deletions wobble/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,20 +241,13 @@ def drop_bad_epochs(self, min_snr=5.):
except:
epochs = np.arange(self.N)
snrs_by_epoch = np.sqrt(np.nanmean(self.ivars, axis=(0,2)))
epochs_to_cut = snrs_by_epoch < min_snr
if np.sum(epochs_to_cut) > 0:
print("Data: Dropping epochs {0} because they have average SNR < {1:.0f}".format(epochs[epochs_to_cut], min_snr))
epochs = epochs[~epochs_to_cut]
for attr in REQUIRED_3D:
old = getattr(self, attr)
setattr(self, attr, [o[~epochs_to_cut] for o in old]) # might fail if self.N = 1
for attr in np.append(REQUIRED_1D, OPTIONAL_1D):
setattr(self, attr, getattr(self,attr)[~epochs_to_cut])
self.epochs = epochs
self.N = len(epochs)
bad_epoch_mask = snrs_by_epoch < min_snr
if np.sum(bad_epoch_mask) > 0:
print("Data: Dropping epochs {0} because they have average SNR < {1:.0f}".format(epochs[bad_epoch_mask], min_snr))
self.delete_epochs(bad_epoch_mask)
if self.N == 0:
print("All epochs failed the quality cuts with min_snr={0:.0f}.".format(min_snr))
return
return

def delete_orders(self, bad_order_mask):
"""
Expand All @@ -267,6 +260,20 @@ def delete_orders(self, bad_order_mask):
setattr(self, attr, new)
self.orders = self.orders[good_order_mask]
self.R = len(self.orders)

def delete_epochs(self, bad_epoch_mask):
"""
Take an N-epoch length boolean mask & drop all epochs marked True.
"""
good_epoch_mask = ~bad_epoch_mask
for attr in REQUIRED_3D:
old = getattr(self, attr)
setattr(self, attr, [o[good_epoch_mask] for o in old]) # might fail if self.N = 1
for attr in np.append(REQUIRED_1D, OPTIONAL_1D):
old = np.array(getattr(self,attr)) # fix bug with filelist being a list
setattr(self, attr, old[good_epoch_mask])
self.epochs = self.epochs[good_epoch_mask]
self.N = len(self.epochs)

class Spectrum(object):
"""
Expand Down
15 changes: 12 additions & 3 deletions wobble/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,10 @@ class Component(object):
in the same units as data `ys`.
If `None`, generate automatically upon initialization.
If not `None`, `template_xs` must be provided in the same shape.
template_dx : `float` or `None` (default `None`)
Spacing between control points in the log-uniform x-value grid of
the spectral template. Only used if `template_xs` is not specified.
If `None', generate automatically.
initialize_at_zero : `bool` (default `False`)
If `True`, initialize template as a flat continuum. Equivalent to
providing a vector of zeros with `template_ys` keyword but does
Expand Down Expand Up @@ -374,13 +378,15 @@ class Component(object):
def __init__(self, name, r, starting_rvs, epoch_mask,
rvs_fixed=False, template_fixed=False, rv_steps = 1,
variable_bases=0, scale_by_airmass=False,
template_xs=None, template_ys=None, initialize_at_zero=False,
template_xs=None, template_ys=None, template_dx=None,
initialize_at_zero=False,
learning_rate_rvs=1., learning_rate_template=0.01,
learning_rate_basis=0.01, regularization_par_file=None,
**kwargs):
for attr in ['name', 'r', 'starting_rvs', 'epoch_mask',
'rvs_fixed', 'template_fixed', 'rv_steps',
'template_xs', 'template_ys', 'initialize_at_zero',
'template_xs', 'template_ys', 'template_dx',
'initialize_at_zero',
'learning_rate_rvs', 'learning_rate_template',
'learning_rate_basis', 'scale_by_airmass']:
setattr(self, attr, eval(attr))
Expand Down Expand Up @@ -481,7 +487,10 @@ def initialize_template(self, data_xs, data_ys, data_ivars):
N = len(self.starting_rvs)
shifted_xs = data_xs + np.log(doppler(self.starting_rvs[:, None], tensors=False)) # component rest frame
if self.template_xs is None:
dx = 2.*(np.log(6000.01) - np.log(6000.)) # log-uniform spacing
if self.template_dx is None:
dx = 2.*(np.log(6000.01) - np.log(6000.)) # log-uniform spacing
else:
dx = self.template_dx
tiny = 10.
self.template_xs = np.arange(np.nanmin(shifted_xs)-tiny*dx,
np.nanmax(shifted_xs)+tiny*dx, dx)
Expand Down

0 comments on commit e65ce86

Please sign in to comment.