From 63ca7a0472b4b5b347d9de6a1af3c1ef269bab25 Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Tue, 9 Feb 2016 16:59:13 +0000 Subject: [PATCH 1/4] Test socket types --- neo4j/v1/connection.py | 6 +----- test/test_session.py | 20 +++++++++++++++++--- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/neo4j/v1/connection.py b/neo4j/v1/connection.py index f64aff0c4..5ca89feb9 100644 --- a/neo4j/v1/connection.py +++ b/neo4j/v1/connection.py @@ -324,11 +324,7 @@ def connect(host, port=None, **config): s = create_connection((host, port)) # Secure the connection if so requested - try: - secure = environ["NEO4J_SECURE"] - except KeyError: - secure = config.get("secure", False) - if secure: + if config.get("secure", False): if __debug__: log_info("~~ [SECURE] %s", host) s = secure_socket(s, host) diff --git a/test/test_session.py b/test/test_session.py index 4cd8b118e..6bd93565e 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -19,6 +19,8 @@ # limitations under the License. +from socket import socket +from ssl import SSLSocket from unittest import TestCase from mock import patch @@ -60,9 +62,6 @@ def test_session_that_dies_in_the_pool_will_not_be_given_out(self): session_2 = driver.session() assert session_2 is not session_1 - -class RunTestCase(TestCase): - def test_must_use_valid_url_scheme(self): with self.assertRaises(ValueError): GraphDatabase.driver("x://xxx") @@ -83,6 +82,21 @@ def test_sessions_are_not_reused_if_still_in_use(self): session_1.close() assert session_1 is not session_2 + def test_insecure_session_uses_insecure_socket(self): + driver = GraphDatabase.driver("bolt://localhost", secure=False) + session = driver.session() + assert isinstance(session.connection.channel.socket, socket) + session.close() + + def test_secure_session_uses_secure_socket(self): + driver = GraphDatabase.driver("bolt://localhost", secure=True) + session = driver.session() + assert isinstance(session.connection.channel.socket, SSLSocket) + session.close() + + +class RunTestCase(TestCase): + def test_can_run_simple_statement(self): session = GraphDatabase.driver("bolt://localhost").session() count = 0 From 8a02a90194f2c415df0ce2af35ea2455902d1f61 Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Thu, 11 Feb 2016 15:58:24 +0000 Subject: [PATCH 2/4] TLS settings and tests --- neo4j/v1/__init__.py | 1 + neo4j/v1/compat.py | 16 --------- neo4j/v1/connection.py | 78 ++++++++++++++++++++++++++++++++++-------- neo4j/v1/constants.py | 36 +++++++++++++++++++ neo4j/v1/session.py | 15 +++++++- test/test_session.py | 62 ++++++++++++++++++++++++++++----- 6 files changed, 169 insertions(+), 39 deletions(-) create mode 100644 neo4j/v1/constants.py diff --git a/neo4j/v1/__init__.py b/neo4j/v1/__init__.py index d51d7b9af..1a1b454b3 100644 --- a/neo4j/v1/__init__.py +++ b/neo4j/v1/__init__.py @@ -18,5 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .constants import * from .session import * from .typesystem import * diff --git a/neo4j/v1/compat.py b/neo4j/v1/compat.py index 24cdbc744..dc21adad6 100644 --- a/neo4j/v1/compat.py +++ b/neo4j/v1/compat.py @@ -90,19 +90,3 @@ def perf_counter(): from urllib.parse import urlparse except ImportError: from urlparse import urlparse - - -try: - from ssl import SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, HAS_SNI -except ImportError: - from ssl import wrap_socket, PROTOCOL_SSLv23 - - def secure_socket(s, host): - return wrap_socket(s, ssl_version=PROTOCOL_SSLv23) - -else: - - def secure_socket(s, host): - ssl_context = SSLContext(PROTOCOL_SSLv23) - ssl_context.options |= OP_NO_SSLv2 - return ssl_context.wrap_socket(s, server_hostname=host if HAS_SNI else None) diff --git a/neo4j/v1/connection.py b/neo4j/v1/connection.py index 5ca89feb9..85cb2758c 100644 --- a/neo4j/v1/connection.py +++ b/neo4j/v1/connection.py @@ -21,25 +21,24 @@ from __future__ import division +from base64 import b64encode from collections import deque from io import BytesIO import logging -from os import environ +from os import makedirs, open as os_open, write as os_write, close as os_close, O_CREAT, O_APPEND, O_WRONLY +from os.path import dirname, isfile from select import select from socket import create_connection, SHUT_RDWR +from ssl import HAS_SNI, SSLError from struct import pack as struct_pack, unpack as struct_unpack, unpack_from as struct_unpack_from -from ..meta import version -from .compat import hex2, secure_socket +from .constants import DEFAULT_PORT, DEFAULT_USER_AGENT, KNOWN_HOSTS, MAGIC_PREAMBLE, \ + SECURITY_NONE, SECURITY_TRUST_ON_FIRST_USE +from .compat import hex2 from .exceptions import ProtocolError from .packstream import Packer, Unpacker -DEFAULT_PORT = 7687 -DEFAULT_USER_AGENT = "neo4j-python/%s" % version - -MAGIC_PREAMBLE = 0x6060B017 - # Signature bytes for each message type INIT = b"\x01" # 0000 0001 // INIT RESET = b"\x0F" # 0000 1111 // RESET @@ -211,6 +210,10 @@ def __init__(self, sock, **config): user_agent = config.get("user_agent", DEFAULT_USER_AGENT) if isinstance(user_agent, bytes): user_agent = user_agent.decode("UTF-8") + self.user_agent = user_agent + + # Pick up the server certificate, if any + self.der_encoded_server_certificate = config.get("der_encoded_server_certificate") def on_failure(metadata): raise ProtocolError("Initialisation failed") @@ -218,7 +221,7 @@ def on_failure(metadata): response = Response(self) response.on_failure = on_failure - self.append(INIT, (user_agent,), response=response) + self.append(INIT, (self.user_agent,), response=response) self.send() while not response.complete: self.fetch() @@ -313,7 +316,39 @@ def close(self): self.closed = True -def connect(host, port=None, **config): +def verify_certificate(host, der_encoded_certificate): + base64_encoded_certificate = b64encode(der_encoded_certificate) + if isfile(KNOWN_HOSTS): + with open(KNOWN_HOSTS) as f_in: + for line in f_in: + known_host, _, known_cert = line.strip().partition(":") + if host == known_host: + if base64_encoded_certificate == known_cert: + # Certificate match + return + else: + # Certificate mismatch + print(base64_encoded_certificate) + print(known_cert) + raise ProtocolError("Server certificate does not match known certificate for %r; check " + "details in file %r" % (host, KNOWN_HOSTS)) + # First use (no hosts match) + try: + makedirs(dirname(KNOWN_HOSTS)) + except OSError: + pass + f_out = os_open(KNOWN_HOSTS, O_CREAT | O_APPEND | O_WRONLY, 0o600) # TODO: Windows + if isinstance(host, bytes): + os_write(f_out, host) + else: + os_write(f_out, host.encode("utf-8")) + os_write(f_out, b":") + os_write(f_out, base64_encoded_certificate) + os_write(f_out, b"\n") + os_close(f_out) + + +def connect(host, port=None, ssl_context=None, **config): """ Connect and perform a handshake and return a valid Connection object, assuming a protocol version can be agreed. """ @@ -323,10 +358,25 @@ def connect(host, port=None, **config): if __debug__: log_info("~~ [CONNECT] %s %d", host, port) s = create_connection((host, port)) - # Secure the connection if so requested - if config.get("secure", False): + # Secure the connection if an SSL context has been provided + if ssl_context: if __debug__: log_info("~~ [SECURE] %s", host) - s = secure_socket(s, host) + try: + s = ssl_context.wrap_socket(s, server_hostname=host if HAS_SNI else None) + except SSLError as cause: + error = ProtocolError("Cannot establish secure connection; %s" % cause.args[1]) + error.__cause__ = cause + raise error + else: + # Check that the server provides a certificate + der_encoded_server_certificate = s.getpeercert(binary_form=True) + if der_encoded_server_certificate is None: + raise ProtocolError("When using a secure socket, the server should always provide a certificate") + security = config.get("security", SECURITY_NONE) + if security == SECURITY_TRUST_ON_FIRST_USE: + verify_certificate(host, der_encoded_server_certificate) + else: + der_encoded_server_certificate = None # Send details of the protocol versions supported supported_versions = [1, 0, 0, 0] @@ -360,4 +410,4 @@ def connect(host, port=None, **config): s.shutdown(SHUT_RDWR) s.close() else: - return Connection(s, **config) + return Connection(s, der_encoded_server_certificate=der_encoded_server_certificate, **config) diff --git a/neo4j/v1/constants.py b/neo4j/v1/constants.py new file mode 100644 index 000000000..8a2dad213 --- /dev/null +++ b/neo4j/v1/constants.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) 2002-2016 "Neo Technology," +# Network Engine for Objects in Lund AB [http://neotechnology.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from os.path import expanduser, join + +from ..meta import version + + +DEFAULT_PORT = 7687 +DEFAULT_USER_AGENT = "neo4j-python/%s" % version + +KNOWN_HOSTS = join(expanduser("~"), ".neo4j", "known_hosts") + +MAGIC_PREAMBLE = 0x6060B017 + +SECURITY_NONE = 0 +SECURITY_TRUST_ON_FIRST_USE = 1 +SECURITY_VERIFIED = 2 diff --git a/neo4j/v1/session.py b/neo4j/v1/session.py index cace94ff6..02b54cbf1 100644 --- a/neo4j/v1/session.py +++ b/neo4j/v1/session.py @@ -29,9 +29,11 @@ class which can be used to obtain `Driver` instances that are used for from __future__ import division from collections import deque, namedtuple +from ssl import SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, CERT_REQUIRED, Purpose from .compat import integer, string, urlparse from .connection import connect, Response, RUN, PULL_ALL +from .constants import SECURITY_NONE, SECURITY_VERIFIED from .exceptions import CypherError, ResultError from .typesystem import hydrated @@ -77,6 +79,16 @@ def __init__(self, url, **config): self.config = config self.max_pool_size = config.get("max_pool_size", DEFAULT_MAX_POOL_SIZE) self.session_pool = deque() + self.security = security = config.get("security", SECURITY_NONE) + if security > SECURITY_NONE: + ssl_context = SSLContext(PROTOCOL_SSLv23) + ssl_context.options |= OP_NO_SSLv2 + if security >= SECURITY_VERIFIED: + ssl_context.verify_mode = CERT_REQUIRED + ssl_context.load_default_certs(Purpose.SERVER_AUTH) + self.ssl_context = ssl_context + else: + self.ssl_context = None def session(self): """ Create a new session based on the graph database details @@ -425,7 +437,7 @@ class Session(object): def __init__(self, driver): self.driver = driver - self.connection = connect(driver.host, driver.port, **driver.config) + self.connection = connect(driver.host, driver.port, driver.ssl_context, **driver.config) self.transaction = None self.last_cursor = None @@ -654,6 +666,7 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + def record(obj): """ Obtain an immutable record for the given object (either by calling obj.__record__() or by copying out the record data) diff --git a/test/test_session.py b/test/test_session.py index 6bd93565e..767ed4f34 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -19,16 +19,22 @@ # limitations under the License. +from os import remove, rename +from os.path import isfile from socket import socket from ssl import SSLSocket from unittest import TestCase from mock import patch -from neo4j.v1.exceptions import ResultError -from neo4j.v1.session import GraphDatabase, CypherError, Record, record +from neo4j.v1.constants import KNOWN_HOSTS, SECURITY_NONE, SECURITY_TRUST_ON_FIRST_USE, SECURITY_VERIFIED +from neo4j.v1.exceptions import CypherError, ResultError +from neo4j.v1.session import GraphDatabase, Record, record from neo4j.v1.typesystem import Node, Relationship, Path +KNOWN_HOSTS_BACKUP = KNOWN_HOSTS + ".backup" + + class DriverTestCase(TestCase): def test_healthy_session_will_be_returned_to_the_pool_on_close(self): @@ -82,17 +88,57 @@ def test_sessions_are_not_reused_if_still_in_use(self): session_1.close() assert session_1 is not session_2 - def test_insecure_session_uses_insecure_socket(self): - driver = GraphDatabase.driver("bolt://localhost", secure=False) + +class SecurityTestCase(TestCase): + + def setUp(self): + if isfile(KNOWN_HOSTS): + rename(KNOWN_HOSTS, KNOWN_HOSTS_BACKUP) + + def tearDown(self): + if isfile(KNOWN_HOSTS_BACKUP): + rename(KNOWN_HOSTS_BACKUP, KNOWN_HOSTS) + + def test_default_session_uses_security_none(self): + # TODO: verify this is the correct default (maybe TOFU?) + driver = GraphDatabase.driver("bolt://localhost") + assert driver.security == SECURITY_NONE + + def test_insecure_session_uses_normal_socket(self): + driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_NONE) + session = driver.session() + connection = session.connection + assert isinstance(connection.channel.socket, socket) + assert connection.der_encoded_server_certificate is None + session.close() + + def test_tofu_session_uses_secure_socket(self): + driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_TRUST_ON_FIRST_USE) session = driver.session() - assert isinstance(session.connection.channel.socket, socket) + connection = session.connection + assert isinstance(connection.channel.socket, SSLSocket) + assert connection.der_encoded_server_certificate is not None session.close() - def test_secure_session_uses_secure_socket(self): - driver = GraphDatabase.driver("bolt://localhost", secure=True) + def test_tofu_session_trusts_certificate_after_first_use(self): + driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_TRUST_ON_FIRST_USE) session = driver.session() - assert isinstance(session.connection.channel.socket, SSLSocket) + connection = session.connection + certificate = connection.der_encoded_server_certificate session.close() + session = driver.session() + connection = session.connection + assert connection.der_encoded_server_certificate == certificate + session.close() + + # TODO: Find a way to run this test + # def test_verified_session_uses_secure_socket(self): + # driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_VERIFIED) + # session = driver.session() + # connection = session.connection + # assert isinstance(connection.channel.socket, SSLSocket) + # assert connection.der_encoded_server_certificate is not None + # session.close() class RunTestCase(TestCase): From 844dce5417cbbd0f3a8b5055225e77caec5a09a8 Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Thu, 18 Feb 2016 13:16:52 +0100 Subject: [PATCH 3/4] Security TOFU, PersonalCertificateStore --- neo4j/v1/connection.py | 84 +++++++++++++++++++++++++----------------- neo4j/v1/constants.py | 2 + neo4j/v1/session.py | 4 +- test/test_session.py | 34 ++++++----------- test/util.py | 23 ++++++++++++ 5 files changed, 89 insertions(+), 58 deletions(-) diff --git a/neo4j/v1/connection.py b/neo4j/v1/connection.py index 85cb2758c..e061fe4b6 100644 --- a/neo4j/v1/connection.py +++ b/neo4j/v1/connection.py @@ -33,7 +33,7 @@ from struct import pack as struct_pack, unpack as struct_unpack, unpack_from as struct_unpack_from from .constants import DEFAULT_PORT, DEFAULT_USER_AGENT, KNOWN_HOSTS, MAGIC_PREAMBLE, \ - SECURITY_NONE, SECURITY_TRUST_ON_FIRST_USE + SECURITY_DEFAULT, SECURITY_TRUST_ON_FIRST_USE from .compat import hex2 from .exceptions import ProtocolError from .packstream import Packer, Unpacker @@ -316,36 +316,51 @@ def close(self): self.closed = True -def verify_certificate(host, der_encoded_certificate): - base64_encoded_certificate = b64encode(der_encoded_certificate) - if isfile(KNOWN_HOSTS): - with open(KNOWN_HOSTS) as f_in: - for line in f_in: - known_host, _, known_cert = line.strip().partition(":") - if host == known_host: - if base64_encoded_certificate == known_cert: - # Certificate match - return - else: - # Certificate mismatch - print(base64_encoded_certificate) - print(known_cert) - raise ProtocolError("Server certificate does not match known certificate for %r; check " - "details in file %r" % (host, KNOWN_HOSTS)) - # First use (no hosts match) - try: - makedirs(dirname(KNOWN_HOSTS)) - except OSError: - pass - f_out = os_open(KNOWN_HOSTS, O_CREAT | O_APPEND | O_WRONLY, 0o600) # TODO: Windows - if isinstance(host, bytes): - os_write(f_out, host) - else: - os_write(f_out, host.encode("utf-8")) - os_write(f_out, b":") - os_write(f_out, base64_encoded_certificate) - os_write(f_out, b"\n") - os_close(f_out) +class CertificateStore(object): + + def match_or_trust(self, host, der_encoded_certificate): + """ Check whether the supplied certificate matches that stored for the + specified host. If it does, return ``True``, if it doesn't, return + ``False``. If no entry for that host is found, add it to the store + and return ``True``. + + :arg host: + :arg der_encoded_certificate: + :return: + """ + raise NotImplementedError() + + +class PersonalCertificateStore(CertificateStore): + + def __init__(self, path=None): + self.path = path or KNOWN_HOSTS + + def match_or_trust(self, host, der_encoded_certificate): + base64_encoded_certificate = b64encode(der_encoded_certificate) + if isfile(self.path): + with open(self.path) as f_in: + for line in f_in: + known_host, _, known_cert = line.strip().partition(":") + if host == known_host: + print("Received: %s" % base64_encoded_certificate) + print("Known: %s" % known_cert) + return base64_encoded_certificate == known_cert + # First use (no hosts match) + try: + makedirs(dirname(self.path)) + except OSError: + pass + f_out = os_open(self.path, O_CREAT | O_APPEND | O_WRONLY, 0o600) # TODO: Windows + if isinstance(host, bytes): + os_write(f_out, host) + else: + os_write(f_out, host.encode("utf-8")) + os_write(f_out, b":") + os_write(f_out, base64_encoded_certificate) + os_write(f_out, b"\n") + os_close(f_out) + return True def connect(host, port=None, ssl_context=None, **config): @@ -372,9 +387,12 @@ def connect(host, port=None, ssl_context=None, **config): der_encoded_server_certificate = s.getpeercert(binary_form=True) if der_encoded_server_certificate is None: raise ProtocolError("When using a secure socket, the server should always provide a certificate") - security = config.get("security", SECURITY_NONE) + security = config.get("security", SECURITY_DEFAULT) if security == SECURITY_TRUST_ON_FIRST_USE: - verify_certificate(host, der_encoded_server_certificate) + store = PersonalCertificateStore() + if not store.match_or_trust(host, der_encoded_server_certificate): + raise ProtocolError("Server certificate does not match known certificate for %r; check " + "details in file %r" % (host, KNOWN_HOSTS)) else: der_encoded_server_certificate = None diff --git a/neo4j/v1/constants.py b/neo4j/v1/constants.py index 8a2dad213..238c24ed4 100644 --- a/neo4j/v1/constants.py +++ b/neo4j/v1/constants.py @@ -34,3 +34,5 @@ SECURITY_NONE = 0 SECURITY_TRUST_ON_FIRST_USE = 1 SECURITY_VERIFIED = 2 + +SECURITY_DEFAULT = SECURITY_TRUST_ON_FIRST_USE diff --git a/neo4j/v1/session.py b/neo4j/v1/session.py index 02b54cbf1..56b0ac5bb 100644 --- a/neo4j/v1/session.py +++ b/neo4j/v1/session.py @@ -33,7 +33,7 @@ class which can be used to obtain `Driver` instances that are used for from .compat import integer, string, urlparse from .connection import connect, Response, RUN, PULL_ALL -from .constants import SECURITY_NONE, SECURITY_VERIFIED +from .constants import SECURITY_NONE, SECURITY_VERIFIED, SECURITY_DEFAULT from .exceptions import CypherError, ResultError from .typesystem import hydrated @@ -79,7 +79,7 @@ def __init__(self, url, **config): self.config = config self.max_pool_size = config.get("max_pool_size", DEFAULT_MAX_POOL_SIZE) self.session_pool = deque() - self.security = security = config.get("security", SECURITY_NONE) + self.security = security = config.get("security", SECURITY_DEFAULT) if security > SECURITY_NONE: ssl_context = SSLContext(PROTOCOL_SSLv23) ssl_context.options |= OP_NO_SSLv2 diff --git a/test/test_session.py b/test/test_session.py index 767ed4f34..ffe2524c5 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -19,23 +19,19 @@ # limitations under the License. -from os import remove, rename -from os.path import isfile from socket import socket from ssl import SSLSocket -from unittest import TestCase from mock import patch -from neo4j.v1.constants import KNOWN_HOSTS, SECURITY_NONE, SECURITY_TRUST_ON_FIRST_USE, SECURITY_VERIFIED +from neo4j.v1.constants import SECURITY_NONE, SECURITY_TRUST_ON_FIRST_USE from neo4j.v1.exceptions import CypherError, ResultError from neo4j.v1.session import GraphDatabase, Record, record from neo4j.v1.typesystem import Node, Relationship, Path +from test.util import ServerTestCase -KNOWN_HOSTS_BACKUP = KNOWN_HOSTS + ".backup" - -class DriverTestCase(TestCase): +class DriverTestCase(ServerTestCase): def test_healthy_session_will_be_returned_to_the_pool_on_close(self): driver = GraphDatabase.driver("bolt://localhost") @@ -89,20 +85,11 @@ def test_sessions_are_not_reused_if_still_in_use(self): assert session_1 is not session_2 -class SecurityTestCase(TestCase): - - def setUp(self): - if isfile(KNOWN_HOSTS): - rename(KNOWN_HOSTS, KNOWN_HOSTS_BACKUP) +class SecurityTestCase(ServerTestCase): - def tearDown(self): - if isfile(KNOWN_HOSTS_BACKUP): - rename(KNOWN_HOSTS_BACKUP, KNOWN_HOSTS) - - def test_default_session_uses_security_none(self): - # TODO: verify this is the correct default (maybe TOFU?) + def test_default_session_uses_tofu(self): driver = GraphDatabase.driver("bolt://localhost") - assert driver.security == SECURITY_NONE + assert driver.security == SECURITY_TRUST_ON_FIRST_USE def test_insecure_session_uses_normal_socket(self): driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_NONE) @@ -141,7 +128,7 @@ def test_tofu_session_trusts_certificate_after_first_use(self): # session.close() -class RunTestCase(TestCase): +class RunTestCase(ServerTestCase): def test_can_run_simple_statement(self): session = GraphDatabase.driver("bolt://localhost").session() @@ -363,7 +350,7 @@ def test_can_obtain_notification_info(self): assert position.column == 1 -class ResetTestCase(TestCase): +class ResetTestCase(ServerTestCase): def test_automatic_reset_after_failure(self): with GraphDatabase.driver("bolt://localhost").session() as session: @@ -387,7 +374,7 @@ def test_defunct(self): assert session.connection.closed -class RecordTestCase(TestCase): +class RecordTestCase(ServerTestCase): def test_record_equality(self): record1 = Record(["name", "empire"], ["Nigel", "The British Empire"]) record2 = Record(["name", "empire"], ["Nigel", "The British Empire"]) @@ -461,7 +448,8 @@ def test_record_repr(self): assert repr(a_record) == "" -class TransactionTestCase(TestCase): +class TransactionTestCase(ServerTestCase): + def test_can_commit_transaction(self): with GraphDatabase.driver("bolt://localhost").session() as session: tx = session.begin_transaction() diff --git a/test/util.py b/test/util.py index 793fadb67..9148b5898 100644 --- a/test/util.py +++ b/test/util.py @@ -20,8 +20,15 @@ import functools +from os import rename +from os.path import isfile +from unittest import TestCase from neo4j.util import Watcher +from neo4j.v1.constants import KNOWN_HOSTS + + +KNOWN_HOSTS_BACKUP = KNOWN_HOSTS + ".backup" def watch(f): @@ -39,3 +46,19 @@ def wrapper(*args, **kwargs): f(*args, **kwargs) watcher.stop() return wrapper + + +class ServerTestCase(TestCase): + """ Base class for test cases that use a remote server. + """ + + known_hosts = KNOWN_HOSTS + known_hosts_backup = known_hosts + ".backup" + + def setUp(self): + if isfile(self.known_hosts): + rename(self.known_hosts, self.known_hosts_backup) + + def tearDown(self): + if isfile(self.known_hosts_backup): + rename(self.known_hosts_backup, self.known_hosts) From e9269db16981c57d15d53a84118d78104a608f0a Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Thu, 25 Feb 2016 10:37:26 +0000 Subject: [PATCH 4/4] Adjusted to default TOFU security level --- examples/test_examples.py | 5 +++-- neo4j/v1/connection.py | 3 +-- test/tck/tck_util.py | 17 ++++++++--------- test/test_session.py | 2 +- 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/examples/test_examples.py b/examples/test_examples.py index 4ad9acd91..001f9008c 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -19,16 +19,17 @@ # limitations under the License. -from unittest import TestCase +from test.util import ServerTestCase # tag::minimal-example-import[] from neo4j.v1 import GraphDatabase # end::minimal-example-import[] -class FreshDatabaseTestCase(TestCase): +class FreshDatabaseTestCase(ServerTestCase): def setUp(self): + ServerTestCase.setUp(self) session = GraphDatabase.driver("bolt://localhost").session() session.run("MATCH (n) DETACH DELETE n") session.close() diff --git a/neo4j/v1/connection.py b/neo4j/v1/connection.py index e061fe4b6..165b4c30d 100644 --- a/neo4j/v1/connection.py +++ b/neo4j/v1/connection.py @@ -342,9 +342,8 @@ def match_or_trust(self, host, der_encoded_certificate): with open(self.path) as f_in: for line in f_in: known_host, _, known_cert = line.strip().partition(":") + known_cert = known_cert.encode("utf-8") if host == known_host: - print("Received: %s" % base64_encoded_certificate) - print("Known: %s" % known_cert) return base64_encoded_certificate == known_cert # First use (no hosts match) try: diff --git a/test/tck/tck_util.py b/test/tck/tck_util.py index d05913f75..cb755fbe1 100644 --- a/test/tck/tck_util.py +++ b/test/tck/tck_util.py @@ -18,11 +18,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neo4j.v1 import compat, Relationship, Node, Path +from neo4j.v1 import GraphDatabase, Relationship, Node, Path, SECURITY_NONE +from neo4j.v1.compat import string -from neo4j.v1 import GraphDatabase -driver = GraphDatabase.driver("bolt://localhost") +driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_NONE) def send_string(text): @@ -39,11 +39,10 @@ def send_parameters(statement, parameters): return list(cursor.stream()) -def to_unicode(val): - try: - return unicode(val) - except NameError: - return str(val) +try: + to_unicode = unicode +except NameError: + to_unicode = str def string_to_type(str): @@ -91,7 +90,7 @@ def __init__(self, entity): elif isinstance(entity, Path): self.content = self.create_path(entity) elif isinstance(entity, int) or isinstance(entity, float) or isinstance(entity, - (str, compat.string)) or entity is None: + (str, string)) or entity is None: self.content['value'] = entity else: raise ValueError("Do not support object type: %s" % entity) diff --git a/test/test_session.py b/test/test_session.py index ffe2524c5..8ee09065d 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -256,7 +256,7 @@ def test_keys_with_an_error(self): _ = list(cursor.keys()) -class SummaryTestCase(TestCase): +class SummaryTestCase(ServerTestCase): def test_can_obtain_summary_after_consuming_result(self): with GraphDatabase.driver("bolt://localhost").session() as session: