Skip to content

Commit

Permalink
FIXUP: simplified create pointer, test for warning
Browse files Browse the repository at this point in the history
  • Loading branch information
hunse committed Aug 31, 2015
1 parent 839d7d9 commit 9e1ff04
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 27 deletions.
39 changes: 26 additions & 13 deletions nengo/spa/tests/test_vocabulary.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@
import pytest

from nengo.spa import Vocabulary
from nengo.utils.testing import warns


def test_add():
v = Vocabulary(3)
def test_add(rng):
v = Vocabulary(3, rng=rng)
v.add('A', [1, 2, 3])
v.add('B', [4, 5, 6])
v.add('C', [7, 8, 9])
assert np.allclose(v.vectors, [[1, 2, 3], [4, 5, 6], [7, 8, 9]])


def test_include_pairs():
v = Vocabulary(10)
def test_include_pairs(rng):
v = Vocabulary(10, rng=rng)
v['A']
v['B']
v['C']
Expand All @@ -35,8 +36,8 @@ def test_include_pairs():
assert v.key_pairs == ['A*B', 'A*C', 'B*C']


def test_parse():
v = Vocabulary(64)
def test_parse(rng):
v = Vocabulary(64, rng=rng)
A = v.parse('A')
B = v.parse('B')
C = v.parse('C')
Expand Down Expand Up @@ -64,8 +65,8 @@ def test_invalid_dimensions():
Vocabulary(-1)


def test_identity():
v = Vocabulary(64)
def test_identity(rng):
v = Vocabulary(64, rng=rng)
assert np.allclose(v.identity.v, np.eye(64)[0])


Expand All @@ -92,8 +93,8 @@ def test_text(rng):
assert v.text(v['D'].v) == '1.00D'


def test_capital():
v = Vocabulary(16)
def test_capital(rng):
v = Vocabulary(16, rng=rng)
with pytest.raises(KeyError):
v.parse('a')
with pytest.raises(KeyError):
Expand All @@ -117,13 +118,25 @@ def test_transform(rng):
assert v2.parse('B').compare(np.dot(t, B.v)) > 0.95


def test_prob_cleanup():
v = Vocabulary(64)
def test_prob_cleanup(rng):
v = Vocabulary(64, rng=rng)
assert 1.0 > v.prob_cleanup(0.7, 10000) > 0.9999
assert 0.9999 > v.prob_cleanup(0.6, 10000) > 0.999
assert 0.99 > v.prob_cleanup(0.5, 1000) > 0.9

v = Vocabulary(128)
v = Vocabulary(128, rng=rng)
assert 0.999 > v.prob_cleanup(0.4, 1000) > 0.997
assert 0.99 > v.prob_cleanup(0.4, 10000) > 0.97
assert 0.9 > v.prob_cleanup(0.4, 100000) > 0.8


def test_create_pointer_warning(rng):
v = Vocabulary(2, rng=rng)

# five pointers shouldn't fit
with warns(UserWarning):
v.parse('A')
v.parse('B')
v.parse('C')
v.parse('D')
v.parse('E')
26 changes: 12 additions & 14 deletions nengo/spa/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,20 +92,18 @@ def create_pointer(self, attempts=100, unitary=False):
pointers is returned.
"""
if self.randomize:
count = 0
p_cand = pointer.SemanticPointer(self.dimensions, rng=self.rng)
p = p_cand
if self.vectors.shape[0] > 0:
while count < attempts:
similarity = np.dot(self.vectors, p_cand.v)
if max(similarity) < self.max_similarity:
p = p_cand
break
elif max(similarity) < max(np.dot(self.vectors, p.v)):
p = p_cand
p_cand = pointer.SemanticPointer(self.dimensions,
rng=self.rng)
count += 1
if self.vectors.shape[0] == 0:
p = pointer.SemanticPointer(self.dimensions, rng=self.rng)
else:
p_sim = np.inf
for _ in range(attempts):
pp = pointer.SemanticPointer(self.dimensions, rng=self.rng)
pp_sim = max(np.dot(self.vectors, pp.v))
if pp_sim < p_sim:
p = pp
p_sim = pp_sim
if p_sim < self.max_similarity:
break
else:
warnings.warn(
'Could not create a semantic pointer with '
Expand Down

0 comments on commit 9e1ff04

Please sign in to comment.