Skip to content
This repository has been archived by the owner on Sep 1, 2023. It is now read-only.

Add capnp serialization to CLAClassifierRegion #2855

Merged
merged 2 commits into from Dec 19, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/nupic/algorithms/cla_classifier_factory.py
Expand Up @@ -46,3 +46,22 @@ def create(*args, **kwargs):
else:
raise ValueError('Invalid classifier implementation (%r). Value must be '
'"py" or "cpp".' % impl)


@staticmethod
def read(proto):
"""
proto: CLAClassifierRegionProto capnproto object
"""
impl = proto.classifierImp
if impl == 'py':
return CLAClassifier.read(proto.claClassifier)
elif impl == 'cpp':
instance = FastCLAClassifier()
instance.read(proto.claClassifier)
return instance
elif impl == 'diff':
raise NotImplementedError("CLAClassifierDiff.read not implemented")
else:
raise ValueError('Invalid classifier implementation (%r). Value must be '
'"py" or "cpp".' % impl)
13 changes: 13 additions & 0 deletions src/nupic/regions/CLAClassifierRegion.capnp
@@ -0,0 +1,13 @@
@0x86ee045dcbcfbf3f;

using import "/nupic/proto/ClaClassifier.capnp".ClaClassifierProto;

# Next ID: 6
struct CLAClassifierRegionProto {
classifierImp @0 :Text;
claClassifier @1 :ClaClassifierProto;
steps @2 :Text;
alpha @3 :Float32;
verbosity @4 :UInt32;
maxCategoryCount @5 :UInt32;
}
51 changes: 51 additions & 0 deletions src/nupic/regions/CLAClassifierRegion.py
Expand Up @@ -28,7 +28,16 @@
import warnings

from nupic.bindings.regions.PyRegion import PyRegion
from nupic.bindings.algorithms import FastCLAClassifier
from nupic.algorithms.cla_classifier_factory import CLAClassifierFactory
from nupic.support.configuration import Configuration

try:
import capnp
except ImportError:
capnp = None
if capnp:
from nupic.regions.CLAClassifierRegion_capnp import CLAClassifierRegionProto



Expand Down Expand Up @@ -198,7 +207,12 @@ def __init__(self,
maxCategoryCount=None
):

# Set default implementation
if implementation is None:
implementation = Configuration.get('nupic.opf.claClassifier.implementation')

# Convert the steps designation to a list
self.classifierImp = implementation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd use classifierImpl (with trailing "l") for consistency.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SPRegion uses spatialImp and TPRegion uses temporalImp, so I chose classifierImp to be consistent with that. Would you still change?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine as is.

self.steps = steps
self.stepsList = eval("[%s]" % (steps))
self.alpha = alpha
Expand Down Expand Up @@ -270,6 +284,43 @@ def setParameter(self, name, index, value):
return PyRegion.setParameter(self, name, index, value)


def write(self, proto):
"""Write state to proto object.

proto: PyRegionProto capnproto object
"""
regionImpl = proto.regionImpl.as_struct(CLAClassifierRegionProto)

regionImpl.classifierImp = self.classifierImp
regionImpl.steps = self.steps
regionImpl.alpha = self.alpha
regionImpl.verbosity = self.verbosity
regionImpl.maxCategoryCount = self.maxCategoryCount

self._claClassifier.write(regionImpl.claClassifier)


@classmethod
def read(cls, proto):
"""Read state from proto object.

proto: PyRegionProto capnproto object
"""
regionImpl = proto.regionImpl.as_struct(CLAClassifierRegionProto)

instance = cls()

instance.classifierImp = regionImpl.classifierImp
instance.steps = regionImpl.steps
instance.alpha = regionImpl.alpha
instance.verbosity = regionImpl.verbosity
instance.maxCategoryCount = regionImpl.maxCategoryCount

instance._claClassifier = CLAClassifierFactory.read(regionImpl)

return instance


def reset(self):
pass

Expand Down