# Plots for Siggi Paper

### Simple examples of information gain

In [None]:
# First import code
import sys
sys.path.append('..')

In [None]:
from siggi import siggi, filters, spectra, calcIG, plotting
from siggi import Sed
from siggi.lsst_utils import Bandpass, BandpassDict
import matplotlib.pyplot as plt
import numpy as np
%load_ext autoreload
%autoreload 2
%matplotlib inline

### Set up red and blue spectra

In [None]:
# siggi helper objects: filter factory and template-spectrum factory.
f = filters()
s = spectra()
# The two template spectra used in the simple two-spectrum example.
red_spec, blue_spec = s.get_red_spectrum(), s.get_blue_spectrum()

In [None]:
def flat_prior_2(z):
    """Constant redshift prior: always 0.5, independent of ``z``.

    Parameters
    ----------
    z : any
        Redshift argument required by the prior interface; ignored.

    Returns
    -------
    float
        The constant 0.5.
    """
    del z  # the prior does not depend on redshift
    return 1.0 / 2

In [None]:
# Red and blue spectra with equal weights at a single redshift (z = 0),
# paired with the constant prior.
sig_example = siggi(
    [red_spec, blue_spec],
    [0.5, 0.5],
    flat_prior_2,
    z_min=0.0,
    z_max=0.0,
    z_steps=1,
)

In [None]:
# Grid of (center 1, center 2) pairs used to seed the filter optimizer:
# every ordered pair on a d_lambda-spaced grid, followed by points where both
# centers coincide, on a grid offset from the first by half a step.
d_lambda = 25.
x = np.arange(375., 1026., d_lambda)
point_list = [[val_1, val_2] for val_1 in x for val_2 in x]
y = np.arange(387.5, 1026., d_lambda)
point_list += [[val_1, val_1] for val_1 in y]

In [None]:
%%time
test_rand_state = np.random.RandomState(42)
num_filters = 2
set_ratio = 0.5

# One starting point per (c_1, c_2) pair in point_list:
# [c_1 - 60, c_1 + 60, c_2 - 60, c_2 + 60], i.e. two 120 nm wide windows
# (presumably lower/upper filter edges -- confirm against optimize_filters).
half_width = 60.
start_points = [[c_1 - half_width, c_1 + half_width,
                 c_2 - half_width, c_2 + half_width]
                for c_1, c_2 in point_list]

res = sig_example.optimize_filters(
    num_filters=num_filters,
    filt_min=300., filt_max=1100.,
    sed_mags=22.0,
    set_ratio=set_ratio,
    system_wavelen_max=1200.,
    n_opt_points=15,
    optimizer_verbosity=5,
    procs=4,
    acq_func_kwargs_dict={'kappa': 1.8},
    frozen_filt_dict=None,
    starting_points=start_points,
    rand_state=test_rand_state,
)

##### Providing a random seed makes the optimization reproducible, as long as the same number of processors is used for parallelization.

In [None]:
%%time
test_rand_state = np.random.RandomState(42)
num_filters = 2
set_ratio = 0.5
res = sig_example.optimize_filters(num_filters=num_filters,
                                   filt_min=300., filt_max=1100.,
                                   sed_mags=22.0,
                                   set_ratio=set_ratio,
                                   system_wavelen_max=1200.,
                                   n_opt_points=15,
                                   optimizer_verbosity=5,
                                   procs=4, acq_func_kwargs_dict={'kappa':1.8},
                                   frozen_filt_dict = None,
                                   starting_points = [[mid-60., mid+60., mid_2-60., mid_2+60.] for mid, mid_2 in point_list],
                                   rand_state=test_rand_state)

### Make plots

In [None]:
# Index of the lowest objective value among all evaluated points, and the
# filter parameters evaluated there.
min_idx = np.asarray(res.yi).argmin()
best_pt = res.Xi[min_idx]

In [None]:
# Lowest objective value reached by the optimizer.
best_val = np.asarray(res.yi).min()

In [None]:
# Report the best filter parameters together with their objective value.
print(best_pt, best_val, sep=' ')

In [None]:
# Uncomment if red and blue filters are on wrong sides
# best_pt = [best_pt[2], best_pt[3], best_pt[0], best_pt[1]]

In [None]:
# Plotting helper for the two-spectrum example, evaluated at the best filters.
sig_plot = plotting(
    [red_spec, blue_spec],
    best_pt,
    sed_mags=22.0,
    set_ratio=0.5,
    frozen_filt_dict=None,
)

In [None]:
fig = plt.figure(figsize=(14, 18))

# Top panel: the two spectra and the optimized filters.
fig.add_subplot(2, 1, 1)
sig_plot.plot_filters(fig=fig)
ax = plt.gca()
# Recolor the first two lines drawn (assumed to be the red and blue spectra,
# in that order, to match the legend labels below).
ax.lines[0].set_color('r')
ax.lines[1].set_color('b')
plt.xlabel('Wavelength (nm)', size=20)
plt.ylabel('Transmission Fraction', size=20)
plt.legend(('Red Spectrum', 'Blue Spectrum', 'Filter 1', 'Filter 2'), loc=1)

# Bottom panel: information gain over the evaluated filter placements.
fig.add_subplot(2, 1, 2)
plt.rcParams.update({'font.size': 16})
# NOTE(review): the last three evaluations are left out of this map; the
# reason is not visible here -- confirm it is intentional.
sig_plot.plot_ig_space(res.Xi[:-3], np.abs(res.yi[:-3]), [0, 1])
cbar = plt.colorbar()
cbar.set_label('Information Gain (bits)')

# Mark the best point at the filter centers. best_pt is laid out like the
# starting points, [f1_lo, f1_hi, f2_lo, f2_hi], so each center is the mean of
# its two edges. A fixed +50 nm offset from the lower edge is only correct for
# 100 nm wide filters, and these runs start from 120 nm wide ones.
# NOTE(review): assumes plot_ig_space places points at edge-mean centers too.
filt_1_center = 0.5 * (best_pt[0] + best_pt[1])
filt_2_center = 0.5 * (best_pt[2] + best_pt[3])
plt.scatter(filt_1_center, filt_2_center, c='r', s=64)
plt.xlabel('Filter 1 Center Wavelength (nm)', size=20)
plt.ylabel('Filter 2 Center Wavelength (nm)', size=20)
#plt.savefig('Example_1.pdf')

In [None]:
# Color-color plot for the two-spectrum example. Both spectra are at z = 0,
# so the redshift list is just zero repeated (np.linspace's default 50 points).
fig = plt.figure(figsize=(20, 16))
color_bands = ['filter_0', 'filter_1', 'filter_0', 'filter_1']
fig = sig_plot.plot_color_color(color_bands, np.linspace(0.00, 0.0), fig=fig)

In [None]:
# Magnitudes of the blue and red spectra through the optimized filter set.
blue_mags = sig_plot.filter_dict.magListForSed(blue_spec)
red_mags = sig_plot.filter_dict.magListForSed(red_spec)
blue_mags, red_mags

### Make sigmoid spectrum plots

In [None]:
# Sigmoid template spectrum used in the redshift examples.
get_sigmoid = s.get_sigmoid_spectrum
sig_spec = get_sigmoid()

In [None]:
# Floor the flux at 0.01 wherever it drops below that value (presumably so
# magnitudes stay finite). The clamp must act on flambda: writing 0.01 into
# sig_spec.wavelen, as before, corrupted the wavelength grid and left the
# flux untouched. This matches the flux clipping used for the
# color-vs-redshift plot.
sig_spec.flambda[np.where(sig_spec.flambda < 0.01)] = 0.01

In [None]:
def prior_z(z, z0=0.5):
    """Redshift prior proportional to ``z**2 * exp(-(z / z0)**1.5)``.

    The shape is normalized so that it sums to 1 over the grid
    ``np.arange(0, 2.51, .05)`` (z = 0, 0.05, ..., 2.5).

    Parameters
    ----------
    z : float or ndarray
        Redshift(s) at which to evaluate the prior.
    z0 : float, optional
        Scale redshift of the exponential cutoff (default 0.5).

    Returns
    -------
    float or ndarray
        Prior value(s) at ``z``.
    """
    def _shape(redshift):
        return redshift ** 2. * np.exp(-(redshift / z0) ** 1.5)

    norm_grid = np.arange(0, 2.51, .05)
    return _shape(z) / np.sum(_shape(norm_grid))

In [None]:
# Visualize the redshift prior on its normalization grid.
z_plot = np.arange(0.00, 2.51, 0.05)
plt.plot(z_plot, prior_z(z_plot))
plt.xlabel('Redshift')
plt.ylabel('Prior Probability')

In [None]:
# Zoom in on the 200-500 nm part of the sigmoid spectrum.
sig_wl, sig_fl = sig_spec.wavelen, sig_spec.flambda
plt.plot(sig_wl, sig_fl)
plt.xlim(200, 500)

In [None]:
# Starting grid for the sigmoid-spectrum optimization: every ordered pair of
# centers on a d_lambda-spaced grid, followed by points where both centers
# coincide, on a grid offset from the first by half a step.
d_lambda = 25.
x = np.arange(400., 1001., d_lambda)
point_list = [[val_1, val_2] for val_1 in x for val_2 in x]
y = np.arange(412.5, 1001., d_lambda)
point_list += [[val_1, val_1] for val_1 in y]

In [None]:
# Reference bandpass set up via imsimBandpass().
# NOTE(review): no active code in this notebook uses ref_filter; it only
# appears in the commented-out calib_filter argument when siggi is built
# for the sigmoid spectrum.
ref_filter = Bandpass()
ref_filter.imsimBandpass()

In [None]:
# Single sigmoid spectrum (weight 1.0) with the prior_z redshift prior,
# over z_steps=50 between z_min=0.05 and z_max=2.5.
# (An optional calib_filter=ref_filter argument was tried and left disabled.)
sig_example = siggi(
    [sig_spec], [1.0], prior_z,
    z_min=0.05, z_max=2.5, z_steps=50,
)

In [None]:
%%time
test_rand_state = np.random.RandomState(2325)
num_filters = 2
set_ratio = 0.5

# One starting point per (c_1, c_2) pair in point_list:
# [c_1 - 50, c_1 + 50, c_2 - 50, c_2 + 50], i.e. two 100 nm wide windows.
half_width = 50.
start_points = [[c_1 - half_width, c_1 + half_width,
                 c_2 - half_width, c_2 + half_width]
                for c_1, c_2 in point_list]

res_2 = sig_example.optimize_filters(
    num_filters=num_filters,
    filt_min=300., filt_max=1100.,
    sed_mags=22.0,
    set_ratio=set_ratio,
    system_wavelen_max=1200.,
    n_opt_points=15,
    optimizer_verbosity=5,
    procs=16,
    acq_func_kwargs_dict={'kappa': 1.8},
    frozen_filt_dict=None,
    starting_points=start_points,
    rand_state=test_rand_state,
)

In [None]:
# Best filter parameters (lowest objective value) for the sigmoid run.
min_idx = np.asarray(res_2.yi).argmin()
best_pt = res_2.Xi[min_idx]
print(best_pt)

In [None]:
# Put the bluer filter first. best_pt is [f1_lo, f1_hi, f2_lo, f2_hi]
# (presumably lower/upper edges), so compare the filter centers and swap the
# two filters only when filter 1 is the redder one. Doing this conditionally
# also makes the cell safe to re-run: an unconditional swap would flip the
# order back on a second execution.
if best_pt[0] + best_pt[1] > best_pt[2] + best_pt[3]:
    best_pt = [best_pt[2], best_pt[3], best_pt[0], best_pt[1]]

In [None]:
from copy import deepcopy

# Redshift a copy for plotting. redshiftSED modifies the Sed in place, so
# calling it on sig_spec directly would alter the same object handed to siggi
# above, and would compound the shift each time this cell is re-run.
sig_spec_z = deepcopy(sig_spec)
sig_spec_z.redshiftSED(0.6)
sig_plot = plotting([sig_spec_z], best_pt,
                    frozen_filt_dict=None, set_ratio=0.5,
                    sed_mags=22.0)

In [None]:
fig = plt.figure(figsize=(14, 18))

# Top panel: the redshifted sigmoid spectrum and the optimized filters.
fig.add_subplot(2, 1, 1)
sig_plot.plot_filters(fig=fig)
plt.xlabel('Wavelength (nm)', size=20)
plt.ylabel('Transmission Fraction', size=20)
plt.legend(('Sigmoid Spectrum (z=0.6)', 'Filter 1', 'Filter 2'), loc=2, fontsize=15)

# Bottom panel: information gain over every evaluated filter placement.
fig.add_subplot(2, 1, 2)
plt.rcParams.update({'font.size': 16})
sig_plot.plot_ig_space(res_2.Xi, np.abs(res_2.yi), [0, 1])
cbar = plt.colorbar()
cbar.set_label('Information Gain (bits)')

# Mark the best point at the filter centers: with best_pt laid out as
# [f1_lo, f1_hi, f2_lo, f2_hi], each center is the mean of its two edges.
# A fixed +50 nm offset is only correct while the optimizer keeps the
# filters exactly 100 nm wide.
# NOTE(review): assumes plot_ig_space places points at edge-mean centers too.
filt_1_center = 0.5 * (best_pt[0] + best_pt[1])
filt_2_center = 0.5 * (best_pt[2] + best_pt[3])
plt.scatter(filt_1_center, filt_2_center, c='r', s=64)
plt.xlabel('Filter 1 Center Wavelength (nm)', size=20)
plt.ylabel('Filter 2 Center Wavelength (nm)', size=20)
plt.savefig('Example_2.pdf')

In [None]:
from copy import deepcopy

# Rebuild the sigmoid template and floor its flux at 0.01.
sig_spec = s.get_sigmoid_spectrum()
faint = sig_spec.flambda < 0.01
sig_spec.flambda[np.where(faint)] = 0.01

# Alternative filter sets that can be substituted for best_pt:
#   [350., 450., 400., 500.]
#   [800., 900., 900., 1000.]
#   [600.0, 700.0, 825.0, 925.0]

sig_plot = plotting([sig_spec], best_pt,
                    sed_mags=22.0, set_ratio=0.5, frozen_filt_dict=None)

# One redshifted copy of the template per point of the 50-step grid 0.05..2.5.
z_grid = np.linspace(0.05, 2.5, 50)
shift_seds = []
for z_val in z_grid:
    shifted_sed = deepcopy(sig_spec)
    shifted_sed.redshiftSED(z_val)
    shift_seds.append(shifted_sed)

# NOTE(review): the third argument is a vector of ones, one per shifted SED --
# presumably equal weights; confirm against calcIG.
calc_ig = calcIG(sig_plot.filter_dict, shift_seds,
                 np.ones(len(shift_seds)),
                 sky_mag=19.0, sed_mags=22.0)
col_x, err_x = calc_ig.calc_colors()


In [None]:
# Color versus redshift for the shifted sigmoid templates; error bars are the
# square root of err_x.
z_grid = np.linspace(0.05, 2.5, 50)
fig = plt.figure(figsize=(12, 12))
plt.plot(z_grid, col_x)
plt.errorbar(z_grid, col_x, yerr=err_x ** 0.5,
             ls='', marker='o', color='k')
plt.xlabel('Redshift')
plt.ylabel(r'Color ($F_2$ - $F_{1}$)')

In [None]:
# Color-color diagram for the sigmoid template over the 50-point redshift grid.
redshifts = np.linspace(0.05, 2.5, 50)
fig = plt.figure(figsize=(20, 16))
fig = sig_plot.plot_color_color(
    ['filter_0', 'filter_1', 'filter_0', 'filter_1'], redshifts, fig=fig)

In [None]:
# Fresh blue template for the third example (replaces the earlier blue_spec).
get_blue = s.get_blue_spectrum
blue_spec = get_blue()

In [None]:
# Starting grid for the blue-spectrum optimization: every ordered pair of
# centers on a d_lambda-spaced grid, followed by points where both centers
# coincide, on a grid offset from the first by half a step.
d_lambda = 25.
x = np.arange(400., 1001., d_lambda)
point_list = [[val_1, val_2] for val_1 in x for val_2 in x]
y = np.arange(412.5, 1001., d_lambda)
point_list += [[val_1, val_1] for val_1 in y]

In [None]:
# Single blue spectrum (weight 1.0) with the prior_z redshift prior,
# over z_steps=50 between z_min=0.05 and z_max=2.5.
sig_example = siggi(
    [blue_spec], [1.0], prior_z,
    z_min=0.05, z_max=2.5, z_steps=50,
)

In [None]:
%%time
test_rand_state = np.random.RandomState(864)
num_filters = 2
set_ratio = 0.5

# One starting point per (c_1, c_2) pair in point_list:
# [c_1 - 50, c_1 + 50, c_2 - 50, c_2 + 50], i.e. two 100 nm wide windows.
half_width = 50.
start_points = [[c_1 - half_width, c_1 + half_width,
                 c_2 - half_width, c_2 + half_width]
                for c_1, c_2 in point_list]

res_3 = sig_example.optimize_filters(
    num_filters=num_filters,
    filt_min=300., filt_max=1100.,
    sed_mags=22.0,
    set_ratio=set_ratio,
    system_wavelen_max=1200.,
    n_opt_points=15,
    optimizer_verbosity=5,
    procs=16,
    acq_func_kwargs_dict={'kappa': 1.8},
    frozen_filt_dict=None,
    starting_points=start_points,
    rand_state=test_rand_state,
)

In [None]:
# Best filter parameters (lowest objective value) for the blue-spectrum run.
min_idx = np.asarray(res_3.yi).argmin()
best_pt = res_3.Xi[min_idx]
print(best_pt)

In [None]:
# Uncomment if blue and red filters are reversed
# best_pt = [best_pt[2], best_pt[3], best_pt[0], best_pt[1]]

In [None]:
from copy import deepcopy

# Redshift a copy for plotting. redshiftSED modifies the Sed in place, so
# calling it on blue_spec directly would alter the same object handed to siggi
# above, and would compound the shift each time this cell is re-run.
blue_spec_z = deepcopy(blue_spec)
blue_spec_z.redshiftSED(0.6)
sig_plot = plotting([blue_spec_z], best_pt,
                    frozen_filt_dict=None, set_ratio=0.5,
                    sed_mags=22.0)

In [None]:
fig = plt.figure(figsize=(14, 18))

# Top panel: the redshifted test spectrum and the optimized filters.
fig.add_subplot(2, 1, 1)
sig_plot.plot_filters(fig=fig)
plt.xlabel('Wavelength (nm)', size=20)
plt.ylabel('Transmission Fraction', size=20)
plt.legend(('Test Spectrum', 'Filter 1', 'Filter 2'), loc=1)

# Bottom panel: information gain over every evaluated filter placement.
fig.add_subplot(2, 1, 2)
plt.rcParams.update({'font.size': 16})
sig_plot.plot_ig_space(res_3.Xi, np.abs(res_3.yi), [0, 1])
cbar = plt.colorbar()
cbar.set_label('Information Gain (bits)')

# Mark the best point at the filter centers: with best_pt laid out as
# [f1_lo, f1_hi, f2_lo, f2_hi], each center is the mean of its two edges.
# A fixed +50 nm offset is only correct while the optimizer keeps the
# filters exactly 100 nm wide.
# NOTE(review): assumes plot_ig_space places points at edge-mean centers too.
filt_1_center = 0.5 * (best_pt[0] + best_pt[1])
filt_2_center = 0.5 * (best_pt[2] + best_pt[3])
plt.scatter(filt_1_center, filt_2_center, c='r', s=64)
plt.xlabel('Filter 1 Center Wavelength (nm)', size=20)
plt.ylabel('Filter 2 Center Wavelength (nm)', size=20)
plt.savefig('Example_3.pdf')

In [None]:
# Color-color diagram for the blue template over the 50-point redshift grid.
redshifts = np.linspace(0.05, 2.5, 50)
fig = plt.figure(figsize=(20, 16))
fig = sig_plot.plot_color_color(
    ['filter_0', 'filter_1', 'filter_0', 'filter_1'], redshifts, fig=fig)