## Automating regularization-parameter tuning

In [None]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tqdm import tqdm

import wobble

In [None]:
# Load only order 10 of the '51peg' e2ds HDF5 file found under data/.
starname = '51peg'
data = wobble.Data(
    starname + '_e2ds.hdf5',
    filepath='data/',
    orders=[10],
)

In [None]:
# Two-component model: a star (component 0) and tellurics (component 1).
# The telluric component has fixed RVs and K variable bases.
K = 3
model = wobble.Model(data)
model.add_star('star')
model.add_telluric('tellurics', rvs_fixed=True, variable_bases=K)
print(model)

In [None]:
# Regularization strengths found by hand-tuning; these are the starting
# points that the automated grid search scales up and down.
# component 0 (star):
model.components[0].L1_template = 1e1
model.components[0].L2_template = 1e1
# component 1 (tellurics):
model.components[1].L2_template = 1e7
model.components[1].L2_basis_vectors = 1e7

In [None]:
def fit_rvs_only(model, data, r, niter=100):
    """Re-fit only the per-epoch parameters of ``model`` for order index ``r``.

    Builds the loss 0.5 * sum((y - model)^2 * ivar) over the epochs selected by
    ``data.epoch_mask``. It then runs Adam on two kinds of variables: each
    component's RV block (unless ``c.rvs_fixed``), and, for components with
    ``c.K > 0``, their basis weights. Nothing else is passed in a ``var_list``,
    so the templates keep whatever values they already have.

    Parameters
    ----------
    model : wobble.Model
        Model to optimize (typically with templates already fit).
    data : wobble.Data
        Data to fit; only epochs where ``data.epoch_mask`` is True enter the loss.
    r : int
        Order index into ``data.ys`` / ``data.ivars`` and the model's
        per-order variable lists.
    niter : int, optional
        Number of optimizer sweeps over all components (default 100, the
        previously hard-coded value).

    Returns
    -------
    wobble.Results
        Results built from ``model`` and ``data``, updated with the model
        state after optimization via ``copy_model``.

    Notes
    -----
    Side effects: sets ``gradients_rvs``/``opt_rvs`` and/or
    ``gradients_basis``/``opt_basis`` on each component. Each call also adds
    new ops to the TensorFlow graph (TF1 graph mode), so repeated calls, as in
    a grid search, make the graph grow.
    """
    mask = data.epoch_mask
    synth = model.synthesize(r)
    # chi-square / 2, restricted to the selected epochs
    nll = 0.5 * tf.reduce_sum(tf.square(tf.boolean_mask(data.ys[r], mask)
                                        - tf.boolean_mask(synth, mask))
                              * tf.boolean_mask(data.ivars[r], mask))

    # Set up one Adam optimizer per free parameter block. Only each
    # optimizer's own variables are initialized, so existing model variables
    # keep their current values.
    session = wobble.get_session()
    for c in model.components:
        if not c.rvs_fixed:
            c.gradients_rvs = tf.gradients(nll, c.rvs_block[r])
            optimizer = tf.train.AdamOptimizer(c.learning_rate_rvs)
            c.opt_rvs = optimizer.minimize(nll,
                            var_list=[c.rvs_block[r]])
            session.run(tf.variables_initializer(optimizer.variables()))
        if c.K > 0:
            c.gradients_basis = tf.gradients(nll, c.basis_weights[r])
            optimizer = tf.train.AdamOptimizer(c.learning_rate_basis)
            c.opt_basis = optimizer.minimize(nll,
                            var_list=c.basis_weights[r])
            session.run(tf.variables_initializer(optimizer.variables()))

    results = wobble.Results(model=model, data=data)

    # Alternate one step per free block per sweep.
    for _ in tqdm(range(niter)):
        for c in model.components:
            if not c.rvs_fixed:
                session.run(c.opt_rvs)  # optimize RVs
            if c.K > 0:
                session.run(c.opt_basis)  # optimize variable components
    results.copy_model(model)  # store the optimized state
    return results
   

In [None]:
def improve_regularization(c, model, data, r, names=None, verbose=True, plot=False, basename=''):
    """Grid-search regularization strengths of component ``c`` by validation chi-square.

    Each attribute in ``names`` is tuned in turn. The search tries the current
    value times 10**-3, 10**-2, ..., 10**3 (7 log-spaced points). For each
    trial value it does two fits:

    * all model templates are re-initialized and ``wobble.optimize_order``
      fits the model on the training epochs (~90%);
    * ``fit_rvs_only`` re-fits only the per-epoch parameters on the held-out
      validation epochs (~10%, at least one).

    The value with the lowest held-out chi-square is written back to ``c``.
    Attributes are tuned one at a time, so later searches see the
    already-updated earlier values.

    Parameters
    ----------
    c : component of ``model``
        Component whose regularization attributes are tuned (via setattr).
    model : wobble.Model
        Model being fit.
    data : wobble.Data
        Full data set. Not modified; shallow copies with different
        ``epoch_mask`` values are used for training and validation.
    r : int
        Order index.
    names : list of str, optional
        Attribute names on ``c`` to tune. The default is ``L1_template`` and
        ``L2_template``, plus ``L1_basis_vectors``, ``L2_basis_vectors`` and
        ``L2_basis_weights`` when ``c.K > 0``.
    verbose : bool
        Print the chi-square of every trial value and the chosen value.
    plot : bool
        If True, save two kinds of figure. One shows a validation epoch (data
        vs. model) for every trial value, as
        ``{basename}_{name}_val{value}.png``. The other shows the chi-square vs.
        value curve for every attribute, as ``{basename}_{name}_chis.png``.
        Each figure is closed after saving. The directory part of ``basename``
        is created if missing.
    basename : str
        Path prefix for saved figures.

    Raises
    ------
    ValueError
        If ``data.N < 2``, because the epochs cannot be split into non-empty
        training and validation sets.

    Notes
    -----
    The model's fitted parameters are left at the state from the *last* grid
    value, not the best one, so re-fit the model afterwards. The epoch split
    uses the global NumPy RNG. A current value of 0 yields an all-zero grid.
    """
    if data.N < 2:
        raise ValueError('need at least 2 epochs to split into training and validation sets')

    if names is None:
        names = ['L1_template', 'L2_template']
        if c.K > 0:
            # Use extend, not append: every entry must be a single attribute
            # name, or getattr(c, name) below fails on a nested list.
            names.extend(['L1_basis_vectors', 'L2_basis_vectors', 'L2_basis_weights'])

    # Hold out ~10% of the epochs (at least one) for validation; the rest train.
    n_validation = max(1, data.N // 10)
    validation_epochs = np.random.choice(data.N, n_validation, replace=False)
    validation_mask = np.isin(np.arange(data.N), validation_epochs)

    # Shallow copies share the underlying arrays; only epoch_mask differs.
    training_data = copy.copy(data)
    training_data.epoch_mask = ~validation_mask
    validation_data = copy.copy(data)
    validation_data.epoch_mask = validation_mask

    if plot:
        # e.g. 'regularization/' must exist before savefig writes into it
        plot_dir = os.path.dirname(basename)
        if plot_dir:
            os.makedirs(plot_dir, exist_ok=True)

    for name in names:
        current_value = getattr(c, name)
        grid = np.logspace(-3.0, 3.0, num=7) * current_value
        chisqs_grid = np.zeros_like(grid)
        for i, val in enumerate(grid):
            setattr(c, name, val)
            for co in model.components:
                co.template_exists[r] = False  # force re-initialization for each trial

            # fit everything on the training epochs ...
            wobble.optimize_order(model, training_data, r)
            # ... then only the per-epoch parameters on the held-out epochs
            results = fit_rvs_only(model, validation_data, r)

            resid = (results.ys[r][validation_mask]
                     - results.ys_predicted[r][validation_mask])
            chisqs_grid[i] = np.sum(resid**2 * results.ivars[r][validation_mask])

            summary = '{0}: value {1:.0e}, chisq {2:.0f}'.format(name, val, chisqs_grid[i])
            if plot:
                e = validation_epochs[0]  # random epoch
                fig, ax = plt.subplots()  # fresh figure so curves do not accumulate
                ax.plot(np.exp(results.xs[r][e]), np.exp(results.ys[r][e]), label='data')
                ax.plot(np.exp(results.xs[r][e]), np.exp(results.ys_predicted[r][e]),
                        label='best-fit model')
                ax.legend()
                ax.set_title(summary, fontsize=12)
                fig.savefig('{0}_{1}_val{2:.0e}.png'.format(basename, name, val))
                plt.close(fig)
            if verbose:
                print(summary)

        if plot:
            fig, ax = plt.subplots()
            ax.scatter(grid, chisqs_grid)
            ax.set_xscale('log')
            ax.set_yscale('log')
            ax.set_xlabel('{0} values'.format(name))
            ax.set_ylabel(r'$\chi^2$')
            fig.savefig('{0}_{1}_chis.png'.format(basename, name))
            plt.close(fig)

        best = grid[np.argmin(chisqs_grid)]
        if verbose:
            print("{0} optimized; setting to {1:.0e}".format(name, best))
        setattr(c, name, best)

In [None]:
%%time
improve_regularization(model.components[0], model, data, 0, names=None, 
                       plot=True, basename='regularization/o10_star')

In [None]:
%%time
improve_regularization(model.components[1], model, data, 0, names=None, 
                       plot=True, basename='regularization/o10_tellurics')