diff --git a/tests/test_client.py b/tests/test_client.py index 2c3bb5d4..7ca3f58a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -584,7 +584,7 @@ def test_add_with_boost(self): self.assertEqual(len(res), 5) self.assertEqual('doc_6', res.docs[0]['id']) - def test_field_update(self): + def test_field_update_inc(self): originalDocs = self.solr.search('doc') self.assertEqual(len(originalDocs), 3) updateList = [] @@ -601,6 +601,25 @@ def test_field_update(self): self.assertEqual(True, all(updatedDoc[k] == originalDoc[k] for k in updatedDoc.keys() if k not in ['_version_', 'popularity'])) + def test_field_update_set(self): + originalDocs = self.solr.search('doc') + updated_popularity = 10 + self.assertEqual(len(originalDocs), 3) + updateList = [] + for i, doc in enumerate(originalDocs): + updateList.append({'id': doc['id'], 'popularity': updated_popularity}) + self.solr.add(updateList, fieldUpdates={'popularity': 'set'}) + + updatedDocs = self.solr.search('doc') + self.assertEqual(len(updatedDocs), 3) + for i, (originalDoc, updatedDoc) in enumerate(zip(originalDocs, updatedDocs)): + self.assertEqual(len(updatedDoc.keys()), len(originalDoc.keys())) + self.assertEqual(updatedDoc['popularity'], updated_popularity) + # TODO: change this to use assertSetEqual: + self.assertEqual(True, all(updatedDoc[k] == originalDoc[k] for k in updatedDoc.keys() + if k not in ['_version_', 'popularity'])) + + def test_field_update_add(self): self.solr.add([ { 'id': 'multivalued_1',