Skip to content

Commit

Permalink
Replace Ia-02cx label with Iax.
Browse files Browse the repository at this point in the history
Add citation to documentation.
Change confusion matrix metrics font
  • Loading branch information
daniel-muthukrishna committed Sep 19, 2019
1 parent 22938a4 commit c4a7287
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 6 deletions.
17 changes: 12 additions & 5 deletions astrodash/model_metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pickle
import matplotlib
import matplotlib.pyplot as plt
import itertools
import numpy as np
Expand Down Expand Up @@ -29,11 +30,17 @@ def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix'
np.savetxt(os.path.join(fig_dir, 'confusion_matrix_%s.csv' % name), cm)
print(cm)

plt.rcParams['text.usetex'] = True
plt.rcParams['font.serif'] = ['Computer Modern Roman'] + plt.rcParams['font.serif']
font = {'family': 'normal',
'size': 16}
matplotlib.rc('font', **font)

fig = plt.figure(figsize=(15, 12))
plt.imshow(cm, interpolation='nearest', cmap=cmap, vmin=-1, vmax=1)
plt.title(title)
cb = plt.colorbar()
cb.ax.set_yticklabels(cb.ax.get_yticklabels(), fontsize=16)
cb.ax.tick_params(labelsize=23)
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=90, fontsize=fontsize_labels)
plt.yticks(tick_marks, classes, fontsize=fontsize_labels)
Expand All @@ -45,8 +52,8 @@ def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix'
color="white" if abs(cm[i, j]) > thresh else "black", fontsize=fontsize_matrix)

plt.tight_layout()
plt.ylabel('True label', fontsize=18)
plt.xlabel('Predicted label', fontsize=18)
plt.ylabel('True label', fontsize=26)
plt.xlabel('Predicted label', fontsize=26)
plt.tight_layout()
plt.savefig(os.path.join(fig_dir, 'confusion_matrix_%s.pdf' % name))

Expand Down Expand Up @@ -91,14 +98,14 @@ def calc_model_metrics(modelFilename, testLabels, testImages, testTypeNames, typ
if confMatrixAggregateAges.shape[0] < len(classnames):
classnames = classnames[:-1]
plot_confusion_matrix(confMatrixAggregateAges, classes=classnames, normalize=True, title='', fig_dir=fig_dir,
name='aggregate_ages', fontsize_labels=15, fontsize_matrix=16)
name='aggregate_ages', fontsize_labels=23, fontsize_matrix=21)

# Aggregate age and subtypes conf matrix
aggregateSubtypesIndexes = np.array([0, 108, 180, 234, 306])
broadTypes = ['Ia', 'Ib', 'Ic', 'II']
confMatrixAggregateSubtypes = get_aggregated_conf_matrix(aggregateSubtypesIndexes, testLabels, predictedLabels)
plot_confusion_matrix(confMatrixAggregateSubtypes, classes=broadTypes, normalize=True, title='',
fig_dir=fig_dir, name='aggregate_subtypes', fontsize_labels=30, fontsize_matrix=30)
fig_dir=fig_dir, name='aggregate_subtypes', fontsize_labels=35, fontsize_matrix=35)
# plt.show()

np.set_printoptions(precision=2)
Expand Down
2 changes: 2 additions & 0 deletions astrodash/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ def snid_template_spectra_all(self):

if ttype == 'Ia-99aa':
ttype = 'Ia-91T'
elif ttype == 'Ia-02cx':
ttype = 'Iax'

return wave, fluxes, numAges, ages, ttype, splineInfo

Expand Down
2 changes: 1 addition & 1 deletion astrodash/training_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

def create_training_params_file(dataDirName):
parameters = {
'typeList': ['Ia-norm', 'Ia-91T', 'Ia-91bg', 'Ia-csm', 'Ia-02cx', 'Ia-pec',
'typeList': ['Ia-norm', 'Ia-91T', 'Ia-91bg', 'Ia-csm', 'Iax', 'Ia-pec',
'Ib-norm', 'Ibn', 'IIb', 'Ib-pec', 'Ic-norm', 'Ic-broad',
'Ic-pec', 'IIP', 'IIL', 'IIn', 'II-pec'],
'nTypes': 17,
Expand Down
6 changes: 6 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ License

The project is licensed under the MIT license.

Citation
--------

You can cite the following paper for this work: https://ui.adsabs.harvard.edu/abs/2019arXiv190302557M/abstract


Author
------
Daniel Muthukrishna
Expand Down
2 changes: 2 additions & 0 deletions templates/training_set/create_templist.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def snid_template_spectra_all(filename):

if ttype == 'Ia-99aa':
ttype = 'Ia-91T'
elif ttype == 'Ia-02cx':
ttype = 'Iax'

return wave, fluxes, numAges, ages, ttype, splineInfo

Expand Down

0 comments on commit c4a7287

Please sign in to comment.