
Commit

Memory mapping while creating arrays. Training agnostic model was previously stalling due to memory.

Fixed minor bug in rlap warning in classify.py
daniel-muthukrishna committed Jun 8, 2018
1 parent 117dd61 commit 0e48da7
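
For context, a minimal sketch of the pattern this commit switches to (the file name and sizes below are illustrative, not the values used in dash/create_arrays.py): instead of growing an in-memory array with np.append for every generated spectrum, a disk-backed np.memmap of known maximum size is preallocated and filled row by row, and only the rows actually written are kept.

import numpy as np

nw = 1024         # illustrative number of wavelength bins per spectrum
maxRows = 10000   # illustrative upper bound on the number of spectra

# Rows written here live in 'images_example.dat' on disk rather than in
# process memory, so building a very large training set no longer stalls.
images = np.memmap('images_example.dat', dtype=np.float16, mode='w+', shape=(maxRows, nw))

nRows = 0
for _ in range(100):                     # stand-in for the real template loop
    images[nRows] = np.random.rand(nw)   # one processed spectrum per row
    nRows += 1

valid = np.array(images[:nRows])         # keep only the rows actually filled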
Showing 4 changed files with 38 additions and 23 deletions.
2 changes: 1 addition & 1 deletion dash/analyse_all_ozdes_atels.py
@@ -46,7 +46,7 @@ def main(spectraDir, atelTextFile, saveMatchesFilename):
print(row.Name)

# Classify and print the best matches
- classification = dash.Classify(filenames, knownRedshifts, classifyHost=False, rlapScores=True, smooth=6)
+ classification = dash.Classify(filenames, knownRedshifts, classifyHost=False, rlapScores=True, smooth=6, knownZ=True)
bestFits, redshifts, bestTypes, rlapFlag, matchesFlag = classification.list_best_matches(n=5, saveFilename=saveMatchesFilename)

print("{0:17} | {1:5} | {2:8} | {3:10} | {4:6} | {5:10} | {6:10}".format("Name", " z ", "DASH_Fit", " Age ", "Prob.", "Flag", "Wiki Fit"))
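For illustration, a hedged usage sketch of the updated call (the file names, redshifts, and save path are placeholders; knownZ=True presumably tells DASH to hold each spectrum at its supplied redshift rather than fitting it):

import dash

filenames = ['spectrum1.fits', 'spectrum2.fits']   # placeholder spectrum files
knownRedshifts = [0.12, 0.34]                       # placeholder redshifts

classification = dash.Classify(filenames, knownRedshifts, classifyHost=False, rlapScores=True, smooth=6, knownZ=True)
bestFits, redshifts, bestTypes, rlapFlag, matchesFlag = classification.list_best_matches(n=5, saveFilename='matches.txt')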
1 change: 1 addition & 0 deletions dash/classify.py
@@ -144,6 +144,7 @@ def rlap_warning_label(self, bestType, inputImage, inputMinMaxIndex):
rlapLabel = "No rlap"
else:
rlapLabel = "(NO_TEMPLATES)"
+ rlapWarningBool = "None"

# import matplotlib
# matplotlib.use('TkAgg')
8 changes: 4 additions & 4 deletions dash/create_and_save_all_data_files.py
@@ -12,12 +12,12 @@


if __name__ == '__main__':
- modelName = 'new_zeroZ_classifyHost'
+ modelName = 'new_agnosticZ'
trainWithHost = True
- classifyHost = True
+ classifyHost = False
minZ = 0.
- maxZ = 0.
- numOfRedshifts = 1
+ maxZ = 0.8
+ numOfRedshifts = 50
trainFraction = 1.0
numTrainBatches = 2000000

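These settings switch training from a single zero-redshift copy of each template to 50 redshifts drawn uniformly from [0, 0.8], which is what makes the preallocated, memory-mapped arrays below necessary. A rough sizing illustration mirroring the arraySize formula in dash/create_arrays.py (the template and fraction counts here are hypothetical; only the hard-coded factor of 50 and numOfRedshifts come from this commit):

numGalTemplates = 10    # hypothetical
numSnTemplates = 20     # hypothetical
agesPerTemplate = 50    # hard-coded factor of 50 in the arraySize formula
numSnFractions = 5      # hypothetical
numOfRedshifts = 50     # after this change (was 1)

arraySize = numGalTemplates * numSnTemplates * agesPerTemplate * numSnFractions * numOfRedshifts
# 10 * 20 * 50 * 5 * 50 = 2,500,000 rows; at 1024 float16 values per row that is
# roughly 5 GB, too large to keep growing with np.append in RAM.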
50 changes: 32 additions & 18 deletions dash/create_arrays.py
@@ -1,3 +1,5 @@
+ import os
+ import glob
import numpy as np
from random import shuffle
import multiprocessing as mp
@@ -292,19 +294,21 @@ def __init__(self, w0, w1, nw, nTypes, minAge, maxAge, ageBinSize, typeList, min
self.createLabels = CreateLabels(self.nTypes, self.minAge, self.maxAge, self.ageBinSize, self.typeList, hostTypes, nHostTypes)
self.hostTypes = hostTypes

- # TODO: Maybe do memory mapping for these arrays
self.images = []
self.labelsIndexes = []
self.filenames = []
self.typeNames = []

def combined_sn_gal_templates_to_arrays(self, args):
snTemplateLocation, snTempList, galTemplateLocation, galTempList, snFractions = args
- images = np.empty((0, int(self.nw)), np.float16) # Number of pixels
- labelsIndexes = []
- filenames = []
- typeNames = []
- agesList = []

+ randnum = np.random.randint(10000)
+ arraySize = len(galTempList) * len(snTempList) * 50 * len(snFractions) * self.numOfRedshifts
+ images = np.memmap('images_{}_{}.dat'.format(snTempList[0], randnum), dtype=np.float16, mode='w+', shape=(arraySize, int(self.nw)))
+ labelsIndexes = np.memmap('labels_{}_{}.dat'.format(snTempList[0], randnum), dtype=np.uint16, mode='w+', shape=arraySize)
+ filenames = np.memmap('filenames_{}_{}.dat'.format(snTempList[0], randnum), dtype=object, mode='w+', shape=arraySize)
+ typeNames = np.memmap('typeNames_{}_{}.dat'.format(snTempList[0], randnum), dtype=object, mode='w+', shape=arraySize)
+ nRows = 0

for j in range(len(galTempList)):
galFilename = galTemplateLocation + galTempList[j] if galTemplateLocation is not None else None
@@ -322,7 +326,6 @@ def combined_sn_gal_templates_to_arrays(self, args):
redshifts = np.random.uniform(low=self.minZ, high=self.maxZ, size=self.numOfRedshifts)
for z in redshifts:
tempWave, tempFlux, nCols, ages, tType, tMinIndex, tMaxIndex = readSpectra.sn_plus_gal_template(ageidx, snCoeff, galCoeff, z)
- agesList.append(ages[ageidx])
if tMinIndex == tMaxIndex or not tempFlux.any():
print("NO DATA for {} {} ageIdx:{} z>={}".format(galTempList[j], snTempList[i], ageidx, z))
break
@@ -337,21 +340,22 @@ def combined_sn_gal_templates_to_arrays(self, args):
nonzeroflux = tempFlux[tMinIndex:tMaxIndex + 1]
newflux = (nonzeroflux - min(nonzeroflux)) / (max(nonzeroflux) - min(nonzeroflux))
newflux2 = np.concatenate((tempFlux[0:tMinIndex], newflux, tempFlux[tMaxIndex + 1:]))
- images = np.append(images, np.array([newflux2]), axis=0)
- labelsIndexes.append(labelIndex) # labels = np.append(labels, np.array([label]), axis=0)
- filenames.append("{0}_{1}_{2}_{3}_snCoeff{4}_z{5}".format(snTempList[i], tType, str(ages[ageidx]), galTempList[j], snCoeff, (z)))
- typeNames.append(typeName)
+ images[nRows] = np.array([newflux2])
+ labelsIndexes[nRows] = labelIndex
+ filenames[nRows] = "{0}_{1}_{2}_{3}_snCoeff{4}_z{5}".format(snTempList[i], tType, str(ages[ageidx]), galTempList[j], snCoeff, (z))
+ typeNames[nRows] = typeName
+ nRows += 1
print(snTempList[i], nCols, galTempList[j])

- return images, np.array(labelsIndexes).astype(int), np.array(filenames), np.array(typeNames)
+ return images, np.array(labelsIndexes).astype(int), np.array(filenames), np.array(typeNames), nRows

def collect_results(self, result):
"""Uses apply_async's callback to setup up a separate Queue for each process"""
- imagesPart, labelsPart, filenamesPart, typeNamesPart = result
- self.images.extend(imagesPart)
- self.labelsIndexes.extend(labelsPart)
- self.filenames.extend(filenamesPart)
- self.typeNames.extend(typeNamesPart)
+ imagesPart, labelsPart, filenamesPart, typeNamesPart, nRows = result
+ self.images.extend(imagesPart[0:nRows])
+ self.labelsIndexes.extend(labelsPart[0:nRows])
+ self.filenames.extend(filenamesPart[0:nRows])
+ self.typeNames.extend(typeNamesPart[0:nRows])

def combined_sn_gal_arrays_multiprocessing(self, snTemplateLocation, snTempFileList, galTemplateLocation, galTempFileList):
if galTemplateLocation is None or galTempFileList is None:
@@ -376,7 +380,7 @@ def combined_sn_gal_arrays_multiprocessing(self, snTemplateLocation, snTempFileL
outputs = results.get()
for i, output in enumerate(outputs):
self.collect_results(output)
- print('combining results...', i, len(outputs))
+ print('combining results...', output[-1], i, len(outputs))

self.images = np.array(self.images)
self.labelsIndexes = np.array(self.labelsIndexes)
@@ -385,4 +389,14 @@ def combined_sn_gal_arrays_multiprocessing(self, snTemplateLocation, snTempFileL

print("Completed Creating Arrays!")

+ # Delete temporary memory mapping files
+ for filename in glob.glob('images_*.dat'):
+ os.remove(filename)
+ for filename in glob.glob('labels*.dat'):
+ os.remove(filename)
+ for filename in glob.glob('filenames_*.dat'):
+ os.remove(filename)
+ for filename in glob.glob('typeNames_*.dat'):
+ os.remove(filename)

return self.images, self.labelsIndexes.astype(np.uint16), self.filenames, self.typeNames
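
Putting the pieces above together, a self-contained sketch of the worker/parent handshake this commit introduces (names, sizes, and the serial loop standing in for the multiprocessing pool are illustrative): each worker preallocates its own uniquely named memmap, reports how many rows it actually filled, the parent keeps only those rows, and the temporary .dat files are deleted once everything is combined.

import glob
import os
import numpy as np

def makeSpectra(maxRows, nw, tag):
    # Each worker writes to its own file so parallel processes never collide.
    images = np.memmap('images_{}.dat'.format(tag), dtype=np.float16, mode='w+', shape=(maxRows, nw))
    nRows = 0
    for _ in range(maxRows // 2):            # not every preallocated slot is used
        images[nRows] = np.random.rand(nw)   # stand-in for a processed spectrum
        nRows += 1
    return images, nRows

allImages = []
for tag in ('part0', 'part1'):
    images, nRows = makeSpectra(200, 1024, tag)
    allImages.extend(images[:nRows])         # trim to the rows actually written
allImages = np.array(allImages)

# Remove the temporary memory-mapping files once results are combined.
for f in glob.glob('images_*.dat'):
    os.remove(f)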
