In [1]:
import numpy as np
import matplotlib.pyplot as plt
import copy
from modules import get_bandpass_dict, Sed

In [2]:
# plotting style
%config InlineBackend.figure_format = 'retina'
plt.style.use('paper.mplstyle')
twocol = 7.1014
onecol = 3.35

In [3]:
bandpass_dict = get_bandpass_dict()

In [4]:
def classify(template_list):
    
    cwwsb4 = np.loadtxt('templates/cwwsb4.list',dtype=str)
    bands = ['NUV','u','g','r','i','z','y','J']
    
    ref_fluxes = []
    for template in cwwsb4:
        name = 'templates/'+template
        x,y = np.loadtxt(name,unpack=True)
        sed = Sed(x,y)

        fluxes = sed.fluxlist(bandpass_dict,bands)
        ref_fluxes.append(fluxes)
        
    for template in template_list:
        name = 'templates/'+template
        x,y = np.loadtxt(name,unpack=True)
        sed = Sed(x,y)

        fluxes = sed.fluxlist(bandpass_dict,bands)
        mse = []
        for rfluxes in ref_fluxes:
            rfluxes_ = rfluxes.copy()
            ratios = fluxes/rfluxes_
            med = np.median(ratios)
            norm_idx = np.argmin(ratios - med)
            #norm_idx = 4
            rfluxes_ *= fluxes[norm_idx]/rfluxes_[norm_idx]
            mse.append(np.mean((rfluxes_ - fluxes)**2))
        idx = np.argmin(mse)
        print(template[:-4],'---',cwwsb4[idx][:-4])

In [5]:
cwwsb4_trained = ['El_trained.sed', 'Sbc_trained.sed',
                  'Scd_trained.sed', 'Im_trained.sed',
                  'SB3_trained.sed', 'SB2_trained.sed',
                  '25Myr_trained.sed', '5Myr_trained.sed']
classify(cwwsb4_trained)

El_trained --- El_B2004a
Sbc_trained --- Sbc_B2004a
Scd_trained --- Scd_B2004a
Im_trained --- Im_B2004a
SB3_trained --- Scd_B2004a
SB2_trained --- SB2_B2004a
25Myr_trained --- ssp_25Myr_z008
5Myr_trained --- ssp_25Myr_z008


In [6]:
N8 = ['N8_'+str(i+1)+'.sed' for i in range(8)]
classify(N8)

N8_1 --- El_B2004a
N8_2 --- Scd_B2004a
N8_3 --- Scd_B2004a
N8_4 --- Scd_B2004a
N8_5 --- Im_B2004a
N8_6 --- SB2_B2004a
N8_7 --- SB2_B2004a
N8_8 --- ssp_25Myr_z008


In [7]:
N16 = ['N16_'+str(i+1)+'.sed' for i in range(16)]
classify(N16)

N16_1 --- El_B2004a
N16_2 --- Scd_B2004a
N16_3 --- El_B2004a
N16_4 --- Scd_B2004a
N16_5 --- Scd_B2004a
N16_6 --- Scd_B2004a
N16_7 --- Scd_B2004a
N16_8 --- SB3_B2004a
N16_9 --- Im_B2004a
N16_10 --- SB2_B2004a
N16_11 --- SB2_B2004a
N16_12 --- SB2_B2004a
N16_13 --- SB2_B2004a
N16_14 --- SB2_B2004a
N16_15 --- ssp_25Myr_z008
N16_16 --- ssp_25Myr_z008
