Skip to content

Commit

Permalink
Added classify_OzDES_runs.py
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-muthukrishna committed Apr 15, 2018
1 parent 9363eb1 commit 7b5ea5b
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 4 deletions.
65 changes: 65 additions & 0 deletions dash/classify_OzDES_runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
import dash


def read_ozdes_wiki_atel(atelTextFile):
""" Read ATEL wiki text, which has been saved to a text file: atelTextFile.
Returns the names, redshifts, and wikiClassifications"""
names = []
redshifts = []
wikiClassifications = []
with open(atelTextFile, 'r') as f:
for line in f:
if line[0:3] == 'DES':
objectInfo = line.split()
name = objectInfo[0]
redshift = objectInfo[1]
wikiClassification = ' '.join(objectInfo[2:])

try:
if redshift != '?': # and 'SN' not in wikiClassification:
names.append(name)
redshifts.append(float(redshift))
wikiClassifications.append(wikiClassification)
except ValueError as e:
print("Invalid redshift for line: {0}".format(line))
raise e

return names, redshifts, wikiClassifications


def main(runDirectory, atelTextFile, saveMatchesFilename):

# Store all the file paths to the objects in this run
directoryPath = os.path.join(os.path.dirname(os.path.abspath(__file__)), runDirectory)
allFilePaths = []
for dirpath, dirnames, filenames in os.walk(directoryPath):
for filename in [f for f in filenames if f.endswith(".dat")]:
allFilePaths.append(os.path.join(dirpath, filename))
allFilePaths.reverse()

# Get filenames and corresponding redshifts
names, knownRedshifts, wikiClassifications = read_ozdes_wiki_atel(atelTextFile)
run = []
for i in range(len(names)):
for filePath in allFilePaths:
if names[i] == filePath.split('/')[-1].split('_')[0]:
run.append((filePath, knownRedshifts[i], wikiClassifications[i]))
#break # Uncomment the break to only classify the last dated spectrum for each object instead of classifying all dates.
filenames = [i[0] for i in run]
knownRedshifts = [i[1] for i in run]
wikiClassifications = [i[2] for i in run]

# Classify and print the best matches
classification = dash.Classify(filenames, knownRedshifts, classifyHost=False, rlapScores=True, smooth=6)
bestFits, redshifts, bestTypes, rlapFlag, matchesFlag = classification.list_best_matches(n=5, saveFilename=saveMatchesFilename)
print("{0:17} | {1:5} | {2:8} | {3:10} | {4:6} | {5:10} | {6:10}".format("Name", " z ", "DASH_Fit", " Age ", "Prob.", "Flag", "Wiki Fit"))
for i in range(len(filenames)):
print("{0:17} | {1:5} | {2:8} | {3:10} | {4:6} | {5:10} | {6:10}".format('_'.join([filenames[i].split('/')[-1].split('_')[0], filenames[i].split('/')[-1].split('_')[3]]) , redshifts[i], bestTypes[i][0], bestTypes[i][1], bestTypes[i][2], matchesFlag[i].replace(' matches',''), wikiClassifications[i]))

# Plot one of the matches
classification.plot_with_gui(indexToPlot=7)


if __name__ == '__main__':
main(runDirectory='../templates/run035/', atelTextFile='wiki_atel_run035.txt', saveMatchesFilename='DASH_matches_run35.txt')
4 changes: 2 additions & 2 deletions dash/create_and_save_all_data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
f.write("Redshift: Zero\n")
f.write("Redshift Range: 0 to 0.\n")
f.write("Redshift Precision: 0.01\n")
f.write("Fraction of Training Set Used: 0.9\n")
f.write("Fraction of Training Set Used: 0.8\n")
f.write("Training Amount: 50 x 500000\n")
f.write("Changed wavelength range to 3500 to 10000A\n")
f.write("Changed wavelength range to 3000 to 10000A\n")
f.write("Set outer region to 0.5\n")
dataFilenames.append(modelInfoFilename)

Expand Down
4 changes: 2 additions & 2 deletions dash/create_training_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def all_templates_to_arrays(self):
return arraysShuf, typeAmounts

def sort_data(self):
trainPercentage = 0.9
testPercentage = 0.1
trainPercentage = 0.8
testPercentage = 0.2
validatePercentage = 0.

arrays, typeAmounts = self.all_templates_to_arrays()
Expand Down

0 comments on commit 7b5ea5b

Please sign in to comment.