Skip to content

Commit

Permalink
Fixed bug when redshifts argument to Classify() is empty.
Browse files Browse the repository at this point in the history
Raised error when trying to use the (Agnostic redshift & Classify Host) model until it is included in the package.
  • Loading branch information
daniel-muthukrishna committed Mar 27, 2019
1 parent bfbd2b7 commit 6b8ca6d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
13 changes: 10 additions & 3 deletions astrodash/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def __init__(self, filenames=[], redshifts=[], smooth=6, minWave=3500, maxWave=1
self.knownZ = True
else:
self.knownZ = False
if not self.redshifts:
self.redshifts = [None] * len(filenames)
self.rlapScores = rlapScores
self.pars = get_training_parameters()
self.nw, w0, w1 = self.pars['nw'], self.pars['w0'], self.pars['w1']
Expand All @@ -54,8 +56,11 @@ def __init__(self, filenames=[], redshifts=[], smooth=6, minWave=3500, maxWave=1
"models/zeroZ/tensorflow_model.ckpt")
else:
if self.classifyHost:
self.modelFilename = os.path.join(self.scriptDirectory, data_files,
"models/agnosticZ_classifyHost/tensorflow_model.ckpt")
raise ValueError("A model that classifies the host while simulatenously not knowing redshift does not "
"exist currently. Please try one of the other 3 models or check back at a later "
"date. Contact the author for support or further queries.")
# self.modelFilename = os.path.join(self.scriptDirectory, data_files,
# "models/agnosticZ_classifyHost/tensorflow_model.ckpt")
else:
self.modelFilename = os.path.join(self.scriptDirectory, data_files,
"models/agnosticZ/tensorflow_model.ckpt")
Expand Down Expand Up @@ -103,7 +108,7 @@ def list_best_matches(self, n=5, saveFilename='DASH_matches.txt'):
for i in range(20):
host, name, age = classification_split(bestTypes[specNum][i])
if not self.knownZ:
redshift, _, redshiftErr = self.calc_redshift(inputImages[i], name, age, inputMinMaxIndexes[i])[0]
redshift, _, redshiftErr = self.calc_redshift(inputImages[i], name, age, inputMinMaxIndexes[i])
redshifts.append(redshift)
redshiftErrs.append(redshiftErr)
prob = softmaxes[specNum][i]
Expand All @@ -125,8 +130,10 @@ def list_best_matches(self, n=5, saveFilename='DASH_matches.txt'):

if not redshifts:
redshifts = self.redshifts
redshiftErrs = [None] * len(self.redshifts)
else:
redshifts = np.array(redshifts)
redshiftErrs = np.array(redshiftErrs)

if saveFilename:
self.save_best_matches(bestMatchLists, redshifts, bestBroadTypes, rlapLabels, matchesReliableLabels,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
long_description = f.read()

setup(name='astrodash',
version='1.0.14',
version='1.0.15',
description='Deep Learning for Automated Spectral Classification of Supernovae',
long_description=long_description,
url='https://github.com/daniel-muthukrishna/DASH',
Expand Down

0 comments on commit 6b8ca6d

Please sign in to comment.