diff --git a/docs/source/index.rst b/docs/source/index.rst index 2179c1a90..782f4d703 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -24,6 +24,8 @@ Session API .. autoclass:: neo4j.v1.Record :members: +.. autofunction:: neo4j.v1.record + .. autoclass:: neo4j.v1.Result :members: diff --git a/neo4j/v1/connection.py b/neo4j/v1/connection.py index 55b048023..04a293a78 100644 --- a/neo4j/v1/connection.py +++ b/neo4j/v1/connection.py @@ -262,9 +262,7 @@ def fetch_next(self): # Unpack from the raw byte stream and call the relevant message handler(s) raw.seek(0) response = self.responses[0] - for message in unpack(): - signature = message.signature - fields = tuple(message) + for signature, fields in unpack(): if __debug__: log_info("S: %s %s", message_names[signature], " ".join(map(repr, fields))) handler_name = "on_%s" % message_names[signature].lower() diff --git a/neo4j/v1/session.py b/neo4j/v1/session.py index 31b3ab33e..70f6c099c 100644 --- a/neo4j/v1/session.py +++ b/neo4j/v1/session.py @@ -473,48 +473,95 @@ def close(self): self.closed = True self.session.transaction = None - class Record(object): - """ Record object for storing result values along with field names. + """ Record is an ordered collection of fields. + + A Record object is used for storing result values along with field names. Fields can be accessed by numeric or named index (``record[0]`` or - ``record["field"]``) or by attribute (``record.field``). + ``record["field"]``). """ def __init__(self, keys, values): - self.__keys__ = keys - self.__values__ = values + self._keys = tuple(keys) + self._values = tuple(values) + + def keys(self): + """ Return the keys (key names) of the record + """ + return self._keys + + def values(self): + """ Return the values of the record + """ + return self._values + + def items(self): + """ Return the fields of the record as a list of key and value tuples + """ + return zip(self._keys, self._values) + + def index(self, key): + """ Return the index of the given key + """ + try: + return self._keys.index(key) + except ValueError: + raise KeyError(key) + + def __record__(self): + return self + + def __contains__(self, key): + return self._keys.__contains__(key) + + def __iter__(self): + return iter(self._keys) + + def copy(self): + return Record(self._keys, self._values) + + def __getitem__(self, item): + if isinstance(item, string): + return self._values[self.index(item)] + elif isinstance(item, integer): + return self._values[item] + else: + raise TypeError(item) + + def __len__(self): + return len(self._keys) def __repr__(self): - values = self.__values__ + values = self._values s = [] - for i, field in enumerate(self.__keys__): + for i, field in enumerate(self._keys): s.append("%s=%r" % (field, values[i])) return "" % " ".join(s) + def __hash__(self): + return hash(self._keys) ^ hash(self._values) + def __eq__(self, other): try: - return vars(self) == vars(other) - except TypeError: - return tuple(self) == tuple(other) + return self._keys == tuple(other.keys()) and self._values == tuple(other.values()) + except AttributeError: + return False def __ne__(self, other): return not self.__eq__(other) - def __len__(self): - return self.__keys__.__len__() +def record(obj): + """ Obtain an immutable record for the given object + (either by calling obj.__record__() or by copying out the record data) + """ + try: + return obj.__record__() + except AttributeError: + keys = obj.keys() + values = [] + for key in keys: + values.append(obj[key]) + return Record(keys, values) + - def __getitem__(self, item): - if isinstance(item, string): - return getattr(self, item) - elif isinstance(item, integer): - return getattr(self, self.__keys__[item]) - else: - raise TypeError(item) - def __getattr__(self, item): - try: - i = self.__keys__.index(item) - except ValueError: - raise AttributeError("No key %r" % item) - else: - return self.__values__[i] diff --git a/test/test_session.py b/test/test_session.py index ef3f39089..736075a28 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -21,7 +21,7 @@ from unittest import TestCase -from neo4j.v1.session import GraphDatabase, CypherError +from neo4j.v1.session import GraphDatabase, CypherError, Record, record from neo4j.v1.typesystem import Node, Relationship, Path @@ -36,11 +36,11 @@ def test_can_run_simple_statement(self): for record in session.run("RETURN 1 AS n"): assert record[0] == 1 assert record["n"] == 1 - with self.assertRaises(AttributeError): + with self.assertRaises(KeyError): + _ = record["x"] + assert record["n"] == 1 + with self.assertRaises(KeyError): _ = record["x"] - assert record.n == 1 - with self.assertRaises(AttributeError): - _ = record.x with self.assertRaises(TypeError): _ = record[object()] assert repr(record) @@ -77,7 +77,6 @@ def test_can_run_simple_statement_from_bytes_string(self): for record in session.run(b"RETURN 1 AS n"): assert record[0] == 1 assert record["n"] == 1 - assert record.n == 1 assert repr(record) assert len(record) == 1 count += 1 @@ -138,12 +137,6 @@ def test_can_handle_cypher_error(self): with self.assertRaises(CypherError): session.run("X") - def test_record_equality(self): - with GraphDatabase.driver("bolt://localhost").session() as session: - result = session.run("unwind([1, 1]) AS a RETURN a") - assert result[0] == result[1] - assert result[0] != "this is not a record" - def test_can_obtain_summary_info(self): with GraphDatabase.driver("bolt://localhost").session() as session: result = session.run("CREATE (n) RETURN n") @@ -211,6 +204,79 @@ def test_can_obtain_notification_info(self): assert position.column == 1 +class RecordTestCase(TestCase): + def test_record_equality(self): + record1 = Record(["name","empire"], ["Nigel", "The British Empire"]) + record2 = Record(["name","empire"], ["Nigel", "The British Empire"]) + record3 = Record(["name","empire"], ["Stefan", "Das Deutschland"]) + assert record1 == record2 + assert record1 != record3 + assert record2 != record3 + + def test_record_hashing(self): + record1 = Record(["name","empire"], ["Nigel", "The British Empire"]) + record2 = Record(["name","empire"], ["Nigel", "The British Empire"]) + record3 = Record(["name","empire"], ["Stefan", "Das Deutschland"]) + assert hash(record1) == hash(record2) + assert hash(record1) != hash(record3) + assert hash(record2) != hash(record3) + + def test_record_keys(self): + aRecord = Record(["name","empire"], ["Nigel", "The British Empire"]) + assert list(aRecord.keys()) == ["name", "empire"] + + def test_record_values(self): + aRecord = Record(["name","empire"], ["Nigel", "The British Empire"]) + assert list(aRecord.values()) == ["Nigel", "The British Empire"] + + def test_record_items(self): + aRecord = Record(["name","empire"], ["Nigel", "The British Empire"]) + assert list(aRecord.items()) == [("name", "Nigel"), ("empire", "The British Empire")] + + def test_record_index(self): + aRecord = Record(["name","empire"], ["Nigel", "The British Empire"]) + assert aRecord.index("name") == 0 + assert aRecord.index("empire") == 1 + with self.assertRaises(KeyError): + aRecord.index("crap") + + def test_record_contains(self): + aRecord = Record(["name","empire"], ["Nigel", "The British Empire"]) + assert "name" in aRecord + assert "empire" in aRecord + assert "Germans" not in aRecord + + def test_record_iter(self): + aRecord = Record(["name","empire"], ["Nigel", "The British Empire"]) + assert list(aRecord.__iter__()) == ["name", "empire"] + + def test_record_record(self): + aRecord = Record(["name","empire"], ["Nigel", "The British Empire"]) + assert record(aRecord) is aRecord + + def test_record_copy(self): + original = Record(["name","empire"], ["Nigel", "The British Empire"]) + duplicate = original.copy() + assert dict(original) == dict(duplicate) + assert original.keys() == duplicate.keys() + assert original is not duplicate + + def test_record_as_dict(self): + aRecord = Record(["name","empire"], ["Nigel", "The British Empire"]) + assert dict(aRecord) == { "name": "Nigel", "empire": "The British Empire" } + + def test_record_as_list(self): + aRecord = Record(["name","empire"], ["Nigel", "The British Empire"]) + assert list(aRecord) == ["name", "empire"] + + def test_record_len(self): + aRecord = Record(["name","empire"], ["Nigel", "The British Empire"]) + assert len(aRecord) == 2 + + def test_record_repr(self): + aRecord = Record(["name","empire"], ["Nigel", "The British Empire"]) + assert repr(aRecord) == "" + class TransactionTestCase(TestCase): def test_can_commit_transaction(self): with GraphDatabase.driver("bolt://localhost").session() as session: