Skip to content

Commit

Permalink
Added training fraction as an argument to create_training_set
Browse files: browse the repository at this point in the history
  • Loading branch information
daniel-muthukrishna committed May 13, 2018
1 parent d381821 commit df25945
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
8 changes: 4 additions & 4 deletions dash/create_and_save_all_data_files.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,7 +12,7 @@


if __name__ == '__main__':
modelName = 'zeroZ'
modelName = 'zeroZ_trainOnAll'
dataDirName = os.path.join(scriptDirectory, 'data_files_{0}/'.format(modelName))
dataFilenames = []
if not os.path.exists(dataDirName):
Expand All @@ -28,8 +28,8 @@
f.write("Classify Host: False\n")
f.write("Redshift: Zero\n")
f.write("Redshift Range: 0. to 0.\n")
f.write("Num of Redshifts: 50\n")
f.write("Fraction of Training Set Used: 0.\n")
f.write("Num of Redshifts: 1\n")
f.write("Fraction of Training Set Used: 1.0\n")
f.write("Training Amount: 50 x 500000\n")
f.write("Changed wavelength range to 3000 to 10000A\n")
f.write("Set outer region to 0.5\n")
Expand All @@ -44,7 +44,7 @@
print("time spent: {0:.2f}".format(t2 - t1))

# CREATE TRAINING SET FILES
trainingSetFilename = create_training_set_files(dataDirName, minZ=0., maxZ=0., numOfRedshifts=1, trainWithHost=True, classifyHost=False)
trainingSetFilename = create_training_set_files(dataDirName, minZ=0., maxZ=0., numOfRedshifts=1, trainWithHost=True, classifyHost=False, trainFraction=1.0)
dataFilenames.append(trainingSetFilename)
t3 = time.time()
print("time spent: {0:.2f}".format(t3 - t2))
Expand Down
17 changes: 9 additions & 8 deletions dash/create_training_set.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -8,7 +8,7 @@

class CreateTrainingSet(object):

def __init__(self, snidTemplateLocation, snidTempFileList, w0, w1, nw, nTypes, minAge, maxAge, ageBinSize, typeList, minZ, maxZ, numOfRedshifts, galTemplateLocation, galTempFileList, hostTypes, nHostTypes):
def __init__(self, snidTemplateLocation, snidTempFileList, w0, w1, nw, nTypes, minAge, maxAge, ageBinSize, typeList, minZ, maxZ, numOfRedshifts, galTemplateLocation, galTempFileList, hostTypes, nHostTypes, trainFraction):
self.snidTemplateLocation = snidTemplateLocation
self.snidTempFileList = snidTempFileList
self.galTemplateLocation = galTemplateLocation
Expand All @@ -21,6 +21,7 @@ def __init__(self, snidTemplateLocation, snidTempFileList, w0, w1, nw, nTypes, m
self.maxAge = maxAge
self.ageBinSize = ageBinSize
self.typeList = typeList
self.trainFraction = trainFraction
self.ageBinning = AgeBinning(self.minAge, self.maxAge, self.ageBinSize)
self.numOfAgeBins = self.ageBinning.age_bin(self.maxAge-0.1) + 1
self.nLabels = self.nTypes * self.numOfAgeBins * nHostTypes
Expand All @@ -42,8 +43,8 @@ def all_templates_to_arrays(self):
return arraysShuf, typeAmounts

def sort_data(self):
trainPercentage = 0.8
testPercentage = 0.2
trainPercentage = self.trainFraction
testPercentage = 1.0 - self.trainFraction
validatePercentage = 0.

arrays, typeAmounts = self.all_templates_to_arrays()
Expand Down Expand Up @@ -74,7 +75,7 @@ def sort_data(self):


class SaveTrainingSet(object):
def __init__(self, snidTemplateLocation, snidTempFileList, w0, w1, nw, nTypes, minAge, maxAge, ageBinSize, typeList, minZ, maxZ, numOfRedshifts, galTemplateLocation=None, galTempFileList=None, hostTypes=None, nHostTypes=1):
def __init__(self, snidTemplateLocation, snidTempFileList, w0, w1, nw, nTypes, minAge, maxAge, ageBinSize, typeList, minZ, maxZ, numOfRedshifts, galTemplateLocation=None, galTempFileList=None, hostTypes=None, nHostTypes=1, trainFraction=0.8):
self.snidTemplateLocation = snidTemplateLocation
self.snidTempFileList = snidTempFileList
self.w0 = w0
Expand All @@ -87,7 +88,7 @@ def __init__(self, snidTemplateLocation, snidTempFileList, w0, w1, nw, nTypes, m
self.typeList = typeList
self.createLabels = CreateLabels(nTypes, minAge, maxAge, ageBinSize, typeList, hostTypes, nHostTypes)

self.createTrainingSet = CreateTrainingSet(snidTemplateLocation, snidTempFileList, w0, w1, nw, nTypes, minAge, maxAge, ageBinSize, typeList, minZ, maxZ, numOfRedshifts, galTemplateLocation, galTempFileList, hostTypes, nHostTypes)
self.createTrainingSet = CreateTrainingSet(snidTemplateLocation, snidTempFileList, w0, w1, nw, nTypes, minAge, maxAge, ageBinSize, typeList, minZ, maxZ, numOfRedshifts, galTemplateLocation, galTempFileList, hostTypes, nHostTypes, trainFraction)
self.sortData = self.createTrainingSet.sort_data()
self.trainImages = self.sortData[0][0]
self.trainLabels = self.sortData[0][1]
Expand Down Expand Up @@ -141,7 +142,7 @@ def save_arrays(self, saveFilename):
os.remove(filename)


def create_training_set_files(dataDirName, minZ=0, maxZ=0, numOfRedshifts=80, trainWithHost=True, classifyHost=False):
def create_training_set_files(dataDirName, minZ=0, maxZ=0, numOfRedshifts=80, trainWithHost=True, classifyHost=False, trainFraction=0.8):
with open(os.path.join(dataDirName, 'training_params.pickle'), 'rb') as f1:
pars = pickle.load(f1)
nTypes, w0, w1, nw, minAge, maxAge, ageBinSize, typeList = pars['nTypes'], pars['w0'], pars['w1'], \
Expand All @@ -162,7 +163,7 @@ def create_training_set_files(dataDirName, minZ=0, maxZ=0, numOfRedshifts=80, tr
else:
galTemplateLocation, galTempFileList = None, None

saveTrainingSet = SaveTrainingSet(snidTemplateLocation, snidTempFileList, w0, w1, nw, nTypes, minAge, maxAge, ageBinSize, typeList, minZ, maxZ, numOfRedshifts, galTemplateLocation, galTempFileList, hostList, nHostTypes)
saveTrainingSet = SaveTrainingSet(snidTemplateLocation, snidTempFileList, w0, w1, nw, nTypes, minAge, maxAge, ageBinSize, typeList, minZ, maxZ, numOfRedshifts, galTemplateLocation, galTempFileList, hostList, nHostTypes, trainFraction)
typeNamesList, typeAmounts = saveTrainingSet.type_amounts()

saveFilename = os.path.join(dataDirName, 'training_set.zip')
Expand All @@ -172,4 +173,4 @@ def create_training_set_files(dataDirName, minZ=0, maxZ=0, numOfRedshifts=80, tr


if __name__ == '__main__':
trainingSetFilename = create_training_set_files('data_files/', minZ=0, maxZ=0, numOfRedshifts=80, trainWithHost=False, classifyHost=False)
trainingSetFilename = create_training_set_files('data_files/', minZ=0, maxZ=0, numOfRedshifts=80, trainWithHost=False, classifyHost=False, trainFraction=0.8)

0 comments on commit df25945

Please sign in to comment.