Skip to content
This repository has been archived by the owner on Sep 1, 2023. It is now read-only.

Commit

Permalink
Merge pull request #3739 from scottpurdy/get-schema
Browse files Browse the repository at this point in the history
Convert getProtoType to getSchema for consistency
  • Loading branch information
scottpurdy committed Jul 7, 2017
2 parents f65c2da + d6a01f1 commit d756394
Show file tree
Hide file tree
Showing 12 changed files with 20 additions and 20 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ prettytable==0.7.2

# When updating nupic.bindings, also update any shared dependencies to keep
# versions in sync.
nupic.bindings==0.7.2
nupic.bindings==1.0.0
numpy==1.12.1
2 changes: 1 addition & 1 deletion src/nupic/frameworks/opf/htm_prediction_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,7 @@ def __setstate__(self, state):


@staticmethod
def getProtoType():
def getSchema():
return HTMPredictionModelProto


Expand Down
10 changes: 5 additions & 5 deletions src/nupic/frameworks/opf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def isInferenceEnabled(self):
return self.__inferenceEnabled

@staticmethod
def getProtoType():
def getSchema():
"""Return the pycapnp proto type that the class uses for serialization.
This is used to convert the proto into the proper type before passing it
Expand All @@ -236,7 +236,7 @@ def _getModelCheckpointFilePath(checkpointDir):

def writeToCheckpoint(self, checkpointDir):
"""Serializes model using capnproto and writes data to ``checkpointDir``"""
proto = self.getProtoType().new_message()
proto = self.getSchema().new_message()

self.write(proto)

Expand Down Expand Up @@ -268,7 +268,7 @@ def readFromCheckpoint(cls, checkpointDir):
checkpointPath = cls._getModelCheckpointFilePath(checkpointDir)

with open(checkpointPath, 'r') as f:
proto = cls.getProtoType().read(f)
proto = cls.getSchema().read(f)

model = cls.read(proto)
return model
Expand All @@ -294,7 +294,7 @@ def writeBaseToProto(self, proto):
def write(self, proto):
"""Write state to proto object.
The type of proto is determined by :meth:`getProtoType`.
The type of proto is determined by :meth:`getSchema`.
"""
raise NotImplementedError()

Expand All @@ -303,7 +303,7 @@ def write(self, proto):
def read(cls, proto):
"""Read state from proto object.
The type of proto is determined by :meth:`getProtoType`.
The type of proto is determined by :meth:`getSchema`.
"""
raise NotImplementedError()

Expand Down
2 changes: 1 addition & 1 deletion src/nupic/frameworks/opf/previous_value_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def resetSequenceStates(self):


@staticmethod
def getProtoType():
def getSchema():
return PreviousValueModelProto


Expand Down
2 changes: 1 addition & 1 deletion src/nupic/frameworks/opf/two_gram_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def resetSequenceStates(self):


@staticmethod
def getProtoType():
def getSchema():
return TwoGramModelProto


Expand Down
2 changes: 1 addition & 1 deletion src/nupic/regions/knn_anomaly_classifier_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,7 @@ def readFromProto(cls, proto):


@staticmethod
def getProtoType():
def getSchema():
return KNNAnomalyClassifierRegionProto


Expand Down
2 changes: 1 addition & 1 deletion src/nupic/regions/knn_classifier_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,7 @@ def getOutputElementCount(self, name):
raise Exception('Unknown output: ' + name)

@staticmethod
def getProtoType():
def getSchema():
return KNNClassifierRegionProto

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions src/nupic/regions/record_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,9 +607,9 @@ def setParameter(self, parameterName, index, parameterValue):


@staticmethod
def getProtoType():
def getSchema():
"""
Overrides :meth:`nupic.bindings.regions.PyRegion.PyRegion.getProtoType`.
Overrides :meth:`nupic.bindings.regions.PyRegion.PyRegion.getSchema`.
"""
return RecordSensorProto

Expand Down
2 changes: 1 addition & 1 deletion src/nupic/regions/sdr_classifier_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def setParameter(self, name, index, value):


@staticmethod
def getProtoType():
def getSchema():
"""
:returns: the pycapnp proto type that the class uses for serialization.
"""
Expand Down
4 changes: 2 additions & 2 deletions src/nupic/regions/sp_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,9 +857,9 @@ def setParameter(self, parameterName, index, parameterValue):


@staticmethod
def getProtoType():
def getSchema():
"""
Overrides :meth:`~nupic.bindings.regions.PyRegion.PyRegion.getProtoType`.
Overrides :meth:`~nupic.bindings.regions.PyRegion.PyRegion.getSchema`.
"""
return SPRegionProto

Expand Down
4 changes: 2 additions & 2 deletions src/nupic/regions/tm_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,9 +778,9 @@ def finishLearning(self):


@staticmethod
def getProtoType():
def getSchema():
"""
Overrides :meth:`~nupic.bindings.regions.PyRegion.PyRegion.getProtoType`.
Overrides :meth:`~nupic.bindings.regions.PyRegion.PyRegion.getSchema`.
"""
return TMRegionProto

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/nupic/frameworks/opf/previous_value_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def testCapnpWriteRead(self):
# Deserialize
m2 = previous_value_model.PreviousValueModel.read(readerProto)

self.assertIs(m1.getProtoType(), PreviousValueModelProto)
self.assertIs(m2.getProtoType(), PreviousValueModelProto)
self.assertIs(m1.getSchema(), PreviousValueModelProto)
self.assertIs(m2.getSchema(), PreviousValueModelProto)

self.assertEqual(m2._numPredictions, m1._numPredictions)
self.assertEqual(m2.getInferenceType(), m1.getInferenceType())
Expand Down

0 comments on commit d756394

Please sign in to comment.