diff --git a/nengo/spa/tests/test_vocabulary.py b/nengo/spa/tests/test_vocabulary.py index 6ef6d9a0f7..9a3a957849 100644 --- a/nengo/spa/tests/test_vocabulary.py +++ b/nengo/spa/tests/test_vocabulary.py @@ -4,18 +4,19 @@ import pytest from nengo.spa import Vocabulary +from nengo.utils.testing import warns -def test_add(): - v = Vocabulary(3) +def test_add(rng): + v = Vocabulary(3, rng=rng) v.add('A', [1, 2, 3]) v.add('B', [4, 5, 6]) v.add('C', [7, 8, 9]) assert np.allclose(v.vectors, [[1, 2, 3], [4, 5, 6], [7, 8, 9]]) -def test_include_pairs(): - v = Vocabulary(10) +def test_include_pairs(rng): + v = Vocabulary(10, rng=rng) v['A'] v['B'] v['C'] @@ -35,8 +36,8 @@ def test_include_pairs(): assert v.key_pairs == ['A*B', 'A*C', 'B*C'] -def test_parse(): - v = Vocabulary(64) +def test_parse(rng): + v = Vocabulary(64, rng=rng) A = v.parse('A') B = v.parse('B') C = v.parse('C') @@ -64,8 +65,8 @@ def test_invalid_dimensions(): Vocabulary(-1) -def test_identity(): - v = Vocabulary(64) +def test_identity(rng): + v = Vocabulary(64, rng=rng) assert np.allclose(v.identity.v, np.eye(64)[0]) @@ -92,8 +93,8 @@ def test_text(rng): assert v.text(v['D'].v) == '1.00D' -def test_capital(): - v = Vocabulary(16) +def test_capital(rng): + v = Vocabulary(16, rng=rng) with pytest.raises(KeyError): v.parse('a') with pytest.raises(KeyError): @@ -117,13 +118,25 @@ def test_transform(rng): assert v2.parse('B').compare(np.dot(t, B.v)) > 0.95 -def test_prob_cleanup(): - v = Vocabulary(64) +def test_prob_cleanup(rng): + v = Vocabulary(64, rng=rng) assert 1.0 > v.prob_cleanup(0.7, 10000) > 0.9999 assert 0.9999 > v.prob_cleanup(0.6, 10000) > 0.999 assert 0.99 > v.prob_cleanup(0.5, 1000) > 0.9 - v = Vocabulary(128) + v = Vocabulary(128, rng=rng) assert 0.999 > v.prob_cleanup(0.4, 1000) > 0.997 assert 0.99 > v.prob_cleanup(0.4, 10000) > 0.97 assert 0.9 > v.prob_cleanup(0.4, 100000) > 0.8 + + +def test_create_pointer_warning(rng): + v = Vocabulary(2, rng=rng) + + # five pointers shouldn't fit + with warns(UserWarning): + v.parse('A') + v.parse('B') + v.parse('C') + v.parse('D') + v.parse('E') diff --git a/nengo/spa/vocab.py b/nengo/spa/vocab.py index 0c9349a8e4..5f9600cf92 100644 --- a/nengo/spa/vocab.py +++ b/nengo/spa/vocab.py @@ -92,20 +92,18 @@ def create_pointer(self, attempts=100, unitary=False): pointers is returned. """ if self.randomize: - count = 0 - p_cand = pointer.SemanticPointer(self.dimensions, rng=self.rng) - p = p_cand - if self.vectors.shape[0] > 0: - while count < attempts: - similarity = np.dot(self.vectors, p_cand.v) - if max(similarity) < self.max_similarity: - p = p_cand - break - elif max(similarity) < max(np.dot(self.vectors, p.v)): - p = p_cand - p_cand = pointer.SemanticPointer(self.dimensions, - rng=self.rng) - count += 1 + if self.vectors.shape[0] == 0: + p = pointer.SemanticPointer(self.dimensions, rng=self.rng) + else: + p_sim = np.inf + for _ in range(attempts): + pp = pointer.SemanticPointer(self.dimensions, rng=self.rng) + pp_sim = max(np.dot(self.vectors, pp.v)) + if pp_sim < p_sim: + p = pp + p_sim = pp_sim + if p_sim < self.max_similarity: + break else: warnings.warn( 'Could not create a semantic pointer with '