# Matplotlib GUI for Linear Interpolation of Spectra

### Purpose:
This Notebook provides a graphical user interface for linearly interpolating over selected regions of a spectrum. This is useful for filling gaps in the data or for removing telluric lines by hand.

### How to Use:
To use the GUI, set the variable `dirfile` to the absolute path of a directory that contains only spectrum files (two whitespace-separated columns: wavelength and flux). Run the Notebook; a Matplotlib figure window opens showing the first spectrum. The standard Matplotlib toolbar already lets you zoom by dragging a rectangle over the plot. Its left and right arrows step backward and forward through the view history (for example, back to the original, unzoomed view). The GUI adds seven buttons:

Left Bound -- Click this button, then click on the plot at the wavelength where an interpolation region should start.

Right Bound -- Click this button, then click on the plot at the wavelength where the interpolation region should end.

Clear Bounds -- Click this button to remove all bounds and any interpolated curve from the plot.

Interp -- Click this button to linearly interpolate over every marked region. Each region needs both a left and a right bound. The result is drawn as a red dashed line.

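Export -- Click this button to write the interpolated spectrum (wavelength and interpolated-flux columns) to an ASCII file. The file is named after the original, with `_interp` inserted before the first `.`, and is written to the current working directory.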

Save Fig -- Click this button to save the figure currently shown into the spectrum directory.

Next -- Click this button to move on to the next spectrum. The current Figure will close.

In [506]:
%matplotlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
from matplotlib.lines import Line2D
import os

Using matplotlib backend: TkAgg


In [510]:
# Directory of spectra to browse. Its files are read with np.loadtxt, so it
# should hold nothing but spectrum files, each with exactly two
# whitespace-separated columns: wavelength and flux.
dirfile = '/home/data/mew488/Downloads/16coi_spec/'

In [508]:

# Event handlers for the interpolation GUI. Interpolator.__init__ builds one
# Callable per window and wires its *_click methods to the widget buttons.
class Callable(object):
    """Button and mouse-click handlers for an Interpolator window.

    All GUI state (bounds, mode flags, spectrum data, figure and axes) lives
    on ``self.interpolator``. This object additionally remembers only the
    canvas connection ids of the pending bound-placement click handlers.
    """

    def __init__(self, interpolator):
        self.interpolator = interpolator
        # Canvas connection id of the pending bound-placement handler for
        # each side, or None when none is connected. Tracking these lets us
        # connect at most one handler per side and disconnect it once the
        # bound is placed, instead of stacking a new 'button_press_event'
        # handler on the canvas every time a bound button is pressed.
        self._cids = {'left': None, 'right': None}

    # -- pure helpers (no GUI access) ---------------------------------------

    @staticmethod
    def _suffixed_name(fname, suffix):
        """Return *fname* with ``'_' + suffix`` inserted before its first '.'.

        'sn.dat' -> 'sn_interp.dat'. A name without any '.' gets the suffix
        appended ('sn' -> 'sn_interp') instead of raising ValueError.
        """
        root, dot, rest = fname.partition('.')
        return root + '_' + suffix + dot + rest

    @staticmethod
    def _interp_over_regions(wvl, flux, lbounds, rbounds):
        """Linearly interpolate *flux* across the given wavelength regions.

        Bounds are paired by position: ``(lbounds[i], rbounds[i])``. Every
        sample with ``min(lb, rb) <= wvl <= max(lb, rb)`` for some pair is
        dropped and re-estimated with ``np.interp`` from the remaining
        samples. A pair clicked in reverse order therefore still works.
        Dropped samples beyond the outermost kept sample take that sample's
        flux (``np.interp`` edge behaviour).

        Args:
            wvl: 1-D wavelengths.
            flux: 1-D fluxes, same length as *wvl*.
            lbounds: left region edges.
            rbounds: right region edges, same length as *lbounds*.

        Returns:
            numpy float array of the same length as *wvl*.

        Raises:
            ValueError: if the regions cover every sample, leaving nothing
                to interpolate from.
        """
        wvl = np.asarray(wvl, dtype=float)
        flux = np.asarray(flux, dtype=float)
        keep = np.ones(wvl.shape, dtype=bool)
        for lb, rb in zip(lbounds, rbounds):
            lo, hi = min(lb, rb), max(lb, rb)
            keep &= (wvl < lo) | (wvl > hi)
        if not keep.any():
            raise ValueError('The interpolation regions cover every data point; '
                             'there is nothing left to interpolate from.')
        # np.interp requires increasing sample positions. A stable sort keeps
        # already-ascending spectra unchanged and fixes descending ones.
        order = np.argsort(wvl[keep], kind='mergesort')
        return np.interp(wvl, wvl[keep][order], flux[keep][order])

    # -- canvas / bound helpers ---------------------------------------------

    def _connect(self, side, handler):
        """Connect *handler* to canvas clicks for *side* unless one already is."""
        if self._cids[side] is None:
            canvas = self.interpolator.fig.canvas
            self._cids[side] = canvas.mpl_connect('button_press_event', handler)

    def _disconnect(self, side):
        """Disconnect the pending click handler for *side*, if there is one."""
        cid = self._cids[side]
        if cid is not None:
            self.interpolator.fig.canvas.mpl_disconnect(cid)
            self._cids[side] = None

    def _place_bound(self, event, flag, bounds, tag):
        """Record and draw a bound at a click, if that bound mode is armed.

        Args:
            event: matplotlib mouse event.
            flag: name of the Interpolator attribute that arms this mode
                ('lbound_on' or 'rbound_on').
            bounds: list that receives the clicked wavelength.
            tag: debug string printed when a bound is placed.

        Returns:
            True if a bound was placed, False if the click was ignored
            (outside the plot axes, or the mode is not armed).
        """
        interp = self.interpolator
        if event.inaxes != interp.axis or not getattr(interp, flag):
            return False
        print(tag)
        bounds.append(event.xdata)
        # axvline spans the full axes height regardless of the y-limits, so
        # the marker stays visible after zooming. A data-coordinate line
        # would be fixed to the limits in effect at click time.
        interp.axis.axvline(event.xdata, color='r', alpha=0.2)
        interp.fig.canvas.draw()
        setattr(interp, flag, False)
        return True

    # -- event handlers -----------------------------------------------------

    def lvline(self, event):
        """Canvas click handler: place the left bound of an interpolation region."""
        if self._place_bound(event, 'lbound_on', self.interpolator.xlbound, 'lvl'):
            self._disconnect('left')

    def rvline(self, event):
        """Canvas click handler: place the right bound of an interpolation region."""
        if self._place_bound(event, 'rbound_on', self.interpolator.xrbound, 'rvl'):
            self._disconnect('right')

    def lbound_click(self, event):
        """'Left Bound' button: the next click inside the plot sets a left bound."""
        self.interpolator.lbound_on = True
        self._connect('left', self.lvline)
        print('lbc')

    def rbound_click(self, event):
        """'Right Bound' button: the next click inside the plot sets a right bound."""
        self.interpolator.rbound_on = True
        self._connect('right', self.rvline)
        print('rbc')

    def clear_bounds_click(self, event):
        """'Clear Bounds' button: remove all bounds and any interpolation result."""
        interp = self.interpolator
        # lines[0] is the original spectrum (plotted first by Interpolator).
        # Everything after it is a bound marker or the interpolated curve.
        # Artist.remove() works on every matplotlib version. Mutating
        # axis.lines directly is rejected by newer releases.
        for line in list(interp.axis.lines[1:]):
            line.remove()
        self._disconnect('left')
        self._disconnect('right')
        interp.lbound_on = False
        interp.rbound_on = False
        interp.interp_on = True
        # A fresh interpolation may be exported again.
        interp.export_on = True
        interp.xlbound = []
        interp.xrbound = []
        if hasattr(interp, 'flux_interp'):
            del interp.flux_interp
        # Rebuild the legend so it no longer lists the removed curve.
        interp.axis.legend()
        interp.fig.canvas.draw()

    def interp_click(self, event):
        """'Interp' button: interpolate over all marked regions and plot the result.

        Does nothing if an interpolation is already shown (clear bounds
        first). If the bounds are unpaired, or the regions cover every
        sample, a message is printed and Interp stays enabled so the user
        can fix the bounds and retry.
        """
        interp = self.interpolator
        if not interp.interp_on:
            return
        if len(interp.xlbound) != len(interp.xrbound):
            print('Bound lengths are different: every region needs both a '
                  'left and a right bound.')
            return
        wvl = interp.specdata[:, 0]
        try:
            flux_interp = self._interp_over_regions(
                wvl, interp.specdata[:, 1], interp.xlbound, interp.xrbound)
        except ValueError as err:
            print(str(err))
            return
        interp.interp_on = False
        interp.export_on = True
        interp.flux_interp = flux_interp
        interp.axis.plot(wvl, flux_interp, c='r', linestyle='--', alpha=0.5,
                         label='Interpolated Data')
        interp.axis.legend()
        interp.fig.canvas.draw()

    def export_click(self, event):
        """'Export' button: write wavelength and interpolated flux to a text file.

        The output name is the spectrum file name with '_<export_suffix>'
        inserted before its first '.'. specpath is a bare file name, so the
        file is written relative to the current working directory.
        """
        interp = self.interpolator
        if not interp.export_on:
            return
        if not hasattr(interp, 'flux_interp'):
            print("You haven't interpolated yet!")
            return
        newfname = self._suffixed_name(interp.specpath, interp.export_suffix)
        np.savetxt(newfname,
                   np.column_stack((interp.specdata[:, 0], interp.flux_interp)),
                   delimiter=' ')
        # Only disable re-export once the file was actually written.
        interp.export_on = False
        print("Exported interpolated data to:")
        print(os.path.abspath(newfname))

    def next_click(self, event):
        """'Next' button: close this figure and open the next spectrum, if any."""
        interp = self.interpolator
        plt.close(interp.fig)
        interp.spec_counter += 1
        if interp.spec_counter >= interp.max_counter:
            print('That was the last spectrum.')
            return
        # Re-running __init__ with the already-loaded data rebuilds the
        # figure, buttons and state (including specpath/specdata) for the
        # new index without re-reading the files.
        interp.__init__(interp.dirpath, interp.export_suffix,
                        interp.spec_counter, interp.allspecdata)

    def save_click(self, event):
        """'Save Fig' button: save the figure into the spectrum directory.

        The name is '<stem>_<export_suffix>_fig', where <stem> is the
        spectrum file name up to its first '.'. No extension is given, so
        matplotlib picks its default output format.
        """
        interp = self.interpolator
        stem = interp.specpath.partition('.')[0]
        figpath = os.path.join(interp.dirpath,
                               stem + '_' + interp.export_suffix + '_fig')
        interp.fig.savefig(figpath)
        print("Saved figure to: ")
        print(figpath)

class Interpolator(object):
    """Interactive matplotlib window for interpolating one spectrum at a time.

    Loads the spectra in *dirpath* (unless *loaded_data* is given), plots
    spectrum number *counter*, and wires the GUI buttons to a Callable.

    Args:
        dirpath: directory holding the spectrum files.
        export_suffix: tag inserted into exported-data and saved-figure names.
        counter: index of the spectrum to display.
        loaded_data: list of 2-D arrays (columns: wavelength, flux) already
            loaded from *dirpath*. When given, files are not re-read.

    Raises:
        ValueError: if no loadable spectrum is found, or *loaded_data* does
            not match the files in *dirpath* on a fresh instance.
    """

    def __init__(self, dirpath, export_suffix, counter=0, loaded_data=None):
        self.calls = Callable(self)

        # Paired left/right interpolation bounds (wavelengths), in click order.
        self.xrbound = []
        self.xlbound = []

        self.dirpath = dirpath
        if loaded_data is None:
            self.allspecpaths, self.allspecdata = self._load_spectra(dirpath)
        else:
            self.allspecdata = loaded_data
            # When __init__ is re-run on an existing object (as the Next
            # button does), the file names are already present. A fresh
            # instance has to recover them from the directory listing.
            if not hasattr(self, 'allspecpaths'):
                self.allspecpaths = self._list_files(dirpath)
                if len(self.allspecpaths) != len(loaded_data):
                    raise ValueError('loaded_data does not match the files in %r'
                                     % (dirpath,))
        if not self.allspecdata:
            raise ValueError('No loadable spectrum files found in %r' % (dirpath,))

        self.max_counter = len(self.allspecdata)
        self.spec_counter = counter
        self.specpath = self.allspecpaths[self.spec_counter]
        self.specdata = self.allspecdata[self.spec_counter]
        self.export_suffix = export_suffix

        self.fig = plt.figure(figsize=(10, 5))
        self.axis = self.fig.add_subplot(111)
        # Plotted first, so it is axis.lines[0]; Clear Bounds keeps that line.
        self.axis.plot(self.specdata[:, 0], self.specdata[:, 1], c='k', alpha=0.5,
                       label='Original Data')
        self.axis.set_title(self.specpath)
        self.axis.set_ylabel('Relative Flux')
        # Raw string: '\A' is not a valid Python escape sequence.
        self.axis.set_xlabel(r'Wavelength ($\AA$)')
        self.axis.legend()

        # Mode flags read and toggled by the Callable handlers.
        self.lbound_on = False
        self.rbound_on = False
        self.interp_on = True
        self.export_on = True

        # Rectangles are [left, bottom, width, height] in figure fractions.
        # The Button objects are stored on self because matplotlib widgets
        # only stay responsive while a reference to them is kept.
        self.clbound_button_ax, self.clbound_button = self._make_button(
            [.925, .35, .05, .075], 'Clear \nBounds', self.calls.clear_bounds_click)
        self.lbound_button_ax, self.lbound_button = self._make_button(
            [.925, .25, .05, .075], 'Left \nBound', self.calls.lbound_click)
        self.rbound_button_ax, self.rbound_button = self._make_button(
            [.925, .15, .05, .075], 'Right \nBound', self.calls.rbound_click)
        self.interp_button_ax, self.interp_button = self._make_button(
            [.925, .05, .05, .075], 'Interp', self.calls.interp_click)
        self.export_button_ax, self.export_button = self._make_button(
            [.925, .45, .05, .075], 'Export', self.calls.export_click)
        self.next_button_ax, self.next_button = self._make_button(
            [.025, .05, .05, .075], 'Next', self.calls.next_click)
        self.save_button_ax, self.save_button = self._make_button(
            [.025, .15, .05, .075], 'Save \nFig', self.calls.save_click)

    def _make_button(self, rect, label, callback):
        """Add a Button at *rect* wired to *callback*; return (axes, button)."""
        ax = self.fig.add_axes(rect)
        button = Button(ax, label)
        button.on_clicked(callback)
        return ax, button

    @staticmethod
    def _list_files(dirpath):
        """Return the regular files in *dirpath*, sorted by name.

        os.listdir order is arbitrary, so sorting gives a stable order.
        """
        return sorted(name for name in os.listdir(dirpath)
                      if os.path.isfile(os.path.join(dirpath, name)))

    @staticmethod
    def _load_spectra(dirpath):
        """Load every parseable spectrum file in *dirpath*.

        Files that np.loadtxt cannot parse, or that have fewer than two
        columns, are reported and skipped instead of aborting the whole
        load. For example, figures written by the Save Fig button land in
        this same directory.

        Returns:
            (names, arrays): file names and their 2-D arrays, aligned by index.
        """
        names, arrays = [], []
        for name in Interpolator._list_files(dirpath):
            try:
                data = np.loadtxt(os.path.join(dirpath, name), ndmin=2)
            except (IOError, ValueError) as err:
                print('Skipping %s: %s' % (name, err))
                continue
            if data.ndim != 2 or data.shape[0] == 0 or data.shape[1] < 2:
                print('Skipping %s: expected two columns (wvl flux)' % name)
                continue
            names.append(name)
            arrays.append(data)
        return names, arrays
    

    

In [511]:
# Launch the GUI on the first spectrum in dirfile; 'interp' is the tag
# inserted into exported-data and saved-figure file names.
inter = Interpolator(dirfile, export_suffix='interp')