Skip to content
This repository has been archived by the owner on Sep 1, 2023. It is now read-only.

Commit

Permalink
Cleanup passthrough encoder and unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
oxtopus committed Apr 10, 2015
1 parent 0c5b45e commit fb8b3a9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 28 deletions.
44 changes: 20 additions & 24 deletions nupic/encoders/pass_through_encoder.py
Expand Up @@ -25,16 +25,15 @@



############################################################################
class PassThroughEncoder(Encoder):
"""Pass an encoded SDR straight to the model.
Each encoding is an SDR in which w out of n bits are turned on.
The input should be a 1-D array or numpy.ndarray of length n
"""

############################################################################
def __init__(self, n, w=None, name="pass_through", forced=False, verbosity=0):
def __init__(self, n, w=None, name="pass_through", forced=False,
verbosity=0):
"""
n -- is the total #bits in output
w -- is used to normalize the sparsity of the output, exactly w bits ON,
Expand All @@ -49,73 +48,70 @@ def __init__(self, n, w=None, name="pass_through", forced=False, verbosity=0):
self.encoders = None
self.forced = forced

############################################################################

def getDecoderOutputFieldTypes(self):
""" [Encoder class virtual method override]
"""
return (FieldMetaType.string,)

############################################################################

def getWidth(self):
return self.n

############################################################################

def getDescription(self):
return self.description

############################################################################

def getScalars(self, input):
""" See method description in base.py """
return numpy.array([0])

############################################################################

def getBucketIndices(self, input):
""" See method description in base.py """
return [0]

############################################################################
def encodeIntoArray(self, input, output):

def encodeIntoArray(self, inputVal, outputVal):
"""See method description in base.py"""
if len(input) != len(output):
if len(inputVal) != len(outputVal):
raise ValueError("Different input (%i) and output (%i) sizes." % (
len(input), len(output)))
len(inputVal), len(outputVal)))

if self.w is not None and sum(input) != self.w:
if self.w is not None and sum(inputVal) != self.w:
raise ValueError("Input has %i bits but w was set to %i." % (
sum(input), self.w))
sum(inputVal), self.w))

output[:] = input[:]
outputVal[:] = inputVal[:]

if self.verbosity >= 2:
print "input:", input, "output:", output
print "decoded:", self.decodedToStr(self.decode(output))
print "input:", inputVal, "output:", outputVal
print "decoded:", self.decodedToStr(self.decode(outputVal))


############################################################################
def decode(self, encoded, parentFieldName=""):
"""See the function description in base.py"""

if parentFieldName != "":
fieldName = "%s.%s" % (parentFieldName, self.name)
else:
fieldName = self.name
# TODO: these methods should be properly implemented

return ({fieldName: ([[0, 0]], "input")}, [fieldName])


############################################################################
def getBucketInfo(self, buckets):
"""See the function description in base.py"""
return [EncoderResult(value=0, scalar=0,
encoding=numpy.zeros(self.n))]
return [EncoderResult(value=0, scalar=0, encoding=numpy.zeros(self.n))]


############################################################################
def topDownCompute(self, encoded):
"""See the function description in base.py"""
return EncoderResult(value=0, scalar=0,
encoding=numpy.zeros(self.n))

############################################################################

def closenessScores(self, expValues, actValues, **kwargs):
"""Does a bitwise compare of the two bitmaps and returns a fractonal
value between 0 and 1 of how similar they are.
Expand Down
7 changes: 3 additions & 4 deletions tests/unit/nupic/encoders/pass_through_encoder_test.py
Expand Up @@ -24,7 +24,6 @@

CL_VERBOSITY = 0

import cPickle as pickle
import tempfile
import unittest2 as unittest

Expand Down Expand Up @@ -64,9 +63,9 @@ def testEncodeBitArray(self):
bitmap[3] = 1
bitmap[5] = 1
out = e.encode(bitmap)
sum_expected = sum(bitmap)
sum_real = out.sum()
self.assertEqual(sum_real, sum_expected)
expectedSum = sum(bitmap)
realSum = out.sum()
self.assertEqual(realSum, expectedSum)


def testClosenessScores(self):
Expand Down

0 comments on commit fb8b3a9

Please sign in to comment.