# Importing and Defining

#### GRATING = 600ZD

In [None]:
import os
import math
import statistics
import numpy as np 
from astropy.io import fits 
# from smooth_kevin import smoother
import py_specrebin
import matplotlib.pyplot as plt 
from matplotlib import rc
# import pandas as pd
# Project root; the raw spec1d files are read from <path_name>/data/<mask>.
path_name = '.'
# Root folder for everything this notebook writes (rebinned data, medians, plots).
optimized_data_path = '{0}/Optimized Data'.format(path_name)

In [None]:
# Common wavelength grids that every spectrum is rebinned onto, one per grating
# (start, stop-exclusive, step); units presumably Angstrom, as the plot labels say.
new_wave_600 = np.arange(4000, 11000, 0.65)
new_wave_1200 = np.arange(6000, 11000, 0.33)

In [None]:
def _join_sides(hdul, column):
    """Return row 0 of ``column`` from extensions 1 and 2 of ``hdul``, joined end to end."""
    return np.concatenate((hdul[1].data[column][0], hdul[2].data[column][0]))


def get_original_data(file_names, mask_name):
    """Read the sky spectrum of every slit file, joining its two extensions.

    For each name in ``file_names`` the FITS file
    ``<path_name>/data/<mask_name>/<name>`` is opened. Row 0 of the
    ``SKYSPEC``, ``LAMBDA`` and ``IVAR`` columns of extension 1 and
    extension 2 is concatenated (extension 1 first; the original code calls
    these the blue and red sides). If any of the three joined arrays sums to
    zero, the matching ``serendip1`` file is read instead.

    Parameters
    ----------
    file_names : list of str
        Dot-separated spec1d file names. The serendip fallback rebuilds the
        name from parts 0, 1, 2, 4 and 5, so at least six parts are needed.
    mask_name : str
        Sub-folder of ``data/`` that holds the files.

    Returns
    -------
    tot_flux, tot_wave, tot_ivar : list of ndarray
        One joined array per input file, in input order.
    """
    tot_flux = []
    tot_wave = []
    tot_ivar = []
    data_dir = path_name + '/' + 'data/{0}'.format(mask_name)

    for file_name in file_names:
        # Context managers guarantee every FITS handle is closed, even on error.
        # The data are copied by np.concatenate before the file closes.
        with fits.open(data_dir + '/' + file_name, ignore_missing_end=True) as h_star:
            star_flux = _join_sides(h_star, 'SKYSPEC')
            star_wave = _join_sides(h_star, 'LAMBDA')
            star_ivar = _join_sides(h_star, 'IVAR')

        # NOTE(review): this reproduces the former test
        # ``(sum(flux) and sum(ivar) and sum(wave)) == 0``, which is true when
        # ANY of the three sums is zero. If "all three are zero" was the
        # intent, replace ``or`` with ``and``.
        if sum(star_flux) == 0 or sum(star_ivar) == 0 or sum(star_wave) == 0:
            parts = file_name.split(".")
            serendip_file_name = "{0}.{1}.{2}.serendip1.{3}.{4}".format(
                parts[0], parts[1], parts[2], parts[4], parts[5])
            # Previously this file was opened and never closed.
            with fits.open(data_dir + '/' + serendip_file_name) as serendip:
                star_flux = _join_sides(serendip, 'SKYSPEC')
                star_ivar = _join_sides(serendip, 'IVAR')
                star_wave = _join_sides(serendip, 'LAMBDA')

        tot_flux.append(star_flux)
        tot_wave.append(star_wave)
        tot_ivar.append(star_ivar)

    return tot_flux, tot_wave, tot_ivar

In [None]:
def rebin(fluxes, waves, ivar, grating):
    """Rebin every spectrum onto the shared wavelength grid of ``grating``.

    Parameters
    ----------
    fluxes, waves, ivar : sequences of arrays
        Per-slit flux, wavelength and inverse-variance arrays, in matching order.
    grating : int
        600 selects ``new_wave_600``; 1200 selects ``new_wave_1200``.

    Returns
    -------
    rbflux : list
        Rebinned fluxes returned by ``py_specrebin.rebinspec``.
    new_wave : ndarray
        The common grid that was used.
    rbivar : list
        Rebinned inverse variances.

    Raises
    ------
    ValueError
        If ``grating`` is neither 600 nor 1200. Without this check the function
        failed later with an UnboundLocalError on ``new_wave``.
    """
    if grating == 600:
        new_wave = new_wave_600
    elif grating == 1200:
        new_wave = new_wave_1200
    else:
        raise ValueError(
            "Unsupported grating {0!r}; expected 600 or 1200".format(grating))

    rbflux = []
    rbivar = []

    for i, wave in enumerate(waves):
        new_flux, new_ivar = py_specrebin.rebinspec(wave, fluxes[i], new_wave, ivar=ivar[i])
        rbflux.append(new_flux)
        rbivar.append(new_ivar)

    return rbflux, new_wave, rbivar

In [None]:
def find_median(rebinned_flux_array):
    """Per-pixel median across all spectra, ignoring non-finite values.

    Parameters
    ----------
    rebinned_flux_array : sequence of 1-D arrays
        Spectra on a common grid (all the same length).

    Returns
    -------
    list
        One median per pixel. A pixel where every spectrum is non-finite
        (NaN or +/-inf) gets NaN, and numpy emits a RuntimeWarning.

    Raises
    ------
    ValueError
        If no spectra are given.
    """
    print(len(rebinned_flux_array))

    if len(rebinned_flux_array) == 0:
        raise ValueError("find_median needs at least one spectrum")

    stacked = np.vstack(rebinned_flux_array)
    # nanmedian only skips NaN, so turn +/-inf into NaN first. This matches
    # the former per-pixel np.isfinite filter.
    stacked = np.where(np.isfinite(stacked), stacked, np.nan)

    # Vectorised replacement for the old per-pixel Python loop.
    return list(np.nanmedian(stacked, axis=0))

In [None]:
def get_exclusions(filepath='ISM_EM_LINES.txt'):
    """Parse the list of slits that show ISM emission lines.

    Each non-blank line is expected as
    ``<mask>[_<suffix>]: <slit number> <object id> ...``. The mask is the text
    before the first ``_`` of the label. The slit number is zero-padded to
    three characters. The object id is the second whitespace-separated token
    after the colon.

    Parameters
    ----------
    filepath : str, optional
        Path of the exclusion list (default ``'ISM_EM_LINES.txt'``).

    Returns
    -------
    list of dict
        One ``{'mask_name', 'slit_number', 'object_id'}`` dict per line.
    """
    all_data = []
    # ``with`` closes the file; the previous version left it open.
    with open(filepath) as fp:
        for line in fp:
            if not line.strip():
                continue  # blank lines used to raise IndexError
            fields = line.split(':')
            mask_name = fields[0].split('_')[0]
            # Tokenise on any whitespace, so tabs or repeated spaces between
            # the slit number and the object id are handled consistently.
            tokens = fields[1].split()
            all_data.append({
                'mask_name': mask_name,
                'slit_number': tokens[0].zfill(3),
                'object_id': tokens[1],
            })
    return all_data

In [None]:
def get_files_to_include(folder):
    """Sort the spec1d files of ``data/<folder>`` into included / excluded / serendip.

    Only names whose first dot-separated part is ``spec1d`` are considered
    (this skips e.g. hidden .DS_Store files). Such a name is read as
    ``spec1d.<mask>.<slit>.<object_id>...``, and the file is then classified:

    * serendip - ``'serendip'`` appears in the object id;
    * excluded - (mask, slit, object id) is listed by ``get_exclusions()``;
    * included - everything else.

    The serendip test is applied independently of the exclusion list. It
    used to sit inside the loop over exclusions, so with an empty list
    serendip files were wrongly reported as included.

    Returns
    -------
    tuple of three sorted lists
        (files to include, files to exclude, serendip files).
    """
    all_file_names_in_folder = os.listdir('data/{}'.format(folder))
    print("The number of files in the folder is {0}".format(len(all_file_names_in_folder)))

    # A set gives O(1) membership instead of scanning every exclusion per file.
    excluded_keys = {
        (entry['mask_name'], entry['slit_number'], entry['object_id'])
        for entry in get_exclusions()
    }

    list_of_files_to_include = []
    list_of_files_to_exclude = []
    serendip_files = []

    for file_name in all_file_names_in_folder:
        parts_of_file_name = file_name.split(".")
        if parts_of_file_name[0] != 'spec1d':  # avoids hidden DS_Store files on my mac
            continue
        mask_name = parts_of_file_name[1]
        slit_number = parts_of_file_name[2]
        object_id = parts_of_file_name[3]

        if 'serendip' in object_id:
            serendip_files.append(file_name)
        elif (mask_name, slit_number, object_id) in excluded_keys:
            list_of_files_to_exclude.append(file_name)
        else:
            list_of_files_to_include.append(file_name)

    print('The number of files left after exclusions is {0}'.format(len(list_of_files_to_include)))

    return sorted(list_of_files_to_include), sorted(list_of_files_to_exclude), sorted(serendip_files)


## Function to Save The Rebinned Data

In [None]:
#Sarthak's function as modified by Liv Gaunt
def exportToFits(rbflux, rbwave, rbivar, mask_name, file_names, incl_or_excl, overwrite=False):
    """Write each rebinned spectrum to its own FITS file.

    Files go to
    ``<optimized_data_path>/<mask>/<mask>_Rebinned/<mask>_<group>/rebinned_<name>``.
    ``group`` is ``Included`` when ``incl_or_excl == True``, ``Excluded``
    when ``incl_or_excl == False`` and ``Serendip`` otherwise (e.g. None).

    Each file has an empty primary HDU whose ``INCLUDE`` header card stores
    ``incl_or_excl``. It also has a binary table with ``RBFLUX``,
    ``RBWAVE`` and ``RBIVAR`` columns (format ``'E'``).

    Parameters
    ----------
    rbflux, rbivar : sequences of arrays
        Per-slit rebinned flux and inverse variance.
    rbwave : array
        The single shared wavelength grid (written into every file).
    mask_name : str
    file_names : list of str
        Original file names, used for the output names (same order as rbflux).
    incl_or_excl : bool or None
        Group flag, see above.
    overwrite : bool, optional
        Passed to ``HDUList.writeto``. The default False keeps the old
        behaviour of raising if the file already exists.
    """
    if incl_or_excl == True:
        group = 'Included'
    elif incl_or_excl == False:
        group = 'Excluded'
    else:
        group = 'Serendip'
    out_dir = '{0}/{1}/{1}_Rebinned/{1}_{2}'.format(optimized_data_path, mask_name, group)

    for i, flux in enumerate(rbflux):
        hdu1 = fits.PrimaryHDU()  # primary HDU (empty) carrying the inclusion tag
        hdu1.header['INCLUDE'] = (incl_or_excl, 'Include in median calc if T')

        c1 = fits.Column(name='RBFLUX', array=flux, format='E')
        c2 = fits.Column(name='RBWAVE', array=rbwave, format='E')  # shared grid, so no [i]
        c3 = fits.Column(name='RBIVAR', array=rbivar[i], format='E')
        hdu2 = fits.BinTableHDU.from_columns([c1, c2, c3])

        fits.HDUList([hdu1, hdu2]).writeto(
            out_dir + '/rebinned_{0}'.format(file_names[i]), overwrite=overwrite)


## Function to Read FITS Files and Get Back the Rebinned Data

In [None]:
def get_fits_rebinned_data(mask_name, file_names, incl_or_excl):
    """Read back the rebinned spectra written by ``exportToFits``.

    Files are read from
    ``<optimized_data_path>/<mask>/<mask>_Rebinned/<mask>_<group>/rebinned_<name>``.
    ``group`` is ``Included`` when ``incl_or_excl == True``, ``Excluded``
    when ``incl_or_excl == False`` and ``Serendip`` otherwise.

    Returns
    -------
    rbflux, rbwave, rbivar : list of ndarray
        Copies of the ``RBFLUX``, ``RBWAVE`` and ``RBIVAR`` columns of
        extension 1, one entry per file name.
    """
    if incl_or_excl == True:
        group = 'Included'
    elif incl_or_excl == False:
        group = 'Excluded'
    else:
        group = 'Serendip'
    in_dir = '{0}/{1}/{1}_Rebinned/{1}_{2}'.format(optimized_data_path, mask_name, group)

    rbflux = []
    rbwave = []
    rbivar = []

    for slit in file_names:
        # Copy the columns so every file can be closed immediately. The
        # previous version kept one open handle per slit.
        with fits.open('{0}/rebinned_{1}'.format(in_dir, slit)) as hdul:
            table = hdul[1].data
            rbflux.append(np.array(table["RBFLUX"]))
            rbwave.append(np.array(table["RBWAVE"]))
            rbivar.append(np.array(table["RBIVAR"]))

    return rbflux, rbwave, rbivar

## Function to Save The Median

In [None]:
def exportToFitsMedian(median, mask_name, overwrite=False):
    """Save the median spectrum as a one-column FITS table.

    The file is
    ``<optimized_data_path>/<mask>/<mask>_Median/Median_of_<mask>.fits.gz``.
    It holds an empty primary HDU plus a table with a ``MEDIAN`` column
    (format ``'E'``).

    Parameters
    ----------
    median : array-like
        Median value per pixel.
    mask_name : str
    overwrite : bool, optional
        Passed to ``HDUList.writeto``. The default False keeps the old
        behaviour of raising if the file already exists.
    """
    hdu1 = fits.PrimaryHDU()
    c1 = fits.Column(name='MEDIAN', array=median, format="E")
    hdu2 = fits.BinTableHDU.from_columns([c1])

    hdul = fits.HDUList([hdu1, hdu2])
    hdul.writeto(
        optimized_data_path + '/{0}'.format(mask_name)
        + '/{0}_Median/Median_of_{0}.fits.gz'.format(mask_name),
        overwrite=overwrite)

## Function To Read and Get Back Median

In [None]:
def get_med_from_fits(mask_name):
    """Read back the median spectrum written by ``exportToFitsMedian``.

    Returns
    -------
    ndarray
        A copy of the ``MEDIAN`` column of extension 1 of
        ``<optimized_data_path>/<mask>/<mask>_Median/Median_of_<mask>.fits.gz``.
    """
    path = (optimized_data_path + '/{0}'.format(mask_name)
            + '/{0}_Median/Median_of_{0}.fits.gz'.format(mask_name))
    # Copy the column so the file handle can be closed (it used to leak).
    with fits.open(path) as median_read:
        median_fits = np.array(median_read[1].data["MEDIAN"])
    return median_fits

## Median Airglow Subtraction


In [None]:
def median_subtraction(slit_index, rebinned_flux, median=None):
    """Subtract the median sky spectrum from one slit's rebinned flux.

    Parameters
    ----------
    slit_index : int
        Position of the slit in ``rebinned_flux``.
    rebinned_flux : sequence of arrays
        Rebinned spectra on the common grid.
    median : array-like, optional
        Median spectrum to subtract. When omitted, the notebook's module-level
        ``median`` is used. That global was previously the only source, which
        made callers' own median arguments silently ineffective.

    Returns
    -------
    list
        ``flux - median`` where the flux is finite; non-finite flux values
        are passed through unchanged.
    """
    if median is None:
        median = globals().get('median')
        if median is None:
            raise ValueError("pass `median` explicitly; no module-level median is defined")

    spectrum = np.asarray(rebinned_flux[slit_index])
    with np.errstate(invalid='ignore'):  # inf - inf only occurs where we keep the flux
        difference = spectrum - np.asarray(median)
    return list(np.where(np.isfinite(spectrum), difference, spectrum))

In [None]:
def get_slit_nums(files):
    """Extract the integer slit number (third dot-separated field) of each file name.

    Example: ``'spec1d.M33D2A.005.obj.fits.gz'`` -> ``5``.

    The former ``len(files) > 1`` guard returned an empty list for a single
    file. Every input now yields one slit number, and an empty input gives
    ``[]``.
    """
    return [int(file_name.split(".")[2]) for file_name in files]

In [None]:
def find_slit_index(slit_nums, slit_num):
    """Return the first position of ``slit_num`` within the list ``slit_nums``.

    Raises ValueError when the slit number is not present.
    """
    for position, candidate in enumerate(slit_nums):
        if candidate is slit_num or candidate == slit_num:
            return position
    raise ValueError('{!r} is not in list'.format(slit_num))

In [None]:
def plotting(mask_name, slit_nums, rebinned_flux, median, incl_or_excl, close_figures=True):
    """Plot and save the median, rebinned and median-subtracted spectrum of every slit.

    One PNG per slit is saved as ``<mask> Slit <n>`` into
    ``<optimized_data_path>/<mask>/<mask>_Spectra/<mask>_Included`` when
    ``incl_or_excl == True`` and into ``..._Excluded`` otherwise. These are the
    folders created by the "Saving The Rebinned Data" cell. The previous
    version wrote to ``<mask>_Spectra/...`` relative to the working directory,
    which that cell never creates.

    Parameters
    ----------
    mask_name : str
    slit_nums : list of int
        Slit numbers, in the same order as ``rebinned_flux``.
    rebinned_flux : sequence of arrays
        Rebinned flux on the ``new_wave_600`` grid, one per slit.
    median : array-like
        Median spectrum on the same grid. It is both plotted and subtracted.
        The subtraction used to read the global ``median`` instead.
    incl_or_excl : bool
        Selects the Included / Excluded output folder.
    close_figures : bool, optional
        Close each figure after saving (default) so plotting a whole mask does
        not keep every figure alive. Pass False to keep them open, e.g. for
        inline display.
    """
    folder = 'Included' if incl_or_excl == True else 'Excluded'
    out_dir = '{0}/{1}/{1}_Spectra/{1}_{2}'.format(optimized_data_path, mask_name, folder)
    median_arr = np.asarray(median)

    for slit_index, slit in enumerate(slit_nums):
        spectrum = np.asarray(rebinned_flux[slit_index])
        with np.errstate(invalid='ignore'):
            difference = spectrum - median_arr
        # non-finite flux pixels are kept as they are
        new_flux = np.where(np.isfinite(spectrum), difference, spectrum)

        fig, _ = plt.subplots(1)
        fig.patch.set_alpha(1)
        plt.ylim(-10000, 10000)  # could try getting a smaller y limit and getting rid of legend
        plt.xlim(4000, 11000)
        # curves are offset vertically (+4000, +8000) so they do not overlap
        plt.plot(new_wave_600, median, scalex=False, scaley=False, color='r', label='median')
        plt.plot(new_wave_600, spectrum + 4000, scalex=False, scaley=False, color='b', label='rebinned')
        plt.plot(new_wave_600, new_flux + 8000, scalex=False, scaley=False, color='g', label='median subtracted')
        plt.title('{} Slit {}'.format(mask_name, slit))
        plt.xlabel('Wavelength')
        plt.ylabel('Flux')
        plt.legend()
        fig.savefig('{0}/{1} Slit {2}'.format(out_dir, mask_name, slit))
        if close_figures:
            plt.close(fig)

## Looking At Specific Slit

In [None]:
def one_slit(slit_number, rebinned_flux, median, incl_or_excl, multiplier,
             min_ylim, max_ylim, min_xlim, max_xlim):
    """Median-subtract and plot a single slit (interactive inspection helper).

    Included slits (``incl_or_excl == True``) are located in the module-level
    ``slit_nums`` list. The plot shows the median, the rebinned flux shifted
    up by 5000 and ``multiplier * flux - median``. It uses fixed limits
    (x 4000-11000, y -10000-10000); the ``*_lim`` arguments are not used in
    this branch.

    Any other slit is located in the module-level ``slit_nums_exclude`` list.
    The plot shows the median, the scaled flux, ``flux - median`` and
    ``multiplier * flux - median``, within the supplied axis limits.

    Wherever the flux is not finite, the derived curves keep the flux value
    unchanged.
    """
    if incl_or_excl == True:
        slit_index = find_slit_index(slit_nums, slit_number)
        spectrum = rebinned_flux[slit_index]
        subtracted = [
            (value * multiplier) - median[i] if np.isfinite(value) else value
            for i, value in enumerate(spectrum)
        ]

        plt.subplots(1)
        plt.ylim(-10000, 10000)
        plt.xlim(4000, 11000)
        plt.plot(new_wave_600, median, c="r", scalex=False, scaley=False, label="median")
        plt.plot(new_wave_600, rebinned_flux[slit_index] + 5000, c="b", scalex=False, scaley=False, label="rebinned")
        plt.plot(new_wave_600, np.array(subtracted), c="g", scalex=False, scaley=False, label="subtracted")
        plt.xlabel("Wavelength (A)")
        plt.ylabel("Flux (Electron/Hour)")
        plt.title("Slit #{}".format(slit_number))
        return

    slit_index = find_slit_index(slit_nums_exclude, slit_number)
    spectrum = rebinned_flux[slit_index]
    finite = [np.isfinite(value) for value in spectrum]

    # scaled flux (non-finite pixels untouched)
    scaled = [value * multiplier if ok else value for value, ok in zip(spectrum, finite)]
    # sky-subtracted with scaling
    subtracted = [
        scaled_value - median[i] if ok else scaled_value
        for i, (scaled_value, ok) in enumerate(zip(scaled, finite))
    ]
    # sky-subtracted without scaling
    unscaled_subtracted = [
        value - median[i] if ok else value
        for i, (value, ok) in enumerate(zip(spectrum, finite))
    ]

    plt.subplots(1)
    plt.ylim(min_ylim, max_ylim)
    plt.xlim(min_xlim, max_xlim)
    plt.plot(new_wave_600, median, c="r", scalex=False, scaley=False, label="median")
    plt.plot(new_wave_600, scaled, c="b", scalex=False, scaley=False, label="rebinned w/ scaling")
    plt.plot(new_wave_600, np.array(unscaled_subtracted), c="purple", scalex=False, scaley=False, label="Subtracted w/o scaling")
    plt.plot(new_wave_600, np.array(subtracted), c="g", scalex=False, scaley=False, label="subtracted")
    plt.xlabel("Wavelength (A)")
    plt.ylabel("Flux (Electron/Hour)")
    plt.title("Slit #{}/Multiplier: {}/From {} to {}".format(slit_number, multiplier, min_xlim, max_xlim))
        

## Define The Mask

In [None]:
mask_name = "M33D2A" #change to fit the appropriate mask 

## Getting Files We Want to Include and Exclude

In [None]:
# Filter the mask's spec1d files into the three groups used below.
(list_of_files_to_include,
 list_of_files_to_exclude,
 list_of_serendip_files) = get_files_to_include(mask_name)

# Aliases used by the rest of the notebook (each list is sorted):
#   file_names          - slits used to build the median (airglow) spectrum
#   file_names_exclude  - slits listed in ISM_EM_LINES.txt (ISM emission lines)
#   file_names_serendip - serendip files
#   file_names_all      - included + excluded slits (no serendip files)
file_names, file_names_exclude, file_names_serendip = (
    list_of_files_to_include, list_of_files_to_exclude, list_of_serendip_files)
file_names_all = file_names + file_names_exclude

## Extracting The Wavelength, Flux, and Inverse Variance

Comment out the code in this section once you have rebinned and saved your data, so the slow rebinning is not rerun unnecessarily.

Uncomment it again whenever you start working with a new mask and need to rebin.

In [None]:
# Read the original spectra of the slits used for the median.
flux, wave, ivar = get_original_data(file_names=file_names, mask_name=mask_name)

In [None]:
# Rebin the included spectra onto the 600-grating grid (takes about 4 minutes).
rbflux, rbwave, rbivar = rebin(fluxes=flux, waves=wave, ivar=ivar, grating=600)

In [None]:
# Read the original spectra of the excluded (ISM emission) slits.
flux_exclude, wave_exclude, ivar_exclude = get_original_data(file_names=file_names_exclude, mask_name=mask_name)

In [None]:
# Rebin the excluded spectra onto the same 600-grating grid.
rbflux_exclude, rbwave_exclude, rbivar_exclude = rebin(
    fluxes=flux_exclude, waves=wave_exclude, ivar=ivar_exclude, grating=600)

In [None]:
# Read the serendip spectra too. They are never used later, but processing
# them keeps the saved data complete.
flux_serendip, wave_serendip, ivar_serendip = get_original_data(file_names=file_names_serendip, mask_name=mask_name)

In [None]:
# Rebin the serendip spectra onto the 600-grating grid.
rbflux_serendip, rbwave_serendip, rbivar_serendip = rebin(
    fluxes=flux_serendip, waves=wave_serendip, ivar=ivar_serendip, grating=600)

## Saving The Rebinned Data

In [None]:
# Output folder tree for this mask, all under <optimized_data_path>/<mask_name>/.
paths = [
        # top-level folders for the rebinned data, the spectra and the median
        "{0}/{1}/{1}_Rebinned".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Spectra".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Median".format(optimized_data_path,mask_name),

        # sub-folders for the rebinned data
        "{0}/{1}/{1}_Rebinned/{1}_Excluded".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Rebinned/{1}_Included".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Rebinned/{1}_Serendip".format(optimized_data_path,mask_name),

        # sub-folders for the spectra plots
        "{0}/{1}/{1}_Spectra/{1}_Excluded".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Spectra/{1}_Included".format(optimized_data_path,mask_name),

        # scaled flux, shifted wavelength and polynomial coefficients
        "{0}/{1}/{1}_Rebinned/{1}_Scale_Values".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Rebinned/{1}_Shift_Values".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Rebinned/{1}_Polynomial_Coefficients".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Rebinned/{1}_Polynomial_Coefficients/{1}_Shifting_Polynomial_Coefficients".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Rebinned/{1}_Polynomial_Coefficients/{1}_Scaling_Polynomial_Coefficients".format(optimized_data_path,mask_name),

        # new median
        "{0}/{1}/{1}_Rebinned/{1}_New_Median".format(optimized_data_path,mask_name),

        # scaling / shifting factors and rebinned flux with shifted wavelength
        "{0}/{1}/{1}_Rebinned/{1}_Shifting_Factor".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Rebinned/{1}_Scaling_Factor".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Rebinned/{1}_Scaling_and_Shifting_Factor".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Rebinned/{1}_Rebinned_Flux_Shifted_Wave".format(optimized_data_path,mask_name),

        # final trimmed spectra
        "{0}/{1}/{1}_Trimmed_Spectra".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Trimmed_Spectra/Excluded".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Trimmed_Spectra/Included".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Trimmed_Spectra/Optimized_Spectrum_Flux".format(optimized_data_path,mask_name),

        # polynomial fits and factor-vs-RMS plots
        "{0}/{1}/{1}_Polynomial_Graph".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Polynomial_Graph/Scaling_vs_RMS".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Polynomial_Graph/Shifting_vs_RMS".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Polynomial_Graph/Scaling_Fitting".format(optimized_data_path,mask_name),
        "{0}/{1}/{1}_Polynomial_Graph/Shifting_Fitting".format(optimized_data_path,mask_name)]

for path in paths:
    # exist_ok=True tolerates folders from a previous run. It still raises if
    # a non-directory file already occupies the path, as the old
    # try/except-OSError pattern did.
    os.makedirs(path, exist_ok=True)

In [None]:
# Save one rebinned FITS file per slit: included, excluded, then serendip.
exportToFits(rbflux, rbwave, rbivar, mask_name, file_names, incl_or_excl=True)
exportToFits(rbflux_exclude, rbwave_exclude, rbivar_exclude, mask_name, file_names_exclude, incl_or_excl=False)
exportToFits(rbflux_serendip, rbwave_serendip, rbivar_serendip, mask_name, file_names_serendip, incl_or_excl=None)

## Reading Back The Rebinned Data

In [None]:
# Reload the saved rebinned spectra for each group.
rbflux_fits, rbwave_fits, rbivar_fits = get_fits_rebinned_data(
    mask_name, file_names, incl_or_excl=True)
rbflux_fits_exclude, rbwave_fits_exclude, rbivar_fits_exclude = get_fits_rebinned_data(
    mask_name, file_names_exclude, incl_or_excl=False)
rbflux_fits_serendip, rbwave_fits_serendip, rbivar_fits_serendip = get_fits_rebinned_data(
    mask_name, file_names_serendip, incl_or_excl=None)

## Finding The Median 

In [None]:
# Per-pixel median of the included slits (10770 values for M33D2A).
median = find_median(rebinned_flux_array=rbflux_fits)

## Saving Median As FITS

In [None]:
exportToFitsMedian(median=median, mask_name=mask_name)  # save the median as FITS

## Getting Back Median From FITS

In [None]:
median_fits = get_med_from_fits(mask_name=mask_name)  # reload the saved median

## Slits to Include and Exclude

In [None]:
# Integer slit numbers for each group of files.
slit_nums, slit_nums_exclude = get_slit_nums(file_names), get_slit_nums(file_names_exclude)
all_slit_nums = get_slit_nums(file_names_all)

print("Slit # to INCLUDE in median calculation: {0}".format(slit_nums))
print("Slit # to EXCLUDE: {0}".format(slit_nums_exclude))

## Plotting All Slits in a Mask and Saving the Plots to a Folder

In [None]:
#plot all slits and save it as a png in a folder 
#Need to change the mask name for each mask 
# plotting(mask_name, slit_nums, rbflux_fits, median_fits, True)
# plotting(mask_name, slit_nums_exclude, rbflux_fits_exclude, median_fits, False)

# Optimization

## Optimization Settings

In [None]:
# Thresholds for the median mask and the sky-subtraction mask.
threshold_median, threshold_sky_sub = 150, 50

# Candidate scaling multipliers and wavelength shifts to test.
multipliers = np.arange(0.7, 1.3, 0.01)
shift_test_values = np.arange(-0.2, 0.21, 0.01)

## Boolean Arrays & Moving Median

In [None]:
from scipy.ndimage import median_filter

def moving_median(a, window=325):
    '''
    Returns the moving median values of the array,
    using a window of a given size, centered at
    each point.

    Non-finite entries (NaN, +/-inf) are first replaced by the value of the
    nearest finite entry (ties go to the left neighbour; leading/trailing
    runs take the first/last finite value). Then
    ``scipy.ndimage.median_filter(..., mode='nearest')`` is applied.

    Version - 4.1
    (fixes: all-finite input used to return all NaN; the input array was
    modified in place.)

    Parameters
    ----------
    a : ndarray
        One dimensional flux array. It is not modified.
    window : int, optional
        The size of each segment for taking the median.

    Returns
    ----------
    median_arr : One dimensional array of moving median. It is all NaN when
        ``a`` has no finite values (including an empty ``a``).

    '''
    # Work on a copy so the caller's array is never modified in place.
    arr = np.array(a, copy=True)
    finite_bool = np.isfinite(arr)

    if not finite_bool.any():
        return np.nan * np.ones(len(arr))

    if not finite_bool.all():
        all_indices = np.arange(len(arr))
        finite_indices = all_indices[finite_bool]
        nan_indices = all_indices[~finite_bool]

        # Insertion position of each non-finite index among the finite ones.
        # 0 means before the first finite value; len(finite_indices) means
        # after the last one.
        pos = np.searchsorted(finite_indices, nan_indices)
        leading = pos == 0
        trailing = pos == len(finite_indices)
        gap = ~(leading | trailing)

        # Interior gaps: copy the closer finite neighbour (left wins ties).
        gap_nan = nan_indices[gap]
        right = finite_indices[pos[gap]]
        left = finite_indices[pos[gap] - 1]
        nearest = np.where(gap_nan - left <= right - gap_nan, left, right)
        arr[gap_nan] = arr[nearest]

        # Edge runs: copy the first / last finite value.
        arr[nan_indices[leading]] = arr[finite_indices[0]]
        arr[nan_indices[trailing]] = arr[finite_indices[-1]]

    return median_filter(arr, window, mode='nearest')

In [None]:
# Subtract a moving median (this window size) before thresholding in sky_sub_bool.
use_moving_median, window = True, 325

In [None]:
def median_threshold(median, threshold):
    """Flag median values that are finite and strictly greater than ``threshold``.

    Non-finite values (NaN, +/-inf) are always flagged False.

    Returns
    -------
    numpy.ndarray
        One flag per element, built with ``np.array`` from a list of bools.
    """
    flags = [bool(np.isfinite(value) and value > threshold) for value in median]
    return np.array(flags)

In [None]:
def create_wave_bool(wavelength, min_wave, max_wave):
    """Flag wavelengths inside the open interval (``min_wave``, ``max_wave``).

    Returns
    -------
    numpy.ndarray
        One flag per element, built with ``np.array`` from a list of bools.
    """
    flags = [bool(min_wave < value < max_wave) for value in wavelength]
    return np.array(flags)

In [None]:
def sky_sub_bool(slit_index, rebinned_flux_list, median, threshold, use_moving_median=use_moving_median):
    """Flag pixels of one slit whose sky-subtracted flux is at or below ``threshold``.

    The residual is ``rebinned_flux_list[slit_index] - median``. When
    ``use_moving_median`` is true, its ``moving_median`` (module-level
    ``window``) is subtracted as well, before comparing with ``threshold``.
    The default of ``use_moving_median`` is the module-level value captured
    when the function is defined.

    Returns
    -------
    numpy.ndarray of bool
    """
    residual = rebinned_flux_list[slit_index] - median

    if use_moving_median == True:
        residual_arr = np.asarray(residual)
        residual = residual_arr - moving_median(np.asarray(residual), window=window)

    return np.array(residual <= threshold)

## Polynomial Fits and Weighted Wavelength

In [None]:
def polynomial_first_order(x, poly_const):
    """Evaluate the line ``c0*x + c1`` with ``poly_const = [c0, c1]`` (highest power first)."""
    slope, intercept = poly_const[0], poly_const[1]
    return (x * slope) + intercept

In [None]:
def polynomial_second_order(x, poly_const):
    """Evaluate the second-order polynomial ``c0*x**2 + c1*x + c2`` (highest power first)."""
    c_sq, c_lin, c_const = poly_const[0], poly_const[1], poly_const[2]
    return (x**2 * c_sq) + (x * c_lin) + c_const

In [None]:
def polynomial_third_order(x, poly_const):
    """Evaluate the third-order polynomial ``c0*x**3 + c1*x**2 + c2*x + c3`` (highest power first)."""
    c3, c2, c1, c0 = poly_const[0], poly_const[1], poly_const[2], poly_const[3]
    return (x**3 * c3) + (x**2 * c2) + (x * c1) + c0

In [None]:
def polynomial_fourth_order(x, poly_const):
    """Evaluate ``c0*x**4 + c1*x**3 + c2*x**2 + c3*x + c4`` (highest power first)."""
    c4, c3, c2, c1, c0 = (poly_const[0], poly_const[1], poly_const[2],
                          poly_const[3], poly_const[4])
    return (x**4 * c4) + (x**3 * c3) + (x**2 * c2) + (x * c1) + c0

In [None]:
def weighted_wave(median, threshold_median, threshold_sky_sub, wavelength, min_wave, max_wave, rbflux, slit_index):
    """Median-weighted mean wavelength of the selected pixels in one window.

    A pixel is selected when all three masks are True:

    * ``median_threshold``: median above ``threshold_median``;
    * ``create_wave_bool``: ``min_wave < wavelength < max_wave``;
    * ``sky_sub_bool``: sky-subtracted flux of slit ``slit_index`` at or
      below ``threshold_sky_sub``.

    Returns
    -------
    ``sum(median * wavelength) / sum(median)`` over the selected pixels, or
    None (after printing a message) when no pixel is selected.
    """
    above_median = median_threshold(median, threshold_median)
    in_window = create_wave_bool(wavelength, min_wave, max_wave)
    sky_ok = sky_sub_bool(slit_index, rbflux, median, threshold_sky_sub)
    selection = above_median * in_window * sky_ok

    picked = [n for n, flag in enumerate(selection) if flag == True]
    weights = [median[n] for n in picked]
    waves = [wavelength[n] for n in picked]

    if not weights:
        print("All boolean values are False. Weighted wavelength cannot be calculated! (Wavelength {0} to {1})".format(min_wave, max_wave))
        return None

    products = [weight * lam for weight, lam in zip(weights, waves)]
    return sum(np.array(products)) / sum(np.array(weights))

In [None]:
def automating_weighted_wave_multiple_slits(index_of_slit, fluxes, median):
    """Weighted wavelength of slit ``index_of_slit`` in consecutive 500-wide windows.

    The windows are consecutive pairs of ``np.arange(4000, 11500, 500)``
    (4000-4500, ..., 10500-11000). Each entry is the result of
    ``weighted_wave`` and may be None when no pixel in that window is
    selected. The function relies on the module-level ``rbwave_fits``,
    ``threshold_median`` and ``threshold_sky_sub``.
    """
    edges = np.arange(4000, 11500, 500)
    return [
        weighted_wave(median, threshold_median, threshold_sky_sub, rbwave_fits[0],
                      lower, upper, fluxes, index_of_slit)
        for lower, upper in zip(edges[:-1], edges[1:])
    ]
            

## Class & Decorator Function To Generate Log Files

**Caution:** Execute the following cell only once per run. Do not modify the ```std_out``` or ```std_err``` variables. If they are modified by accident, please restart the kernel and run the notebook from the beginning.

In [None]:
# Keep handles to the interpreter's own stdout/stderr before anything replaces them.
# The logging decorator (write_to_log) tees output into these and restores them afterwards.
import sys

std_out = sys.stdout
std_err = sys.stderr

In [None]:
# For duplicating the output stream to log files during the optimization process.

class multifile(object):
    """Tee-style proxy that forwards stream operations to several file-like objects.

    Method calls (e.g. ``write``, ``flush``) are made on every wrapped file, in
    order. The result of the last call is returned, or ``None`` when there are no
    files. Plain data attributes (e.g. ``encoding``, ``closed``) are read from the
    first wrapped file rather than being turned into callables. A name the first
    file does not have raises ``AttributeError``, so ``hasattr`` behaves normally.
    """

    def __init__(self, files):
        # files: sequence of file-like objects that every call is fanned out to
        self._files = files

    def __getattr__(self, attr, *args):
        # __getattr__ only runs for names not found normally. Read _files straight
        # from __dict__ so a half-initialised instance (e.g. during copy/unpickle)
        # raises AttributeError instead of recursing forever.
        files = self.__dict__.get('_files')
        if files is None:
            raise AttributeError(attr)
        if files:
            first = getattr(files[0], attr, *args)
            if not callable(first):
                # Data attribute: return the value itself, not a wrapper function.
                return first
        return self._wrap(attr, *args)

    def _wrap(self, attr, *args):
        def g(*a, **kw):
            res = None  # defined even when there are no files to forward to
            for f in self._files:
                res = getattr(f, attr, *args)(*a, **kw)
            return res
        return g

In [None]:
# Decorating function to generate log files during shifting, and scaling.

def write_to_log(filename):
    """Decorator factory: tee stdout/stderr of the wrapped call into ``filename``.

    The log file is opened in 'w' mode, so it is overwritten on every call. While
    the wrapped function runs, ``sys.stdout``/``sys.stderr`` are replaced by
    ``multifile`` proxies. These write both to the saved original streams
    (``std_out``/``std_err``) and to the log file. The original streams are
    restored, and the log file closed, even if the wrapped function raises.

    Parameters
    ----------
    filename : str
        Path of the log file.

    Raises
    ------
    OSError
        If the log file cannot be opened. This is the real error, no longer
        masked by a failing ``close()`` on an unbound name.
    """
    import functools

    def inner_decorator_1(func):
        @functools.wraps(func)
        def inner_decorator_2(*args, **kwargs):
            # Open before redirecting anything: if opening fails, there is nothing
            # to close and the streams were never touched.
            with open(filename, 'w') as log_file:
                try:
                    sys.stdout = multifile([std_out, log_file])
                    sys.stderr = multifile([std_err, log_file])

                    # Function runs here.
                    return func(*args, **kwargs)
                finally:
                    # Restore the streams before the with-block closes the log file,
                    # so nothing can be written to a closed file.
                    sys.stdout = std_out
                    sys.stderr = std_err
        return inner_decorator_2
    return inner_decorator_1

# Shifting

In [None]:
# def shifting_wavelength(waves, shifting_value):
    
# #     waves_shifted = []
    
# #     for i in range(len(waves)):
# #         waves_shifted.append(waves[i] + shifting_value)
    
#     return waves_shifted

In [None]:
# def looping_shifting_wavelength(original_wavelength,test_values,index_of_slit):
    
#     wave_shifted_dict = {}
    
#     for value in test_values: 
#         wave_shifted = shifting_wavelength(original_wavelength[index_of_slit],value)
#         wave_shifted_dict["Shifted_{}".format(round(value,2))] = wave_shifted
    
#     return wave_shifted_dict

In [None]:
def rebin_wave_shifted(flux, wave, ivar, test_values, index_of_slit):
    """Rebin one slit's spectrum once for every trial wavelength shift.

    Parameters
    ----------
    flux, wave, ivar : sequences indexed by slit
        Per-slit flux, wavelength and inverse-variance arrays. Only entry
        ``index_of_slit`` is used.
    test_values : iterable of float
        Trial shifts added to the slit's wavelength array before rebinning.
    index_of_slit : int
        Which slit to rebin.

    Returns
    -------
    dict
        ``"Shifted_<round(shift, 2)>"`` -> flattened rebinned flux (numpy array).
        The rebinned wavelength and ivar returned by ``rebin`` are discarded.
    """
    shifted_fluxes = {}

    for trial_shift in test_values:
        rebinned_flux, _rebinned_wave, _rebinned_ivar = rebin(
            [flux[index_of_slit]],
            [wave[index_of_slit] + trial_shift],
            [ivar[index_of_slit]],
            600,
        )
        label = "Shifted_{}".format(round(trial_shift, 2))
        shifted_fluxes[label] = np.asarray(rebinned_flux).ravel()

    return shifted_fluxes

In [None]:
def shifted_flux_subtraction_threshold(rbflux_shifted_dict, median,
                                       use_moving_median=use_moving_median):
    """Median-subtract every shifted rebinned flux and flag values under the sky threshold.

    Parameters
    ----------
    rbflux_shifted_dict : dict
        ``"Shifted_<value>"`` -> rebinned flux for that trial shift.
    median : array-like
        Median spectrum subtracted from every flux.
    use_moving_median : bool, optional
        When equal to True, a baseline is additionally removed. The baseline is the
        ``moving_median`` (module-level ``window``) of the unshifted residual
        (``'Shifted_0.0'`` minus ``median``), and the same baseline is reused for
        every shift. NOTE: the default is the module-level ``use_moving_median``
        value captured when this cell was executed, not at call time.

    Returns
    -------
    tuple of (dict, dict)
        Residual per key, and a boolean array per key marking residuals below the
        module-level ``threshold_sky_sub``.
    """
    remove_baseline = (use_moving_median == True)  # noqa: E712 -- same comparison as before

    if remove_baseline:
        unshifted_residual = rbflux_shifted_dict['Shifted_0.0'] - median
        baseline = moving_median(unshifted_residual, window=window)

    residuals = {}
    below_threshold = {}
    for label, shifted_flux in list(rbflux_shifted_dict.items()):
        residual = shifted_flux - median
        if remove_baseline:
            residual = residual - baseline
        residuals[label] = residual
        below_threshold[label] = np.asarray(residual < threshold_sky_sub)

    return (residuals, below_threshold)

In [None]:
def rbflux_shifted_minus_median(sub_dict, sub_boolean_dict,
                                median_and_wavelength_boolean_array, shift_test_values):
    """Select, for every trial shift, the residual values that pass all masks.

    For each shift the mask is the product of that shift's below-threshold mask
    (``sub_boolean_dict``) and the shared median/wavelength mask. The product of
    two boolean arrays is a boolean array, which is then used to index the residual.

    Parameters
    ----------
    sub_dict : dict
        ``"Shifted_<value>"`` -> median-subtracted flux.
    sub_boolean_dict : dict
        ``"Shifted_<value>"`` -> boolean mask of residuals below threshold.
    median_and_wavelength_boolean_array : array of bool
        Mask shared by all shifts.
    shift_test_values : iterable of float
        Trial shifts; keys are ``"Shifted_<round(value, 2)>"``.

    Returns
    -------
    dict
        ``"Shifted_<value>"`` -> 1-D array of the selected residuals.
    """
    def _key(shift):
        return "Shifted_{}".format(round(shift, 2))

    return {
        key: sub_dict[key][sub_boolean_dict[key] * median_and_wavelength_boolean_array]
        for key in map(_key, shift_test_values)
    }

In [None]:
def rms_calculation_shift(rms_dict_sorted_shift, shift_test_values):
    """Return the NaN-ignoring root-mean-square of the residuals for each trial shift.

    Parameters
    ----------
    rms_dict_sorted_shift : dict
        ``"Shifted_<value>"`` -> sequence of residual values.
    shift_test_values : iterable of float
        Trial shifts; keys are ``"Shifted_<round(value, 2)>"``.

    Returns
    -------
    dict
        ``"Shifted_<value>"`` -> ``sqrt(nanmean(values**2))``. An empty or all-NaN
        selection gives NaN (numpy emits a RuntimeWarning).
    """
    def _root_mean_square(values):
        arr = np.array(values)
        return np.nanmean(arr * arr) ** 0.5

    labels = ["Shifted_{}".format(round(shift, 2)) for shift in shift_test_values]
    return {label: _root_mean_square(rms_dict_sorted_shift[label]) for label in labels}

In [None]:
def plotting_the_shift_rms(slit_number,shift_test_values, rms_dict_shift, min_wave, max_wave):
    """Plot RMS against trial shift for one wavelength segment and return the best shift.

    The figure is saved under
    ``<optimized_data_path>/<mask_name>/<mask_name>_Polynomial_Graph/Shifting_vs_RMS/Slit_<slit_number>/``
    and the directory is created if needed. ``mask_name`` and ``optimized_data_path``
    are module globals.

    Parameters
    ----------
    slit_number : int
        Slit being processed (used in titles and file names).
    shift_test_values : sequence of float
        Trial shifts, in plotting order.
    rms_dict_shift : dict
        ``"Shifted_<round(value, 2)>"`` -> RMS for that trial shift.
    min_wave, max_wave : number
        Segment edges, used only for labelling.

    Returns
    -------
    float
        The trial shift (rounded to 2 decimals) with the smallest RMS. On ties,
        the first such shift in ``shift_test_values`` wins.
    """
    out_dir = optimized_data_path + '/{0}'.format(mask_name) + "/{0}_Polynomial_Graph/Shifting_vs_RMS/Slit_{1}".format(mask_name,slit_number)
    os.makedirs(out_dir, exist_ok=True)

    rms_values = [rms_dict_shift["Shifted_{0}".format(round(shift, 2))] for shift in shift_test_values]

    fig = plt.figure(figsize=(8,6))
    fig.patch.set_alpha(1)
    plt.plot(shift_test_values, rms_values)
    plt.xlabel("Shifting")
    plt.ylabel("RMS")
    plt.title("Mask {0}: Slit #{1}\nRMS vs Shifting ({2} A to {3} A)".format(mask_name,slit_number,min_wave,max_wave))
    fig.savefig(out_dir+'/{0}_Slit_{1}_{2}_to_{3}.png'.format(mask_name,slit_number,min_wave,max_wave))

    best_position = rms_values.index(min(rms_values))
    shifting_value = round(shift_test_values[best_position], 2)

    print("Shifting w/ minimum RMS (Slit #{0}): {1}".format(slit_number,shifting_value) + " ({0} A to {1} A)".format(min_wave,max_wave))

    return shifting_value
        
    

## Automating All The Functions Used In Shifting Process (For A Single Slit)

In [None]:
def finding_shifting(slit_number, flux, wave, ivar,
                     median_boolean_array, wavelength, slit_index):
    """Find the RMS-minimising wavelength shift for each 500 A segment of one slit.

    Steps:

    1. The slit is rebinned once per module-level ``shift_test_values`` entry.
    2. Each result is median-subtracted and thresholded.
    3. For every segment between 4000 and 11000 A, the residuals that pass the
       median, wavelength and sky-subtraction masks are collected.
    4. The trial shift with the lowest RMS is picked for that segment.

    A segment is marked None when the ``"Shifted_0.2"`` selection has fewer than
    two values.

    NOTE(review): ``median`` is not a parameter here, so the module-level
    ``median`` global is used. Callers that compute a local ``median`` do not
    affect it; confirm that is intended.

    Returns
    -------
    tuple of (dict, dict)
        ``"<lo>_to_<hi>"`` -> best shift (or None), and the full
        ``"Shifted_<value>"`` -> rebinned flux dict.
    """
    rbflux_shifted_dict = rebin_wave_shifted(flux, wave, ivar, shift_test_values, slit_index)

    sub_dict, sub_boolean_dict = shifted_flux_subtraction_threshold(rbflux_shifted_dict, median)

    # 500 A segment edges: 4000, 4500, ..., 11000.
    edges = np.arange(4000,11500,500)

    shifting_value_dict = {}

    for lower, upper in zip(edges[:-1], edges[1:]):
        segment_label = "{0}_to_{1}".format(lower, upper)

        in_segment = create_wave_bool(wavelength, lower, upper)
        combined_mask = median_boolean_array * in_segment

        selected_residuals = rbflux_shifted_minus_median(sub_dict, sub_boolean_dict,
                                                         combined_mask, shift_test_values)

        # Fewer than two usable points: no meaningful RMS for this segment.
        if len(selected_residuals["Shifted_0.2"]) in (0, 1):
            shifting_value_dict[segment_label] = None
            print("Boolean are all False. No values can be use to calculate the RMS. From {0} to {1}".format(lower,upper))
            continue

        rms_dict_shift = rms_calculation_shift(selected_residuals, shift_test_values)
        shifting_value_dict[segment_label] = plotting_the_shift_rms(slit_number, shift_test_values,
                                                                    rms_dict_shift, lower, upper)

    return shifting_value_dict, rbflux_shifted_dict

## Saving Flux Rebinned With Shifted Wavelength

In [None]:
#saving rbflux_shifted_dict as npy files takes a fair amount of space, so reduce it first:
#keep only the rebinned fluxes whose shift factor was actually selected.
def sorting_needed_shift_factor(shifting_value_dict,rbflux_shifted_dict):
    """Return the subset of ``rbflux_shifted_dict`` for the selected shift factors.

    Parameters
    ----------
    shifting_value_dict : dict
        Segment label -> optimal shift factor, or None when no factor was found.
    rbflux_shifted_dict : dict
        ``"Shifted_<value>"`` -> rebinned flux for that trial shift.

    Returns
    -------
    dict
        ``"Shifted_<value>"`` -> rebinned flux, with one entry per distinct
        non-None shift factor, in the order the factors first appear.

    Raises
    ------
    KeyError
        If a selected factor has no matching ``"Shifted_<value>"`` entry.
    """
    # dict.fromkeys drops duplicates while keeping first-seen order (a set would
    # not). Filtering up front avoids removing items from a list while iterating it.
    needed_factors = [value for value in dict.fromkeys(shifting_value_dict.values())
                      if value is not None]

    return {"Shifted_{0}".format(value): rbflux_shifted_dict["Shifted_{0}".format(value)]
            for value in needed_factors}


In [None]:
#not sure how to save a dictionary as a FITS file, therefore use an npy file
def saving_rbflux_shifted(mask_name,slit_number_used,sorted_rbflux_shifted_dict):
    """Save the reduced shifted-flux dictionary of one slit as a ``.npy`` file.

    ``np.save`` stores the dict as a pickled 0-d object array, so reading it back
    requires ``np.load(..., allow_pickle=True).item()``. The target directory under
    the module-level ``optimized_data_path`` is not created here.
    """
    npy_path = "{2}/{0}/{0}_Rebinned/{0}_Rebinned_Flux_Shifted_Wave/rbflux_shifted_dict_{0}_{1}.npy".format(
        mask_name, slit_number_used, optimized_data_path)
    np.save(npy_path, sorted_rbflux_shifted_dict)

## Polynomial Fit for Shifting

In [None]:
def wavelength_shifting_function(shifting_value_dict, weighted_wavelength):
    """Fit and plot 2nd-, 3rd- and 4th-order polynomials of optimal shift factor vs wavelength.

    Parameters
    ----------
    shifting_value_dict : dict
        Segment label -> optimal shift factor, or None when a segment has none.
        Values are paired positionally with ``weighted_wavelength``.
    weighted_wavelength : sequence
        Weighted wavelength of each segment, in the same order as the dict values.

    Returns
    -------
    tuple
        (kept shift factors, their weighted wavelengths, 2nd-order coefficients
        from ``np.polyfit``, highest power first).

    Notes
    -----
    Segments whose shift factor is None are skipped. The figure is drawn but not
    saved. Callers catch TypeError from this function and fall back to "no shift".
    NOTE(review): that TypeError presumably comes from ``np.polyfit`` when no usable
    points remain (or a weighted wavelength is None) -- confirm.
    """
    # Keep only segments that produced a shift factor, paired with their wavelength.
    kept_shifts = []
    kept_waves = []
    for position, factor in enumerate(shifting_value_dict.values()):
        if factor is None:
            continue
        kept_shifts.append(factor)
        kept_waves.append(weighted_wavelength[position])

    # Dense wavelength grid (4000 to 11000 A) on which the fits are evaluated for plotting.
    grid = np.arange(4000,11000,0.65)

    coeff_2 = np.polyfit(kept_waves, kept_shifts, 2)
    print("Second order polynomial: y = {0}x^2 + {1}x + {2}".format(*coeff_2))

    coeff_3 = np.polyfit(kept_waves, kept_shifts, 3)
    print("Third order polynomial: y = {0}(x^3) + {1}(x^2) + {2}(x) + {3}".format(*coeff_3))

    coeff_4 = np.polyfit(kept_waves, kept_shifts, 4)
    print("Fourth order polynomial: y = {0}(x^4) + {1}(x^3) + {2}(x^2) + {3}(x) + {4}".format(*coeff_4))

    curve_2 = [polynomial_second_order(w, coeff_2) for w in grid]
    curve_3 = [polynomial_third_order(w, coeff_3) for w in grid]
    curve_4 = [polynomial_fourth_order(w, coeff_4) for w in grid]

    fig,axs = plt.subplots(1)
    axs.set_xlim(3900,11100)
    axs.set_ylim(-0.25,0.25)
    axs.set_title("Optimal Shifting Factor vs Wavelength")
    axs.set_xlabel("Wavelength (A)")
    axs.set_ylabel("Optimal Shifting Factor")
    axs.plot(grid,curve_2,scalex=False,scaley=False,label="Second-Order",c="green")
    axs.plot(grid,curve_3,scalex=False,scaley=False,label="Third-Order",c="blue")
    axs.plot(grid,curve_4,scalex=False,scaley=False,label="Fourth-Order",c="orange")
    axs.scatter(kept_waves, kept_shifts,label="Optimal Shifting Factor",c="black")
    axs.legend()

    return kept_shifts,kept_waves,coeff_2

## Improving The Polynomial Fits (Shifting)

In [None]:
def remove_outliers_sigma_clip_shifting(shifting_values_list,weighted_wavelength_list,poly_const_second_deg):
    """Two-pass sigma clipping of optimal shift factors around a second-order fit.

    Each pass computes the absolute deviation of every point from the current
    second-order polynomial. A point is flagged when its deviation is >= 3.5 times
    the sample standard deviation (``statistics.stdev``) of those absolute
    deviations. If exactly one point is flagged it is dropped. If several are
    flagged, only the largest is dropped and the others are appended back at the
    end of the lists. The first pass uses ``poly_const_second_deg``; the second
    pass refits with ``np.polyfit`` on the points that survived the first pass.

    Parameters
    ----------
    shifting_values_list : list
        Optimal shift factors.
    weighted_wavelength_list : list
        Weighted wavelength of each shift factor (same length and order).
    poly_const_second_deg : array-like
        Coefficients of the initial second-order fit, passed to
        ``polynomial_second_order``.

    Returns
    -------
    tuple of (list, list)
        Shift factors and wavelengths kept after clipping. Returns right after the
        first pass if that pass flags nothing, or if at most one point is left.

    Notes
    -----
    ``statistics.stdev`` raises ``statistics.StatisticsError`` with fewer than two
    points. Flagged points are mapped back with ``list.index``, which returns the
    first match, so two points with identical deviations would be conflated. The
    ``exit()`` calls after each ``return`` are never reached.
    """

    #optimal shifting value - optimal shifting value based on line of best fit
    deviation = []

    #determine the vertical difference (deviation) between dots and line of best fit
    for n in range(len(weighted_wavelength_list)):
        deviation.append(np.abs(polynomial_second_order(weighted_wavelength_list[n],poly_const_second_deg) - shifting_values_list[n]))

    print("First Deviation: {}".format(deviation))
    # Labelled "RMS" in the output, but this is the sample standard deviation of the absolute deviations.
    print("First Deviation RMS: {}".format(statistics.stdev(deviation)))

    #BEGIN FIRST ITERATION

    #first iteration
    shifting_values_1 = []
    wavelength_values_1 = []

    #to store outlier
    outlier_deviation = []

    #remove outliers from data (3.5-sigma cut; stdev is recomputed on every loop iteration)
    for n in range(len(deviation)):
        if deviation[n]/statistics.stdev(deviation) < 3.5:
            shifting_values_1.append(shifting_values_list[n])#keep all non-outlier
            wavelength_values_1.append(weighted_wavelength_list[n])
        #add all outliers to a separate list
        else:
            outlier_deviation.append(deviation[n])

    #if there's no outlier, return inputs (every point was kept)
    if len(outlier_deviation) == 0:
        return shifting_values_1, wavelength_values_1
        exit()  # unreachable: follows a return

    #if there are several outliers, remove the largest one and keep the remaining outliers
    #(a single outlier falls through both branches and is simply dropped)
    elif len(outlier_deviation) > 1:
        outlier_deviation.remove(max(outlier_deviation))
        for value in outlier_deviation:
            # list.index returns the first matching deviation; duplicates would map to the same point
            shifting_values_1.append(shifting_values_list[deviation.index(value)])
            wavelength_values_1.append(weighted_wavelength_list[deviation.index(value)])

    #if at most one value is left after the outliers are removed, return shifting values and wavelength
    if len(shifting_values_1) == 0 or len(shifting_values_1) == 1:
        return shifting_values_1,wavelength_values_1
        exit()  # unreachable: follows a return

    #PERFORM A NEW FITTING (second-order fit on the first-pass survivors)

    new_poly_coeff = np.polyfit(wavelength_values_1,shifting_values_1,2)


    #CALCULATE NEW DEVIATION AND REMOVE OUTLIERS

    new_deviation = []

    for n in range(len(wavelength_values_1)):
        new_deviation.append(np.abs(polynomial_second_order(wavelength_values_1[n],new_poly_coeff) - shifting_values_1[n]))

    print("Second Deviation: {}".format(new_deviation))
    # As above: sample standard deviation, printed under the label "RMS".
    print("Second Deviation RMS: {}".format(statistics.stdev(new_deviation)))

    #final set of data with outliers removed
    no_outlier_shifting_values_list = []
    no_outlier_wavelength_values_list = []

    #to store outliers
    new_outlier_deviation = []

    #remove outliers (same 3.5-sigma cut as the first pass)
    for n in range(len(new_deviation)):
        if new_deviation[n]/statistics.stdev(new_deviation) < 3.5:
            no_outlier_shifting_values_list.append(shifting_values_1[n])#keep all non-outlier
            no_outlier_wavelength_values_list.append(wavelength_values_1[n])
        else:
            new_outlier_deviation.append(new_deviation[n])

    #if there's no outlier or just one outlier (which is dropped), return final list of shift factors and wavelengths
    if len(new_outlier_deviation) == 0 or len(new_outlier_deviation) == 1:
        return no_outlier_shifting_values_list, no_outlier_wavelength_values_list
        exit()  # unreachable: follows a return

    #if there are several outliers, remove the largest one and keep the remaining outliers
    elif len(new_outlier_deviation) > 1:
        new_outlier_deviation.remove(max(new_outlier_deviation))
        for value in new_outlier_deviation:
            no_outlier_shifting_values_list.append(shifting_values_1[new_deviation.index(value)])
            no_outlier_wavelength_values_list.append(wavelength_values_1[new_deviation.index(value)])
        return no_outlier_shifting_values_list,no_outlier_wavelength_values_list
    

In [None]:
def new_polynomial_fits_shifting(mask_name,slit_num,shifting_values_list,wavelength_values_list):
    """Refit the clipped shift factors with a 2nd-order polynomial, clamp its ends and plot it.

    Outside the range of the data the fit is held constant: it equals the
    polynomial's value at the smallest wavelength below that point, and its value
    at the largest wavelength above it. The plot is saved to
    ``<optimized_data_path>/<mask_name>/<mask_name>_Polynomial_Graph/Shifting_Fitting/``.
    That directory is not created here.

    Parameters
    ----------
    mask_name : str
        Mask identifier (titles and file path).
    slit_num : int
        Slit identifier (titles and file path).
    shifting_values_list : list
        Shift factors with outliers removed.
    wavelength_values_list : list
        Matching weighted wavelengths.

    Returns
    -------
    tuple
        (``wavelength_values_list`` unchanged, left end value, right end value,
        2nd-order coefficients from ``np.polyfit``).
    """
    coeffs = np.polyfit(wavelength_values_list, shifting_values_list, 2)
    print("New Second Order Polynomial: y = {0}x^2 + {1}x + {2}".format(coeffs[0], coeffs[1], coeffs[2]))

    # Dense wavelength grid (4000 to 11000 A) used only for plotting the fit.
    grid = np.arange(4000,11000,0.65)
    fitted = [polynomial_second_order(w, coeffs) for w in grid]

    shortest_wave = min(wavelength_values_list)
    longest_wave = max(wavelength_values_list)

    left_end_val = polynomial_second_order(shortest_wave, coeffs)
    right_end_val = polynomial_second_order(longest_wave, coeffs)

    print("Optimal Shift Factor of All Wavelength Before {0} Angstroms: {1}".format(shortest_wave,left_end_val))

    print("Optimal Shift Factor of All Wavelength After {0} Angstroms: {1} ".format(longest_wave,right_end_val))

    # Flatten both tails of the curve beyond the data range.
    fitted = [
        left_end_val if w < shortest_wave else (right_end_val if w > longest_wave else value)
        for w, value in zip(grid, fitted)
    ]

    fig,axs = plt.subplots(1)
    fig.set_size_inches(8,6)
    fig.patch.set_alpha(1)
    axs.set_xlim(3900,11100)
    axs.set_ylim(-0.25,0.25)
    axs.set_title("Mask {0}: Slit #{1}\nOptimal Shift Factor vs Wavelength".format(mask_name,slit_num))
    axs.set_xlabel("Wavelength (A)")
    axs.set_ylabel("Optimal Shift Factor")
    axs.plot(grid,fitted,scalex=False,scaley=False,label="Second-Order",c="green")
    axs.scatter(wavelength_values_list,shifting_values_list,label="Optimal Shift Factor",c="black")
    fig.savefig(optimized_data_path + '/{0}'.format(mask_name) + "/{0}_Polynomial_Graph/Shifting_Fitting/Poly_Graph_{0}_slit_{1}.png".format(mask_name,slit_num))

    return wavelength_values_list,left_end_val,right_end_val,coeffs

## Shift Original Wavelength Using Polynomial Fits

In [None]:
def use_new_shifting_factor(wave,no_outlier_wavelength_values_list,left_end_val,right_end_val,poly_const_second_deg):
    """Apply the fitted wavelength shift to every pixel of an original wavelength array.

    Pixels below the smallest fitted wavelength get ``left_end_val`` and pixels
    above the largest get ``right_end_val``. Pixels in between get
    ``polynomial_second_order(w, poly_const_second_deg)``.

    Parameters
    ----------
    wave : sequence of float
        Original wavelength array of one slit.
    no_outlier_wavelength_values_list : list
        Weighted wavelengths that defined the fit; only their min/max are used.
    left_end_val, right_end_val : float
        Constant shifts used outside the fitted range.
    poly_const_second_deg : array-like
        Second-order polynomial coefficients.

    Returns
    -------
    list
        ``wave[i] + shift`` for every pixel, in order.
    """
    if len(wave) == 0:
        return []  # nothing to shift

    lower_edge = min(no_outlier_wavelength_values_list)
    upper_edge = max(no_outlier_wavelength_values_list)

    shifted_wave = []
    for pixel_wave in wave:
        if pixel_wave < lower_edge:
            delta = left_end_val
        elif pixel_wave > upper_edge:
            delta = right_end_val
        else:
            delta = polynomial_second_order(pixel_wave, poly_const_second_deg)
        shifted_wave.append(pixel_wave + delta)

    return shifted_wave

## Saving Shifted Wavelength as FITS Files

In [None]:
def save_new_polynomial_fits_values_shift(no_outlier_scaling_values_list, no_outlier_wavelength_values_list, original_wavelength_shift,mask_name,slit_number_used):
    """Write the clipped shift factors and the shifted wavelength of one slit to FITS.

    The first extension is a binary table with float32 (format 'E') columns
    SHIFT_VALUES, WAVELENGTH_VALUES and ORIGINAL_WAVELENGTH_SHIFTED. Despite its
    name, ``no_outlier_scaling_values_list`` holds shift values. ``writeto`` is
    called without ``overwrite``, so with astropy's default an existing file makes
    the call fail.
    """
    table_columns = [
        fits.Column(name='SHIFT_VALUES', array=no_outlier_scaling_values_list, format='E'),
        fits.Column(name='WAVELENGTH_VALUES', array=no_outlier_wavelength_values_list, format='E'),
        fits.Column(name='ORIGINAL_WAVELENGTH_SHIFTED', array=original_wavelength_shift, format='E'),
    ]
    table_hdu = fits.BinTableHDU.from_columns(table_columns)

    out_file = (optimized_data_path + '/{0}'.format(mask_name)
                + '/{0}_Rebinned/{0}_Shift_Values/Shift_Values_Polynomial_Fits_{0}_{1}.fits.gz'.format(mask_name,slit_number_used))
    fits.HDUList([fits.PrimaryHDU(), table_hdu]).writeto(out_file)


## Saving Shift Factor

In [None]:
def saving_shifting_factor(mask_name,slit_number_used,shifting_value_dict):
    """Write the per-segment optimal shift factors of one slit to a gzipped FITS table.

    The factors go into the float32 (format 'E') column SHIFTING_FACTOR of the first
    extension, in the dict's insertion order. None entries are handed to astropy
    as-is. NOTE(review): presumably they end up stored as NaN -- confirm.
    """
    factor_column = fits.Column(name="SHIFTING_FACTOR", array=list(shifting_value_dict.values()), format='E')
    table_hdu = fits.BinTableHDU.from_columns([factor_column])

    out_file = (optimized_data_path + '/{0}'.format(mask_name)
                + '/{0}_Rebinned/{0}_Shifting_Factor/Shifting_Factor_{0}_{1}.fits.gz'.format(mask_name,slit_number_used))
    fits.HDUList([fits.PrimaryHDU(), table_hdu]).writeto(out_file)

## Saving Shifting Polynomial Coefficients

In [None]:
def save_shift_polynomial_coeff(mask_name,slit_number_used,poly_const_second_deg_shift,left_end_val,right_end_val,no_outlier_wavelength_values_list):
    """Write the shift polynomial of one slit to a gzipped FITS table.

    Columns (float32, format 'E'):

    - POLYNOMIAL_COEFFICIENTS_SHIFTING: the 2nd-order coefficients.
    - END_VALUES: ``[left_end_val, right_end_val]``.
    - NO_OUTLIER_WEIGHTED_WAVE: the weighted wavelengths kept by clipping.
    """
    table_hdu = fits.BinTableHDU.from_columns([
        fits.Column(name='POLYNOMIAL_COEFFICIENTS_SHIFTING', array=poly_const_second_deg_shift, format='E'),
        fits.Column(name='END_VALUES', array=[left_end_val,right_end_val], format='E'),
        fits.Column(name='NO_OUTLIER_WEIGHTED_WAVE',array=no_outlier_wavelength_values_list,format='E'),
    ])

    out_file = (optimized_data_path + '/{0}'.format(mask_name)
                + '/{0}_Rebinned/{0}_Polynomial_Coefficients/{0}_Shifting_Polynomial_Coefficients/Shift_Poly_Coeff_{0}_{1}.fits.gz'.format(mask_name,slit_number_used))
    fits.HDUList([fits.PrimaryHDU(), table_hdu]).writeto(out_file)
    

## Automate Shifting of Included Slits

In [None]:
@write_to_log('{0}/{1}/{1}_Shifting_Log_Included.log'.format(optimized_data_path,mask_name))
def shifting_included_slits(slit_numbers_to_automate):
    """Find, fit and save the wavelength shift of each requested included slit.

    For every slit number (looked up in ``slit_nums``):

    1. Compute the weighted wavelength and the optimal shift factor of each segment.
    2. Save the factors and the matching shifted rebinned fluxes.
    3. Fit a clipped second-order polynomial to the factors.
    4. Apply the fit to the slit's original wavelength.
    5. Save the shifted wavelength and the polynomial coefficients.

    Printed output is also written to the log file named in the decorator.

    If the fitting chain raises TypeError, the slit is saved with zero shift:
    zero coefficients and zero end values. NOTE(review): presumably the TypeError
    means no usable shift factors were found -- confirm.

    NOTE(review): ``median`` below is local to this function, but ``finding_shifting``
    reads a module-level ``median``; confirm both refer to the same spectrum.
    """

    #boolean array using median and threshold
    median = median_fits
    median_boolean_array = median_threshold(median,threshold_median)

    for slit_number in slit_numbers_to_automate:

        slit_number_used = slit_number
        print('\nShifting slit #{0} (Included) from mask {1}.'.format(slit_number_used,mask_name))

        index_of_slit = slit_nums.index(slit_number_used) #index of slit

        #Determine the weighted wavelength
        weighted_wavelength_list = automating_weighted_wave_multiple_slits(index_of_slit, rbflux_fits, median) #may contain None values; filtered out during fitting
        print("List of All Weighted Wavelength: {0}".format(weighted_wavelength_list))


        #Determine the Optimal Shift Factor
        print('Finding the Optimal Shift Factor...')
        shifting_value_dict,rbflux_shifted_dict = finding_shifting(slit_number_used, flux, wave, ivar,
                                                                   median_boolean_array, rbwave_fits[0], index_of_slit)

        #Saving the Shift Factor
        saving_shifting_factor(mask_name,slit_number_used,shifting_value_dict)

        #Sort through rbflux_shifted_dict and remove all key-value we don't need (reduce size of saved file)
        sorted_rbflux_shifted_dict = sorting_needed_shift_factor(shifting_value_dict,rbflux_shifted_dict)

        #Saving the rebinned flux calculated using the shifted wavelength
        saving_rbflux_shifted(mask_name,slit_number_used,sorted_rbflux_shifted_dict)

        try:
            #Original Polynomial Fits for Shifting
            shifting_values_list,wavelength_values_shift,poly_const_second_deg_shifting = wavelength_shifting_function(shifting_value_dict,weighted_wavelength_list)

            #Improving Original Polynomial Fits for Shifting
            #Remove Outliers
            no_outlier_shifting_values_list, no_outlier_wavelength_values_list_shift = remove_outliers_sigma_clip_shifting(shifting_values_list,wavelength_values_shift,poly_const_second_deg_shifting)

            #New Polynomial Fits for Shifting
            no_outlier_wavelength_values_list_shift,left_end_val_shift,right_end_val_shift,poly_const_second_deg_shift = new_polynomial_fits_shifting(mask_name,slit_number_used,no_outlier_shifting_values_list,no_outlier_wavelength_values_list_shift)

            #Use New Poly Fits to Shift Original Flux
            original_wavelength_shift = use_new_shifting_factor(wave[index_of_slit],
                               no_outlier_wavelength_values_list_shift,left_end_val_shift,
                               right_end_val_shift,poly_const_second_deg_shift)

        except TypeError:
            poly_const_second_deg_shift = [0,0,0] #no shifting
            # The end values must be reset too. Otherwise they are unbound on the
            # first failing slit (UnboundLocalError), or silently carried over from
            # the previous slit into the saved coefficients file.
            left_end_val_shift = 0
            right_end_val_shift = 0
            no_outlier_shifting_values_list = list(shifting_value_dict.values())
            no_outlier_wavelength_values_list_shift = weighted_wavelength_list
            original_wavelength_shift = wave[index_of_slit]

        #Saving Shifted Original Wavelength as Fits Files
        save_new_polynomial_fits_values_shift(no_outlier_shifting_values_list,no_outlier_wavelength_values_list_shift,original_wavelength_shift,mask_name,slit_number_used)

        #Saving Shifting Polynomial Coefficients
        save_shift_polynomial_coeff(mask_name,slit_number_used,poly_const_second_deg_shift,left_end_val_shift,right_end_val_shift,no_outlier_wavelength_values_list_shift)

## Automate Shifting of Excluded Slits

In [None]:
@write_to_log('{0}/{1}/{1}_Shifting_Log_Excluded.log'.format(optimized_data_path,mask_name))
def shifting_excluded_slits(slit_numbers_to_automate):
    """Find, fit and save the wavelength shift of each requested excluded slit.

    Same pipeline as for the included slits, but it reads the ``*_exclude`` globals
    and looks slits up in ``slit_nums_exclude``:

    1. Compute the weighted wavelength and the optimal shift factor of each segment.
    2. Save the factors and the matching shifted rebinned fluxes.
    3. Fit a clipped second-order polynomial to the factors.
    4. Apply the fit to the slit's original wavelength.
    5. Save the shifted wavelength and the polynomial coefficients.

    Printed output is also written to the log file named in the decorator.

    If the fitting chain raises TypeError, the slit is saved with zero shift:
    zero coefficients and zero end values. NOTE(review): presumably the TypeError
    means no usable shift factors were found -- confirm.

    NOTE(review): ``median`` below is local to this function, but ``finding_shifting``
    reads a module-level ``median``; confirm both refer to the same spectrum.
    """

    #boolean array using median and threshold
    median = median_fits
    median_boolean_array = median_threshold(median,threshold_median)

    for slit_number in slit_numbers_to_automate:

        slit_number_used = slit_number
        print('\nShifting slit #{0} (Excluded) from mask {1}.'.format(slit_number_used,mask_name))

        index_of_slit = slit_nums_exclude.index(slit_number_used) #index of slit

        #Determine the weighted wavelength
        weighted_wavelength_list = automating_weighted_wave_multiple_slits(index_of_slit, rbflux_fits_exclude, median) #may contain None values; filtered out during fitting
        print("List of All Weighted Wavelength: {0}".format(weighted_wavelength_list))


        #Determine the Optimal Shift Factor
        print('Finding the Optimal Shift Factor...')
        shifting_value_dict,rbflux_shifted_dict = finding_shifting(slit_number_used, flux_exclude, wave_exclude,ivar_exclude,
                                                                   median_boolean_array, rbwave_fits_exclude[0], index_of_slit)

        #Saving the Shift Factor
        saving_shifting_factor(mask_name,slit_number_used,shifting_value_dict)

        #Sort through rbflux_shifted_dict and remove all key-value we don't need (reduce size of saved file)
        sorted_rbflux_shifted_dict = sorting_needed_shift_factor(shifting_value_dict,rbflux_shifted_dict)

        #Saving the rebinned flux calculated using the shifted wavelength
        saving_rbflux_shifted(mask_name,slit_number_used,sorted_rbflux_shifted_dict)

        try:
            #Original Polynomial Fits for Shifting
            shifting_values_list,wavelength_values_shift,poly_const_second_deg_shifting = wavelength_shifting_function(shifting_value_dict,weighted_wavelength_list)

            #Improving Original Polynomial Fits for Shifting
            #Remove Outliers
            no_outlier_shifting_values_list, no_outlier_wavelength_values_list_shift = remove_outliers_sigma_clip_shifting(shifting_values_list,wavelength_values_shift,poly_const_second_deg_shifting)

            #New Polynomial Fits for Shifting
            no_outlier_wavelength_values_list_shift,left_end_val_shift,right_end_val_shift,poly_const_second_deg_shift = new_polynomial_fits_shifting(mask_name,slit_number_used,no_outlier_shifting_values_list,no_outlier_wavelength_values_list_shift)

            #Use New Poly Fits to Shift Original Flux
            original_wavelength_shift = use_new_shifting_factor(wave_exclude[index_of_slit],
                               no_outlier_wavelength_values_list_shift,left_end_val_shift,
                               right_end_val_shift,poly_const_second_deg_shift)

        except TypeError:
            poly_const_second_deg_shift = [0,0,0] #no shifting
            # The end values must be reset too. Otherwise they are unbound on the
            # first failing slit (UnboundLocalError), or silently carried over from
            # the previous slit into the saved coefficients file.
            left_end_val_shift = 0
            right_end_val_shift = 0
            no_outlier_shifting_values_list = list(shifting_value_dict.values())
            no_outlier_wavelength_values_list_shift = weighted_wavelength_list
            original_wavelength_shift = wave_exclude[index_of_slit]

        #Saving Shifted Original Wavelength as Fits Files
        save_new_polynomial_fits_values_shift(no_outlier_shifting_values_list,no_outlier_wavelength_values_list_shift,original_wavelength_shift,mask_name,slit_number_used)

        #Saving Shifting Polynomial Coefficients
        save_shift_polynomial_coeff(mask_name,slit_number_used,poly_const_second_deg_shift,left_end_val_shift,right_end_val_shift,no_outlier_wavelength_values_list_shift)

## Shifting - Execution

In [None]:
# Slits to process: every included slit by default. Replace with a shorter list to rerun a subset.
slit_numbers_you_want_to_shift_included_slits = slit_nums

# Everything printed during the run is duplicated into this mask's
# "_Shifting_Log_Included.log" file. That file is reopened for writing, and
# therefore overwritten, on every execution of this cell. Copy it elsewhere
# first if you want to keep the record of a particular run.
shifting_included_slits(slit_numbers_you_want_to_shift_included_slits)


In [None]:
# Slits to process: every excluded slit by default. Replace with a shorter list to rerun a subset.
slit_numbers_you_want_to_shift_excluded_slits = slit_nums_exclude

# Everything printed during the run is duplicated into this mask's
# "_Shifting_Log_Excluded.log" file. That file is reopened for writing, and
# therefore overwritten, on every execution of this cell. Copy it elsewhere
# first if you want to keep the record of a particular run.
shifting_excluded_slits(slit_numbers_you_want_to_shift_excluded_slits)
    

# New Median

In [None]:
# def read_shifted_wavelength(slit_nums,mask_name):
    
#     shifted_original_wavelength_fits = []
    
#     for slit_num in slit_nums:
        
#         shifted_wavelength = fits.open(("{2}/{0}/{0}_Rebinned/{0}_Shift_Values/Shift_Values_Polynomial_Fits_{0}_{1}.fits.gz".format(mask_name,slit_num,optimized_data_path)))
#         shifted_original_wavelength_fits.append(shifted_wavelength[1].data["ORIGINAL_WAVELENGTH_SHIFTED"])
    
#     return shifted_original_wavelength_fits

## Getting Back Shift Factor

In [None]:
def read_shift_factor(mask_name,slit_number_used):
    """Read the saved per-segment optimal shift factors of one slit.

    Parameters
    ----------
    mask_name : str
        Mask identifier used in the output directory layout.
    slit_number_used :
        Slit number whose ``Shifting_Factor`` FITS file is read.

    Returns
    -------
    numpy.ndarray
        In-memory copy of the ``SHIFTING_FACTOR`` column of extension 1.
    """
    file_path = (optimized_data_path + '/{0}'.format(mask_name)
                 + '/{0}_Rebinned/{0}_Shifting_Factor/Shifting_Factor_{0}_{1}.fits.gz'.format(mask_name,slit_number_used))
    # This is called once per slit. The original never closed the HDUList and
    # leaked one file handle per call; the context manager releases it.
    with fits.open(file_path) as shift_factor_hdu:
        # Copy the column so the result does not depend on the (now closed) file.
        shift_factor = np.array(shift_factor_hdu[1].data["SHIFTING_FACTOR"])

    return shift_factor

## Getting Back Flux Rebinned With Shifted Wavelength

In [None]:
def read_rbflux_shifted_wave(mask_name,slit_number_used):
    """Load the dict of flux rebinned onto shifted wavelengths for one slit.

    The ``.npy`` file holds a single pickled object (hence ``allow_pickle=True``
    and ``.item()``). It is presumably written by an earlier stage of this
    notebook, so only load files you trust.
    """
    npy_path = "{2}/{0}/{0}_Rebinned/{0}_Rebinned_Flux_Shifted_Wave/rbflux_shifted_dict_{0}_{1}.npy".format(
        mask_name, slit_number_used, optimized_data_path)
    loaded_array = np.load(npy_path, allow_pickle=True)
    return loaded_array.item()

In [None]:
def get_shifted_flux_array(index_of_slit,rbflux_shifted_dict,rebinned_flux,shifting_factors):
    """Stitch a slit's flux together from 500 A segments, using shifted flux where available.

    The range 4000-11000 A is cut into 500 A segments ``[lo, hi)``. For segment
    ``k``, the ``new_wave_600`` pixels in that range are copied from one of two sources:

    * the unshifted rebinned flux ``rebinned_flux[index_of_slit]``, when
      ``shifting_factors[k]`` is not finite (no optimal shift was found), or
    * ``rbflux_shifted_dict["Shifted_<factor rounded to 2 dp>"]`` otherwise.

    Returns
    -------
    numpy.ndarray
        The concatenation of all segment pieces, in wavelength order.
    """
    edges = np.arange(4000,11500,500)

    # The dict keys use the factor rounded to 2 decimals; an exact 0 maps to '0.0'.
    factor_keys = ['0.0' if factor == 0 else str(round(factor,2)) for factor in shifting_factors]

    stitched = []
    for seg, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        base_flux = rebinned_flux[index_of_slit]  # unshifted rebinned flux of this slit
        in_segment = np.where((new_wave_600>=lo) & (new_wave_600<hi))

        if np.isfinite(shifting_factors[seg]):
            source = np.array(rbflux_shifted_dict["Shifted_{0}".format(factor_keys[seg])])
        else:
            source = np.array(base_flux)

        stitched.extend(source[in_segment])

    return np.array(stitched)

In [None]:
def looping_shifted_flux(slit_numbers_to_automate, incl_or_excl):
    """Build the shifted-flux array of every requested slit.

    Parameters
    ----------
    slit_numbers_to_automate : iterable
        Slit numbers to process, in output order.
    incl_or_excl : str
        ``'Included'`` selects ``rbflux_fits`` / ``slit_nums``. Any other value
        selects ``rbflux_fits_exclude`` / ``slit_nums_exclude``.

    Returns
    -------
    numpy.ndarray
        One row per slit, each produced by ``get_shifted_flux_array``.
    """
    is_included = incl_or_excl == 'Included'
    fluxes = rbflux_fits if is_included else rbflux_fits_exclude
    slits = slit_nums if is_included else slit_nums_exclude

    def _shifted_flux_of(slit):
        # Position of the slit in the selected list, then the saved shift data.
        position = slits.index(slit)
        shifted_dict = read_rbflux_shifted_wave(mask_name,slit)
        factors = read_shift_factor(mask_name,slit)
        return get_shifted_flux_array(position,shifted_dict,fluxes,factors)

    return np.array([_shifted_flux_of(slit) for slit in slit_numbers_to_automate])

In [None]:
# Shifted flux of every included slit (one row per slit).
shifted_fluxes_included = looping_shifted_flux(slit_numbers_to_automate=slit_nums, incl_or_excl='Included')

In [None]:
# Shifted flux of every excluded slit (one row per slit).
shifted_fluxes_excluded = looping_shifted_flux(slit_numbers_to_automate=slit_nums_exclude, incl_or_excl='Excluded')

## Making The New Median

In [None]:
# New median spectrum built only from the shifted fluxes of the included slits.
new_median = find_median(shifted_fluxes_included)

## Saving The New Median

In [None]:
def exportToFitsNewMedian(median,mask_name):
    """Write the new median spectrum to ``New_Median_of_<mask>.fits.gz``.

    Extension 0 is an empty primary HDU. Extension 1 is a binary table with a
    single float32 ('E') column named ``NEW MEDIAN``. ``writeto`` is called
    without ``overwrite``, so with astropy's default an existing file raises an error.
    """
    out_file = (optimized_data_path + '/{0}'.format(mask_name)
                + '/{0}_Rebinned/{0}_New_Median/New_Median_of_{0}.fits.gz'.format(mask_name))
    median_column = fits.Column(name='NEW MEDIAN',array=median,format="E")
    table_hdu = fits.BinTableHDU.from_columns([median_column])
    fits.HDUList([fits.PrimaryHDU(), table_hdu]).writeto(out_file)

In [None]:
# Persist the new median so later cells (and later sessions) can read it back.
exportToFitsNewMedian(median=new_median, mask_name=mask_name)

## Getting Back New Median

In [None]:
def get_new_med_from_fits(mask_name):
    """Read back the new median spectrum saved by ``exportToFitsNewMedian``.

    Parameters
    ----------
    mask_name : str
        Mask identifier used in the output directory layout.

    Returns
    -------
    numpy.ndarray
        In-memory copy of the ``NEW MEDIAN`` column of extension 1 of
        ``New_Median_of_<mask>.fits.gz``.
    """
    median_path = optimized_data_path + '/{0}'.format(mask_name) + '/{0}_Rebinned/{0}_New_Median/New_Median_of_{0}.fits.gz'.format(mask_name)
    # The original left the HDUList open; close it once the column is copied.
    with fits.open(median_path) as median_read:
        median_fits = np.array(median_read[1].data["NEW MEDIAN"])  # contains the median
    return median_fits

In [None]:
# Median as stored on disk; the scaling step uses this rather than new_median.
new_median_fits = get_new_med_from_fits(mask_name=mask_name)

# Scaling

In [None]:
def rbflux_minus_median(rbflux, median, multipliers,
                       use_moving_median=use_moving_median):
    """Return ``{"Multiplier_<m rounded to 2 dp>": m * rbflux - median}`` for each multiplier.

    When ``use_moving_median`` compares equal to ``True``, two extra steps are applied:

    * a moving-median baseline of ``rbflux - median`` (window size taken from
      the module-level ``window``) is first removed from ``rbflux``, and
    * each scaled residual then has its own moving median subtracted.

    Note that the default of ``use_moving_median`` is bound to the module-level
    value at definition time.
    """
    def _key(m):
        return "Multiplier_{}".format(round(m,2))

    if not (use_moving_median==True):
        return {_key(m): (m * rbflux) - median for m in multipliers}

    baseline = moving_median(rbflux-median,  window=window)
    detrended = rbflux - baseline

    residuals = {}
    for m in multipliers:
        residual = (m * detrended) - median
        residuals[_key(m)] = residual - moving_median(residual,  window=window)
    return residuals

In [None]:
def sorting_rms(multiply_boolean, subtraction_dict, multipliers):
    """Keep only the selected pixels of each multiplier's residual.

    Parameters
    ----------
    multiply_boolean : array-like
        Pixel mask; entries equal to ``True`` are kept.
    subtraction_dict : dict
        Residual arrays keyed ``"Multiplier_<m rounded to 2 dp>"``.
    multipliers : iterable of float
        Multipliers whose residuals are filtered.

    Returns
    -------
    dict
        Same keys, each mapped to the residual entries at the selected pixels.
    """
    selected = np.where(multiply_boolean == True)
    keys = ("Multiplier_{}".format(round(m,2)) for m in multipliers)
    return {key: subtraction_dict[key][selected] for key in keys}

In [None]:
def rms_calculation(rms_dict_sorted, multipliers):
    """Compute the RMS of the selected residuals for each multiplier.

    Parameters
    ----------
    rms_dict_sorted : dict
        ``"Multiplier_<m rounded to 2 dp>"`` -> array-like of residuals, as
        produced by ``sorting_rms``.
    multipliers : iterable of float
        Multipliers to evaluate.

    Returns
    -------
    dict or None
        ``{key: sqrt(nanmean(values**2))}``. Returns ``None``, after printing a
        message, as soon as one RMS is not finite (for example when no pixel was
        selected, so the mean of an empty slice is nan).

    Raises
    ------
    KeyError
        If a multiplier has no entry in ``rms_dict_sorted``. The previous bare
        ``except`` swallowed this and misreported it as "no True boolean".
    """
    import warnings

    rms_dict = {}
    for multiplier in multipliers:
        key = "Multiplier_{}".format(round(multiplier,2))
        values_for_rms_cal = np.array(rms_dict_sorted[key])

        with warnings.catch_warnings():
            # An empty selection is an expected case and is handled just below,
            # so silence numpy's "Mean of empty slice" warning for it.
            warnings.simplefilter("ignore", category=RuntimeWarning)
            rms = np.nanmean(values_for_rms_cal*values_for_rms_cal)**0.5

        if not np.isfinite(rms):
            print("Everything is False. There's no True boolean. Therefore, RMS cannot be calculated.")
            return None

        rms_dict[key] = rms

    return rms_dict
        

In [None]:
def plotting_the_rms(slit_number,multipliers, rms_dict, min_wave, max_wave):
    """Plot RMS vs. scaling for one slit and wavelength window; return the best scaling.

    The figure is saved under
    ``<optimized_data_path>/<mask>/<mask>_Polynomial_Graph/Scaling_vs_RMS/Slit_<n>/``.
    The directory is created if needed.

    Parameters
    ----------
    slit_number :
        Slit being optimized (used in the title and file name).
    multipliers : sequence of float
        Candidate scale factors, in plotting order.
    rms_dict : dict or None
        Output of ``rms_calculation``. ``None`` means no RMS could be computed.
    min_wave, max_wave :
        Bounds of the wavelength window (used in the title and file name).

    Returns
    -------
    float or None
        The multiplier (rounded to 2 dp) with the smallest RMS, or ``None``
        when there is no RMS to plot.
    """
    path = optimized_data_path + '/{0}'.format(mask_name) + "/{0}_Polynomial_Graph/Scaling_vs_RMS/Slit_{1}".format(mask_name,slit_number)
    # Same semantics as the old try/except OSError: an existing directory is
    # fine, while an existing non-directory at that path still raises.
    os.makedirs(path, exist_ok=True)

    # rms_calculation returns None when no pixel survives the masks. This case
    # is handled explicitly; the old bare ``except`` also hid real failures
    # (missing keys, errors while saving the figure).
    if not rms_dict:
        print("Because we have no RMS there is no plot.")
        return None

    value_list = [rms_dict["Multiplier_{0}".format(round(multiplier,2))] for multiplier in multipliers]

    fig = plt.figure(figsize=(8,6))
    fig.patch.set_alpha(1)
    plt.plot(multipliers,value_list)
    plt.xlabel("Scaling")
    plt.ylabel("RMS")
    plt.title("Mask {0}: Slit #{1}\nRMS vs Scaling ({2} A to {3} A)".format(mask_name,slit_number,min_wave,max_wave))
    fig.savefig(path+'/{0}_Slit_{1}_{2}_to_{3}.png'.format(mask_name,slit_number,min_wave,max_wave))

    # First multiplier attaining the minimum RMS.
    min_val = min(value_list)
    scale = round(multipliers[value_list.index(min_val)],2)

    print("Scaling w/ minimum RMS (Slit #{0}): {1}".format(slit_number,scale) + " ({0} A to {1} A)".format(min_wave,max_wave))

    return scale

## Automating All The Functions Used In Scaling Process (For A Single Slit)

In [None]:
def finding_scaling(slit_number, median_and_subtraction_boolean_array,
                    wavelength, min_wave, max_wave,
                    multipliers, subtraction_dict): 
    """Find the optimal scale factor of one slit inside the window [min_wave, max_wave].

    Steps:

    1. Build a wavelength mask for the window with ``create_wave_bool``
       (presumably True inside the window; see that helper).
    2. Multiply it with the median/sky-subtraction mask.
    3. Keep only the selected residuals for every multiplier.
    4. Compute their RMS values.
    5. Plot RMS vs multiplier and return the multiplier with the lowest RMS
       (``None`` when no RMS could be computed).
    """
    window_mask = create_wave_bool(wavelength, min_wave, max_wave)
    combined_mask = median_and_subtraction_boolean_array * window_mask

    selected_residuals = sorting_rms(combined_mask, subtraction_dict, multipliers)
    rms_by_multiplier = rms_calculation(selected_residuals, multipliers)

    return plotting_the_rms(slit_number,multipliers,rms_by_multiplier,min_wave,max_wave)

In [None]:
def looping_finding_scale(slit_number, subtraction_dict, median_and_subtraction_boolean_array):
    """Run ``finding_scaling`` over consecutive 500 A windows from 4000 A to 11000 A.

    Uses the module-level ``rbwave_fits[0]`` as the wavelength grid and the
    module-level ``multipliers`` as candidate scale factors.

    Returns
    -------
    dict
        ``"<lo>_to_<hi>"`` -> optimal scale factor for that window (``None``
        when none could be determined), in wavelength order.
    """
    edges = np.arange(4000,11500,500)

    per_window_scale = {}
    for lo, hi in zip(edges[:-1], edges[1:]):
        per_window_scale["{0}_to_{1}".format(lo,hi)] = finding_scaling(
            slit_number, median_and_subtraction_boolean_array,
            rbwave_fits[0], lo, hi,
            multipliers, subtraction_dict)

    return per_window_scale


## Polynomial Fit for Scaling

In [None]:
def wavelength_scaling_function(scaling_value_dict, weighted_wavelength):
    """Fit 2nd-, 3rd- and 4th-order polynomials to optimal scale factor vs wavelength and plot them.

    Windows whose scale factor is ``None`` are skipped. The value at position
    ``i`` of ``scaling_value_dict`` (insertion order) is paired with
    ``weighted_wavelength[i]``.

    Returns
    -------
    tuple
        ``(scaling_values, wavelength_values, poly_const_second_deg)``: the kept
        scale factors, their wavelengths, and the 2nd-order ``np.polyfit``
        coefficients (highest power first). ``np.polyfit`` raises ``TypeError``
        when no point is kept; callers rely on that.
    """
    kept_pairs = [(value, weighted_wavelength[position])
                  for position, value in enumerate(scaling_value_dict.values())
                  if value != None]
    scaling_values = [value for value, _ in kept_pairs]
    wavelength_values = [wave for _, wave in kept_pairs]

    fine_grid = np.arange(4000,11000,0.65)  # every plotted wavelength between 4000 and 11000 A

    coeffs_2 = np.polyfit(wavelength_values,scaling_values,2)
    print("Second order polynomial: y = {0}x^2 + {1}x + {2}".format(*coeffs_2))

    coeffs_3 = np.polyfit(wavelength_values,scaling_values,3)
    print("Third order polynomial: y = {0}(x^3) + {1}(x^2) + {2}(x) + {3}".format(*coeffs_3))

    coeffs_4 = np.polyfit(wavelength_values,scaling_values,4)
    print("Fourth order polynomial: y = {0}(x^4) + {1}(x^3) + {2}(x^2) + {3}(x) + {4}".format(*coeffs_4))

    # Evaluate all three fits on the fine grid.
    curve_2, curve_3, curve_4 = [], [], []
    for wave in fine_grid:
        curve_2.append(polynomial_second_order(wave,coeffs_2))
        curve_3.append(polynomial_third_order(wave,coeffs_3))
        curve_4.append(polynomial_fourth_order(wave,coeffs_4))

    fig, ax = plt.subplots(1)
    ax.set_xlim(3900,11100)
    ax.set_ylim(-0.5,2.5)
    ax.set_title("Optimal Scale Factor vs Wavelength")
    ax.set_xlabel("Wavelength (A)")
    ax.set_ylabel("Optimal Scale Factor")
    ax.plot(fine_grid,curve_2,scalex=False,scaley=False,label="Second-Order",c="green")
    ax.plot(fine_grid,curve_3,scalex=False,scaley=False,label="Third-Order")
    ax.plot(fine_grid,curve_4,scalex=False,scaley=False,label="Fourth-Order")
    ax.scatter(wavelength_values,scaling_values,label="Optimal Scale Factor",c="black")
    ax.legend()

    return scaling_values,wavelength_values,coeffs_2


## Improving The Polynomial Fits (Scaling)

In [None]:
def _scaling_deviations(wavelengths, scaling_values, poly_coeff):
    """Return ``|poly(wavelength) - scaling value|`` for each point, using the 2nd-order polynomial."""
    return [np.abs(polynomial_second_order(wave, poly_coeff) - value)
            for wave, value in zip(wavelengths, scaling_values)]


def _sigma_clip_keep_indices(deviation, sigma):
    """Return ``(indices to keep, number of outliers)`` for one 3-sigma clipping pass.

    A point is an outlier when ``deviation / sigma`` is not below 3. Only the
    largest outlier is dropped. Any other outliers are kept and appended after
    the non-outliers, matching the original ordering. When ``sigma`` is 0 all
    deviations are identical, so no point is treated as an outlier. The
    original divided by zero there and could flag every point.
    """
    if sigma == 0:
        outliers = []
    else:
        outliers = [i for i, dev in enumerate(deviation) if not (dev / sigma < 3.0)]

    outlier_set = set(outliers)
    keep = [i for i in range(len(deviation)) if i not in outlier_set]

    if len(outliers) > 1:
        # Work with indices instead of deviation.index(value); the latter picked
        # the wrong point when two deviations were exactly equal.
        largest = max(outliers, key=lambda i: deviation[i])
        keep.extend(i for i in outliers if i != largest)

    return keep, len(outliers)


def remove_outliers_sigma_clip_scaling(scaling_values_list,weighted_wavelength_list,poly_const_second_deg):
    """Remove outliers from the scale-factor-vs-wavelength points with up to two 3-sigma passes.

    Pass 1 uses the deviations from the fit ``poly_const_second_deg``. If no
    outlier is found, or at most one point is left, the result is returned
    immediately. Otherwise a new 2nd-order polynomial is fitted to the
    remaining points and pass 2 is applied to its deviations. Each pass drops
    only the single largest outlier.

    Parameters
    ----------
    scaling_values_list : list of float
        Optimal scale factors.
    weighted_wavelength_list : list of float
        Wavelength paired with each scale factor.
    poly_const_second_deg : array-like
        Initial 2nd-order coefficients (highest power first, as from ``np.polyfit``).

    Returns
    -------
    tuple of list
        ``(no_outlier_scaling_values_list, no_outlier_wavelength_values_list)``.
        With fewer than two points no clipping is possible and copies of the
        inputs are returned. The original raised ``StatisticsError`` there,
        which callers did not catch.
    """
    # FIRST ITERATION: deviation of each point from the initial fit
    deviation = _scaling_deviations(weighted_wavelength_list, scaling_values_list, poly_const_second_deg)
    print("First Deviation: {}".format(deviation))

    if len(deviation) < 2:
        # statistics.stdev needs at least two data points.
        print("First Deviation RMS: not computed (fewer than two points), no clipping done")
        return list(scaling_values_list), list(weighted_wavelength_list)

    sigma = statistics.stdev(deviation)
    print("First Deviation RMS: {}".format(sigma))

    keep, n_outliers = _sigma_clip_keep_indices(deviation, sigma)
    scaling_values_1 = [scaling_values_list[i] for i in keep]
    wavelength_values_1 = [weighted_wavelength_list[i] for i in keep]

    # No outlier, or too few points left to refit: stop after one pass.
    if n_outliers == 0 or len(scaling_values_1) <= 1:
        return scaling_values_1, wavelength_values_1

    # SECOND ITERATION: refit without the dropped point and clip again
    new_poly_coeff = np.polyfit(wavelength_values_1,scaling_values_1,2)
    new_deviation = _scaling_deviations(wavelength_values_1, scaling_values_1, new_poly_coeff)
    print("Second Deviation: {}".format(new_deviation))

    new_sigma = statistics.stdev(new_deviation)  # at least two points here
    print("Second Deviation RMS: {}".format(new_sigma))

    keep_2, _ = _sigma_clip_keep_indices(new_deviation, new_sigma)
    no_outlier_scaling_values_list = [scaling_values_1[i] for i in keep_2]
    no_outlier_wavelength_values_list = [wavelength_values_1[i] for i in keep_2]
    return no_outlier_scaling_values_list,no_outlier_wavelength_values_list

In [None]:
def new_polynomial_fits(mask_name,slit_num,scaling_values_list,wavelength_values_list):
    """Refit the clipped scale factors with a 2nd-order polynomial, flatten its ends, plot and save it.

    Outside the fitted wavelength range the curve is held constant at its value
    at the shortest/longest fitted wavelength ("straightened ends").

    Parameters
    ----------
    mask_name : str
        Mask identifier (plot title and output path).
    slit_num :
        Slit number (plot title and file name).
    scaling_values_list, wavelength_values_list : list of float
        Scale factors and wavelengths with outliers already removed.

    Returns
    -------
    tuple
        ``(wavelength_values_list, left_end_val, right_end_val, poly_const_second_deg)``.
        ``np.polyfit`` raises ``TypeError`` on empty input; callers rely on that.
    """
    poly_const_second_deg = np.polyfit(wavelength_values_list,scaling_values_list,2)  # highest power first
    print("New Second Order Polynomial: y = {0}x^2 + {1}x + {2}".format(poly_const_second_deg[0],
                                                                    poly_const_second_deg[1],poly_const_second_deg[2]))

    all_wavelength_values = np.arange(4000,11000,0.65)  # every plotted wavelength between 4000 and 11000 A
    poly_second_order = [polynomial_second_order(value,poly_const_second_deg) for value in all_wavelength_values]

    # Hoisted: the original recomputed min()/max() for each of the ~10,800 grid points.
    min_fit_wave = min(wavelength_values_list)
    max_fit_wave = max(wavelength_values_list)

    # Scale factor at the far left / far right fitted wavelength
    left_end_val = polynomial_second_order(min_fit_wave,poly_const_second_deg)
    right_end_val = polynomial_second_order(max_fit_wave,poly_const_second_deg)

    print("Optimal Scale Factor of All Wavelength Before {0} Angstroms: {1}".format(min_fit_wave,left_end_val))
    print("Optimal Scale Factor of All Wavelength After {0} Angstroms: {1} ".format(max_fit_wave,right_end_val))

    # Straighten both ends of the curve outside the fitted range
    for n, wavelength in enumerate(all_wavelength_values):
        if wavelength < min_fit_wave:
            poly_second_order[n] = left_end_val
        elif wavelength > max_fit_wave:
            poly_second_order[n] = right_end_val

    fig,axs = plt.subplots(1)
    fig.set_size_inches(8,6)
    fig.patch.set_alpha(1)
    axs.set_xlim(3900,11100)
    axs.set_ylim(-0.5,2.5)
    axs.set_title("Mask {0}: Slit #{1}\nOptimal Scale Factor vs Wavelength".format(mask_name,slit_num))
    axs.set_xlabel("Wavelength (A)")
    axs.set_ylabel("Optimal Scale Factor")
    axs.plot(all_wavelength_values,poly_second_order,scalex=False,scaley=False,label="Second-Order",c="green")
    axs.scatter(wavelength_values_list,scaling_values_list,label="Optimal Scale Factor",c="black")
    fig.savefig(optimized_data_path + '/{0}'.format(mask_name) + "/{0}_Polynomial_Graph/Scaling_Fitting/Poly_Graph_{0}_slit_{1}.png".format(mask_name,slit_num))

    return wavelength_values_list,left_end_val,right_end_val,poly_const_second_deg

## Scale Original Flux Using Polynomial Fits

In [None]:
def use_new_scaling_factor(flux_include,wave_include,no_outlier_wavelength_values_list,left_end_val,right_end_val,poly_const_second_deg):
    
    flux = flux_include #original flux from FITS files 
    wave = wave_include #original wavelength from FITS files
    
    no_outlier_wavelength_values_list = no_outlier_wavelength_values_list #weighted wavelength with outliers removed 
    
    original_flux_scale = []
    
    for n in range(len(wave_include)): 
        if wave_include[n] < min(no_outlier_wavelength_values_list):
            original_flux_scale.append(flux_include[n] * left_end_val)
        
        elif wave_include[n] > max(no_outlier_wavelength_values_list): 
            original_flux_scale.append(flux_include[n] * right_end_val)
                                      
        else:
            scale_factor = polynomial_second_order(wave_include[n],poly_const_second_deg)
            original_flux_scale.append(flux_include[n] * scale_factor)
            
    return original_flux_scale

## Saving Scaled Flux as FITS Files

In [None]:
def save_new_polynomial_fits_values_scale(no_outlier_scaling_values_list, no_outlier_wavelength_values_list, original_flux_scale, mask_name,slit_number_used):
    """Save the clipped scale factors, their wavelengths and the scaled flux of one slit.

    Writes ``Scale_Values_Polynomial_Fits_<mask>_<slit>.fits.gz``. Extension 0 is
    an empty primary HDU. Extension 1 is a binary table with float32 ('E')
    columns SCALE_VALUES, WAVELENGTH_VALUES and ORIGINAL_FLUX_SCALED.
    """
    column_specs = (
        ('SCALE_VALUES', no_outlier_scaling_values_list),
        ('WAVELENGTH_VALUES', no_outlier_wavelength_values_list),
        ('ORIGINAL_FLUX_SCALED', original_flux_scale),
    )
    columns = [fits.Column(name=col_name, array=col_data, format='E') for col_name, col_data in column_specs]

    destination = optimized_data_path + '/{0}'.format(mask_name) + '/{0}_Rebinned/{0}_Scale_Values/Scale_Values_Polynomial_Fits_{0}_{1}.fits.gz'.format(mask_name,slit_number_used)
    fits.HDUList([fits.PrimaryHDU(), fits.BinTableHDU.from_columns(columns)]).writeto(destination)

## Saving Scale Factor

In [None]:
def saving_scaling_factor(mask_name,slit_number_used,scaling_value_dict):
    """Save the per-window optimal scale factors of one slit.

    The values of ``scaling_value_dict`` (insertion order) are written as a
    float32 ('E') column ``SCALING_FACTOR`` in extension 1 of
    ``Scaling_Factor_<mask>_<slit>.fits.gz``. Extension 0 is an empty primary HDU.
    """
    factors = list(scaling_value_dict.values())
    target = (optimized_data_path + '/{0}'.format(mask_name)
              + '/{0}_Rebinned/{0}_Scaling_Factor/Scaling_Factor_{0}_{1}.fits.gz'.format(mask_name,slit_number_used))

    table_hdu = fits.BinTableHDU.from_columns([fits.Column(name="SCALING_FACTOR", array=factors, format='E')])
    fits.HDUList([fits.PrimaryHDU(), table_hdu]).writeto(target)

## Saving Scaling Polynomial Coefficients

In [None]:
def save_scale_polynomial_coeff(mask_name,slit_number_used,poly_const_second_deg_scal,left_end_val,right_end_val,no_outlier_wavelength_values_list):
    """Save the scaling polynomial, its flattened end values and the fitted wavelengths of one slit.

    Extension 1 of ``Scale_Poly_Coeff_<mask>_<slit>.fits.gz`` holds three float32
    ('E') columns:

    * POLYNOMIAL_COEFFICIENTS_SCALING,
    * END_VALUES (``[left, right]``),
    * NO_OUTLIER_WEIGHTED_WAVE.

    Extension 0 is an empty primary HDU.
    """
    table_hdu = fits.BinTableHDU.from_columns([
        fits.Column(name='POLYNOMIAL_COEFFICIENTS_SCALING', array=poly_const_second_deg_scal, format='E'),
        fits.Column(name='END_VALUES', array=[left_end_val,right_end_val], format='E'),
        fits.Column(name='NO_OUTLIER_WEIGHTED_WAVE',array=no_outlier_wavelength_values_list,format='E'),
    ])

    output_path = optimized_data_path + '/{0}'.format(mask_name) + '/{0}_Rebinned/{0}_Polynomial_Coefficients/{0}_Scaling_Polynomial_Coefficients/Scale_Poly_Coeff_{0}_{1}.fits.gz'.format(mask_name,slit_number_used)
    fits.HDUList([fits.PrimaryHDU(), table_hdu]).writeto(output_path)
    

## Automate Scaling of Included Slits

In [None]:
@write_to_log('{0}/{1}/{1}_Scaling_Log_Included.log'.format(optimized_data_path,mask_name))
def scaling_included_slits(slit_numbers_to_automate):
    """Find, fit and save the optimal flux scaling of each requested included slit.

    For every slit:

    1. Combine the median-threshold mask with the sky-subtraction mask.
    2. Scan ``multipliers`` in 500 A windows for the minimum-RMS scale factor.
    3. Fit a sigma-clipped 2nd-order polynomial of scale factor vs weighted wavelength.
    4. Apply it to the slit's shifted flux (on the ``new_wave_600`` grid).
    5. Save the scaled flux, scale factors and coefficients as FITS files.

    If the fitting chain raises ``TypeError`` (presumably no usable scale
    factor, e.g. ``np.polyfit`` on empty input), a flat scale of 1 is saved
    instead.
    """
    median = new_median_fits

    #boolean array using median and threshold
    median_boolean_array = median_threshold(median, threshold_median)

    for slit_number in slit_numbers_to_automate:

        slit_number_used = slit_number
        print('\nOptimizing slit #{0} (Included) from mask {1}.'.format(slit_number_used,mask_name))

        index_of_slit = slit_nums.index(slit_number_used) #index of slit

        #boolean array using rbflux - median
        sky_sub_boolean_array = sky_sub_bool(index_of_slit, shifted_fluxes_included, median, threshold_sky_sub)

        median_and_subtraction_boolean_array = median_boolean_array * sky_sub_boolean_array

        #Determine the Optimal Scale Factor
        print('Finding the Optimal Scale Factor...')
        subtraction_dict = rbflux_minus_median(shifted_fluxes_included[index_of_slit], median, multipliers)

        scaling_value_dict = looping_finding_scale(slit_number_used, subtraction_dict, median_and_subtraction_boolean_array)

        #Determine the weighted wavelength (NOTE: may contain None entries)
        weighted_wavelength_list = automating_weighted_wave_multiple_slits(index_of_slit, shifted_fluxes_included, median)
        print("List of All Weighted Wavelength: {0}".format(weighted_wavelength_list))

        try:
            #Original Polynomial Fits for Scaling
            scaling_values_list,wavelength_values_scal,poly_const_second_deg_scaling = wavelength_scaling_function(scaling_value_dict,weighted_wavelength_list)

            #Improving Original Polynomial Fits for Scaling: remove outliers using sigma clipping
            no_outlier_scaling_values_list, no_outlier_wavelength_values_list_scal = remove_outliers_sigma_clip_scaling(scaling_values_list,wavelength_values_scal,poly_const_second_deg_scaling)

            #New Polynomial Fits for Scaling
            no_outlier_wavelength_values_list_scal,left_end_val_scal,right_end_val_scal,poly_const_second_deg_scal = new_polynomial_fits(mask_name,slit_number_used,no_outlier_scaling_values_list,no_outlier_wavelength_values_list_scal)
            print('left end value:',left_end_val_scal)

            # Use New Poly Fits to Scale the Flux. The shifted flux lives on the
            # rebinned new_wave_600 grid (see get_shifted_flux_array), so pixel n
            # must be paired with new_wave_600[n]. The original passed the
            # unrebinned wave[index_of_slit], which misaligned pixels and wavelengths.
            original_flux_scale = use_new_scaling_factor(shifted_fluxes_included[index_of_slit],new_wave_600,
                               no_outlier_wavelength_values_list_scal,left_end_val_scal,
                               right_end_val_scal,poly_const_second_deg_scal)
        except TypeError:
            # Fallback: flat polynomial y = 1, i.e. no scaling. The end values
            # must be reset too; before, they were undefined on the first slit
            # (UnboundLocalError) or silently carried over from the previous slit.
            poly_const_second_deg_scal = [0,0,1]
            left_end_val_scal = 1.0
            right_end_val_scal = 1.0
            no_outlier_scaling_values_list = list(scaling_value_dict.values())
            no_outlier_wavelength_values_list_scal = weighted_wavelength_list
            original_flux_scale = shifted_fluxes_included[index_of_slit]

        #Saving Scaled Flux as FITS Files
        save_new_polynomial_fits_values_scale(no_outlier_scaling_values_list, no_outlier_wavelength_values_list_scal, original_flux_scale,mask_name,slit_number_used)

        #Saving the Scale Factor
        saving_scaling_factor(mask_name,slit_number_used,scaling_value_dict)

        #Saving Scaling Polynomial Coefficients
        save_scale_polynomial_coeff(mask_name,slit_number_used,poly_const_second_deg_scal,left_end_val_scal,right_end_val_scal,no_outlier_wavelength_values_list_scal)

## Automate Scaling of Excluded Slits

In [None]:
@write_to_log('{0}/{1}/{1}_Scaling_Log_Excluded.log'.format(optimized_data_path,mask_name))
def scaling_excluded_slits(slit_numbers_to_automate):
    """Find, fit and save the optimal flux scaling of each requested excluded slit.

    Same procedure as for the included slits, but using ``slit_nums_exclude``
    and ``shifted_fluxes_excluded``. For every slit:

    1. Combine the median-threshold mask with the sky-subtraction mask.
    2. Scan ``multipliers`` in 500 A windows for the minimum-RMS scale factor.
    3. Fit a sigma-clipped 2nd-order polynomial of scale factor vs weighted wavelength.
    4. Apply it to the slit's shifted flux (on the ``new_wave_600`` grid).
    5. Save the results as FITS files.

    If the fitting chain raises ``TypeError``, a flat scale of 1 is saved instead.
    """
    median = new_median_fits

    #boolean array using median and threshold
    median_boolean_array = median_threshold(median, threshold_median)

    for slit_number in slit_numbers_to_automate:

        slit_number_used = slit_number
        print('\nOptimizing slit #{0} (Excluded) from mask {1}.'.format(slit_number_used,mask_name))

        index_of_slit = slit_nums_exclude.index(slit_number_used) #index of slit

        #boolean array using rbflux - median
        sky_sub_boolean_array = sky_sub_bool(index_of_slit, shifted_fluxes_excluded, median, threshold_sky_sub)

        median_and_subtraction_boolean_array = median_boolean_array * sky_sub_boolean_array

        #Determine the Optimal Scale Factor
        print('Finding the Optimal Scale Factor...')
        subtraction_dict = rbflux_minus_median(shifted_fluxes_excluded[index_of_slit], median, multipliers)

        scaling_value_dict = looping_finding_scale(slit_number_used, subtraction_dict, median_and_subtraction_boolean_array)

        #Determine the weighted wavelength (NOTE: may contain None entries)
        weighted_wavelength_list = automating_weighted_wave_multiple_slits(index_of_slit, shifted_fluxes_excluded, median)
        print("List of All Weighted Wavelength: {0}".format(weighted_wavelength_list))

        try:
            #Original Polynomial Fits for Scaling
            scaling_values_list,wavelength_values_scal,poly_const_second_deg_scaling = wavelength_scaling_function(scaling_value_dict,weighted_wavelength_list)

            #Improving Original Polynomial Fits for Scaling: remove outliers using sigma clipping
            no_outlier_scaling_values_list, no_outlier_wavelength_values_list_scal = remove_outliers_sigma_clip_scaling(scaling_values_list,wavelength_values_scal,poly_const_second_deg_scaling)

            #New Polynomial Fits for Scaling
            no_outlier_wavelength_values_list_scal,left_end_val_scal,right_end_val_scal,poly_const_second_deg_scal = new_polynomial_fits(mask_name,slit_number_used,no_outlier_scaling_values_list,no_outlier_wavelength_values_list_scal)

            # Use New Poly Fits to Scale the Flux. The shifted flux lives on the
            # rebinned new_wave_600 grid (see get_shifted_flux_array), so pixel n
            # must be paired with new_wave_600[n]. The original passed the
            # unrebinned wave_exclude[index_of_slit], which misaligned pixels and wavelengths.
            original_flux_scale = use_new_scaling_factor(shifted_fluxes_excluded[index_of_slit],new_wave_600,
                               no_outlier_wavelength_values_list_scal,left_end_val_scal,
                               right_end_val_scal,poly_const_second_deg_scal)
        except TypeError:
            # Fallback: flat polynomial y = 1, i.e. no scaling. The end values
            # must be reset too; before, they were undefined on the first slit
            # (UnboundLocalError) or silently carried over from the previous slit.
            poly_const_second_deg_scal = [0,0,1]
            left_end_val_scal = 1.0
            right_end_val_scal = 1.0
            no_outlier_scaling_values_list = list(scaling_value_dict.values())
            no_outlier_wavelength_values_list_scal = weighted_wavelength_list
            original_flux_scale = shifted_fluxes_excluded[index_of_slit]

        #Saving Scaled Flux as FITS Files
        save_new_polynomial_fits_values_scale(no_outlier_scaling_values_list, no_outlier_wavelength_values_list_scal, original_flux_scale,mask_name,slit_number_used)

        #Saving the Scale Factor
        saving_scaling_factor(mask_name,slit_number_used,scaling_value_dict)

        #Saving Scaling Polynomial Coefficients
        save_scale_polynomial_coeff(mask_name,slit_number_used,poly_const_second_deg_scal,left_end_val_scal,right_end_val_scal,no_outlier_wavelength_values_list_scal)

## Scaling - Execution

In [None]:
# Run the scaling step on the "included" slits.
slit_numbers_you_want_to_scale_included_slits = slit_nums  # put in the slit numbers you want to run (default: all included slits)

# The text output is written to a log file.
# The log file is overwritten (not appended) every time this cell is executed.
# If you want to keep a record of a particular run, save a copy of the log file before running this cell again.

scaling_included_slits(slit_numbers_you_want_to_scale_included_slits)


In [None]:
# Run the scaling step on the "excluded" slits.
slit_numbers_you_want_to_scale_excluded_slits = slit_nums_exclude  # put in the slit numbers you want to iterate over (default: all excluded slits)

# The text output is written to a log file.
# The log file is overwritten (not appended) every time this cell is executed.
# If you want to keep a record of a particular run, save a copy of the log file before running this cell again.

scaling_excluded_slits(slit_numbers_you_want_to_scale_excluded_slits)


# Reading Back Optimized Data

## Getting Back Shift & Scale Factors

In [None]:
def read_scale_and_shift_factor(mask_name,slit_number_used):
    """Read back the saved scale factor and shift factor for one slit.

    Parameters
    ----------
    mask_name : str
        Mask name; selects the mask's subdirectory under ``optimized_data_path``.
    slit_number_used : int
        Slit whose factors are read.

    Returns
    -------
    tuple
        ``(scale_factor, shift_factor)``. ``scale_factor`` is the
        ``SCALING_FACTOR`` column of extension 1 of
        ``Scaling_Factor_<mask>_<slit>.fits.gz``. ``shift_factor`` is
        whatever ``read_shift_factor`` returns.
    """
    scale_factor_path = (optimized_data_path + '/{0}'.format(mask_name)
                         + '/{0}_Rebinned/{0}_Scaling_Factor/Scaling_Factor_{0}_{1}.fits.gz'.format(mask_name,slit_number_used))

    # Close the FITS file when done. Copy the column first so the returned
    # array does not reference the closed HDU list.
    with fits.open(scale_factor_path) as scale_factor_hdu:
        scale_factor = np.array(scale_factor_hdu[1].data["SCALING_FACTOR"])

    shift_factor = read_shift_factor(mask_name,slit_number_used)

    return scale_factor,shift_factor

## Getting Back All Polynomial Coefficients

In [None]:
def read_poly_coeff(mask_name,slit_number_used):
    """Read back the saved shifting/scaling polynomial data for one slit.

    Two files are read from the mask's ``_Polynomial_Coefficients`` folder:
    ``Shift_Poly_Coeff_<mask>_<slit>.fits.gz`` and
    ``Scale_Poly_Coeff_<mask>_<slit>.fits.gz`` (extension 1 of each).

    Parameters
    ----------
    mask_name : str
        Mask name; selects the mask's subdirectory under ``optimized_data_path``.
    slit_number_used : int
        Slit whose polynomial data is read.

    Returns
    -------
    tuple
        ``(scale_coeffs, shift_coeffs, end_values, no_outlier_weighted_wave)``.
        Note that the scaling coefficients come first, even though the
        shifting file is read first. These are the columns
        ``POLYNOMIAL_COEFFICIENTS_SCALING``, ``POLYNOMIAL_COEFFICIENTS_SHIFTING``,
        ``END_VALUES`` and ``NO_OUTLIER_WEIGHTED_WAVE``.
    """
    shift_path = (optimized_data_path + '/{0}'.format(mask_name)
                  + '/{0}_Rebinned/{0}_Polynomial_Coefficients/{0}_Shifting_Polynomial_Coefficients/Shift_Poly_Coeff_{0}_{1}.fits.gz'.format(mask_name,slit_number_used))
    scale_path = (optimized_data_path + '/{0}'.format(mask_name)
                  + '/{0}_Rebinned/{0}_Polynomial_Coefficients/{0}_Scaling_Polynomial_Coefficients/Scale_Poly_Coeff_{0}_{1}.fits.gz'.format(mask_name,slit_number_used))

    # Use context managers so both FITS files are closed. Copy the columns so
    # the returned arrays do not reference the closed HDU lists.
    with fits.open(shift_path) as shift_poly_coeff:
        read_poly_coeff_shift = np.array(shift_poly_coeff[1].data["POLYNOMIAL_COEFFICIENTS_SHIFTING"])

    with fits.open(scale_path) as scale_poly_coeff:
        scale_data = scale_poly_coeff[1].data
        read_poly_coeff_scale = np.array(scale_data["POLYNOMIAL_COEFFICIENTS_SCALING"])
        end_values = np.array(scale_data["END_VALUES"])
        no_outlier_weighted_wave = np.array(scale_data["NO_OUTLIER_WEIGHTED_WAVE"])

    return read_poly_coeff_scale,read_poly_coeff_shift,end_values,no_outlier_weighted_wave

## Applying Optimal Shift & Scale Factors (For A Single Slit)

In [None]:
def optimizing_using_saved_data(index_of_slit,median,rbflux_shifted_dict,rebinned_flux,shifting_values,poly_coeff_scale,end_values,no_outlier_weighted_wave):
    """Apply a slit's saved shift and scale solution, then subtract the median.

    Parameters
    ----------
    index_of_slit : int
        Slit position, passed through to ``get_shifted_flux_array``.
    median : array_like
        Spectrum subtracted after scaling.
    rbflux_shifted_dict, rebinned_flux, shifting_values
        Passed straight through to ``get_shifted_flux_array`` to build the
        shifted flux.
    poly_coeff_scale : array_like
        Coefficients given to ``polynomial_second_order`` to evaluate the
        scale factor inside the fitted wavelength range.
    end_values : sequence
        ``end_values[0]`` is used below the fitted range and
        ``end_values[1]`` above it.
    no_outlier_weighted_wave : array_like
        Wavelengths used in the scaling fit. Their min/max bound the range
        where the polynomial is evaluated.

    Returns
    -------
    numpy.ndarray
        ``shifted_flux * scale_factors - median``, with one scale factor per
        point of ``new_wave_600``.
    """
    new_flux_joined = get_shifted_flux_array(index_of_slit,rbflux_shifted_dict,rebinned_flux,shifting_values)

    # The fit-range bounds do not change inside the loop. Compute them once
    # instead of rescanning the array for every grid wavelength.
    fit_min = min(no_outlier_weighted_wave)
    fit_max = max(no_outlier_weighted_wave)

    all_scale_factors = []  # one scale factor per wavelength in new_wave_600
    for wavelength in new_wave_600:
        if wavelength < fit_min:
            # Below the fitted range: hold the lower end value constant.
            all_scale_factors.append(end_values[0])
        elif wavelength > fit_max:
            # Above the fitted range: hold the upper end value constant.
            all_scale_factors.append(end_values[1])
        else:
            all_scale_factors.append(polynomial_second_order(wavelength,poly_coeff_scale))

    new_rbflux_scaled = (np.asarray(new_flux_joined) * np.array(all_scale_factors)) - np.array(median)

    return new_rbflux_scaled

## Automation of Included Slits

In [None]:
def automation_func_include(mask_name,slit_number):
    """Rebuild the optimized, median-subtracted flux of one included slit.

    Reads the saved shift factors, the shifted rebinned flux and the scaling
    polynomial data for ``slit_number``. These are applied with
    ``optimizing_using_saved_data`` to ``rbflux_fits``, and ``new_median_fits``
    is subtracted.

    Parameters
    ----------
    mask_name : str
        Mask whose saved optimization data is read.
    slit_number : int
        Slit to rebuild. It must appear in ``slit_nums``; ``list.index``
        raises ``ValueError`` otherwise.

    Returns
    -------
    The array returned by ``optimizing_using_saved_data``.
    """
    # Position of this slit within the included-slit arrays.
    index_of_slit = slit_nums.index(slit_number)

    # Only the shift factors are needed. Scaling is evaluated from the
    # polynomial coefficients below, so the saved scale factor is not read.
    shift_factor_list = read_shift_factor(mask_name,slit_number)

    # Rebinned flux with shifted wavelength.
    rbflux_shifted_dict = read_rbflux_shifted_wave(mask_name,slit_number)

    # The shifting coefficients are not needed here.
    poly_coeff_scale,_poly_coeff_shift,end_values,no_outlier_weighted_wave = read_poly_coeff(mask_name,slit_number)

    # Subtraction using the new median.
    return optimizing_using_saved_data(index_of_slit,new_median_fits,rbflux_shifted_dict,rbflux_fits,
                                       shift_factor_list,poly_coeff_scale,end_values,no_outlier_weighted_wave)


In [None]:
def combining_all_include_new_flux(mask_name,slit_nums):
    """Return the optimized flux of each included slit, in ``slit_nums`` order.

    Each entry is the result of ``automation_func_include(mask_name, slit)``.
    """
    return [automation_func_include(mask_name, slit) for slit in slit_nums]

In [None]:
# Optimized, median-subtracted flux for every included slit, in the same order as slit_nums.
new_flux_list_include_slits = combining_all_include_new_flux(mask_name,slit_nums)

## Automation of Excluded Slits

In [None]:
def automation_func_exclude(mask_name,slit_number):
    """Rebuild the optimized, median-subtracted flux of one excluded slit.

    Reads the saved shift factors, the shifted rebinned flux and the scaling
    polynomial data for ``slit_number``. These are applied with
    ``optimizing_using_saved_data`` to ``rbflux_fits_exclude``, and
    ``new_median_fits`` is subtracted.

    Parameters
    ----------
    mask_name : str
        Mask whose saved optimization data is read.
    slit_number : int
        Slit to rebuild. It must appear in ``slit_nums_exclude``;
        ``list.index`` raises ``ValueError`` otherwise.

    Returns
    -------
    The array returned by ``optimizing_using_saved_data``.
    """
    # Position of this slit within the excluded-slit arrays.
    index_of_slit = slit_nums_exclude.index(slit_number)

    # Only the shift factors are needed. Scaling is evaluated from the
    # polynomial coefficients below, so the saved scale factor is not read.
    shift_factor_list = read_shift_factor(mask_name,slit_number)

    # Rebinned flux with shifted wavelength.
    rbflux_shifted_dict = read_rbflux_shifted_wave(mask_name,slit_number)

    # The shifting coefficients are not needed here.
    poly_coeff_scale,_poly_coeff_shift,end_values,no_outlier_weighted_wave = read_poly_coeff(mask_name,slit_number)

    # Subtraction using the new median.
    return optimizing_using_saved_data(index_of_slit,new_median_fits,rbflux_shifted_dict,rbflux_fits_exclude,
                                       shift_factor_list,poly_coeff_scale,end_values,no_outlier_weighted_wave)


In [None]:
def combining_all_exclude_new_flux(mask_name,slit_nums_exclude):
    """Return the optimized flux of each excluded slit, in ``slit_nums_exclude`` order.

    Each entry is the result of ``automation_func_exclude(mask_name, slit)``.
    """
    return [automation_func_exclude(mask_name, slit) for slit in slit_nums_exclude]

In [None]:
# Optimized, median-subtracted flux for every excluded slit, in the same order as slit_nums_exclude.
new_flux_list_exclude_slits = combining_all_exclude_new_flux(mask_name,slit_nums_exclude)

# Compare Optimized Subtraction With Initial Subtraction

In [None]:
def compare_subtraction_plots(slits,excl_or_incl):
    """Save side-by-side plots of the initial and optimized subtraction per slit.

    For every slit in ``slits`` a two-panel PNG is written to
    ``./Compare Subtractions/<mask_name>/<excl_or_incl>/``:

    - the left panel shows the rebinned flux minus ``median_fits``;
    - the right panel shows the optimized flux recomputed from the saved
      shift/scale data.

    If ``./Raja's notes/<mask_name>_notes.txt`` can be read and contains a
    note for the slit, that note is printed under the plots.

    Parameters
    ----------
    slits : iterable of int
        Slit numbers to plot. Each must be in the slit list selected by
        ``excl_or_incl``.
    excl_or_incl : str
        ``'Included'`` or ``'Excluded'``. Chooses the flux array, slit list
        and optimization function, and names the output subfolder.

    Raises
    ------
    ValueError
        If ``excl_or_incl`` is neither ``'Included'`` nor ``'Excluded'``.
    """
    # Validate first so a typo fails immediately, instead of creating an
    # output directory and then failing with a NameError.
    if excl_or_incl not in ('Excluded', 'Included'):
        raise ValueError("excl_or_incl must be 'Included' or 'Excluded', got {0!r}".format(excl_or_incl))

    rbwave = new_wave_600

    if excl_or_incl == 'Excluded':
        opt_flux = automation_func_exclude
        rbflux = rbflux_fits_exclude
        slit_list = slit_nums_exclude
    else:
        opt_flux = automation_func_include
        rbflux = rbflux_fits
        slit_list = slit_nums

    path = './Compare Subtractions/{0}/{1}/'.format(mask_name,excl_or_incl)
    # exist_ok still raises if the path exists but is not a directory.
    os.makedirs(path, exist_ok=True)

    notes_path = "./Raja's notes/{0}_notes.txt".format(mask_name)
    try:
        with open(notes_path,'r') as f:
            # The first field is the slit number; everything after the third
            # field is the note text. Blank lines are skipped.
            notes = {int(fields[0]): ' '.join(fields[3:])
                     for fields in (line.split() for line in f) if fields}
    except OSError:
        notes = None

    def plot_sub_figure(i):
        # Compute both spectra before creating the figure, so a read error
        # does not leave an open figure behind.
        y1 = rbflux[slit_list.index(i)] - median_fits
        y2 = opt_flux(mask_name,i)
        # A slit with no entry in the notes file is plotted without notes.
        note = notes.get(i) if notes is not None else None

        was_interactive = plt.isinteractive()
        plt.ioff()
        fig = plt.figure(figsize=(16,6),dpi=600)
        try:
            fig.patch.set_alpha(1)

            plt.subplot(1,2,1)
            plt.plot(rbwave,y1,'dodgerblue',ls='-',label='Initial subtraction')
            plt.ylabel("Flux (Electron/Hour)")
            plt.xlabel(r"Angstroms ($\AA$)")
            plt.title('Initial subtraction', fontsize=14)
            plt.legend()

            plt.subplot(1,2,2)
            plt.plot(rbwave,y2,'dodgerblue',ls='-',label='Optimized flux')
            plt.ylabel("Flux (Electron/Hour)")
            plt.xlabel(r"Angstroms ($\AA$)")
            plt.title('Optimized subtraction', fontsize=14)
            plt.legend()

            plt.suptitle('Test figure: comparing subtractions\n slit #{0} in mask {1} ({2})'.format(i,mask_name,excl_or_incl),fontsize=16)
            if note is not None:
                txt = 'Notes: {0}'.format(note)
                plt.figtext(0.95, 0.08, txt, alpha=0.5, wrap=True, ha='right', va='top', fontstyle='italic', fontsize=8)
                # Leave room at the bottom for the notes text.
                plt.tight_layout(rect=[0,0.1,0.9,1])
            else:
                plt.tight_layout(rect=[0,0,1,0.99])

            s = 'test_figure_{0}_slit_{1}.png'.format(mask_name,i)
            plt.savefig(path+s)
            print('Saved figure {0}'.format(s))
        finally:
            # Always release the figure and restore the caller's interactive mode.
            plt.close(fig)
            if was_interactive:
                plt.ion()

    for i in slits:
        plot_sub_figure(i)

In [None]:
# Save initial-vs-optimized subtraction plots for every included slit.
compare_subtraction_plots(slit_nums,"Included")

In [None]:
# Save initial-vs-optimized subtraction plots for every excluded slit.
compare_subtraction_plots(slit_nums_exclude,"Excluded")

In [None]:
# %matplotlib notebook

In [None]:
# plt.figure(1)
# plt.plot(rbwave_fits[0], rbflux_fits_exclude[slit_nums_exclude.index(42)] - median_fits)
# lines = np.array([6562.82, 6548.10, 6583.60, 6716.47, 6730.85])
# c = 299792
# v = -97
# lines *= (1+(v/c))
# plt.vlines(lines, -999, +999, color='black')
# plt.xlim(6700, 6740)
# plt.ylim(120, 240)