In [ ]:
import matplotlib.pyplot as plt
import DMK_go_coude as Fns
import numpy as np
import os, readcol
import scipy.optimize as optim
import scipy.interpolate as interp
import pickle

from astropy.io import fits 
from mpfit import mpfit
from scipy import signal

%matplotlib inline

In [ ]:
# Input/output locations for this reduction. ``dir`` (which shadows the
# builtin of the same name; later cells rely on it) is the 20140321 data
# directory, ``rdir`` the reduction-products directory inside it, and
# ``codedir`` the reduction-code checkout that holds the reference files
# (thar_photron.fits, ThAr_list.txt, global_wsol_arizz.pkl).
# Alternative checkout location: $HOME/Research/Codes/coudereduction/
dir     = os.environ.get( "HOME" ) + '/Research/YMG/coude_data/20140321/'
rdir    = ''.join( [ dir, 'reduction/' ] )
codedir = os.environ.get( "HOME" ) + '/codes/coudereduction/'

In [ ]:
# Work from the data directory: relative paths used later (e.g. the frame log
# 'headstrip.csv') are resolved against it.
os.chdir(dir)

In [ ]:
# Dark-current rate, multiplied by each frame's exposure time below. It is 0.0
# here, so no dark signal is modelled. (Units presumably electrons per unit of
# ExpTime, since the product is later divided by the gain -- TODO confirm.)
DarkCurVal = 0.0

# Frame log for the night, read as a record array (one row per frame). The
# columns used below are File, Type, Object, ExpTime, rdn and gain.
InfoFile = 'headstrip.csv'
FileInfo = readcol.readcol( InfoFile, fsep = ',', asRecArray = True )
DarkCube = FileInfo.ExpTime * DarkCurVal

# Select frames by image type. Arcs must also have Object 'Thar', 'THAR' or
# 'A'; solar-port exposures are excluded from the science frames.
# NOTE(review): 'solar_ort' looks like a typo for 'solar_port' -- confirm it
# matches an actual Object entry in headstrip.csv.
BiasInds = np.where( FileInfo.Type == 'zero' )[0]
FlatInds = np.where( FileInfo.Type == 'flat' )[0]
ArcInds  = np.where( (FileInfo.Type == 'comp') & ( (FileInfo.Object == 'Thar') | (FileInfo.Object == 'THAR') | (FileInfo.Object == 'A') ) )[0]
ObjInds  = np.where( (FileInfo.Type == 'object') & (FileInfo.Object != 'SolPort') & (FileInfo.Object != 'solar port') & (FileInfo.Object != 'solar_ort') )[0]

# Master bias and flat field. CalsDone is passed through to Fns.Basic_Cals;
# presumably True means "reuse calibrations already saved in rdir" -- verify in Fns.
CalsDone = True
SuperBias, FlatField = Fns.Basic_Cals( FileInfo.File[BiasInds], FileInfo.File[FlatInds], CalsDone, rdir, plots = False )

# Bad-pixel mask from the master bias and flat. 99.9 is passed as a threshold
# (presumably a percentile cut -- see Fns.Make_BPM); ShowBPM toggles display.
ShowBPM = False
BPM = Fns.Make_BPM( SuperBias, FlatField, 99.9, ShowBPM )

# Per-frame read noise and dark level for the arc frames, each divided by the
# gain (presumably converting electrons to ADU -- TODO confirm units). Arcs are
# built with the bias only (no flat field, no BPM).
RdNoise  = FileInfo.rdn[ArcInds] / FileInfo.gain[ArcInds]
DarkCur  = DarkCube[ArcInds] / FileInfo.gain[ArcInds]
ArcCube, ArcSNR = Fns.Make_Cube( FileInfo.File[ArcInds], RdNoise, DarkCur, Bias = SuperBias )

# Same for the science frames, which additionally get the flat field and BPM.
# (RdNoise / DarkCur are deliberately rebound here for the object frames.)
RdNoise  = FileInfo.rdn[ObjInds] / FileInfo.gain[ObjInds]
DarkCur  = DarkCube[ObjInds] / FileInfo.gain[ObjInds]
ObjCube, ObjSNR = Fns.Make_Cube( FileInfo.File[ObjInds], RdNoise, DarkCur, Bias = SuperBias, Flat = FlatField, BPM = BPM )

# Order-tracing parameters passed to Fns.Get_Trace. The meaning of OrderStart
# and MedCut is defined in Fns (not visible here); TraceDone presumably reuses
# a trace already saved in rdir -- verify in Fns.
OrderStart = -32
TraceDone = True
MedCut = 95.0
MedTrace, FitTrace = Fns.Get_Trace( FlatField, ObjCube, OrderStart, MedCut, rdir, TraceDone, plots = False )

# Previously extracted spectra ("oldway" pickles in rdir). Each file is closed
# right after unpickling, and axis 1 of the (3-D, per the [:, ::-1, :] slice)
# array is reversed; axis 1 is the axis Get_WavSol indexes as the order.
def _Load_Oldway( fname ):
    """Unpickle ``rdir + fname`` and return it with axis 1 reversed."""
    with open( rdir + fname, 'rb' ) as pklfile:
        return pickle.load( pklfile )[:, ::-1, :]

spec       = _Load_Oldway( 'extracted_spec_oldway.pkl' )
sig_spec   = _Load_Oldway( 'extracted_sigspec_oldway.pkl' )
wspec      = _Load_Oldway( 'extracted_wspec_oldway.pkl' )
sig_wspec  = _Load_Oldway( 'extracted_sigwspec_oldway.pkl' )

In [ ]:
#fullspec, fullsig_spec = Fns.extractor( ObjCube, ObjSNR, FitTrace, quick = False, nosub = False, arc = False )

In [ ]:
def Gaussian( x, A, mean, sigma, const ):
    """Gaussian profile on a constant offset.

    Returns ``A * exp(-(x - mean)**2 / (2 * sigma**2)) + const``, evaluated
    element-wise when ``x`` is an array.
    """
    gauss = A * np.exp( - ( x - mean ) ** 2.0 / ( 2.0 * sigma ** 2 ) ) + const
    return gauss

def TwoGaussian( x, A1, mean1, sigma1, const, A2, mean2, sigma2 ):
    """Sum of two Gaussian profiles on one shared constant baseline.

    ``const`` is added exactly once, so it is the model's baseline far from
    both peaks. (Previously both components added it, giving 2*const.)
    """
    return Gaussian( x, A1, mean1, sigma1, const ) + Gaussian( x, A2, mean2, sigma2, 0.0 )

In [ ]:
def Get_WavSol( Cube, RoughSol, OrderToFit, plots = True ):
    """Fit a wavelength solution for one order of an arc (ThAr) cube.

    The spectrum ``Cube[0, OrderToFit, :]`` is fit with Fit_WavSol against the
    ThAr line list in ``codedir + 'ThAr_list.txt'``. The fit starts from the
    matching row of ``RoughSol``.

    Parameters
    ----------
    Cube : 3-D array indexed as ``Cube[0, OrderToFit, :]`` (presumably
        frame, order, pixel -- only frame 0 is used).
    RoughSol : starting wavelength solutions, one row per order. It may
        have fewer rows than ``Cube`` has orders; the difference
        ``Cube.shape[1] - RoughSol.shape[0]`` is subtracted from
        ``OrderToFit`` to index it.
    OrderToFit : index along ``Cube``'s second axis.
    plots : if True, show the fit diagnostics plus an overlay against the
        reference ThAr spectrum ``codedir + 'thar_photron.fits'``.

    Returns
    -------
    wavsols, wavkeep : as returned by Fit_WavSol. Row 0 of ``wavsols`` is the
        starting solution and the last row is the final one; ``wavkeep``
        holds the wavelengths of the lines kept in the final fit.
    """
    orderdif  = Cube.shape[1] - RoughSol.shape[0]
    startwsol = RoughSol[OrderToFit - orderdif]

    # Shift the order to a zero floor, then raise everything below the median
    # up to the median to flatten the noisy continuum before peak finding.
    # (The subtraction makes a new array, so Cube itself is not modified.)
    arcspec = Cube[0, OrderToFit, :]
    arcspec = arcspec - np.min( arcspec )
    arcmed  = np.median( arcspec )
    arcspec[arcspec < arcmed] = arcmed

    THARlines = readcol.readcol( codedir + 'ThAr_list.txt', asRecArray = True )

    wavsols, wavkeep = Fit_WavSol( startwsol, arcspec, THARlines.wav, [50], minsep = 1.5, plots = plots )

    if plots:
        # Log scale so weak and strong lines are visible together; offset so
        # the minimum is zero.
        logarcspec = np.log10( arcspec )
        logarcspec = logarcspec - np.min( logarcspec )

        # Reference ThAr spectrum: copy the data and header values out while
        # the file is open so the handle is closed afterwards.
        with fits.open( codedir + 'thar_photron.fits' ) as tharhdus:
            THARhead = tharhdus[0].header
            THARspec = np.array( tharhdus[0].data )
            THARwav  = np.arange( len( THARspec ) ) * THARhead['CDELT1'] + THARhead['CRVAL1']
        logTHARspec = np.log10( THARspec )

        plt.clf()
        plt.plot( wavsols[-1], logarcspec, 'k-' )
        plt.plot( THARwav, logTHARspec, 'r-' )
        plt.xlim( wavsols[-1,0], wavsols[-1,-1] )
        for peak in wavkeep:
            plt.axvline( x = peak, color = 'b', ls = ':' )
        plt.show()

        Plot_Wavsol_Windows( wavsols, wavkeep, logarcspec, THARwav, logTHARspec )

    return wavsols, wavkeep

def Plot_Wavsol_Windows( wavsols, wavkeep, spec, tharwav, tharspec ):
    """Page through the final wavelength solution in windows 10 units wide.

    Each window overplots ``spec`` against the last row of ``wavsols`` (black),
    the reference ``tharspec`` against ``tharwav`` (red), and a blue dotted
    line at every wavelength in ``wavkeep``. Windows start at the minimum of
    the solution and advance by 10 while the window start does not exceed its
    maximum; one figure is shown per window. Returns None.
    """
    finalwav = wavsols[-1]
    wavtop   = np.max( finalwav )
    winlo    = np.min( finalwav )

    while winlo <= wavtop:
        plt.clf()
        plt.plot( finalwav, spec, 'k-' )
        plt.plot( tharwav, tharspec, 'r-' )
        plt.xlim( winlo, winlo + 10 )
        for line in wavkeep:
            plt.axvline( x = line, color = 'b', ls = ':' )
        plt.show()
        winlo = winlo + 10

In [ ]:
def Find_Peaks_Wavelet( wav, spec, peaksnr = 5, pwidth = 10, minsep = 1 ):
    """Locate line peaks in a 1-D spectrum and refine their centres.

    Candidates come from ``scipy.signal.find_peaks_cwt`` (width 1,
    ``min_snr = peaksnr``, ``noise_perc = 20``). Each candidate is fit with
    ``Gaussian`` over the ``2*pwidth``-pixel window centred on it to get a
    sub-pixel centre. That centre is linearly interpolated in ``wav`` to get a
    wavelength. Where fitted peaks lie within ``minsep`` (in the units of
    ``wav``) of each other, only the brightest survives (ties are all kept).

    Parameters
    ----------
    wav : 1-D array giving the wavelength of each pixel of ``spec``.
    spec : 1-D spectrum.
    peaksnr : ``min_snr`` for find_peaks_cwt.
    pwidth : half-width of the fitting window in pixels; candidates within
        ``pwidth`` pixels of either end are discarded.
    minsep : minimum separation of surviving peaks, in the units of ``wav``.

    Returns
    -------
    pixcent, wavcent : float arrays of fitted centres in pixels and in
        wavelength. Candidates whose fit fails are dropped.
    """
    # Older SciPy returns a list; make it an array so the mask below works.
    peaks = np.asarray( signal.find_peaks_cwt( spec, np.arange( 1, 2 ), min_snr = peaksnr, noise_perc = 20 ), dtype = int )

    # Offset from start/end of spectrum so every window is complete.
    peaks = peaks[ (peaks > pwidth) & (peaks < len(spec) - pwidth) ]

    pixels  = np.arange( len(wav) )
    specmed = np.median( spec )
    pixcent = []
    wavcent = []

    for peak in peaks:

        yi   = spec[peak - pwidth:peak + pwidth]
        inds = np.arange( len(yi), dtype = float )

        # The candidate pixel sits at index pwidth of the window (this was
        # hard-coded as 9, which only roughly fit pwidth = 10).
        pguess   = [ yi[pwidth], np.median( inds ), 0.9, specmed ]
        lowerbds = [ 0.1*pguess[0], pguess[1] - 2.0, 0.3, 0.0  ]
        upperbds = [ np.inf, pguess[1] + 2.0, 1.5, np.inf ]

        try:
            params = optim.curve_fit( Gaussian, inds, yi, p0 = pguess, bounds = (lowerbds,upperbds) )[0]
        except ( RuntimeError, ValueError ):
            # RuntimeError: no convergence. ValueError: the initial guess lies
            # outside the bounds (e.g. negative peak or median). Skip it.
            continue

        # Window index -> pixel, then linear interpolation in wav. np.interp
        # also handles an exactly-integer centre, where the old
        # ceil/floor slope divided by zero and produced NaN.
        pixval = peak - pwidth + params[1]
        pixcent.append( pixval )
        wavcent.append( np.interp( pixval, pixels, wav ) )

    pixcent = np.array( pixcent, dtype = float )
    wavcent = np.array( wavcent, dtype = float )

    # Within each minsep neighbourhood, drop every peak fainter than the brightest.
    vals = spec[pixcent.astype(int)]
    oks  = np.ones( len(pixcent), bool )
    for wc in wavcent:
        close = np.where( np.absolute( wavcent - wc ) <= minsep )[0]
        oks[ close[ vals[close] < np.max( vals[close] ) ] ] = False

    return pixcent[oks], wavcent[oks]

def _Plot_WavSol_Resids( lines, wavresids, velresids, rejmask ):
    """Plot wavelength and velocity fit residuals against catalog wavelength.

    Points where ``rejmask`` is True are drawn as crosses (about to be
    rejected), the rest as open circles. Red dashed lines mark +/-3 times the
    median absolute residual of each panel.
    """
    keepmask = np.logical_not( rejmask )
    plt.clf()
    fig, axes = plt.subplots( 2, 1, sharex = 'all' )
    for ax, resids in zip( axes, ( wavresids, velresids ) ):
        meddev = np.median( np.absolute( resids ) )
        ax.axhline( y = 0.0, color = 'k', ls = ':' )
        if np.any( rejmask ):
            ax.plot( lines[rejmask], resids[rejmask], 'kx' )
        ax.plot( lines[keepmask], resids[keepmask], 'ko', mfc = 'none' )
        for x in [ -3, 3 ]:
            ax.axhline( y = x * meddev, color = 'r', ls = '--' )
    axes[0].set_ylabel( r'Resids ($\AA$)' )
    axes[1].set_ylabel( 'Resids (km/s)' )
    fig.subplots_adjust( hspace = 0 )
    plt.show()

def Fit_WavSol( wav, spec, linecatalog, snrarr, pwidth = 10, minsep = 1, plots = True,
                matchtol = 1.5, polydeg = 4 ):
    """Refine a wavelength solution against a line catalog, one pass per S/N.

    Pass ``i`` finds peaks with Find_Peaks_Wavelet (``peaksnr = snrarr[i]``)
    using solution ``i``. It matches each peak to the nearest catalog line
    strictly inside that solution's range, keeping matches within
    ``matchtol``. Then it fits a degree-``polydeg`` polynomial from pixel to
    catalog wavelength. While any velocity residual has magnitude >= 1 km/s,
    points whose absolute wavelength residual is >= 3x the median are rejected
    and the fit repeated. The result becomes solution ``i + 1``.

    Parameters
    ----------
    wav : starting wavelength solution, one value per pixel of ``spec``.
    spec : 1-D arc spectrum.
    linecatalog : catalog line wavelengths (same units as ``wav``).
    snrarr : sequence of peak S/N thresholds, one per pass.
    pwidth, minsep : forwarded to Find_Peaks_Wavelet.
    plots : if True, plot residuals at every clipping step.
    matchtol : max |peak - catalog line| for a match (previously fixed at 1.5).
    polydeg : degree of the pixel -> wavelength polynomial (previously fixed at 4).

    Returns
    -------
    wavsols : ``(len(snrarr) + 1, len(spec))`` array; row 0 is ``wav`` and
        row ``i + 1`` is the solution from pass ``i``.
    wavkeep : measured wavelengths of the peaks kept in the last pass (empty
        if ``snrarr`` is empty).

    Raises
    ------
    ValueError : if a pass has no catalog lines in range or no matched peaks.
    """
    wavsols    = np.zeros( ( len(snrarr) + 1, len(spec) ) )
    wavsols[0] = wav
    catlines   = np.asarray( linecatalog, dtype = float )
    wavkeep    = np.array([])

    for i, snr in enumerate( snrarr ):

        # pwidth was previously accepted but never forwarded.
        pixcent, wavcent = Find_Peaks_Wavelet( wavsols[i], spec, peaksnr = snr, pwidth = pwidth, minsep = minsep )

        # Catalog lines strictly inside this solution's range (assumes the
        # solution increases from its first to last pixel).
        ordlines = catlines[ (catlines > wavsols[i,0]) & (catlines < wavsols[i,-1]) ]
        if len( ordlines ) == 0:
            raise ValueError( 'Fit_WavSol: no catalog lines inside the current wavelength range.' )

        # Nearest catalog line to each peak (first one on ties, like argmin).
        dists     = np.absolute( ordlines[np.newaxis, :] - wavcent[:, np.newaxis] )
        nearest   = np.argmin( dists, axis = 1 )
        matched   = dists[ np.arange( len(wavcent) ), nearest ] <= matchtol

        pixkeep   = pixcent[matched]
        wavkeep   = wavcent[matched]
        lineskeep = ordlines[ nearest[matched] ]
        if len( pixkeep ) == 0:
            raise ValueError( 'Fit_WavSol: no peaks matched a catalog line; cannot fit.' )

        while True:
            wavparams  = np.polyfit( pixkeep, lineskeep, polydeg )
            wavresids  = np.polyval( wavparams, pixkeep ) - lineskeep
            wavabsdev  = np.absolute( wavresids )

            # Residuals as velocities (c ~ 3e5 km/s).
            velresids  = wavresids / lineskeep * 3e5
            velabsdev  = np.absolute( velresids )

            # Test the magnitude: the old signed test (velresids >= 1.0)
            # ignored residuals below -1 km/s.
            if not np.any( velabsdev >= 1.0 ):
                if plots:
                    _Plot_WavSol_Resids( lineskeep, wavresids, velresids, np.zeros( len(lineskeep), bool ) )
                    print( 'No points are worse than 1 km/s, fit good.' )
                break

            toreject = wavabsdev >= 3.0 * np.median( wavabsdev )
            if not np.any( toreject ):
                print( 'There is something seriously wrong.\n' )
                print( 'There are points > 1 km/s, but none are found to be rejected. FIX' )
                break

            if plots:
                _Plot_WavSol_Resids( lineskeep, wavresids, velresids, toreject )

            tokeep    = np.logical_not( toreject )
            pixkeep   = pixkeep[tokeep]
            wavkeep   = wavkeep[tokeep]
            lineskeep = lineskeep[tokeep]

        wavsols[i+1] = np.polyval( wavparams, np.arange( len(spec) ) )

    return wavsols, wavkeep

In [ ]:
# Order under test. (An earlier choice was wspec.shape[1] - 15 - 1.)
testord = 58

# Rough global wavelength solution: element 0 of the pickled object, loaded
# with the file closed afterwards.
with open( codedir + 'global_wsol_arizz.pkl', 'rb' ) as wsolfile:
    oldwsol = pickle.load( wsolfile )[0]

# Row of oldwsol for the same order; oldwsol may have fewer rows than wspec
# has orders (Get_WavSol applies this same offset internally).
oldord  = testord - ( wspec.shape[1] - oldwsol.shape[0] )

# Same output as the Python 2 statement `print testord, oldord`, but valid in Python 3 too.
print( '{0} {1}'.format( testord, oldord ) )

In [ ]:
# Fit the test order starting from the rough global solution (plots on by default).
wavsols, wavkeep = Get_WavSol( Cube = wspec, RoughSol = oldwsol, OrderToFit = testord )

In [ ]:
# Summary of the fit: order index, number of lines kept in the final fit,
# and the wavelength span of the final solution (last row of wavsols).
print( testord )
print( len( wavkeep ) )
print( np.min( wavsols[-1] ) )
print( np.max( wavsols[-1] ) )

In [ ]:
# Inspect one row of the rough global solution.
# NOTE(review): 47 is hard-coded -- confirm whether oldord was intended.
print( oldwsol[47] )