## QuasarNET Performance vs training set size
#### Plot to show performance of QuasarNET as a function of the size of the training set
xxx

In [1]:
import astropy
import copy
import glob
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline

from quasarnet.io import read_truth
from qn_analysis import plot, strategy, utils, variables

In [2]:
figsize = (12,6.5)
fontsize = 18

plt.rc('font', size=fontsize)

In [3]:
## General parameters
pc_dr12 = 90
prop = pc_dr12/100.
datatype = 'coadd'

## QN parameters
n_detect = 1
c_th = 0.8
include_c_qn = True
common_specids = False

### Look at histories from different models

In [4]:
data = {}

for i in range(10):

    ## Load the QN using nchunks value.
    fi_qn = glob.glob(variables.OUTDIR+'/qn_models/main_setup/{}/prop_{}/*_{}/*hist*.fits'.format(datatype,prop,i))
    print(fi_qn)
    
    if len(fi_qn)>0:
        data['QN_{} {}'.format(p,i)] = astropy.io.fits.open(fi_qn[0])
    #data['QN_{}'.format(p)] = utils.load_qn_data(f_qn,n_detect=n_detect,c_th=c_th,include_c=include_c_qn)

['/global/cfs/projectdirs/desi/users/jfarr/QuasarNET_paper//qn_models/main_setup/coadd/prop_0.9/model_indtest_0_0/qn_train_coadd_indtest_0_0_hist.fits']


NameError: name 'p' is not defined

In [None]:
filename = '../plots/qn_loss_vs_nepochs_{}.pdf'.format(prop)

fig, axs = plt.subplots(1,1,figsize=figsize,squeeze=False)
for c in data.keys():
    ind = (c.split(' ')[-1])
    label = 'model {}'.format(ind)
    axs[0,0].plot(np.linspace(1,200,200),data[c][1].data['loss'],label=label)
axs[0,0].semilogy()
axs[0,0].set_ylabel('loss')
axs[0,0].set_xlabel('# epochs')
axs[0,0].set_xlim(0,200)
axs[0,0].set_ylim(3e-4,6e-1)
plt.legend(ncol=2)
plt.savefig(filename)
plt.show()

### Load the results from the different QN models

In [None]:
f_truth = variables.OUTDIR+'/data/truth/truth_dr12q.fits'
truth = read_truth([f_truth])

In [None]:
data = {}
dts = {}

for i in range(10):
    
    if prop>0.5:
        stype = 'indtest'
    else:
        stype = 'indtrain'
    ## Load the QN data.
    f_qn = variables.OUTDIR+'/outputs/qn_outputs/main_setup/{d}/prop_{p}/model_{s}_0_{i}/qnAll-train_{p}_{d}_0_{i}-test_{d}.fits'.format(d=datatype,p=prop,i=i,s=stype)
    data['QN_{} {}'.format(pc_dr12,i)] = utils.load_qn_data(f_qn,n_detect=n_detect,c_th=c_th,include_c=include_c_qn)

    #d = {'QN_5 {}'.format(i): data['QN_5 {}'.format(i)]}
    #dt = utils.reduce_data_to_table(d,truth,include_c_qn=include_c_qn,common_specids=False)
    #dts['QN_5 {}'.format(i)] = dt


### Find the set of common spectra, and reduce all data to that set
This finds which spectra* are common to all datasets, and removes any that are not common. It then matches the data from each classifier to each spectrum, and produces a single data table.

\* using spec_id = plate$\times$1000000000 + mjd$\times$10000 + fiber for BOSS, or spec_id = targetid for DESI (to be updated)

In [None]:
data_table = utils.reduce_data_to_table(data,truth,include_c_qn=include_c_qn,common_specids=common_specids)
data_table[:5]

In [None]:
# Include only objects in the data table that have the highest confidence level, and VI redshift not equal to -1.
w = (data_table['ZCONF_PERSON']==2) & (data_table['Z_VI']>-1)
data_table = data_table[w]
len(data_table)

### Compare the performance of the different model/data combos.
xxx

In [None]:
ls = {s: '-' for s in data.keys()}

training_set_sizes = {0.9: 563000, 0.8: 500000, 0.5: 313000, 0.2: 125000, 0.1: 63000, 0.05: 31000, 0.025: 116000, 0.01: 6000}

In [None]:
## Define general parameter values.
dv_max = 6000.
nspec_sdr12q = 627751

# confidence thresholds.
c_th_min = 0.0
c_th_max = 1.0
n_int = 101
c_th_values = np.linspace(c_th_min,c_th_max,n_int)

In [None]:
## Define strategies.
strategies = {}

for s in data.keys():
    print(s)
    
    #p = float(s.split('_')[-1])/100
    p = prop
    
    #print('making name')
    name = 'Model trained on\n{:.0%} DR12Q Superset\n'.format(p)
    name += r'($\sim${:,} spectra)'.format(training_set_sizes[p])

    #print('filtering')
    if type(data_table['ISQSO_{}'.format(s)])==astropy.table.column.MaskedColumn:
        filt = (~data_table['ISQSO_{}'.format(s)].data.mask)
    else:
        filt = np.ones(len(data_table)).astype(bool)
    temp_data_table = data_table[filt]
    
    #print('starting defs')
    # QN definitions.
    n_detect = 1
    strat = strategy.Strategy('qn',cf_kwargs={'qn_name':s})
    #print('making preds')
    preds = [strat.predict(temp_data_table,filter=None,c_kwargs={'c_th':c_th_value,'n_detect':n_detect}) for c_th_value in c_th_values]
    #preds = [strat.predict(dts[s],filter=filt,c_kwargs={'c_th':c_th_value,'n_detect':n_detect}) for c_th_value in c_th_values]
    #print('making dict entry')
    strategies[s] = {#'isqso': [pred.isqso for pred in preds],
                     #'z': [pred.z for pred in preds],
                     'predictions': preds,
                     'n': name, 
                     'ls': ls[s]}


In [None]:
strategies_to_plot = {'QN_{}'.format(pc_dr12): {'strategies': [s for s in data.keys() if not ('v0' in s)],
                                'ls': '-',
                                'n': name,
                               },
#                      'QN_v0_{}'.format(pc_dr12): {'strategies': [s for s in data.keys() if ('v0' in s)],
#                                'ls': '--',
#                                'n': name,
#                               }
                     }

In [None]:
filename = '../plots/qn_model_variation_individual_{}.pdf'.format(prop)

plot.plot_qn_model_compare(data_table,strategies,filename=filename,
                           dv_max=dv_max,nydec=2,figsize=(12,12),
                           ymin=0.98,ymax=1.,verbose=False,npanel=2,
                           norm_dvhist=True,c_th=c_th_values)
plt.show()

In [None]:
filename = '../plots/qn_model_variation_combined_{}.pdf'.format(prop)

plot.plot_qn_model_compare(data_table,strategies,filename=filename,
                           dv_max=dv_max,nydec=2,figsize=(12,12),
                           ymin=0.98,ymax=1.,verbose=True,npanel=2,
                           norm_dvhist=True,c_th=c_th_values,show_std=True,
                           strategies_to_plot=strategies_to_plot)
plt.show()