Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix SP creation to use correct # of attempts #817

Merged
merged 1 commit into from Sep 14, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.rst
Expand Up @@ -135,6 +135,10 @@ Release History
- Corrected the ``rmses`` values in ``BuiltConnection.solver_info`` when using
  ``Nnls`` and ``NnlsL2`` solvers, and the ``reg`` argument for ``NnlsL2``.
(`#839 <https://github.com/nengo/nengo/pull/839>`_)
- ``spa.Vocabulary.create_pointer`` now respects the specified number of
creation attempts, and returns the most dissimilar pointer if none can be
found below the similarity threshold.
(`#817 <https://github.com/nengo/nengo/pull/817>`_)

2.0.1 (January 27, 2015)
========================
Expand Down
39 changes: 26 additions & 13 deletions nengo/spa/tests/test_vocabulary.py
Expand Up @@ -4,18 +4,19 @@
import pytest

from nengo.spa import Vocabulary
from nengo.utils.testing import warns


def test_add():
v = Vocabulary(3)
def test_add(rng):
    """Pointers registered via ``add`` appear row-wise in ``vectors``."""
    vocab = Vocabulary(3, rng=rng)
    entries = [('A', [1, 2, 3]), ('B', [4, 5, 6]), ('C', [7, 8, 9])]
    for key, vec in entries:
        vocab.add(key, vec)
    assert np.allclose(vocab.vectors, [vec for _, vec in entries])


def test_include_pairs():
v = Vocabulary(10)
def test_include_pairs(rng):
v = Vocabulary(10, rng=rng)
v['A']
v['B']
v['C']
Expand All @@ -35,8 +36,8 @@ def test_include_pairs():
assert v.key_pairs == ['A*B', 'A*C', 'B*C']


def test_parse():
v = Vocabulary(64)
def test_parse(rng):
v = Vocabulary(64, rng=rng)
A = v.parse('A')
B = v.parse('B')
C = v.parse('C')
Expand Down Expand Up @@ -64,8 +65,8 @@ def test_invalid_dimensions():
Vocabulary(-1)


def test_identity():
v = Vocabulary(64)
def test_identity(rng):
    """The vocabulary's identity pointer is the first standard basis vector."""
    vocab = Vocabulary(64, rng=rng)
    expected = np.eye(64)[0]
    assert np.allclose(vocab.identity.v, expected)


Expand All @@ -92,8 +93,8 @@ def test_text(rng):
assert v.text(v['D'].v) == '1.00D'


def test_capital():
v = Vocabulary(16)
def test_capital(rng):
v = Vocabulary(16, rng=rng)
with pytest.raises(KeyError):
v.parse('a')
with pytest.raises(KeyError):
Expand All @@ -117,13 +118,25 @@ def test_transform(rng):
assert v2.parse('B').compare(np.dot(t, B.v)) > 0.95


def test_prob_cleanup():
v = Vocabulary(64)
def test_prob_cleanup(rng):
    """Sanity bounds on ``prob_cleanup`` across dimensions and thresholds."""
    vocab64 = Vocabulary(64, rng=rng)
    assert 0.9999 < vocab64.prob_cleanup(0.7, 10000) < 1.0
    assert 0.999 < vocab64.prob_cleanup(0.6, 10000) < 0.9999
    assert 0.9 < vocab64.prob_cleanup(0.5, 1000) < 0.99

    vocab128 = Vocabulary(128, rng=rng)
    assert 0.997 < vocab128.prob_cleanup(0.4, 1000) < 0.999
    assert 0.97 < vocab128.prob_cleanup(0.4, 10000) < 0.99
    assert 0.8 < vocab128.prob_cleanup(0.4, 100000) < 0.9


def test_create_pointer_warning(rng):
    """Parsing too many pointers for the dimensionality emits a warning."""
    vocab = Vocabulary(2, rng=rng)

    # In only 2 dimensions, five pointers cannot all stay below the
    # similarity threshold, so pointer creation should warn.
    with warns(UserWarning):
        for name in ['A', 'B', 'C', 'D', 'E']:
            vocab.parse(name)
26 changes: 16 additions & 10 deletions nengo/spa/vocab.py
Expand Up @@ -86,18 +86,24 @@ def create_pointer(self, attempts=100, unitary=False):
"""Create a new semantic pointer.

This will take into account the randomize and max_similarity
parameters from self.
parameters from self. If a pointer satisfying max_similarity
is not generated after the specified number of attempts, the
candidate pointer with lowest maximum cosine with all existing
pointers is returned.
"""
if self.randomize:
count = 0
p = pointer.SemanticPointer(self.dimensions, rng=self.rng)
if self.vectors.shape[0] > 0:
while count < 100:
similarity = np.dot(self.vectors, p.v)
if max(similarity) < self.max_similarity:
break
p = pointer.SemanticPointer(self.dimensions, rng=self.rng)
count += 1
if self.vectors.shape[0] == 0:
p = pointer.SemanticPointer(self.dimensions, rng=self.rng)
else:
p_sim = np.inf
for _ in range(attempts):
pp = pointer.SemanticPointer(self.dimensions, rng=self.rng)
pp_sim = max(np.dot(self.vectors, pp.v))
if pp_sim < p_sim:
p = pp
p_sim = pp_sim
if p_sim < self.max_similarity:
break
else:
warnings.warn(
'Could not create a semantic pointer with '
Expand Down