# Amplitude and width clustering

The blase model *should* overfit.  One common path for overfitting is to have lines with really large widths to make up for continuum imperfections.  Let's see if we can identify and flag these.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'

import seaborn as sns
sns.set_context('paper', font_scale=2)

In [None]:
from blase.emulator import PhoenixEmulator

In [None]:
! ls -t1 ../examples/*.pt

In [None]:
! du -hs '../examples/native_res_0p1prom.pt'

In [None]:
# Load the saved parameter dictionary for the trained clone.
# map_location='cpu' lets a checkpoint written on a GPU machine load on a
# CPU-only one; the values are copied into the emulator's own tensors by
# load_state_dict later, and the clustering cells call .cpu() on them anyway.
# No torch.no_grad() wrapper is needed here: deserializing tensors performs no
# autograd-tracked operations.
# NOTE(review): torch.load unpickles the file, so only open checkpoints from a
# trusted source.
model_params = torch.load('../examples/native_res_0p1prom.pt', map_location='cpu')

In [None]:
# Build a fresh emulator to receive the saved parameters.
# NOTE(review): 4700 and 4.5 are presumably the Teff (K) and log g of the
# PHOENIX template — confirm against the PhoenixEmulator signature.
# prominence=0.1 appears to correspond to the "0p1prom" tag in the checkpoint
# filename loaded above.
emulator = PhoenixEmulator(4700, 4.5, prominence=0.1,)

In [None]:
# Copy the saved parameters into the emulator; the call's return value is
# shown as the cell output.
emulator.load_state_dict(model_params)

This step takes a *TON* of RAM unless you use `torch.no_grad`!

In [None]:
# Evaluate the clone on the emulator's own native wavelength grid.
# Gradients are not needed for inspection, so autograd is switched off for the
# forward pass (the note above reports very large RAM use otherwise).
with torch.no_grad():
    cloned_spectrum = emulator.forward(emulator.wl_native)

In [None]:
# Sanity check: show the container type and shape of the emulator output.
type(cloned_spectrum), cloned_spectrum.shape

## $\pm 2.5\%$ residuals with a long tail

At native resolution

In [None]:
# Overlay the PHOENIX model flux and the emulator's clone on the native
# wavelength grid for a visual comparison.
plt.figure(figsize=(20, 5))
plt.plot(emulator.wl_native, emulator.flux_native, label='PHOENIX model')
# .detach() hands matplotlib a tensor that is not attached to the autograd graph.
plt.plot(emulator.wl_native, cloned_spectrum.detach(), label='Clone')
plt.legend()

In [None]:
# Native-resolution residual (model minus clone) in the same flux units as the
# inputs; later cells multiply it by 100 to display it in percent.
residual = emulator.flux_native - cloned_spectrum.detach()

In [None]:
# Scatter of the native-resolution residual (same units as `residual`, i.e.
# not yet in percent).
# NOTE(review): the name `stddev` is rebound further down to the scatter of the
# resampled residual, which is already in percent — beware when re-running
# cells out of order.
stddev = torch.std(residual)

In [None]:
# Plot the native-resolution residual in percent with dashed +/- 1 sigma guides.
# The scatter is computed here instead of being read from the notebook-level
# `stddev`: that name is rebound later to a value that is already in percent,
# so re-running this cell after that point would draw the guides 100x too far
# out.  float() also hands matplotlib a plain Python number rather than a
# 0-dim tensor.
residual_sigma_pct = float(torch.std(residual)) * 100.0

plt.figure(figsize=(20, 5))
plt.plot(emulator.wl_native, residual*100.0, label='Residual')
plt.axhline(+residual_sigma_pct, color='k', linestyle='dashed')
plt.axhline(-residual_sigma_pct, color='k', linestyle='dashed')
plt.ylim(-10, 10)
# The wavelength grid is treated as Angstroms when the spectra are built below.
plt.xlabel(r'Wavelength ($\AA$)')
plt.ylabel('Residual (%)')

Hmmm, those residuals seem large compared to the residuals reported after training... did something go wrong when we loaded the saved model parameters?  Is there some hysteresis when loading a model?  A rounding error?

How big are the residuals when you smooth them to HPF resolution?

## Smooth to HPF resolution

In [None]:
from gollum.phoenix import PHOENIXSpectrum
import astropy.units as u

from muler.hpf import HPFSpectrumList

In [None]:
# Wrap the original PHOENIX model in a PHOENIXSpectrum so the spectrum methods
# used below (broadening, RV shift, resampling) are available.
# The wavelength grid is tagged as Angstroms and the flux as dimensionless.
original_native = PHOENIXSpectrum(spectral_axis=emulator.wl_native*u.Angstrom, 
                                flux=emulator.flux_native*u.dimensionless_unscaled)

In [None]:
# Same wrapper for the emulator's clone, on the identical wavelength grid, so
# both spectra can go through the same simulated-observation steps.
# NOTE(review): cloned_spectrum is passed without .detach() here, unlike the
# plotting cells above; it was produced under torch.no_grad(), which presumably
# makes that unnecessary — confirm.
clone_native = PHOENIXSpectrum(spectral_axis=emulator.wl_native*u.Angstrom, 
                                flux=cloned_spectrum*u.dimensionless_unscaled)

In [None]:
# Read an example HPF observation as a list of echelle orders.
# NOTE(review): the path points outside this repository (a sibling
# muler_example_data checkout); the cell fails if that data is not present.
echelle_orders = HPFSpectrumList.read('../../muler_example_data/HPF/01_A0V_standards/Goldilocks_20210517T054403_v1.0_0060.spectra.fits')

In [None]:
# Reduce the echelle orders to a single spectrum.  The steps run in the order
# written: sky subtraction, deblazing, normalization, NaN removal, edge
# trimming, and finally stitching the orders together.  The result is used
# below as the resampling target for the simulated observations.
hpf_spectrum = (
    echelle_orders
    .sky_subtract(method='vector')
    .deblaze()
    .normalize()
    .remove_nans()
    .trim_edges((6, 2042))
    .stitch()
)

In [None]:
def simulate_observation(spectrum, vsini=13.5, rv=-16.2,
                         resolving_power=55_000, template=None):
    """Simulate an observation with HPF.

    Applies, in order: rotational broadening, a radial-velocity shift,
    instrumental broadening, and resampling onto a target spectrum's grid.

    Parameters
    ----------
    spectrum : object
        Spectrum providing ``rotationally_broaden``, ``rv_shift``,
        ``instrumental_broaden`` and ``resample`` (e.g. a PHOENIXSpectrum).
    vsini : float, optional
        Value forwarded to ``rotationally_broaden`` (presumably km/s —
        confirm against the spectrum class documentation).  Default 13.5.
    rv : float, optional
        Value forwarded to ``rv_shift`` (presumably km/s — confirm).
        Default -16.2.
    resolving_power : float, optional
        Resolving power forwarded to ``instrumental_broaden``.
        Default 55_000.
    template : object, optional
        Spectrum whose sampling the result is resampled onto.  When omitted,
        the notebook-level ``hpf_spectrum`` is used, matching the original
        behaviour of this function.

    Returns
    -------
    object
        Whatever ``resample`` returns for the processed spectrum.
    """
    if template is None:
        # Fall back to the reduced HPF spectrum defined earlier in the notebook.
        template = hpf_spectrum
    return (
        spectrum
        .rotationally_broaden(vsini)
        .rv_shift(rv)
        .instrumental_broaden(resolving_power=resolving_power)
        .resample(template)
    )

In [None]:
# Push the PHOENIX model and its clone through the identical simulated
# observation so any remaining difference comes from the cloning alone.
original_sim, clone_sim = [
    simulate_observation(native_spec)
    for native_spec in (original_native, clone_native)
]

In [None]:
# Full-range overlay of the simulated observations: the original is drawn
# first and the clone is added to the same axes.
ax = original_sim.plot(ylo=0, yhi=2)
clone_sim.plot(ax=ax)

In [None]:
# Same overlay, zoomed to the 10820-10960 window on the x axis and 0.5-1 on
# the y axis so individual line mismatches become visible.
ax = original_sim.plot(ylo=0.5, yhi=1)
clone_sim.plot(ax=ax)
ax.set_xlim(10820, 10960)

Yuck!  The cloning is not adequate at this zoom level.  Did we not train long enough?

In [None]:
# Residual between the two simulated observations, scaled by 100 so it reads
# in percent (the plot below labels it as such).
residual_spec = (original_sim - clone_sim)*100

In [None]:
# Scatter of the smoothed, resampled residual.  Already in percent because
# residual_spec was multiplied by 100.
# NOTE(review): this rebinds `stddev`, which earlier held the native-resolution
# scatter (a tensor, not in percent).
stddev = residual_spec.flux.std().value

In [None]:
# Display the resampled-residual scatter (percent).
stddev

## $\pm 0.9\%$ residuals after smoothing and resampling


In [None]:
# Plot the smoothed/resampled residual (percent) with dashed +/- 1 sigma guides.
# The scatter is computed here from residual_spec rather than read from the
# notebook-level `stddev`, because that name is also used earlier for the
# native-resolution scatter in different units; re-running cells out of order
# would otherwise put the guides in the wrong place.
resampled_sigma_pct = residual_spec.flux.std().value

ax = residual_spec.plot(ylo=-10, yhi=10)
ax.axhline(+resampled_sigma_pct, color='k', linestyle='dashed')
ax.axhline(-resampled_sigma_pct, color='k', linestyle='dashed')
ax.set_ylabel('Residual (%)')

We still want better than 1%!  That level of residual is comparable to the SNR of a real spectrum.

## Clustering of parameters

In [None]:
# Pull the fitted per-line parameters out of the saved dictionary as NumPy
# arrays.  The next cell applies np.exp to both, so these are the stored
# (pre-exponentiation) values.
amps = model_params['amplitudes'].detach().cpu().numpy()
widths = model_params['gamma_widths'].detach().cpu().numpy()

In [None]:
# Exponentiate the stored 'amplitudes' and 'gamma_widths' entries to obtain the
# values plotted below.
# Both are derived from model_params on every run instead of overwriting
# amps/widths with np.exp of themselves: the in-place form silently applied
# exp a second time whenever this cell was re-executed.
amps = np.exp(model_params['amplitudes'].detach().cpu().numpy())
widths = np.exp(model_params['gamma_widths'].detach().cpu().numpy())

In [None]:
# Scatter of amplitude against width for every fitted line, on log-log axes.
# Low alpha makes dense clusters stand out.
plt.plot(widths, amps, 'o', alpha=0.1)
plt.yscale('log')
plt.xscale('log')
# Raw string: '\g', '\;' and '\A' are not valid Python escape sequences and
# trigger a warning in a normal string literal on recent Python versions.  The
# text passed to matplotlib is unchanged.
plt.xlabel(r'$\gamma_L \; (\AA) $')
plt.ylabel('Amplitude');

Hmm, I'd expect to see continuum overfitting in the bottom right corner:  Wide lines with low amplitude.

In [None]:
from scipy.signal import find_peaks

In [None]:
# Locate local maxima of the absolute native-resolution residual.
# `residual` is not in percent (it is multiplied by 100 for plotting), so
# height=0.1 keeps only peaks whose absolute residual is at least 0.1 in those
# units, i.e. 10% on the percent plots.
# NOTE(review): `residual` is a torch result passed straight to NumPy/SciPy;
# this presumably relies on implicit tensor-to-array conversion — confirm.
biggest_residuals = find_peaks(np.abs(residual), height=0.1)

In [None]:
# Split the find_peaks result into the peak positions (indices into the
# residual array) and the accompanying properties dictionary, whose
# 'peak_heights' entry is used in the plot below.
indices, meta_info = biggest_residuals

In [None]:
# Absolute native-resolution residual in percent, with the detected peaks
# marked as red dots.
plt.figure(figsize=(20, 5))
plt.plot(emulator.wl_native, np.abs(residual*100.0), label='Residual')
plt.ylim(0, 100)

# Peak heights were measured on the unscaled residual, hence the factor of 100
# to match the percent axis.
plt.plot(emulator.wl_native[indices], 100*meta_info['peak_heights'], 'ro')
plt.ylabel('Residual (%)')