diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 31754cbef..6d3f42d36 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -135,10 +135,12 @@ def protocol_handlers(cls, protocol_version=None): # Carry out Bolt subclass imports locally to avoid circular dependency issues. from neo4j.io._bolt3 import Bolt3 from neo4j.io._bolt4x0 import Bolt4x0 + from neo4j.io._bolt4x1 import Bolt4x1 handlers = { Bolt3.PROTOCOL_VERSION: Bolt3, - Bolt4x0.PROTOCOL_VERSION: Bolt4x0 + Bolt4x0.PROTOCOL_VERSION: Bolt4x0, + Bolt4x1.PROTOCOL_VERSION: Bolt4x1, } if protocol_version is None: @@ -203,6 +205,10 @@ def open(cls, address, *, auth=None, timeout=None, **pool_config): # Carry out Bolt subclass imports locally to avoid circular dependency issues. from neo4j.io._bolt4x0 import Bolt4x0 connection = Bolt4x0(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent) + elif pool_config.protocol_version == (4, 1): + # Carry out Bolt subclass imports locally to avoid circular dependency issues. + from neo4j.io._bolt4x1 import Bolt4x1 + connection = Bolt4x1(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent) else: log.debug("[#%04X] S: ", s.getpeername()[1]) s.shutdown(SHUT_RDWR) @@ -650,6 +656,7 @@ def fetch_routing_info(self, *, address, timeout, database): # Carry out Bolt subclass imports locally to avoid circular dependency issues. from neo4j.io._bolt3 import Bolt3 from neo4j.io._bolt4x0 import Bolt4x0 + from neo4j.io._bolt4x1 import Bolt4x1 from neo4j.api import ( SYSTEM_DATABASE, @@ -686,7 +693,7 @@ def fail(md): on_success=metadata.update, on_failure=fail, ) - elif cx.PROTOCOL_VERSION == Bolt4x0.PROTOCOL_VERSION: + elif cx.PROTOCOL_VERSION in (Bolt4x0.PROTOCOL_VERSION, Bolt4x1.PROTOCOL_VERSION): if database == DEFAULT_DATABASE: cx.run( "CALL dbms.routing.getRoutingTable($context)", diff --git a/neo4j/io/_bolt4x1.py b/neo4j/io/_bolt4x1.py new file mode 100644 index 000000000..15a056995 --- /dev/null +++ b/neo4j/io/_bolt4x1.py @@ -0,0 +1,544 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) 2002-2020 "Neo4j," +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import deque +from select import select +from ssl import SSLSocket +from struct import pack as struct_pack +from time import perf_counter +from neo4j.api import ( + Version, + READ_ACCESS, + WRITE_ACCESS, +) +from neo4j.io._courier import MessageInbox +from neo4j.meta import get_user_agent +from neo4j.exceptions import ( + Neo4jError, + AuthError, + ServiceUnavailable, + DatabaseUnavailable, + NotALeader, + ForbiddenOnReadOnlyDatabase, + SessionExpired, +) +from neo4j._exceptions import ( + BoltIncompleteCommitError, + BoltProtocolError, +) +from neo4j.packstream import ( + Unpacker, + Packer, +) +from neo4j.io import ( + Bolt, + BoltPool, +) +from neo4j.conf import PoolConfig +from neo4j.api import ServerInfo +from neo4j.addressing import Address + +from logging import getLogger +log = getLogger("neo4j") + + +class Bolt4x1(Bolt): + + PROTOCOL_VERSION = Version(4, 1) + + # The socket + in_use = False + + # The socket + _closed = False + + # The socket + _defunct = False + + #: The pool of which this connection is a member + pool = None + + def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None): + self.unresolved_address = unresolved_address + self.socket = sock + self.server_info = ServerInfo(Address(sock.getpeername()), Bolt4x1.PROTOCOL_VERSION) + self.outbox = Outbox() + self.inbox = Inbox(self.socket, on_error=self._set_defunct) + self.packer = Packer(self.outbox) + self.unpacker = Unpacker(self.inbox) + self.responses = deque() + self._max_connection_lifetime = max_connection_lifetime # self.pool_config.max_connection_lifetime + self._creation_timestamp = perf_counter() + self.supports_multiple_results = True + self.supports_multiple_databases = True + self._is_reset = True + + # Determine the user agent + if user_agent: + self.user_agent = user_agent + else: + self.user_agent = get_user_agent() + + # Determine auth details + if not auth: + self.auth_dict = {} + elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: + from neo4j import Auth + self.auth_dict = vars(Auth("basic", *auth)) + else: + try: + self.auth_dict = vars(auth) + except (KeyError, TypeError): + raise AuthError("Cannot determine auth details from %r" % auth) + + # Check for missing password + try: + credentials = self.auth_dict["credentials"] + except KeyError: + pass + else: + if credentials is None: + raise AuthError("Password cannot be None") + + @property + def encrypted(self): + return isinstance(self.socket, SSLSocket) + + @property + def der_encoded_server_certificate(self): + return self.socket.getpeercert(binary_form=True) + + @property + def local_port(self): + try: + return self.socket.getsockname()[1] + except IOError: + return 0 + + def hello(self): + headers = {"user_agent": self.user_agent} + headers.update(self.auth_dict) + logged_headers = dict(headers) + if "credentials" in logged_headers: + logged_headers["credentials"] = "*******" + log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) + self._append(b"\x01", (headers,), + response=InitResponse(self, on_success=self.server_info.metadata.update)) + self.send_all() + self.fetch_all() + + def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): + if not parameters: + parameters = {} + extra = {} + if mode in (READ_ACCESS, "r"): + extra["mode"] = "r" # It will default to mode "w" if nothing is specified + if db: + extra["db"] = db + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of seconds") + fields = (query, parameters, extra) + log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) + if query.upper() == u"COMMIT": + self._append(b"\x10", fields, CommitResponse(self, **handlers)) + else: + self._append(b"\x10", fields, Response(self, **handlers)) + self._is_reset = False + + def discard(self, n=-1, qid=-1, **handlers): + extra = {"n": n} + if qid != -1: + extra["qid"] = qid + log.debug("[#%04X] C: DISCARD %r", self.local_port, extra) + self._append(b"\x2F", (extra,), Response(self, **handlers)) + + def pull(self, n=-1, qid=-1, **handlers): + extra = {"n": n} + if qid != -1: + extra["qid"] = qid + log.debug("[#%04X] C: PULL %r", self.local_port, extra) + self._append(b"\x3F", (extra,), Response(self, **handlers)) + self._is_reset = False + + def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, + db=None, **handlers): + extra = {} + if mode in (READ_ACCESS, "r"): + extra["mode"] = "r" # It will default to mode "w" if nothing is specified + if db: + extra["db"] = db + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of seconds") + log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) + self._append(b"\x11", (extra,), Response(self, **handlers)) + self._is_reset = False + + def commit(self, **handlers): + log.debug("[#%04X] C: COMMIT", self.local_port) + self._append(b"\x12", (), CommitResponse(self, **handlers)) + + def rollback(self, **handlers): + log.debug("[#%04X] C: ROLLBACK", self.local_port) + self._append(b"\x13", (), Response(self, **handlers)) + + def _append(self, signature, fields=(), response=None): + """ Add a message to the outgoing queue. + + :arg signature: the signature of the message + :arg fields: the fields of the message as a tuple + :arg response: a response object to handle callbacks + """ + self.packer.pack_struct(signature, fields) + self.outbox.chunk() + self.outbox.chunk() + self.responses.append(response) + + def reset(self): + """ Add a RESET message to the outgoing queue, send + it and consume all remaining messages. + """ + + def fail(metadata): + raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address) + + log.debug("[#%04X] C: RESET", self.local_port) + self._append(b"\x0F", response=Response(self, on_failure=fail)) + self.send_all() + self.fetch_all() + self._is_reset = True + + def _send_all(self): + data = self.outbox.view() + if data: + self.socket.sendall(data) + self.outbox.clear() + + def send_all(self): + """ Send all queued messages to the server. + """ + if self.closed(): + raise ServiceUnavailable("Failed to write to closed connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + if self.defunct(): + raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + try: + self._send_all() + except (IOError, OSError) as error: + log.error("Failed to write data to connection " + "{!r} ({!r}); ({!r})". + format(self.unresolved_address, + self.server_info.address, + "; ".join(map(repr, error.args)))) + if self.pool: + self.pool.deactivate(address=self.unresolved_address) + raise + + def fetch_message(self): + """ Receive at least one message from the server, if available. + + :return: 2-tuple of number of detail messages and number of summary + messages fetched + """ + if self._closed: + raise ServiceUnavailable("Failed to read from closed connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + if self._defunct: + raise ServiceUnavailable("Failed to read from defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + if not self.responses: + return 0, 0 + + # Receive exactly one message + try: + details, summary_signature, summary_metadata = next(self.inbox) + except (IOError, OSError) as error: + log.error("Failed to read data from connection " + "{!r} ({!r}); ({!r})". + format(self.unresolved_address, + self.server_info.address, + "; ".join(map(repr, error.args)))) + if self.pool: + self.pool.deactivate(address=self.unresolved_address) + raise + + if details: + log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data + self.responses[0].on_records(details) + + if summary_signature is None: + return len(details), 0 + + response = self.responses.popleft() + response.complete = True + if summary_signature == b"\x70": + log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata) + response.on_success(summary_metadata or {}) + elif summary_signature == b"\x7E": + log.debug("[#%04X] S: IGNORED", self.local_port) + response.on_ignored(summary_metadata or {}) + elif summary_signature == b"\x7F": + log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) + try: + response.on_failure(summary_metadata or {}) + except (ServiceUnavailable, DatabaseUnavailable): + if self.pool: + self.pool.deactivate(address=self.unresolved_address), + raise + except (NotALeader, ForbiddenOnReadOnlyDatabase): + if self.pool: + self.pool.on_write_failure(address=self.unresolved_address), + raise + else: + raise BoltProtocolError("Unexpected response message with signature %02X" % summary_signature, self.unresolved_address) + + return len(details), 1 + + def _set_defunct(self, error=None): + direct_driver = isinstance(self.pool, BoltPool) + + message = ("Failed to read from defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + log.error(message) + # We were attempting to receive data but the connection + # has unexpectedly terminated. So, we need to close the + # connection from the client side, and remove the address + # from the connection pool. + self._defunct = True + self.close() + if self.pool: + self.pool.deactivate(address=self.unresolved_address) + # Iterate through the outstanding responses, and if any correspond + # to COMMIT requests then raise an error to signal that we are + # unable to confirm that the COMMIT completed successfully. + for response in self.responses: + if isinstance(response, CommitResponse): + raise BoltIncompleteCommitError(message, address=None) + + if direct_driver: + raise ServiceUnavailable(message) + else: + raise SessionExpired(message) + + def timedout(self): + return 0 <= self._max_connection_lifetime <= perf_counter() - self._creation_timestamp + + def fetch_all(self): + """ Fetch all outstanding messages. + + :return: 2-tuple of number of detail messages and number of summary + messages fetched + """ + detail_count = summary_count = 0 + while self.responses: + response = self.responses[0] + while not response.complete: + detail_delta, summary_delta = self.fetch_message() + detail_count += detail_delta + summary_count += summary_delta + return detail_count, summary_count + + def close(self): + """ Close the connection. + """ + if not self._closed: + if not self._defunct: + log.debug("[#%04X] C: GOODBYE", self.local_port) + self._append(b"\x02", ()) + try: + self._send_all() + except: + pass + log.debug("[#%04X] C: ", self.local_port) + try: + self.socket.close() + except IOError: + pass + finally: + self._closed = True + + def closed(self): + return self._closed + + def defunct(self): + return self._defunct + + +class Outbox: + + def __init__(self, capacity=8192, max_chunk_size=16384): + self._max_chunk_size = max_chunk_size + self._header = 0 + self._start = 2 + self._end = 2 + self._data = bytearray(capacity) + + def max_chunk_size(self): + return self._max_chunk_size + + def clear(self): + self._header = 0 + self._start = 2 + self._end = 2 + self._data[0:2] = b"\x00\x00" + + def write(self, b): + to_write = len(b) + max_chunk_size = self._max_chunk_size + pos = 0 + while to_write > 0: + chunk_size = self._end - self._start + remaining = max_chunk_size - chunk_size + if remaining == 0 or remaining < to_write <= max_chunk_size: + self.chunk() + else: + wrote = min(to_write, remaining) + new_end = self._end + wrote + self._data[self._end:new_end] = b[pos:pos+wrote] + self._end = new_end + pos += wrote + new_chunk_size = self._end - self._start + self._data[self._header:(self._header + 2)] = struct_pack(">H", new_chunk_size) + to_write -= wrote + + def chunk(self): + self._header = self._end + self._start = self._header + 2 + self._end = self._start + self._data[self._header:self._start] = b"\x00\x00" + + def view(self): + end = self._end + chunk_size = end - self._start + if chunk_size == 0: + return memoryview(self._data[:self._header]) + else: + return memoryview(self._data[:end]) + + +class Inbox(MessageInbox): + + def __next__(self): + tag, fields = self.pop() + if tag == b"\x71": + return fields, None, None + elif fields: + return [], tag, fields[0] + else: + return [], tag, None + + +class Response: + """ Subscriber object for a full response (zero or + more detail messages followed by one summary message). + """ + + def __init__(self, connection, **handlers): + self.connection = connection + self.handlers = handlers + self.complete = False + + def on_records(self, records): + """ Called when one or more RECORD messages have been received. + """ + handler = self.handlers.get("on_records") + if callable(handler): + handler(records) + + def on_success(self, metadata): + """ Called when a SUCCESS message has been received. + """ + handler = self.handlers.get("on_success") + if callable(handler): + handler(metadata) + + if not metadata.get("has_more"): + handler = self.handlers.get("on_summary") + if callable(handler): + handler() + + def on_failure(self, metadata): + """ Called when a FAILURE message has been received. + """ + self.connection.reset() + handler = self.handlers.get("on_failure") + if callable(handler): + handler(metadata) + handler = self.handlers.get("on_summary") + if callable(handler): + handler() + raise Neo4jError.hydrate(**metadata) + + def on_ignored(self, metadata=None): + """ Called when an IGNORED message has been received. + """ + handler = self.handlers.get("on_ignored") + if callable(handler): + handler(metadata) + handler = self.handlers.get("on_summary") + if callable(handler): + handler() + + +class InitResponse(Response): + + def on_failure(self, metadata): + code = metadata.get("code") + message = metadata.get("message", "Connection initialisation failed") + if code == "Neo.ClientError.Security.Unauthorized": + raise AuthError(message) + else: + raise ServiceUnavailable(message) + + +class CommitResponse(Response): + + pass diff --git a/tests/integration/examples/test_driver_introduction_example.py b/tests/integration/examples/test_driver_introduction_example.py new file mode 100644 index 000000000..942ab19c0 --- /dev/null +++ b/tests/integration/examples/test_driver_introduction_example.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) 2002-2020 "Neo4j," +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from contextlib import redirect_stdout +from io import StringIO + +# tag::driver-introduction-example-import[] +from neo4j import GraphDatabase +import logging +from neo4j.exceptions import ServiceUnavailable +# end::driver-introduction-example-import[] + +from neo4j._exceptions import BoltHandshakeError + + +# python -m pytest tests/integration/examples/test_aura_example.py -s -v + +# tag::driver-introduction-example[] +class App: + + def __init__(self, uri, user, password): + # Aura queries use an encrypted connection + self.driver = GraphDatabase.driver(uri, auth=(user, password), encrypted=True) + + def close(self): + # Don't forget to close the driver connection when you are finished with it + self.driver.close() + + def create_friendship(self, person1_name, person2_name): + with self.driver.session() as session: + # Write transactions allow the driver to handle retries and transient errors + result = session.write_transaction( + self._create_and_return_friendship, person1_name, person2_name) + for row in result: + print("Created friendship between: {p1}, {p2}".format(p1=row['p1'], p2=row['p2'])) + + @staticmethod + def _create_and_return_friendship(tx, person1_name, person2_name): + # To learn more about the Cypher syntax, see https://neo4j.com/docs/cypher-manual/current/ + # The Reference Card is also a good resource for keywords https://neo4j.com/docs/cypher-refcard/current/ + query = """ + CREATE (p1:Person { name: $person1_name }) + CREATE (p2:Person { name: $person2_name }) + CREATE (p1)-[:KNOWS]->(p2) + RETURN p1, p2 + """ + result = tx.run(query, person1_name=person1_name, person2_name=person2_name) + try: + return [{"p1": row["p1"]["name"], "p2": row["p2"]["name"]} + for row in result] + # Capture any errors along with the query and data for traceability + except ServiceUnavailable as exception: + logging.error("{query} raised an error: \n {exception}".format( + query=query, exception=exception)) + raise + + def find_person(self, person_name): + with self.driver.session() as session: + result = session.read_transaction(self._find_and_return_person, person_name) + for row in result: + print("Found person: {row}".format(row=row)) + + @staticmethod + def _find_and_return_person(tx, person_name): + query = """ + MATCH (p:Person) + WHERE p.name = $person_name + RETURN p.name AS name + """ + result = tx.run(query, person_name=person_name) + return [row["name"] for row in result] + +if __name__ == "__main__": + # Aura uses the "bolt+routing" protocol + bolt_url = "%%BOLT_URL_PLACEHOLDER%%" + user = "" + password = "" + app = App(bolt_url, user, password) + app.create_friendship("Alice", "David") + app.find_person("Alice") + app.close() +# end::driver-introduction-example[] + + +def test_driver_introduction_example(uri, auth): + try: + s = StringIO() + with redirect_stdout(s): + app = App(uri, auth[0], auth[1]) + app.create_friendship("Alice", "David") + app.find_person("Alice") + app.close() + + assert s.getvalue().startswith("Found person: Alice") + except ServiceUnavailable as error: + if isinstance(error.__cause__, BoltHandshakeError): + pytest.skip(error.args[0]) diff --git a/tests/requirements.txt b/tests/requirements.txt index ab208c48e..360116f21 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,4 +1,4 @@ -git+https://github.com/neo4j-drivers/boltkit@4.0#egg=boltkit +git+https://github.com/neo4j-drivers/boltkit@4.1#egg=boltkit coverage pytest pytest-benchmark diff --git a/tests/stub/scripts/v4x1/empty_explicit_hello_goodbye.script b/tests/stub/scripts/v4x1/empty_explicit_hello_goodbye.script new file mode 100644 index 000000000..84da1223d --- /dev/null +++ b/tests/stub/scripts/v4x1/empty_explicit_hello_goodbye.script @@ -0,0 +1,6 @@ +!: BOLT 4.1 + +C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"} +S: SUCCESS {"server": "Neo4j/4.1.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"} +C: GOODBYE +S: \ No newline at end of file diff --git a/tests/stub/test_directdriver.py b/tests/stub/test_directdriver.py index 5896e1fac..4504eb287 100644 --- a/tests/stub/test_directdriver.py +++ b/tests/stub/test_directdriver.py @@ -94,6 +94,7 @@ def test_bolt_uri_constructs_bolt_driver(driver_info, test_script): # ("v2/empty_explicit_hello_goodbye.script", ServiceUnavailable), # skip: cant close stub server gracefully ("v3/empty_explicit_hello_goodbye.script", None), ("v4x0/empty_explicit_hello_goodbye.script", None), + ("v4x1/empty_explicit_hello_goodbye.script", None), ] ) def test_direct_driver_handshake_negotiation(driver_info, test_script, test_expected): diff --git a/tests/unit/io/test_class_bolt.py b/tests/unit/io/test_class_bolt.py index 45f691804..082b07f21 100644 --- a/tests/unit/io/test_class_bolt.py +++ b/tests/unit/io/test_class_bolt.py @@ -28,7 +28,7 @@ def test_class_method_protocol_handlers(): # python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_class_method_protocol_handlers protocol_handlers = Bolt.protocol_handlers() - assert len(protocol_handlers) == 2 + assert len(protocol_handlers) == 3 @pytest.mark.parametrize( @@ -53,7 +53,7 @@ def test_class_method_protocol_handlers_with_invalid_protocol_version(): def test_class_method_get_handshake(): # python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_class_method_get_handshake handshake = Bolt.get_handshake() - assert handshake == b"\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00" + assert handshake == b"\x00\x00\x01\x04\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00\x00" def test_magic_preamble():