# Injection/recovery tests: Part I

How accurate is `blasé`?  It depends!  The best way to assess accuracy is to test the code on spectra with known line properties.  We therefore create noised-up synthetic spectra with known perturbations to lines and see how close `blase` comes to recovering the ground truth.  This simulation procedure may be referred to as "injection/recovery tests"; it is common in many subfields of science as a strategy for quantifying uncertainty.

We anticipate that there is some threshold signal-to-noise ratio below which the information loss is just too great to overcome, and `blase` will face the impossibility of sorting signal from noise.  The goal of this experiment is to expose those contours, and build an intuition for the failure modes `blasé` can expect.

In [None]:
%config Completer.use_jedi = False

In [None]:
import torch
from blase.emulator import SparseLogEmulator, ExtrinsicModel, InstrumentalModel
import matplotlib.pyplot as plt
from gollum.phoenix import PHOENIXSpectrum
from gollum.telluric import TelFitSpectrum
from blase.utils import doppler_grid
import astropy.units as u
import numpy as np

# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
    
%matplotlib inline
%config InlineBackend.figure_format='retina'

In [None]:
# Echo the selected device string ("cuda" or "cpu") as the cell output.
device

## We need data simply for the wavelength coordinates and pixel sampling

In [None]:
from muler.hpf import HPFSpectrum, HPFSpectrumList

In [None]:
# Remote directory and file name of the example HPF spectrum; the two are
# concatenated to form the full URL when the file is read.
path = "https://github.com/OttoStruve/muler_example_data/raw/main/HPF/01_A0V_standards/"
filename = "Goldilocks_20210212T072837_v1.0_0037.spectra.fits"
# Single-order alternative, kept for reference:
# raw_data = HPFSpectrum(file=path + filename, order=5)

In [None]:
# Read the whole file into an HPFSpectrumList. `path` is an https URL, so this
# presumably requires network access — TODO confirm how muler resolves URLs.
raw_data = HPFSpectrumList.read(path+filename)

In [None]:
# Keep only list entries 2 through 8 (presumably echelle orders — TODO confirm).
# NOTE: this rebinds `raw_data`, so re-running the cell slices the
# already-sliced list again.
raw_data = HPFSpectrumList(raw_data[2:9])

In [None]:
# Reduce the selected spectra to one stitched spectrum; the processing steps
# are applied in the order listed.
data = (
    raw_data.sky_subtract()
    .trim_edges()
    .remove_nans()
    .deblaze()
    .stitch()
)

In [None]:
# Pull the wavelength samples and bin edges of the stitched spectrum out as
# bare values (`.value` presumably strips the astropy units — TODO confirm).
wavelength_coordinates = data.wavelength.value
bin_edges = data.bin_edges.value

In [None]:
# Extend the model grid 30 wavelength units beyond each end of the data
# (presumably Angstroms — TODO confirm) so the model fully covers the data.
wl_padding = 30.0
wl_lo = wavelength_coordinates.min() - wl_padding
wl_hi = wavelength_coordinates.max() + wl_padding
wavelength_grid = doppler_grid(wl_lo, wl_hi)

### Fetch a Phoenix model

In [None]:
from gollum.phoenix import PHOENIXGrid

In [None]:
# Properties assumed for the synthetic star and the spectrograph.
observed_RV = 0.0  # radial velocity; zero for simplicity
vsini = 15.9  # projected rotational velocity, km/s
resolving_power = 55_000  # passed to instrumental_broaden() further down

In [None]:
# Load the PHOENIX model over the padded wavelength range, remove the
# blackbody shape, and normalize it.
native_spectrum = (
    PHOENIXSpectrum(teff=5400, logg=4.5, metallicity=0.0, wl_lo=wl_lo, wl_hi=wl_hi)
    .divide_by_blackbody()
    .normalize()
)

# Flatten the remaining large-scale trend with a 5th-order polynomial continuum.
continuum_fit = native_spectrum.fit_continuum(polyorder=5)
native_spectrum = native_spectrum.divide(continuum_fit, handle_meta="ff")

In [None]:
# Apply the extrinsic and instrumental effects to the native model, then
# resample it onto the pixel sampling of the observed data.
spectrum = (
    native_spectrum.rotationally_broaden(vsini)
    .rv_shift(observed_RV)
    .instrumental_broaden(resolving_power=resolving_power)
    .resample(data)
)

### Clone the PHOENIX stellar model with `blase`

In [None]:
# The emulator is given the natural log of the native-resolution model flux.
ln_native_flux = np.log(native_spectrum.flux.value)
stellar_emulator = SparseLogEmulator(
    native_spectrum.wavelength.value,
    ln_native_flux,
    prominence=0.01,
    device=device,
)
stellar_emulator.to(device)

### Fine-tune the clone

In [None]:
# Fine-tune the cloned emulator: 1000 epochs with LR=0.01, as passed below.
stellar_emulator.optimize(epochs=1000, LR=0.01)

In [None]:
# Keep the fine-tuned emulator's state dict; later cells reload it to reset the
# emulator and use its 'amplitudes' entry as the pre-perturbation reference.
clone_params = stellar_emulator.state_dict()

### Initialize the models

In [None]:
## Extrinsic Layer
extrinsic_layer = ExtrinsicModel(wavelength_grid, device=device)
# as_tensor() leaves an existing tensor untouched, so re-running this cell does
# not trigger torch's copy-construct warning.
vsini = torch.as_tensor(vsini)
extrinsic_layer.ln_vsini.data = torch.log(vsini)
extrinsic_layer.to(device)

## Stellar emulator Layer
# Rebuild the emulator on the padded model grid, seeded from the fine-tuned clone.
stellar_emulator = SparseLogEmulator(wavelength_grid, 
                                     init_state_dict=stellar_emulator.state_dict(), device=device)
stellar_emulator.radial_velocity.data = torch.tensor(observed_RV)
stellar_emulator.to(device)

## Instrument Layer
instrumental_model = InstrumentalModel(bin_edges, wavelength_grid, device=device)
instrumental_model.to(device)

# The model has already been moved to `device`, so build the replacement value
# on that device too; a CPU tensor here would leave the parameter on the wrong
# device when running on CUDA.
# 0.064 is presumably a line-spread sigma in Angstroms — TODO confirm.
instrumental_model.ln_sigma_angs.data = torch.log(torch.tensor(0.064, device=device))

## Make fake "synthetic" data

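#### Perturb individual line amplitudes with a log-normal scatter of 0.7 in ln(amplitude) and a systematic offset of $-0.4$ (the median line is scaled by $e^{-0.4}\approx0.67$, i.e. lines are typically weaker than expected)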

In [None]:
# One Gaussian draw per spectral line in ln-amplitude space (mean -0.4,
# standard deviation 0.7); exponentiating gives the multiplicative factor.
ln_amp_perturbs = np.random.normal(-0.4, 0.7, stellar_emulator.n_lines)
amp_perturbs = np.exp(ln_amp_perturbs)

In [None]:
# Distribution of the multiplicative amplitude factors; the dashed line marks
# a factor of one (no change).
hist_bins = np.arange(0, 5, 0.1)
plt.hist(amp_perturbs, bins=hist_bins)
plt.axvline(1, linestyle='dashed', color='k', label='Unchanged')
plt.xlabel('Amplitude scale factor')
plt.legend();

In [None]:
# Reset the emulator to the saved clone parameters before perturbing it.
stellar_emulator.load_state_dict(clone_params)

In [None]:
# Both spectra are pure forward evaluations, so no autograd graph is needed.
with torch.no_grad():
    # Perturbed spectrum: shift every line's ln-amplitude in place, then push
    # the result through the extrinsic and instrumental layers.
    ln_amp_offsets = torch.tensor(ln_amp_perturbs).to(device)
    stellar_emulator.amplitudes.data += ln_amp_offsets
    super_res_truth = stellar_emulator.forward()
    broadened_flux = extrinsic_layer(super_res_truth)
    perturbed = instrumental_model.forward(broadened_flux)

    # Pristine spectrum: restore the clone parameters and run the same chain.
    stellar_emulator.load_state_dict(clone_params)
    stellar_flux = stellar_emulator.forward()
    broadened_flux = extrinsic_layer(stellar_flux)
    pristine = instrumental_model.forward(broadened_flux)

#### Noise-up the spectra to $S/N\sim200$ (per-pixel $\sigma=0.005$ on the normalized flux)

In [None]:
# Constant per-pixel standard deviation of the injected Gaussian noise.
per_pixel_uncertainty = torch.tensor(0.005, device=device, dtype=torch.float64)

# One zero-mean draw per data pixel, added to the noise-free perturbed spectrum.
n_pixels = len(data.wavelength)
noise_draw = np.random.normal(0, per_pixel_uncertainty.item(), n_pixels)
synthetic_data = perturbed + torch.tensor(noise_draw, device=device)

In [None]:
# Compare the noisy synthetic data with the two noise-free spectra.
plt.figure(figsize=(8, 4))
_wl = data.wavelength
plt.plot(_wl, synthetic_data.cpu(), '.', label='Noised-up', color='k', alpha=0.2)
for _curve, _curve_label in ((pristine, 'Pristine'), (perturbed, 'Perturbed')):
    plt.step(_wl, _curve.cpu(), label=_curve_label, alpha=1, lw=1)

plt.legend();

In [None]:
# The noisy synthetic spectrum is the target of the fit.
data_target = synthetic_data.to(device)

# astype() hands torch a fresh float64 array. NOTE(review): this presumably
# also guards against non-native byte order in FITS-derived arrays, so the
# conversion is kept even though dtype is given again below — TODO confirm.
_wl_float64 = wavelength_coordinates.astype(np.float64)
data_wavelength = torch.tensor(_wl_float64, device=device, dtype=torch.float64)

## Transfer learn a semi-empirical model

In [None]:
from torch import nn
from tqdm import trange
import torch.optim as optim

In [None]:
# Mean-reduced squared-error loss; the training loop feeds it model and target
# fluxes that have each been divided by the per-pixel uncertainty.
loss_fn = nn.MSELoss(reduction="mean")

### Fix certain parameters, allow others to vary
As we have seen before, you can fix parameters by "turning off their gradients".  We will start by turning off *all* of the stellar emulator's gradients, and then turn a few back on.

In [None]:
# Switch off gradients for every entry in the stellar emulator's state dict;
# selected parameters are re-enabled afterwards.
for param_name in stellar_emulator.state_dict():
    getattr(stellar_emulator, param_name).requires_grad = False

In [None]:
# Re-enable gradients for the quantities to be fitted: line amplitudes, the
# radial velocity, and the instrumental ln-sigma.
for _tunable in (
    stellar_emulator.amplitudes,
    stellar_emulator.radial_velocity,
    instrumental_model.ln_sigma_angs,
):
    _tunable.requires_grad = True
# stellar_emulator.lam_centers is deliberately not re-enabled here.

In [None]:
# Collect every parameter that still has gradients enabled, in the order
# stellar emulator, extrinsic layer, instrumental model.
trainable_params = [
    param
    for layer in (stellar_emulator, extrinsic_layer, instrumental_model)
    for param in layer.parameters()
    if param.requires_grad
]

# Adam with learning rate 0.01 and the AMSGrad variant enabled.
optimizer = optim.Adam(trainable_params, 0.01, amsgrad=True)

In [None]:
# Training-run settings.
losses = []  # list reserved for loss values
n_epochs = 200  # number of iterations of the training loop

## Regularization


Then we need the prior.  For now, let's just apply priors on the amplitudes (almost everything else is fixed).  We need to choose values for the regularization hyperparameters.

In [None]:
# Widths of the Gaussian regularization terms: the first scales the amplitude
# penalty, the second is only referenced by the disabled line-centre penalty.
stellar_amp_regularization = 5.1
stellar_lam_regularization = 0.5

In [None]:
# Plot (perturbation / regularization)**2 against each line's injected
# ln-amplitude perturbation, to see how large the quadratic penalty is.
plt.plot(ln_amp_perturbs, ln_amp_perturbs**2/stellar_amp_regularization**2, '.')

In [None]:
import copy

In [None]:
# Snapshot the current (pre-fit) amplitudes and line centres as the centre of
# the prior. detach().clone() gives plain tensors outside the autograd graph;
# a deep copy of the Parameter would keep requires_grad=True, so each
# backward() would accumulate an unused gradient on the reference.
stellar_init_amps = stellar_emulator.amplitudes.detach().clone()
stellar_init_lams = stellar_emulator.lam_centers.detach().clone()

# Define the prior on the amplitude
def ln_prior(stellar_amps, init_amps=None, regularization=None):
    """Return the Gaussian regularization penalty on the amplitude vector.

    The value is ``0.5 * sum((stellar_amps - init_amps)**2 / regularization**2)``.
    It is non-negative and zero when the amplitudes equal the reference, so it
    acts as a penalty (a negative log-prior up to a constant) when added to
    the loss. No penalty on the line centres is applied.

    Parameters
    ----------
    stellar_amps : torch.Tensor
        Current amplitude vector.
    init_amps : torch.Tensor, optional
        Reference amplitudes the penalty is centred on. Defaults to the
        module-level ``stellar_init_amps``.
    regularization : float, optional
        Width of the Gaussian penalty. Defaults to the module-level
        ``stellar_amp_regularization``.

    Returns
    -------
    torch.Tensor
        Scalar penalty.
    """
    # Resolve the defaults at call time so that the one-argument call keeps
    # using whatever the notebook globals currently hold.
    if init_amps is None:
        init_amps = stellar_init_amps
    if regularization is None:
        regularization = stellar_amp_regularization

    amp_diff = stellar_amps - init_amps
    return 0.5 * torch.sum((amp_diff ** 2) / (regularization ** 2))

In [None]:
# Training mode does not change between iterations, so set it once up front.
stellar_emulator.train()
extrinsic_layer.train()
instrumental_model.train()

t_iter = trange(n_epochs, desc="Training", leave=True)
for epoch in t_iter:
    # Start every step from clean gradients, even if a previous run of this
    # cell was interrupted between backward() and the reset.
    optimizer.zero_grad()

    # Forward model: stellar emulator -> extrinsic layer -> instrument.
    stellar_flux = stellar_emulator.forward()
    broadened_flux = extrinsic_layer(stellar_flux)
    detector_flux = instrumental_model.forward(broadened_flux)

    # Uncertainty-scaled mean squared residual plus the amplitude penalty.
    loss = loss_fn(detector_flux / per_pixel_uncertainty, data_target / per_pixel_uncertainty)
    loss += ln_prior(stellar_emulator.amplitudes)
    loss.backward()
    optimizer.step()

    # Record the loss history; `losses` was previously created but never filled.
    losses.append(loss.item())
    t_iter.set_description("Training Loss: {:0.8f}".format(loss.item()))

### Spot check the transfer-learned joint model

In [None]:
# Overlay the pristine, perturbed and fitted spectra. The noisy data points
# are left off this figure; an optional zoom (8500-8700) is kept disabled.
plt.figure(figsize=(8, 4))
_wl = data.wavelength
_retrieved = detector_flux.detach().cpu()
plt.step(_wl, pristine.cpu(), label='Pristine', alpha=0.3, lw=1, color='k')
plt.step(_wl, perturbed.cpu(), label='Perturbed', alpha=0.7, lw=2)
plt.step(_wl, _retrieved, label='Retrieved', alpha=0.7, lw=2)

# plt.xlim(8500, 8700)

plt.legend();

In [None]:
# Residual of the fit against the noise-free perturbed spectrum (not the noisy
# data), using the model flux from the last training iteration.
residual_truth = perturbed.cpu() - detector_flux.detach().cpu()

In [None]:
from scipy.stats import norm

In [None]:
# Echo the per-pixel uncertainty multiplied by 100, i.e. in percent.
100*per_pixel_uncertainty.cpu()

In [None]:
# Reference curve: zero-mean Gaussian whose width is the per-pixel
# uncertainty in percent, evaluated at the histogram bin positions.
bins = np.arange(-2, 2, 0.1)
sigma_percent = (100 * per_pixel_uncertainty.cpu()).item()
pdf = norm.pdf(bins, loc=0, scale=sigma_percent)

In [None]:
# Histogram of the residuals in percent on a log axis, with the Gaussian
# reference curve drawn on top.
residual_percent = residual_truth * 100
plt.hist(residual_percent, bins=bins, density=True)
plt.yscale('log')
plt.xlabel('Residual (%)', fontsize=12)
plt.plot(bins, pdf)

In [None]:
# Residual against wavelength, with a single zero-residual reference line
# (the original drew the same line twice).
plt.plot(wavelength_coordinates, residual_truth, 'ko', alpha=0.02)
plt.axhline(0)

### Retrieved line strengths

In [None]:
# True ln-amplitude of every line: the clone value plus the injected
# perturbation. Computed once and shared by both point sets.
true_ln_amps = clone_params['amplitudes'].cpu() + ln_amp_perturbs

plt.figure(figsize=(5, 5))
# NOTE(review): this series plots the unperturbed clone amplitudes (the
# starting point of the fit) against the truth; the 'Injected' label may
# deserve a clearer name — confirm intent.
plt.plot(true_ln_amps, clone_params['amplitudes'].cpu(), 'ko', alpha=0.2, label='Injected')
plt.plot(true_ln_amps, stellar_emulator.amplitudes.detach().cpu(), 'o', label='Recovered')
plt.plot([-8, 0], [-8, 0], 'k--', label='1:1')
plt.xlim(-8, 0)
plt.ylim(-8, 0)

# Label the axes and show the legend so the series labels above are visible.
plt.xlabel('True ln amplitude (clone + perturbation)')
plt.ylabel('Model ln amplitude')
plt.legend();