diff --git a/CHANGES.rst b/CHANGES.rst index e7ca80022a..54c6f7d742 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -135,6 +135,10 @@ Release History - Corrected the ``rmses`` values in ``BuiltConnection.solver_info`` when using ``NNls`` and ``Nnl2sL2`` solvers, and the ``reg`` argument for ``Nnl2sL2``. (`#839 <https://github.com/nengo/nengo/issues/839>`_) +- ``spa.Vocabulary.create_pointer`` now respects the specified number of + creation attempts, and returns the most dissimilar pointer if none can be + found below the similarity threshold. + (`#817 <https://github.com/nengo/nengo/issues/817>`_) 2.0.1 (January 27, 2015) ======================== diff --git a/nengo/spa/tests/test_vocabulary.py b/nengo/spa/tests/test_vocabulary.py index 6ef6d9a0f7..9a3a957849 100644 --- a/nengo/spa/tests/test_vocabulary.py +++ b/nengo/spa/tests/test_vocabulary.py @@ -4,18 +4,19 @@ import pytest from nengo.spa import Vocabulary +from nengo.utils.testing import warns -def test_add(): - v = Vocabulary(3) +def test_add(rng): + v = Vocabulary(3, rng=rng) v.add('A', [1, 2, 3]) v.add('B', [4, 5, 6]) v.add('C', [7, 8, 9]) assert np.allclose(v.vectors, [[1, 2, 3], [4, 5, 6], [7, 8, 9]]) -def test_include_pairs(): - v = Vocabulary(10) +def test_include_pairs(rng): + v = Vocabulary(10, rng=rng) v['A'] v['B'] v['C'] @@ -35,8 +36,8 @@ def test_include_pairs(): assert v.key_pairs == ['A*B', 'A*C', 'B*C'] -def test_parse(): - v = Vocabulary(64) +def test_parse(rng): + v = Vocabulary(64, rng=rng) A = v.parse('A') B = v.parse('B') C = v.parse('C') @@ -64,8 +65,8 @@ def test_invalid_dimensions(): Vocabulary(-1) -def test_identity(): - v = Vocabulary(64) +def test_identity(rng): + v = Vocabulary(64, rng=rng) assert np.allclose(v.identity.v, np.eye(64)[0]) @@ -92,8 +93,8 @@ def test_text(rng): assert v.text(v['D'].v) == '1.00D' -def test_capital(): - v = Vocabulary(16) +def test_capital(rng): + v = Vocabulary(16, rng=rng) with pytest.raises(KeyError): v.parse('a') with pytest.raises(KeyError): @@ -117,13 +118,25 @@ def test_transform(rng): assert v2.parse('B').compare(np.dot(t, B.v)) > 
0.95 -def test_prob_cleanup(): - v = Vocabulary(64) +def test_prob_cleanup(rng): + v = Vocabulary(64, rng=rng) assert 1.0 > v.prob_cleanup(0.7, 10000) > 0.9999 assert 0.9999 > v.prob_cleanup(0.6, 10000) > 0.999 assert 0.99 > v.prob_cleanup(0.5, 1000) > 0.9 - v = Vocabulary(128) + v = Vocabulary(128, rng=rng) assert 0.999 > v.prob_cleanup(0.4, 1000) > 0.997 assert 0.99 > v.prob_cleanup(0.4, 10000) > 0.97 assert 0.9 > v.prob_cleanup(0.4, 100000) > 0.8 + + +def test_create_pointer_warning(rng): + v = Vocabulary(2, rng=rng) + + # five pointers shouldn't fit + with warns(UserWarning): + v.parse('A') + v.parse('B') + v.parse('C') + v.parse('D') + v.parse('E') diff --git a/nengo/spa/vocab.py b/nengo/spa/vocab.py index 0271269d40..5f9600cf92 100644 --- a/nengo/spa/vocab.py +++ b/nengo/spa/vocab.py @@ -86,18 +86,24 @@ def create_pointer(self, attempts=100, unitary=False): """Create a new semantic pointer. This will take into account the randomize and max_similarity - parameters from self. + parameters from self. If a pointer satisfying max_similarity + is not generated after the specified number of attempts, the + candidate pointer with lowest maximum cosine with all existing + pointers is returned. """ if self.randomize: - count = 0 - p = pointer.SemanticPointer(self.dimensions, rng=self.rng) - if self.vectors.shape[0] > 0: - while count < 100: - similarity = np.dot(self.vectors, p.v) - if max(similarity) < self.max_similarity: - break - p = pointer.SemanticPointer(self.dimensions, rng=self.rng) - count += 1 + if self.vectors.shape[0] == 0: + p = pointer.SemanticPointer(self.dimensions, rng=self.rng) + else: + p_sim = np.inf + for _ in range(attempts): + pp = pointer.SemanticPointer(self.dimensions, rng=self.rng) + pp_sim = max(np.dot(self.vectors, pp.v)) + if pp_sim < p_sim: + p = pp + p_sim = pp_sim + if p_sim < self.max_similarity: + break else: warnings.warn( 'Could not create a semantic pointer with '