Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Session API
.. autoclass:: neo4j.v1.Record
:members:

.. autofunction:: neo4j.v1.record

.. autoclass:: neo4j.v1.Result
:members:

Expand Down
4 changes: 1 addition & 3 deletions neo4j/v1/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,7 @@ def fetch_next(self):
# Unpack from the raw byte stream and call the relevant message handler(s)
raw.seek(0)
response = self.responses[0]
for message in unpack():
signature = message.signature
fields = tuple(message)
for signature, fields in unpack():
if __debug__:
log_info("S: %s %s", message_names[signature], " ".join(map(repr, fields)))
handler_name = "on_%s" % message_names[signature].lower()
Expand Down
99 changes: 73 additions & 26 deletions neo4j/v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,48 +473,95 @@ def close(self):
self.closed = True
self.session.transaction = None


class Record(object):
""" Record object for storing result values along with field names.
""" Record is an ordered collection of fields.

A Record object is used for storing result values along with field names.
Fields can be accessed by numeric or named index (``record[0]`` or
``record["field"]``) or by attribute (``record.field``).
``record["field"]``).
"""

def __init__(self, keys, values):
self.__keys__ = keys
self.__values__ = values
self._keys = tuple(keys)
self._values = tuple(values)

def keys(self):
""" Return the keys (key names) of the record
"""
return self._keys

def values(self):
""" Return the values of the record
"""
return self._values

def items(self):
""" Return the fields of the record as a list of key and value tuples
"""
return zip(self._keys, self._values)

def index(self, key):
""" Return the index of the given key
"""
try:
return self._keys.index(key)
except ValueError:
raise KeyError(key)

def __record__(self):
return self

def __contains__(self, key):
return self._keys.__contains__(key)

def __iter__(self):
return iter(self._keys)

def copy(self):
return Record(self._keys, self._values)

def __getitem__(self, item):
if isinstance(item, string):
return self._values[self.index(item)]
elif isinstance(item, integer):
return self._values[item]
else:
raise TypeError(item)

def __len__(self):
return len(self._keys)

def __repr__(self):
values = self.__values__
values = self._values
s = []
for i, field in enumerate(self.__keys__):
for i, field in enumerate(self._keys):
s.append("%s=%r" % (field, values[i]))
return "<Record %s>" % " ".join(s)

def __hash__(self):
return hash(self._keys) ^ hash(self._values)

def __eq__(self, other):
try:
return vars(self) == vars(other)
except TypeError:
return tuple(self) == tuple(other)
return self._keys == tuple(other.keys()) and self._values == tuple(other.values())
except AttributeError:
return False

def __ne__(self, other):
return not self.__eq__(other)

def __len__(self):
return self.__keys__.__len__()
def record(obj):
""" Obtain an immutable record for the given object
(either by calling obj.__record__() or by copying out the record data)
"""
try:
return obj.__record__()
except AttributeError:
keys = obj.keys()
values = []
for key in keys:
values.append(obj[key])
return Record(keys, values)


def __getitem__(self, item):
if isinstance(item, string):
return getattr(self, item)
elif isinstance(item, integer):
return getattr(self, self.__keys__[item])
else:
raise TypeError(item)

def __getattr__(self, item):
try:
i = self.__keys__.index(item)
except ValueError:
raise AttributeError("No key %r" % item)
else:
return self.__values__[i]
90 changes: 78 additions & 12 deletions test/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from unittest import TestCase

from neo4j.v1.session import GraphDatabase, CypherError
from neo4j.v1.session import GraphDatabase, CypherError, Record, record
from neo4j.v1.typesystem import Node, Relationship, Path


Expand All @@ -36,11 +36,11 @@ def test_can_run_simple_statement(self):
for record in session.run("RETURN 1 AS n"):
assert record[0] == 1
assert record["n"] == 1
with self.assertRaises(AttributeError):
with self.assertRaises(KeyError):
_ = record["x"]
assert record["n"] == 1
with self.assertRaises(KeyError):
_ = record["x"]
assert record.n == 1
with self.assertRaises(AttributeError):
_ = record.x
with self.assertRaises(TypeError):
_ = record[object()]
assert repr(record)
Expand Down Expand Up @@ -77,7 +77,6 @@ def test_can_run_simple_statement_from_bytes_string(self):
for record in session.run(b"RETURN 1 AS n"):
assert record[0] == 1
assert record["n"] == 1
assert record.n == 1
assert repr(record)
assert len(record) == 1
count += 1
Expand Down Expand Up @@ -138,12 +137,6 @@ def test_can_handle_cypher_error(self):
with self.assertRaises(CypherError):
session.run("X")

def test_record_equality(self):
with GraphDatabase.driver("bolt://localhost").session() as session:
result = session.run("unwind([1, 1]) AS a RETURN a")
assert result[0] == result[1]
assert result[0] != "this is not a record"

def test_can_obtain_summary_info(self):
with GraphDatabase.driver("bolt://localhost").session() as session:
result = session.run("CREATE (n) RETURN n")
Expand Down Expand Up @@ -211,6 +204,79 @@ def test_can_obtain_notification_info(self):
assert position.column == 1


class RecordTestCase(TestCase):
def test_record_equality(self):
record1 = Record(["name","empire"], ["Nigel", "The British Empire"])
record2 = Record(["name","empire"], ["Nigel", "The British Empire"])
record3 = Record(["name","empire"], ["Stefan", "Das Deutschland"])
assert record1 == record2
assert record1 != record3
assert record2 != record3

def test_record_hashing(self):
record1 = Record(["name","empire"], ["Nigel", "The British Empire"])
record2 = Record(["name","empire"], ["Nigel", "The British Empire"])
record3 = Record(["name","empire"], ["Stefan", "Das Deutschland"])
assert hash(record1) == hash(record2)
assert hash(record1) != hash(record3)
assert hash(record2) != hash(record3)

def test_record_keys(self):
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
assert list(aRecord.keys()) == ["name", "empire"]

def test_record_values(self):
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
assert list(aRecord.values()) == ["Nigel", "The British Empire"]

def test_record_items(self):
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
assert list(aRecord.items()) == [("name", "Nigel"), ("empire", "The British Empire")]

def test_record_index(self):
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
assert aRecord.index("name") == 0
assert aRecord.index("empire") == 1
with self.assertRaises(KeyError):
aRecord.index("crap")

def test_record_contains(self):
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
assert "name" in aRecord
assert "empire" in aRecord
assert "Germans" not in aRecord

def test_record_iter(self):
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
assert list(aRecord.__iter__()) == ["name", "empire"]

def test_record_record(self):
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
assert record(aRecord) is aRecord

def test_record_copy(self):
original = Record(["name","empire"], ["Nigel", "The British Empire"])
duplicate = original.copy()
assert dict(original) == dict(duplicate)
assert original.keys() == duplicate.keys()
assert original is not duplicate

def test_record_as_dict(self):
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
assert dict(aRecord) == { "name": "Nigel", "empire": "The British Empire" }

def test_record_as_list(self):
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
assert list(aRecord) == ["name", "empire"]

def test_record_len(self):
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
assert len(aRecord) == 2

def test_record_repr(self):
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
assert repr(aRecord) == "<Record name='Nigel' empire='The British Empire'>"

class TransactionTestCase(TestCase):
def test_can_commit_transaction(self):
with GraphDatabase.driver("bolt://localhost").session() as session:
Expand Down