In [None]:
# import pandas as pd
import numpy as np
# import logging
# import matplotlib.pyplot as plt
import plotly.graph_objects as go
# from plotly.subplots import make_subplots
# from random import sample
# import tensorflow as tf
# from keras import models
# from scipy.signal import find_peaks
# from sklearn.decomposition import PCA
# import itertools

In [None]:
class LineFinder():
    """
    Find peaks ("lines") in a 1-D spectrum and visualise them with plotly.

    Workflow and the attribute each step sets:
        find_lines()              -> self.peaks
        plot_found_lines()        -> self.plot
        find_peaks_in_reference() -> self.reference_peak_indices
        match_peaks_to_tables()   -> self.potential_lines
        plot_potential_lines()    -> self.plot_pl

    The plotting methods compute missing peaks / base plot on demand.
    spectrum is indexed with integer arrays and wvl with .iloc, so presumably
    spectrum is a numpy array and wvl a pandas Series of the same length --
    TODO confirm against callers.
    """

    # Default scipy.signal.find_peaks arguments. Each can be overridden per
    # instance (constructor kwargs) or per call (find_lines(**kwargs)).
    _PEAK_DEFAULTS = {
        'height': .001,
        'prominence': .003,
        'distance': 2,
        'width': 3,
        'threshold': 1e-6,
        'rel_height': 0.8,
        'wlen': 50,
    }

    def __init__(self, spectrum, wvl, name, **kwargs):
        """
        spectrum : intensities searched for peaks (a copy is stored).
        wvl      : x-axis values matching spectrum (a copy is stored).
        name     : legend label of the spectrum trace.
        kwargs   : overrides for any key of _PEAK_DEFAULTS; other keys are ignored.
        """
        self.spectrum = spectrum.copy()
        self.wvl = wvl.copy()
        self.name = name

        for param, default in self._PEAK_DEFAULTS.items():
            setattr(self, param, kwargs.get(param, default))

    def find_lines(self, **kwargs):
        """
        Run scipy.signal.find_peaks on self.spectrum.

        Sets self.peaks to find_peaks' (peak_indices, properties) tuple.
        kwargs override the stored peak-finding parameters for this call only;
        unknown names raise TypeError. The plotting methods read the
        'peak_heights', 'left_bases' and 'right_bases' properties, which
        find_peaks only reports when height / prominence are not None.
        """
        # The module-level scipy import is commented out, so import it here;
        # without this the call below raises NameError.
        from scipy.signal import find_peaks

        unknown = set(kwargs).difference(self._PEAK_DEFAULTS)
        if unknown:
            raise TypeError(
                'unexpected peak-finding parameter(s): {}'.format(
                    ', '.join(sorted(unknown))
                )
            )

        params = {param: getattr(self, param) for param in self._PEAK_DEFAULTS}
        params.update(kwargs)

        self.peaks = find_peaks(self.spectrum, **params)

    def _ensure_peaks(self):
        """Run find_lines() if self.peaks has not been computed yet."""
        if not hasattr(self, 'peaks'):
            print('performing line finding')
            self.find_lines()

    def _ensure_plot(self):
        """Build self.plot via plot_found_lines(False) if it does not exist yet."""
        if not hasattr(self, 'plot'):
            print('creating base plot')
            self.plot_found_lines(False)

    def _scaled(self, spectrum, scale):
        """Return a copy of spectrum, multiplied by max(self.spectrum) if scale."""
        spectrum = spectrum.copy()
        if scale:
            # Out-of-place multiply: integer input is promoted instead of
            # failing numpy's in-place casting rule.
            spectrum = spectrum * np.max(self.spectrum)
        return spectrum

    def plot_found_lines(self, show_cond=True):
        """
        Build self.plot: the spectrum, markers at the found peaks, and one
        translucent rectangle per peak spanning its left/right bases from
        y=0 up to 1.2x the peak height. Runs find_lines() first if needed
        and shows the figure when show_cond is true.
        """
        self._ensure_peaks()
        peak_ndx, props = self.peaks

        self.plot = go.Figure()
        self.plot.add_trace(
            go.Scatter(
                x=np.squeeze(self.wvl),
                y=np.squeeze(self.spectrum),
                showlegend=True,
                name=self.name
            )
        )
        self.plot.add_trace(
            go.Scatter(
                x=self.wvl.iloc[peak_ndx],
                y=self.spectrum[peak_ndx],
                mode='markers',
                name='Found peaks'
            )
        )

        self.plot = _update_layout(self.plot)

        spans = zip(props['left_bases'], props['right_bases'], props['peak_heights'])
        for ndx, (left, right, height) in enumerate(spans):
            self.plot.add_shape(
                type='rect',
                x0=self.wvl.iloc[left],
                x1=self.wvl.iloc[right],
                y0=0,
                y1=height * 1.2,
                opacity=.2,
                # alternating colours, presumably so adjacent spans stay distinguishable
                fillcolor='rgb(255,0,0)' if ndx % 2 == 0 else 'rgb(255,255,0)'
            )

        if show_cond:
            self.plot.show()

    def add_spectrum_to_plot(
        self,
        spectrum,
        name,
        wvl=None,
        scale=True,
        show_cond=True
    ):
        """
        Overlay another spectrum on self.plot as a line trace.

        wvl defaults to self.wvl; when scale is true the spectrum is
        multiplied by max(self.spectrum). Builds the base plot if missing.
        """
        self._ensure_plot()

        if wvl is None:
            wvl = self.wvl

        spectrum = self._scaled(spectrum, scale)

        self.plot.add_trace(
            go.Scatter(
                x=wvl,
                y=spectrum,
                mode='lines',
                name=name
            )
        )

        if show_cond:
            self.plot.show()

    def find_peaks_in_reference(
        self,
        spectrum,
        name,
        wvl=None,
        scale=True,
        show_cond=True
    ):
        """
        Locate, inside each found peak's [left_base, right_base) window, the
        index of the maximum of a reference spectrum, store those indices in
        self.reference_peak_indices, and add the (optionally scaled) reference
        spectrum plus its peak markers to self.plot.

        Currently assumes that the wavelengths are the same as those of the
        initialized spectrum. The argmax is taken on the unscaled spectrum;
        wvl (default self.wvl) must support .iloc.
        """
        self._ensure_plot()
        self._ensure_peaks()

        if wvl is None:
            wvl = self.wvl

        props = self.peaks[1]
        self.reference_peak_indices = [
            spectrum[left:right].argmax() + left
            for left, right in zip(props['left_bases'], props['right_bases'])
        ]

        spectrum = self._scaled(spectrum, scale)

        self.plot.add_trace(
            go.Scatter(
                x=wvl,
                y=spectrum,
                mode='lines',
                name=name
            )
        )
        self.plot.add_trace(
            go.Scatter(
                x=wvl.iloc[self.reference_peak_indices],
                y=spectrum[self.reference_peak_indices],
                mode='markers',
                name='Spectrum  peaks'
            )
        )

        if show_cond:
            self.plot.show()

    def match_peaks_to_tables(
        self,
        line_tables,
        verbose=False
    ):
        """
        Map each reference-peak wavelength to the result of
        _get_potential_lines(wavelength, line_tables=..., verbose=...) and
        store the mapping in self.potential_lines.

        Raises AttributeError if find_peaks_in_reference() has not been run.
        """
        if not hasattr(self, 'reference_peak_indices'):
            raise AttributeError(
                'reference peaks must be located first via find_peaks_in_reference()'
            )

        self.potential_lines = {
            x: _get_potential_lines(
                x,
                line_tables=line_tables,
                verbose=verbose
            )
            for x in self.wvl.iloc[self.reference_peak_indices]
        }

    @staticmethod
    def _hover_text(line, candidates):
        """
        Build the HTML hover text for one peak position.

        candidates is a DataFrame with columns 'element', 'sp_num',
        'obs_wl_air(nm)', 'intens', 'Ek(cm-1)' and 'Aki(s^-1)'. One row per
        (element, sp_num) group is listed, ordered by descending 'intens'.
        When no row survives, the text reads '<line> :: <br> unknown'.
        """
        print_header = 'line | NIST intensity | E<sub>k</sub> (cm<sup>-1</sup>) | A<sub>ki</sub> (s<sup>-1</sup>)'

        # NOTE(review): the ascending sort + head(1) keeps the LOWEST-intensity
        # row of each species; if the strongest line was intended, this
        # should sort 'intens' descending. Behaviour kept as-is.
        print_data = (
            candidates
            .sort_values(['intens', 'sp_num'])
            .groupby(['element', 'sp_num'])
            .head(1)
            .sort_values('intens', ascending=False)
        )

        if print_data.empty:
            return '{:.2f} :: <br> unknown'.format(line)

        parts = ['{:.2f} :: <br> {} <br> {}'.format(
            line,
            print_header,
            '-' * (len(print_header) // 2)
        )]
        # iterrows walks rows by position, so duplicate index labels are safe.
        for _, row in print_data.iterrows():
            parts.append('{} {} {:.2f} nm | {} | {} | {}'.format(
                row['element'],
                'I' * int(row['sp_num']),
                row['obs_wl_air(nm)'],
                row['intens'],
                row['Ek(cm-1)'],
                row['Aki(s^-1)']
            ))

        return ' <br> '.join(parts)

    def plot_potential_lines(
        self,
        show_cond=True
    ):
        """
        Build self.plot_pl: the traces of self.plot (minus the 'Found peaks'
        markers) plus, for every key of self.potential_lines, a faint vertical
        line with hover text listing the candidate lines.

        Raises AttributeError if match_peaks_to_tables() has not been run.
        """
        self._ensure_plot()
        self._ensure_peaks()
        if not hasattr(self, 'potential_lines'):
            raise AttributeError(
                '''peaks must be matched to reference tables first 
                via match_peaks_to_tables()'''
            )

        # Only data traces are copied; layout shapes of self.plot are not.
        self.plot_pl = go.Figure()
        for trace in self.plot.data:
            if trace['name'] != 'Found peaks':
                self.plot_pl.add_trace(trace)

        self.plot_pl = _update_layout(self.plot_pl)

        all_y = [trace['y'] for trace in self.plot_pl.data]
        max_y = max(max(y) for y in all_y)
        min_y = min(min(y) for y in all_y)

        HOVER_COUNT = 100
        y_hover_coords = np.linspace(min_y, max_y, HOVER_COUNT)

        for line, candidates in self.potential_lines.items():
            full_hover_text = self._hover_text(line, candidates)

            self.plot_pl.add_vline(
                line,
                opacity=.2
            )
            # add_vline draws a layout shape with no hover text, so a fully
            # transparent column of points at the same x carries the label.
            self.plot_pl.add_trace(
                go.Scatter(
                    x=[line] * HOVER_COUNT,
                    y=y_hover_coords,
                    name='',
                    opacity=.0,
                    showlegend=False,
                    hovertext=full_hover_text,
                    hoverinfo='text',
                    hoverlabel=dict(
                        bgcolor='rgba(250,250,250,.5)'
                    )
                )
            )

        if show_cond:
            self.plot_pl.show()