Skip to content
Browse files

write tryprune.py in progress

  • Loading branch information...
1 parent 1143427 commit 9cca4180400774231f64363c0f6acb87e5444f23 @epico epico committed Jul 25, 2011
Showing with 55 additions and 3 deletions.
  1. +55 −3 tryprune.py
View
58 tryprune.py
@@ -1,5 +1,7 @@
#!/usr/bin/python3
import os
+import os.path
+import shutil
import sys
from subprocess import Popen, PIPE
from argparse import ArgumentParser
@@ -45,7 +47,17 @@ def exportModel(modelfile, textmodel):
sys.exit('Corrupted model found when exporting:' + modelfile)
#end processing
-def mergeOneModel(mergedmodel, onemodel):
+def mergeOneModel(mergedmodel, onemodel, score):
+ #validate first
+ validateModel(onemodel)
+
+ onemodelstatuspath = onemodel + config.getStatusPostfix()
+ onemodelstatus = utils.load_status(onemodelstatuspath)
+ if not utils.check_epoch(onemodelstatus, 'Estimate'):
+ raise utils.Epoch('Please estimate first.\n')
+ if score != onemodelstatus['EstimateScore']:
+ raise AssertionError('estimate scores mis-match.\n')
+
#begin processing
cmdline = ['./merge_k_mixture_model', \
'--result-file', \
@@ -59,8 +71,28 @@ def mergeOneModel(mergedmodel, onemodel):
sys.exit('Corrupted model found when merging:' + onemodel)
#end processing
-def mergeSomeModels(indexfile, mergenum):
- pass
+def mergeSomeModels(tryname, mergedmodel, sortedindexname, mergenum):
+ last_score = 1.
+ #begin processing
+ indexfile = open(sortedindexname, 'r')
+ for i in range(mergenum):
+ line = indexfile.readline()
+ if not line:
+ raise AssertionError('No more models.\n')
+ line = line.rstrip(os.linesep)
+ (subdir, modelname, score) = line.split('#', 2)
+ score = float(score)
+ if score > last_score:
+ raise AssertionError('score must be descending.\n')
+
+ onemodel = os.path.join(config.getModelDir(), subdir, modelname)
+ mergeOneModel(mergedmodel, onemodel, score)
+ last_score = score
+ indexfile.close()
+ #end processing
+
+ #validate merged model
+ validateModel(mergedmodel)
def pruneModel(modelfile, k, CDF):
#begin processing
@@ -98,3 +130,23 @@ def pruneModel(modelfile, k, CDF):
args = parser.parse_args()
print(args)
+ tryname = 'try' + args.tryname
+ #merge model candidates
+ mergedmodel = os.path.join(config.getFinalDir(), tryname, 'merged.db')
+ sortedindexname = os.path.join(args.modeldir, \
+ config.getSortedEstimateIndex())
+ mergeSomeModels(tryname, mergedmodel, sortedindexname, args.mergenumber)
+
+ #export textual format
+ exportfile = os.path.join(config.getFinalDir(), tryname, 'kmm_merged.text')
+ exportModel(mergedmodel, exportfile)
+
+ #prune merged model
+ prunedmodel = os.path.join(config.getFinalDir(), tryname, 'pruned.db')
+ #backup merged model
+ shutil.copyfile(mergedmodel, prunedmodel)
+ pruneModel(prunedmodel, args.k, args.CDF)
+
+ #export textual format
+ exportfile = os.path.join(config.getFinalDir(), tryname, 'kmm_pruned.text')
+ exportModel(prunedmodel, exportModel)

0 comments on commit 9cca418

Please sign in to comment.
Something went wrong with that request. Please try again.