diff --git a/neo4j/__init__.py b/neo4j/__init__.py index 2ea7228c0..066d64875 100644 --- a/neo4j/__init__.py +++ b/neo4j/__init__.py @@ -20,3 +20,12 @@ from .meta import version as __version__ + +# Export current (v1) API. This should be updated to export the latest +# version of the API when a new one is added. This gives the option to +# `import neo4j.vX` for a specific version or `import neo4j` for the +# latest. +from .v1.constants import * +from .v1.exceptions import * +from .v1.session import * +from .v1.types import * diff --git a/neo4j/util.py b/neo4j/util.py index ddaa1e2e9..0bdbbf960 100644 --- a/neo4j/util.py +++ b/neo4j/util.py @@ -19,14 +19,8 @@ # limitations under the License. -from __future__ import unicode_literals - import logging -from argparse import ArgumentParser -from json import loads as json_loads -from sys import stdout, stderr - -from .v1.session import GraphDatabase, CypherError +from sys import stdout class ColourFormatter(logging.Formatter): @@ -50,7 +44,7 @@ def format(self, record): class Watcher(object): - """ Log watcher for debug output. + """ Log watcher for monitoring driver and protocol activity. """ handlers = {} @@ -74,3 +68,16 @@ def stop(self): self.logger.removeHandler(self.handlers[self.logger_name]) except KeyError: pass + + +def watch(logger_name, level=logging.INFO, out=stdout): + """ Quick wrapper for using the Watcher. + + :param logger_name: name of logger to watch + :param level: minimum log level to show (default INFO) + :param out: where to send output (default stdout) + :return: Watcher instance + """ + watcher = Watcher(logger_name) + watcher.watch(level, out) + return watcher diff --git a/neo4j/v1/__init__.py b/neo4j/v1/__init__.py index ab5f76641..e445046cc 100644 --- a/neo4j/v1/__init__.py +++ b/neo4j/v1/__init__.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .connection import ProtocolError from .constants import * +from .exceptions import * from .session import * from .types import * diff --git a/neo4j/v1/connection.py b/neo4j/v1/bolt.py similarity index 96% rename from neo4j/v1/connection.py rename to neo4j/v1/bolt.py index 25d25174e..d7bf8446f 100644 --- a/neo4j/v1/connection.py +++ b/neo4j/v1/bolt.py @@ -31,8 +31,6 @@ from socket import create_connection, SHUT_RDWR, error as SocketError from struct import pack as struct_pack, unpack as struct_unpack, unpack_from as struct_unpack_from -import errno - from .constants import DEFAULT_PORT, DEFAULT_USER_AGENT, KNOWN_HOSTS, MAGIC_PREAMBLE, \ TRUST_DEFAULT, TRUST_ON_FIRST_USE from .compat import hex2 @@ -239,6 +237,13 @@ def on_failure(metadata): def __del__(self): self.close() + @property + def healthy(self): + """ Return ``True`` if this connection is healthy, ``False`` if + unhealthy and ``None`` if closed. + """ + return None if self.closed else not self.defunct + def append(self, signature, fields=(), response=None): """ Add a message to the outgoing queue. @@ -333,6 +338,12 @@ def fetch(self): handler(*fields) raw.close() + def fetch_all(self): + while self.responses: + response = self.responses[0] + while not response.complete: + self.fetch() + def close(self): """ Close the connection. """ @@ -389,7 +400,7 @@ def match_or_trust(self, host, der_encoded_certificate): return True -def connect(host, port=None, ssl_context=None, **config): +def connect(host_port, ssl_context=None, **config): """ Connect and perform a handshake and return a valid Connection object, assuming a protocol version can be agreed. """ @@ -397,18 +408,18 @@ def connect(host, port=None, ssl_context=None, **config): # Establish a connection to the host and port specified # Catches refused connections see: # https://docs.python.org/2/library/errno.html - port = port or DEFAULT_PORT - if __debug__: log_info("~~ [CONNECT] %s %d", host, port) + if __debug__: log_info("~~ [CONNECT] %s", host_port) try: - s = create_connection((host, port)) + s = create_connection(host_port) except SocketError as error: if error.errno == 111 or error.errno == 61: - raise ProtocolError("Unable to connect to %s on port %d - is the server running?" % (host, port)) + raise ProtocolError("Unable to connect to %s on port %d - is the server running?" % host_port) else: raise # Secure the connection if an SSL context has been provided if ssl_context and SSL_AVAILABLE: + host, port = host_port if __debug__: log_info("~~ [SECURE] %s", host) try: s = ssl_context.wrap_socket(s, server_hostname=host if HAS_SNI else None) diff --git a/neo4j/v1/session.py b/neo4j/v1/session.py index ef7a84459..bf95d35b7 100644 --- a/neo4j/v1/session.py +++ b/neo4j/v1/session.py @@ -28,23 +28,19 @@ class which can be used to obtain `Driver` instances that are used for from __future__ import division -from collections import deque, namedtuple +from collections import deque +from .bolt import connect, Response, RUN, PULL_ALL from .compat import integer, string, urlparse -from .connection import connect, Response, RUN, PULL_ALL -from .constants import ENCRYPTED_DEFAULT, TRUST_DEFAULT, TRUST_SIGNED_CERTIFICATES +from .constants import DEFAULT_PORT, ENCRYPTED_DEFAULT, TRUST_DEFAULT, TRUST_SIGNED_CERTIFICATES from .exceptions import CypherError, ProtocolError, ResultError from .ssl_compat import SSL_AVAILABLE, SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, CERT_REQUIRED +from .summary import ResultSummary from .types import hydrated DEFAULT_MAX_POOL_SIZE = 50 -STATEMENT_TYPE_READ_ONLY = "r" -STATEMENT_TYPE_READ_WRITE = "rw" -STATEMENT_TYPE_WRITE_ONLY = "w" -STATEMENT_TYPE_SCHEMA_WRITE = "s" - def basic_auth(user, password): """ Generate a basic auth token for a given user and password. @@ -100,15 +96,21 @@ class Driver(object): """ Accessor for a specific graph database resource. """ - def __init__(self, url, **config): - self.url = url - parsed = urlparse(self.url) - if parsed.scheme == "bolt": - self.host = parsed.hostname - self.port = parsed.port + def __init__(self, address, **config): + if "://" in address: + parsed = urlparse(address) + if parsed.scheme == "bolt": + host = parsed.hostname + port = parsed.port or DEFAULT_PORT + else: + raise ProtocolError("Only the 'bolt' URI scheme is supported [%s]" % address) + elif ":" in address: + host, port = address.split(":") + port = int(port) else: - raise ProtocolError("Unsupported URI scheme: '%s' in url: '%s'. Currently only supported 'bolt'." % - (parsed.scheme, url)) + host = address + port = DEFAULT_PORT + self.address = (host, port) self.config = config self.max_pool_size = config.get("max_pool_size", DEFAULT_MAX_POOL_SIZE) self.session_pool = deque() @@ -137,20 +139,20 @@ def session(self): >>> from neo4j.v1 import GraphDatabase >>> driver = GraphDatabase.driver("bolt://localhost") >>> session = driver.session() - """ session = None - done = False - while not done: + connected = False + while not connected: try: session = self.session_pool.pop() except IndexError: - session = Session(self) - done = True + connection = connect(self.address, self.ssl_context, **self.config) + session = Session(self, connection) + connected = True else: if session.healthy: session.connection.reset() - done = session.healthy + connected = session.healthy return session def recycle(self, session): @@ -284,170 +286,39 @@ def peek(self): raise ResultError("End of stream") -class ResultSummary(object): - """ A summary of execution returned with a :class:`.StatementResult` object. - """ - - #: The statement that was executed to produce this result. - statement = None - - #: Dictionary of parameters passed with the statement. - parameters = None +def run(connection, statement, parameters=None): + """ Run a Cypher statement on a given connection. - #: The type of statement (``'r'`` = read-only, ``'rw'`` = read/write). - statement_type = None - - #: A set of statistical information held in a :class:`.Counters` instance. - counters = None - - #: A :class:`.Plan` instance - plan = None - - #: A :class:`.ProfiledPlan` instance - profile = None - - #: Notifications provide extra information for a user executing a statement. - #: They can be warnings about problematic queries or other valuable information that can be - #: presented in a client. - #: Unlike failures or errors, notifications do not affect the execution of a statement. - notifications = None - - def __init__(self, statement, parameters, **metadata): - self.statement = statement - self.parameters = parameters - self.statement_type = metadata.get("type") - self.counters = SummaryCounters(metadata.get("stats", {})) - if "plan" in metadata: - self.plan = make_plan(metadata["plan"]) - if "profile" in metadata: - self.profile = make_plan(metadata["profile"]) - self.plan = self.profile - self.notifications = [] - for notification in metadata.get("notifications", []): - position = notification.get("position") - if position is not None: - position = Position(position["offset"], position["line"], position["column"]) - self.notifications.append(Notification(notification["code"], notification["title"], - notification["description"], notification["severity"], position)) - - -class SummaryCounters(object): - """ Set of statistics from a Cypher statement execution. + :param connection: connection to carry the request and response + :param statement: Cypher statement + :param parameters: optional dictionary of parameters + :return: statement result """ + # Ensure the statement is a Unicode value + if isinstance(statement, bytes): + statement = statement.decode("UTF-8") + + params = {} + for key, value in (parameters or {}).items(): + if isinstance(key, bytes): + key = key.decode("UTF-8") + if isinstance(value, bytes): + params[key] = value.decode("UTF-8") + else: + params[key] = value + parameters = params - #: - nodes_created = 0 - - #: - nodes_deleted = 0 - - #: - relationships_created = 0 - - #: - relationships_deleted = 0 - - #: - properties_set = 0 - - #: - labels_added = 0 - - #: - labels_removed = 0 - - #: - indexes_added = 0 - - #: - indexes_removed = 0 - - #: - constraints_added = 0 - - #: - constraints_removed = 0 + run_response = Response(connection) + pull_all_response = Response(connection) + result = StatementResult(connection, run_response, pull_all_response) + result.statement = statement + result.parameters = parameters - def __init__(self, statistics): - for key, value in dict(statistics).items(): - key = key.replace("-", "_") - setattr(self, key, value) + connection.append(RUN, (statement, parameters), response=run_response) + connection.append(PULL_ALL, response=pull_all_response) + connection.send() - def __repr__(self): - return repr(vars(self)) - - @property - def contains_updates(self): - return bool(self.nodes_created or self.nodes_deleted or \ - self.relationships_created or self.relationships_deleted or \ - self.properties_set or self.labels_added or self.labels_removed or \ - self.indexes_added or self.indexes_removed or \ - self.constraints_added or self.constraints_removed) - - -#: A plan describes how the database will execute your statement. -#: -#: operator_type: -#: the name of the operation performed by the plan -#: identifiers: -#: the list of identifiers used by this plan -#: arguments: -#: a dictionary of arguments used in the specific operation performed by the plan -#: children: -#: a list of sub-plans -Plan = namedtuple("Plan", ("operator_type", "identifiers", "arguments", "children")) - -#: A profiled plan describes how the database executed your statement. -#: -#: db_hits: -#: the number of times this part of the plan touched the underlying data stores -#: rows: -#: the number of records this part of the plan produced -ProfiledPlan = namedtuple("ProfiledPlan", Plan._fields + ("db_hits", "rows")) - -#: Representation for notifications found when executing a statement. A -#: notification can be visualized in a client pinpointing problems or -#: other information about the statement. -#: -#: code: -#: a notification code for the discovered issue. -#: title: -#: a short summary of the notification -#: description: -#: a long description of the notification -#: severity: -#: the severity level of the notification -#: position: -#: the position in the statement where this notification points to, if relevant. -Notification = namedtuple("Notification", ("code", "title", "description", "severity", "position")) - -#: A position within a statement, consisting of offset, line and column. -#: -#: offset: -#: the character offset referred to by this position; offset numbers start at 0 -#: line: -#: the line number referred to by the position; line numbers start at 1 -#: column: -#: the column number referred to by the position; column numbers start at 1 -Position = namedtuple("Position", ("offset", "line", "column")) - - -def make_plan(plan_dict): - """ Construct a Plan or ProfiledPlan from a dictionary of metadata values. - - :param plan_dict: - :return: - """ - operator_type = plan_dict["operatorType"] - identifiers = plan_dict.get("identifiers", []) - arguments = plan_dict.get("args", []) - children = [make_plan(child) for child in plan_dict.get("children", [])] - if "dbHits" in plan_dict or "rows" in plan_dict: - db_hits = plan_dict.get("dbHits", 0) - rows = plan_dict.get("rows", 0) - return ProfiledPlan(operator_type, identifiers, arguments, children, db_hits, rows) - else: - return Plan(operator_type, identifiers, arguments, children) + return result class Session(object): @@ -456,11 +327,10 @@ class Session(object): method. """ - def __init__(self, driver): + def __init__(self, driver, connection): self.driver = driver - self.connection = connect(driver.host, driver.port, driver.ssl_context, **driver.config) + self.connection = connection self.transaction = None - self.last_result = None def __enter__(self): return self @@ -473,8 +343,7 @@ def healthy(self): """ Return ``True`` if this session is healthy, ``False`` if unhealthy and ``None`` if closed. """ - connection = self.connection - return None if connection.closed else not connection.defunct + return self.connection.healthy def run(self, statement, parameters=None): """ Run a parameterised Cypher statement. @@ -487,41 +356,13 @@ def run(self, statement, parameters=None): if self.transaction: raise ProtocolError("Statements cannot be run directly on a session with an open transaction;" " either run from within the transaction or use a different session.") - return self._run(statement, parameters) - - def _run(self, statement, parameters=None): - # Ensure the statement is a Unicode value - if isinstance(statement, bytes): - statement = statement.decode("UTF-8") - - params = {} - for key, value in (parameters or {}).items(): - if isinstance(key, bytes): - key = key.decode("UTF-8") - if isinstance(value, bytes): - params[key] = value.decode("UTF-8") - else: - params[key] = value - parameters = params - - run_response = Response(self.connection) - pull_all_response = Response(self.connection) - result = StatementResult(self.connection, run_response, pull_all_response) - result.statement = statement - result.parameters = parameters - - self.connection.append(RUN, (statement, parameters), response=run_response) - self.connection.append(PULL_ALL, response=pull_all_response) - self.connection.send() - - self.last_result = result - return result + return run(self.connection, statement, parameters) def close(self): """ Recycle this session through the driver it came from. """ - if self.last_result: - self.last_result.buffer() + if self.connection and not self.connection.closed: + self.connection.fetch_all() if self.transaction: self.transaction.close() self.driver.recycle(self) @@ -534,7 +375,11 @@ def begin_transaction(self): if self.transaction: raise ProtocolError("You cannot begin a transaction on a session with an open transaction;" " either run from within the transaction or use a different session.") - self.transaction = Transaction(self) + + def clear_transaction(): + self.transaction = None + + self.transaction = Transaction(self.connection, on_close=clear_transaction) return self.transaction @@ -559,9 +404,10 @@ class Transaction(object): #: with commit or rollback. closed = False - def __init__(self, session): - self.session = session - self.session._run("BEGIN") + def __init__(self, connection, on_close): + self.connection = connection + self.on_close = on_close + run(self.connection, "BEGIN") def __enter__(self): return self @@ -574,12 +420,12 @@ def __exit__(self, exc_type, exc_value, traceback): def run(self, statement, parameters=None): """ Run a Cypher statement within the context of this transaction. - :param statement: - :param parameters: - :return: + :param statement: Cypher statement + :param parameters: dictionary of parameters + :return: result object """ assert not self.closed - return self.session._run(statement, parameters) + return run(self.connection, statement, parameters) def commit(self): """ Mark this transaction as successful and close in order to @@ -600,11 +446,11 @@ def close(self): """ assert not self.closed if self.success: - self.session._run("COMMIT") + run(self.connection, "COMMIT") else: - self.session._run("ROLLBACK") + run(self.connection, "ROLLBACK") self.closed = True - self.session.transaction = None + self.on_close() class Record(object): diff --git a/neo4j/v1/summary.py b/neo4j/v1/summary.py new file mode 100644 index 000000000..f6fabfbbd --- /dev/null +++ b/neo4j/v1/summary.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) 2002-2016 "Neo Technology," +# Network Engine for Objects in Lund AB [http://neotechnology.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections import namedtuple + + +STATEMENT_TYPE_READ_ONLY = "r" +STATEMENT_TYPE_READ_WRITE = "rw" +STATEMENT_TYPE_WRITE_ONLY = "w" +STATEMENT_TYPE_SCHEMA_WRITE = "s" + + +class ResultSummary(object): + """ A summary of execution returned with a :class:`.StatementResult` object. + """ + + #: The statement that was executed to produce this result. + statement = None + + #: Dictionary of parameters passed with the statement. + parameters = None + + #: The type of statement (``'r'`` = read-only, ``'rw'`` = read/write). + statement_type = None + + #: A set of statistical information held in a :class:`.Counters` instance. + counters = None + + #: A :class:`.Plan` instance + plan = None + + #: A :class:`.ProfiledPlan` instance + profile = None + + #: Notifications provide extra information for a user executing a statement. + #: They can be warnings about problematic queries or other valuable information that can be + #: presented in a client. + #: Unlike failures or errors, notifications do not affect the execution of a statement. + notifications = None + + def __init__(self, statement, parameters, **metadata): + self.statement = statement + self.parameters = parameters + self.statement_type = metadata.get("type") + self.counters = SummaryCounters(metadata.get("stats", {})) + if "plan" in metadata: + self.plan = make_plan(metadata["plan"]) + if "profile" in metadata: + self.profile = make_plan(metadata["profile"]) + self.plan = self.profile + self.notifications = [] + for notification in metadata.get("notifications", []): + position = notification.get("position") + if position is not None: + position = Position(position["offset"], position["line"], position["column"]) + self.notifications.append(Notification(notification["code"], notification["title"], + notification["description"], notification["severity"], position)) + + +class SummaryCounters(object): + """ Set of statistics from a Cypher statement execution. + """ + + #: + nodes_created = 0 + + #: + nodes_deleted = 0 + + #: + relationships_created = 0 + + #: + relationships_deleted = 0 + + #: + properties_set = 0 + + #: + labels_added = 0 + + #: + labels_removed = 0 + + #: + indexes_added = 0 + + #: + indexes_removed = 0 + + #: + constraints_added = 0 + + #: + constraints_removed = 0 + + def __init__(self, statistics): + for key, value in dict(statistics).items(): + key = key.replace("-", "_") + setattr(self, key, value) + + def __repr__(self): + return repr(vars(self)) + + @property + def contains_updates(self): + return bool(self.nodes_created or self.nodes_deleted or + self.relationships_created or self.relationships_deleted or + self.properties_set or self.labels_added or self.labels_removed or + self.indexes_added or self.indexes_removed or + self.constraints_added or self.constraints_removed) + + +#: A plan describes how the database will execute your statement. +#: +#: operator_type: +#: the name of the operation performed by the plan +#: identifiers: +#: the list of identifiers used by this plan +#: arguments: +#: a dictionary of arguments used in the specific operation performed by the plan +#: children: +#: a list of sub-plans +Plan = namedtuple("Plan", ("operator_type", "identifiers", "arguments", "children")) + +#: A profiled plan describes how the database executed your statement. +#: +#: db_hits: +#: the number of times this part of the plan touched the underlying data stores +#: rows: +#: the number of records this part of the plan produced +ProfiledPlan = namedtuple("ProfiledPlan", Plan._fields + ("db_hits", "rows")) + +#: Representation for notifications found when executing a statement. A +#: notification can be visualized in a client pinpointing problems or +#: other information about the statement. +#: +#: code: +#: a notification code for the discovered issue. +#: title: +#: a short summary of the notification +#: description: +#: a long description of the notification +#: severity: +#: the severity level of the notification +#: position: +#: the position in the statement where this notification points to, if relevant. +Notification = namedtuple("Notification", ("code", "title", "description", "severity", "position")) + +#: A position within a statement, consisting of offset, line and column. +#: +#: offset: +#: the character offset referred to by this position; offset numbers start at 0 +#: line: +#: the line number referred to by the position; line numbers start at 1 +#: column: +#: the column number referred to by the position; column numbers start at 1 +Position = namedtuple("Position", ("offset", "line", "column")) + + +def make_plan(plan_dict): + """ Construct a Plan or ProfiledPlan from a dictionary of metadata values. + + :param plan_dict: + :return: + """ + operator_type = plan_dict["operatorType"] + identifiers = plan_dict.get("identifiers", []) + arguments = plan_dict.get("args", []) + children = [make_plan(child) for child in plan_dict.get("children", [])] + if "dbHits" in plan_dict or "rows" in plan_dict: + db_hits = plan_dict.get("dbHits", 0) + rows = plan_dict.get("rows", 0) + return ProfiledPlan(operator_type, identifiers, arguments, children, db_hits, rows) + else: + return Plan(operator_type, identifiers, arguments, children) diff --git a/test/test_session.py b/test/test_session.py index 5a5b66af6..9695f3d5a 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -26,7 +26,7 @@ from mock import patch from neo4j.v1.constants import TRUST_ON_FIRST_USE -from neo4j.v1.exceptions import CypherError, ResultError +from neo4j.v1.exceptions import CypherError, ProtocolError, ResultError from neo4j.v1.session import GraphDatabase, basic_auth, Record, SSL_AVAILABLE from neo4j.v1.types import Node, Relationship, Path @@ -34,7 +34,6 @@ auth_token = basic_auth("neo4j", "neo4j") -from neo4j.v1.exceptions import ProtocolError class DriverTestCase(ServerTestCase): @@ -346,7 +345,7 @@ def test_automatic_reset_after_failure(self): assert False, "A Cypher error should have occurred" def test_defunct(self): - from neo4j.v1.connection import ChunkChannel, ProtocolError + from neo4j.v1.bolt import ChunkChannel, ProtocolError with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: assert not session.connection.defunct with patch.object(ChunkChannel, "chunk_reader", side_effect=ProtocolError()):