Skip to content
This repository has been archived by the owner on Sep 1, 2023. It is now read-only.

Commit

Permalink
Merge pull request #2800 from chetan51/fix_checkpoint_test
Browse files Browse the repository at this point in the history
Fix opf_checkpoint_test.py
  • Loading branch information
chetan51 committed Dec 16, 2015
2 parents 728bd19 + 630942d commit 24ec928
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 20 deletions.
6 changes: 3 additions & 3 deletions src/nupic/frameworks/opf/experiment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,15 +485,15 @@ def _getModelCheckpointDir(experimentDir, checkpointLabel):
Returns:
absolute path to the serialization directory
"""
checkpointDir = os.path.join(_getCheckpointParentDir(experimentDir),
checkpointDir = os.path.join(getCheckpointParentDir(experimentDir),
checkpointLabel + g_defaultCheckpointExtension)
checkpointDir = os.path.abspath(checkpointDir)

return checkpointDir



def _getCheckpointParentDir(experimentDir):
def getCheckpointParentDir(experimentDir):
"""Get checkpoint parent dir.
Returns: absolute path to the base serialization directory within which
Expand Down Expand Up @@ -539,7 +539,7 @@ def _isCheckpointDir(checkpointDir):

def _printAvailableCheckpoints(experimentDir):
"""List available checkpoints for the specified experiment."""
checkpointParentDir = _getCheckpointParentDir(experimentDir)
checkpointParentDir = getCheckpointParentDir(experimentDir)

if not os.path.exists(checkpointParentDir):
print "No available checkpoints."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import shutil

from nupic.data.file_record_stream import FileRecordStream
from nupic.frameworks.opf.experiment_runner import runExperiment
from nupic.frameworks.opf.experiment_runner import runExperiment, getCheckpointParentDir
from nupic.support import initLogging
from nupic.support.unittesthelpers.testcasebase import (
unittest, TestCaseBase as HelperTestCaseBase)
Expand All @@ -45,22 +45,6 @@ def shortDescription(self):
return None


@staticmethod
def getOpfNonTemporalPredictionFilepath(experimentDir, taskLabel):
path = os.path.join(experimentDir,
"inference",
"%s.nontemporal.predictionLog.csv" % taskLabel)
return os.path.abspath(path)


@staticmethod
def getOpfTemporalPredictionFilepath(experimentDir, taskLabel):
path = os.path.join(experimentDir,
"inference",
"%s.temporal.predictionLog.csv" % taskLabel)
return os.path.abspath(path)


def compareOPFPredictionFiles(self, path1, path2, temporal,
maxMismatches=None):
""" Compare temporal or non-temporal predictions for the given experiment
Expand Down Expand Up @@ -384,6 +368,11 @@ def _testSamePredictions(self, experiment, predSteps, checkpointAt,
except StopIteration:
break

# clean up model checkpoint directories
shutil.rmtree(getCheckpointParentDir(aExpDir))
shutil.rmtree(getCheckpointParentDir(bExpDir))
shutil.rmtree(getCheckpointParentDir(aPlusBExpDir))

print "Predictions match!"


Expand Down

0 comments on commit 24ec928

Please sign in to comment.