Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-40841: Bring in upstream plotting and hologram definitions #46

Merged
merged 10 commits into from
Sep 26, 2023
4 changes: 2 additions & 2 deletions spectractor/extractor/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,9 +842,9 @@ def run_ffm_minimisation(w, method="newton", niter=2):
if parameters.LSST_SAVEFIGPATH:
fig.savefig(os.path.join(parameters.LSST_SAVEFIGPATH, 'fwhm_2.pdf'))

my_logger.info("\n\tStart regularization parameter only.")
# Optimize the regularisation parameter only if it was not done before
if w.amplitude_priors_method == "spectrum" and w.reg == parameters.PSF_FIT_REG_PARAM: # pragma: no cover
my_logger.info("\n\tStart regularization parameter estimation...")
w_reg = RegFitWorkspace(w, opt_reg=parameters.PSF_FIT_REG_PARAM, verbose=True)
w_reg.run_regularisation()
w.opt_reg = w_reg.opt_reg
Expand Down Expand Up @@ -1135,7 +1135,7 @@ def SpectractorRun(image, output_directory, guess=None):
# Save the spectrum
my_logger.info('\n\t ======================= SAVE SPECTRUM =============================')
spectrum.save_spectrum(output_filename, overwrite=True)
spectrum.lines.print_detected_lines(amplitude_units=spectrum.units)
spectrum.lines.table = spectrum.lines.build_detected_line_table(amplitude_units=spectrum.units)

# Plot the spectrum
if parameters.VERBOSE and parameters.DISPLAY:
Expand Down
57 changes: 25 additions & 32 deletions spectractor/extractor/spectroscopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,10 @@ def sort_lines(self):
sorted_lines = sorted(sorted_lines, key=lambda x: x.wavelength)
return sorted_lines

def plot_atomic_lines(self, ax, color_atomic='g', color_atmospheric='b', fontsize=12, force=False):
def plot_atomic_lines(self, ax, color_atomic='g', color_atmospheric='b', fontsize=12, force=False,
calibration_only=False):
"""Over plot the atomic lines as vertical lines, only if they are fitted or with high
signal to noise ratio, unless force keyword is set to True.
signal-to-noise ratio, unless force keyword is set to True.

Parameters
----------
Expand All @@ -264,7 +265,9 @@ def plot_atomic_lines(self, ax, color_atomic='g', color_atmospheric='b', fontsiz
fontsize: int
Font size of the spectral line labels (default: 12).
force: bool
Force the plot of vertical lines if set to True (default: False).
Force the plot of vertical lines if set to True even if they are not detected (default: False).
calibration_only: bool
Plot only the lines used for calibration if True (default: False).

Examples
--------
Expand Down Expand Up @@ -306,6 +309,8 @@ def plot_atomic_lines(self, ax, color_atomic='g', color_atmospheric='b', fontsiz
for line in self.lines:
if (not line.fitted or not line.high_snr) and not force:
continue
if not line.use_for_calibration and calibration_only:
continue
color = color_atomic
if line.atmospheric:
color = color_atmospheric
Expand All @@ -316,15 +321,15 @@ def plot_atomic_lines(self, ax, color_atomic='g', color_atmospheric='b', fontsiz
xycoords='axes fraction', color=color, fontsize=fontsize)
return ax

def plot_detected_lines(self, ax=None, print_table=False):
def plot_detected_lines(self, ax=None, calibration_only=False):
"""Overplot the fitted lines on a spectrum.

Parameters
----------
ax: Axes
The Axes instance if needed (default: None).
print_table: bool, optional
If True, print a summary table (default: False).
calibration_only: bool
Plot only the lines used for calibration if True (default: False).

Examples
--------
Expand Down Expand Up @@ -391,6 +396,8 @@ def plot_detected_lines(self, ax=None, print_table=False):
"""
lambdas = np.zeros(1)
for line in self.lines:
if not line.use_for_calibration and calibration_only:
continue
if line.fitted is True:
# look for lines in subset fit
bgd_npar = line.fit_bgd_npar
Expand All @@ -403,22 +410,16 @@ def plot_detected_lines(self, ax=None, print_table=False):
bgd = np.polynomial.legendre.legval(x_norm, line.fit_popt[0:bgd_npar])
# bgd = np.polyval(line.fit_popt[0:bgd_npar], lambdas)
ax.plot(lambdas, bgd, lw=2, color='b', linestyle='--')
if print_table:
self.table = self.print_detected_lines(print_table=True)

def print_detected_lines(self, output_file_name="", overwrite=False, print_table=False, amplitude_units=""):
"""Print the detected line on screen as an Astropy table, and write it in a file.
def build_detected_line_table(self, amplitude_units="", calibration_only=False):
"""Build the detected line on screen as an Astropy table.

Parameters
----------
output_file_name: str, optional
Output file name. If not empty, save the table in a file (default: '').
overwrite: bool, optional
If True, overwrite the existing file if it exists (default: False).
print_table: bool, optional
If True, print a summary table (default: False).
amplitude_units: str, optional
Units of the line amplitude (default: "").
calibration_only: bool
Include only the lines used for calibration if True (default: False).

Returns
-------
Expand Down Expand Up @@ -450,20 +451,15 @@ def print_detected_lines(self, output_file_name="", overwrite=False, print_table

Print the result
>>> spec.lines = lines
>>> t = lines.print_detected_lines(output_file_name="test_detected_lines.csv")

.. doctest::
:hide:

>>> assert len(t) > 0
>>> assert os.path.isfile('test_detected_lines.csv')
>>> os.remove('test_detected_lines.csv')
>>> t = lines.build_detected_line_table()
"""
lambdas = np.zeros(1)
rows = []
j = 0

for line in self.lines:
if not line.use_for_calibration and calibration_only:
continue
if line.fitted is True:
# look for lines in subset fit
bgd_npar = line.fit_bgd_npar
Expand Down Expand Up @@ -493,21 +489,18 @@ def print_detected_lines(self, output_file_name="", overwrite=False, print_table
for col in t.colnames[-2:]:
t[col].unit = 'nm'
t[t.colnames[-3]].unit = 'reduced'
if output_file_name != "":
t.write(output_file_name, overwrite=overwrite)
if print_table:
print(t)
t.convert_bytestring_to_unicode()
return t


# Line catalog

# Hydrogen lines
HALPHA = Line(656.3, atmospheric=False, label='$H\\alpha$', label_pos=[-0.04, 0.02], use_for_calibration=True)
HALPHA = Line(656.3, atmospheric=False, label='$H\\alpha$', label_pos=[-0.02, 0.02], use_for_calibration=True)
HBETA = Line(486.3, atmospheric=False, label='$H\\beta$', label_pos=[0.007, 0.02], use_for_calibration=True)
HGAMMA = Line(434.0, atmospheric=False, label='$H\\gamma$', label_pos=[0.007, 0.02], use_for_calibration=False)
HDELTA = Line(410.2, atmospheric=False, label='$H\\delta$', label_pos=[0.007, 0.02], use_for_calibration=False)
HEPSILON = Line(397.0, atmospheric=False, label='$H\\epsilon$', label_pos=[-0.04, 0.02], use_for_calibration=False)
HEPSILON = Line(397.0, atmospheric=False, label='$H\\epsilon$', label_pos=[-0.02, 0.02], use_for_calibration=False)
HYDROGEN_LINES = [HALPHA, HBETA, HGAMMA, HDELTA, HEPSILON]

# Stellar lines (Fraunhofer lines) https://en.wikipedia.org/wiki/Fraunhofer_lines
Expand All @@ -526,8 +519,8 @@ def print_detected_lines(self, output_file_name="", overwrite=False, print_table
# O2 = Line(762.2, atmospheric=True, label=r'$O_2$', # 762.2 is a weighted average of the O2 line simulated by Libradtran
# label_pos=[0.007, 0.02],
# use_for_calibration=True) # http://onlinelibrary.wiley.com/doi/10.1029/98JD02799/pdf
O2_1 = Line(760.3, atmospheric=True, label='',
label_pos=[0.007, 0.02], use_for_calibration=True) # libradtran paper fig.3
O2_1 = Line(760.3, atmospheric=True, label='$O_2$',
label_pos=[-0.02, 0.02], use_for_calibration=True) # libradtran paper fig.3
O2_2 = Line(763.1, atmospheric=True, label='$O_2$',
label_pos=[0.007, 0.02], use_for_calibration=True) # libradtran paper fig.3
O2B = Line(687.472, atmospheric=True, label=r'$O_2(B)$', # 687.472 is a weighted average of the O2B line simulated by Libradtran
Expand Down
127 changes: 119 additions & 8 deletions spectractor/extractor/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
import string
import astropy
import warnings
import itertools
warnings.filterwarnings('ignore', category=astropy.io.fits.card.VerifyWarning, append=True)

from spectractor import parameters
from spectractor.config import set_logger, load_config, update_derived_parameters
from spectractor.extractor.dispersers import Hologram
from spectractor.extractor.targets import load_target
from spectractor.tools import (ensure_dir, load_fits, plot_image_simple,
from spectractor.tools import (ensure_dir, load_fits, plot_image_simple, plot_table_in_axis,
find_nearest, plot_spectrum_simple, fit_poly1d_legendre, gauss,
rescale_x_to_legendre, fit_multigauss_and_bgd, multigauss_and_bgd, multigauss_and_bgd_jacobian)
from spectractor.extractor.psf import load_PSF
Expand Down Expand Up @@ -382,7 +383,7 @@ def load_filter(self):
t.reset_lambda_range(transmission_threshold=1e-4)
return t

def plot_spectrum(self, ax=None, xlim=None, live_fit=False, label='', force_lines=False):
def plot_spectrum(self, ax=None, xlim=None, live_fit=False, label='', force_lines=False, calibration_only=False):
"""Plot spectrum with emission and absorption lines.

Parameters
Expand All @@ -398,15 +399,20 @@ def plot_spectrum(self, ax=None, xlim=None, live_fit=False, label='', force_line
(default: False).
force_lines: bool
Force the over plot of vertical lines for atomic lines if set to True (default: False).
calibration_only: bool
Plot only the lines used for calibration if True (default: False).

Examples
--------
>>> s = Spectrum(file_name='tests/data/reduc_20170530_134_spectrum.fits')
>>> s.plot_spectrum(xlim=[500,900], live_fit=False, force_lines=True)
"""
if ax is None:
doplot = True
plt.figure(figsize=[12, 6])
ax = plt.gca()
else:
doplot = False
if label == '':
label = f'Order {self.order:d} spectrum\n' \
r'$D_{\mathrm{CCD}}=' \
Expand All @@ -426,18 +432,20 @@ def plot_spectrum(self, ax=None, xlim=None, live_fit=False, label='', force_line
plot_indices = np.logical_and(self.target.wavelengths[k] > np.min(self.lambdas),
self.target.wavelengths[k] < np.max(self.lambdas))
s = self.target.spectra[k] / np.max(self.target.spectra[k][plot_indices]) * np.max(self.data)
ax.plot(self.target.wavelengths[k], s, lw=2, label='Tabulated spectra #%d' % k)
ax.plot(self.target.wavelengths[k], s, lw=2, label=f'Tabulated spectra #{k}')
if self.lambdas is not None:
self.lines.plot_detected_lines(ax, print_table=parameters.VERBOSE)
self.lines.plot_detected_lines(ax)
if self.lines is not None and len(self.lines.table) > 0:
self.my_logger.info(f"\n{self.lines.table}")
if self.lambdas is not None and self.lines is not None:
self.lines.plot_atomic_lines(ax, fontsize=12, force=force_lines)
self.lines.plot_atomic_lines(ax, fontsize=12, force=force_lines, calibration_only=calibration_only)
ax.legend(loc='best')
if self.filters is not None:
ax.get_legend().set_title(self.filters)
plt.gcf().tight_layout()
if parameters.LSST_SAVEFIGPATH: # pragma: no cover
plt.gcf().savefig(os.path.join(parameters.LSST_SAVEFIGPATH, f'{self.target.label}_spectrum.pdf'))
if parameters.DISPLAY: # pragma: no cover
if parameters.DISPLAY and doplot: # pragma: no cover
if live_fit:
plt.draw()
plt.pause(1e-8)
Expand Down Expand Up @@ -503,6 +511,107 @@ def plot_spectrogram(self, ax=None, scale="lin", title="", units="Image units",
if parameters.PdfPages: # pragma: no cover
parameters.PdfPages.savefig()

def plot_spectrum_summary(self, xlim=None, figsize=(12, 12), save_as=''):
"""Plot spectrum with emission and absorption lines.

Parameters
----------
xlim: list, optional
List of minimum and maximum abscisses (default: None).
figsize: tuple
Figure size (default: (12, 12)).
save_as : str, optional
Path to save the figure to, if specified.

Examples
--------
>>> s = Spectrum(file_name='tests/data/reduc_20170530_134_spectrum.fits')
>>> s.plot_spectrum_summary()
"""
if not np.any([line.fitted for line in self.lines.lines]):
fwhm_func = interp1d(self.chromatic_psf.table['lambdas'],
self.chromatic_psf.table['fwhm'],
fill_value=(parameters.CALIB_PEAK_WIDTH, parameters.CALIB_PEAK_WIDTH),
bounds_error=False)
detect_lines(self.lines, self.lambdas, self.data, self.err, fwhm_func=fwhm_func,
calibration_lines_only=True)

def generate_axes(fig):
tableShrink = 3
tableGap = 1
gridspec = fig.add_gridspec(nrows=23, ncols=20)
axes = {}
axes['A'] = fig.add_subplot(gridspec[0:3, 0:19])
axes['C'] = fig.add_subplot(gridspec[3:6, 0:19], sharex=axes['A'])
axes['B'] = fig.add_subplot(gridspec[6:14, 0:19])
axes['CA'] = fig.add_subplot(gridspec[0:3, 19:20])
axes['CC'] = fig.add_subplot(gridspec[3:6, 19:20])
axes['D'] = fig.add_subplot(gridspec[14:16, 0:19], sharex=axes['B'])
axes['E'] = fig.add_subplot(gridspec[16+tableGap:23, tableShrink:19-tableShrink])
return axes

fig = plt.figure(figsize=figsize)
axes = generate_axes(fig)
plt.suptitle(f"{self.target.label} {self.date_obs}", y=0.91, fontsize=16)
mainPlot = axes['B']
spectrogramPlot = axes['A']
spectrogramPlotCb = axes['CA']
residualsPlot = axes['C']
residualsPlotCb = axes['CC']
widthPlot = axes['D']
tablePlot = axes['E']

label = f'Order {self.order:d} spectrum\n' \
r'$D_{\mathrm{CCD}}=' \
rf'{self.disperser.D:.2f}\,$mm'
plot_spectrum_simple(mainPlot, self.lambdas, self.data, data_err=self.err, xlim=xlim, label=label,
title='', units=self.units, lw=1, linestyle="-")
if len(self.target.spectra) > 0:
plot_indices = np.logical_and(self.target.wavelengths[0] > np.min(self.lambdas),
self.target.wavelengths[0] < np.max(self.lambdas))
s = self.target.spectra[0] / np.max(self.target.spectra[0][plot_indices]) * np.max(self.data)
mainPlot.plot(self.target.wavelengths[0], s, lw=2, label='Normalized\nCALSPEC spectrum')
self.lines.plot_atomic_lines(mainPlot, fontsize=12, force=False, calibration_only=True)
self.lines.plot_detected_lines(mainPlot, calibration_only=True)

table = self.lines.build_detected_line_table(calibration_only=True)
plot_table_in_axis(tablePlot, table)

mainPlot.legend()

widthPlot.plot(self.lambdas, np.array(self.chromatic_psf.table['fwhm']), "r-", lw=2)
widthPlot.set_ylabel("FWHM [pix]")
widthPlot.set_xlabel(r'$\lambda$ [nm]')
widthPlot.grid()

spectrogram = np.copy(self.spectrogram)
res = self.spectrogram_residuals.reshape((-1, self.spectrogram_Nx))
std = np.std(res)
if spectrogram.shape[0] != res.shape[0]:
margin = (spectrogram.shape[0] - res.shape[0]) // 2
spectrogram = spectrogram[margin:-margin]
plot_image_simple(spectrogramPlot, data=spectrogram, title='Data',
aspect='auto', cax=spectrogramPlotCb, units='ADU/s', cmap='viridis')
spectrogramPlot.set_title('Data', fontsize=10, loc='center', color='white', y=0.8)
spectrogramPlot.grid(False)
plot_image_simple(residualsPlot, data=res, vmin=-5 * std, vmax=5 * std, title='(Data-Model)/Err',
aspect='auto', cax=residualsPlotCb, units=r'$\sigma$', cmap='bwr')
residualsPlot.set_title('(Data-Model)/Err', fontsize=10, loc='center', color='black', y=0.8)
residualsPlot.grid(False)

# hide the tick labels in the plots which share an x axis
for label in itertools.chain(mainPlot.get_xticklabels(), residualsPlot.get_xticklabels(), spectrogramPlot.get_xticklabels()):
label.set_visible(False)

# align y labels
for ax in [spectrogramPlot, residualsPlot, mainPlot, widthPlot]:
ax.yaxis.set_label_coords(-0.05, 0.5)

fig.subplots_adjust(hspace=0)
if save_as:
plt.savefig(save_as)
plt.show()

def save_spectrum(self, output_file_name, overwrite=False):
"""Save the spectrum into a fits file (data, error and wavelengths).

Expand Down Expand Up @@ -593,7 +702,7 @@ def save_spectrum(self, output_file_name, overwrite=False):
elif extname == "LINES":
u.set_enabled_aliases({'flam': u.erg / u.s / u.cm**2 / u.nm,
'reduced': u.dimensionless_unscaled})
tab = self.lines.print_detected_lines(amplitude_units=self.units.replace("erg/s/cm$^2$/nm", "flam"), print_table=False)
tab = self.lines.build_detected_line_table(amplitude_units=self.units.replace("erg/s/cm$^2$/nm", "flam"))
hdus[extname] = fits.table_to_hdu(tab)
elif extname == "CONFIG":
# HIERARCH and CONTINUE not compatible together in FITS headers
Expand Down Expand Up @@ -1788,7 +1897,9 @@ def detect_lines(lines, lambdas, spec, spec_err=None, cov_matrix=None, fwhm_func
lambda_shifts.append(peak_pos - line.wavelength)
snrs.append(snr)
if ax is not None:
lines.plot_detected_lines(ax, print_table=parameters.DEBUG)
lines.plot_detected_lines(ax)
lines.table = lines.build_detected_line_table()
lines.my_logger.debug(f"\n{lines.table}")
if len(lambda_shifts) > 0:
global_chisq /= len(lambda_shifts)
shift = np.average(np.abs(lambda_shifts) ** 2, weights=np.array(snrs) ** 2)
Expand Down