Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSL support #28

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## Unreleased
**New features:**
* Support SSL encrypted connection to Tarantool EE (closes [#22](https://github.com/igorcoding/asynctnt/issues/22))

## v2.0.1
* Fixed an issue with encoding datetimes less than 01-01-1970 (fixes [#29](https://github.com/igorcoding/asynctnt/issues/29))
* Fixed "Edit on Github" links in docs (fixes [#26](https://github.com/igorcoding/asynctnt/issues/26))
Expand Down
1 change: 1 addition & 0 deletions asynctnt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .const import Transport
from .connection import Connection, connect
from .iproto.protocol import (
Iterator, Response, TarantoolTuple, PushIterator,
Expand Down
144 changes: 128 additions & 16 deletions asynctnt/connection.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import asyncio
import enum
import functools
import ssl
import os
from typing import Optional, Union

from .api import Api
from .const import Transport
from .exceptions import TarantoolDatabaseError, \
ErrorCode, TarantoolError
ErrorCode, TarantoolError, SSLError
from .iproto import protocol
from .log import logger
from .stream import Stream
from .utils import get_running_loop
from .utils import get_running_loop, PY_37

__all__ = (
'Connection', 'connect', 'ConnectionState'
Expand All @@ -27,11 +29,13 @@ class ConnectionState(enum.IntEnum):

class Connection(Api):
__slots__ = (
'_host', '_port', '_username', '_password',
'_fetch_schema', '_auto_refetch_schema', '_initial_read_buffer_size',
'_encoding', '_connect_timeout', '_reconnect_timeout',
'_request_timeout', '_ping_timeout', '_loop', '_state', '_state_prev',
'_transport', '_protocol',
'_host', '_port', '_parameter_transport', '_ssl_key_file',
'_ssl_cert_file', '_ssl_ca_file', '_ssl_ciphers',
'_username', '_password', '_fetch_schema',
'_auto_refetch_schema', '_initial_read_buffer_size',
'_encoding', '_connect_timeout', '_ssl_handshake_timeout',
'_reconnect_timeout', '_request_timeout', '_ping_timeout',
'_loop', '_state', '_state_prev', '_transport', '_protocol',
'_disconnect_waiter', '_reconnect_task',
'_connect_lock', '_disconnect_lock',
'_ping_task', '__create_task'
Expand All @@ -40,11 +44,17 @@ class Connection(Api):
def __init__(self, *,
host: str = '127.0.0.1',
port: Union[int, str] = 3301,
transport: Optional[Transport] = Transport.DEFAULT,
igorcoding marked this conversation as resolved.
Show resolved Hide resolved
ssl_key_file: Optional[str] = None,
ssl_cert_file: Optional[str] = None,
ssl_ca_file: Optional[str] = None,
ssl_ciphers: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
fetch_schema: bool = True,
auto_refetch_schema: bool = True,
connect_timeout: float = 3.,
ssl_handshake_timeout: float = 3.,
request_timeout: float = -1.,
reconnect_timeout: float = 1. / 3.,
ping_timeout: float = 5.,
Expand Down Expand Up @@ -78,6 +88,22 @@ def __init__(self, *,
:param port:
Tarantool port
(pass ``/path/to/sockfile`` to connect ot unix socket)
:param transport:
This parameter can be used to configure traffic encryption.
Pass ``asynctnt.Transport.SSL`` value to enable SSL
encryption (by default there is no encryption)
:param ssl_key_file:
A path to a private SSL key file.
Optional, mandatory if server uses CA file
:param ssl_cert_file:
A path to an SSL certificate file.
Optional, mandatory if server uses CA file
:param ssl_ca_file:
A path to a trusted certificate authorities (CA) file.
Optional
:param ssl_ciphers:
A colon-separated (:) list of SSL cipher suites
the connection can use. Optional
:param username:
Username to use for auth
(if ``None`` you are connected as a guest)
Expand All @@ -93,6 +119,10 @@ def __init__(self, *,
be checked by Tarantool, so no errors will occur
:param connect_timeout:
Time in seconds how long to wait for connecting to socket
:param ssl_handshake_timeout:
Time in seconds to wait for the TLS handshake to complete
before aborting the connection (used only for a TLS
connection). Supported for Python 3.7 or newer
:param request_timeout:
Request timeout (in seconds) for all requests
(by default there is no timeout)
Expand All @@ -116,6 +146,13 @@ def __init__(self, *,
super().__init__()
self._host = host
self._port = port

self._parameter_transport = transport
self._ssl_key_file = ssl_key_file
self._ssl_cert_file = ssl_cert_file
self._ssl_ca_file = ssl_ca_file
self._ssl_ciphers = ssl_ciphers

self._username = username
self._password = password
self._fetch_schema = False if fetch_schema is None else fetch_schema
Expand All @@ -131,6 +168,7 @@ def __init__(self, *,
self._encoding = encoding or 'utf-8'

self._connect_timeout = connect_timeout
self._ssl_handshake_timeout = ssl_handshake_timeout
self._reconnect_timeout = reconnect_timeout or 0
self._request_timeout = request_timeout
self._ping_timeout = ping_timeout or 0
Expand Down Expand Up @@ -220,6 +258,54 @@ def protocol_factory(self,
on_connection_lost=self.connection_lost,
loop=self._loop)

def _create_ssl_context(self):
try:
if hasattr(ssl, 'TLSVersion'):
# Since python 3.7
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
# Reset to default OpenSSL values.
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
# Require TLSv1.2, because other protocol versions don't seem
# to support the GOST cipher.
context.minimum_version = ssl.TLSVersion.TLSv1_2
context.maximum_version = ssl.TLSVersion.TLSv1_2
else:
# Deprecated, but it works for python < 3.7
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)

if self._ssl_cert_file:
# If the password argument is not specified and a password is
# required, OpenSSL’s built-in password prompting mechanism
# will be used to interactively prompt the user for a password.
#
# We should disable this behaviour, because a python
# application that uses the connector unlikely assumes
# interaction with a human + a Tarantool implementation does
# not support this at least for now.
def password_raise_error():
raise SSLError("a password for decrypting the private " +
"key is unsupported")
context.load_cert_chain(certfile=self._ssl_cert_file,
keyfile=self._ssl_key_file,
password=password_raise_error)

if self._ssl_ca_file:
context.load_verify_locations(cafile=self._ssl_ca_file)
context.verify_mode = ssl.CERT_REQUIRED
# A Tarantool implementation does not check hostname. We don't
# do that too. As a result we don't set here:
# context.check_hostname = True

if self._ssl_ciphers:
context.set_ciphers(self._ssl_ciphers)

return context
except SSLError as e:
raise
except Exception as e:
raise SSLError(e)

async def _connect(self, return_exceptions: bool = True):
if self._loop is None:
self._loop = get_running_loop()
Expand All @@ -246,6 +332,12 @@ async def full_connect():
while True:
connected_fut = _create_future(self._loop)

ssl_context = None
ssl_handshake_timeout = None
if self._parameter_transport == Transport.SSL:
ssl_context = self._create_ssl_context()
ssl_handshake_timeout = self._ssl_handshake_timeout

if self._host.startswith('unix/'):
unix_path = self._port
assert isinstance(unix_path, str), \
Expand All @@ -257,16 +349,34 @@ async def full_connect():
'Unix socket `{}` not found'.format(
unix_path)

conn = self._loop.create_unix_connection(
functools.partial(self.protocol_factory,
connected_fut),
unix_path
)
if PY_37:
conn = self._loop.create_unix_connection(
functools.partial(self.protocol_factory,
connected_fut),
unix_path,
ssl=ssl_context,
ssl_handshake_timeout=ssl_handshake_timeout)
else:
conn = self._loop.create_unix_connection(
functools.partial(self.protocol_factory,
connected_fut),
unix_path,
ssl=ssl_context)

else:
conn = self._loop.create_connection(
functools.partial(self.protocol_factory,
connected_fut),
self._host, self._port)
if PY_37:
conn = self._loop.create_connection(
functools.partial(self.protocol_factory,
connected_fut),
self._host, self._port,
ssl=ssl_context,
ssl_handshake_timeout=ssl_handshake_timeout)
else:
conn = self._loop.create_connection(
functools.partial(self.protocol_factory,
connected_fut),
self._host, self._port,
ssl=ssl_context)

tr, pr = await conn

Expand Down Expand Up @@ -337,6 +447,8 @@ async def full_connect():

if return_exceptions:
self._reconnect_task = None
if isinstance(e, ssl.SSLError):
e = SSLError(e)
raise e

logger.exception(e)
Expand Down
5 changes: 5 additions & 0 deletions asynctnt/const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import enum

class Transport(enum.IntEnum):
DEFAULT = 1
SSL = 2
6 changes: 6 additions & 0 deletions asynctnt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ class TarantoolNotConnectedError(TarantoolNetworkError):
"""
pass

class SSLError(TarantoolError):
"""
Raised when something is wrong with encrypted connection
"""
pass


class ErrorCode(enum.IntEnum):
"""
Expand Down
45 changes: 43 additions & 2 deletions asynctnt/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)

from asynctnt.utils import get_running_loop
from asynctnt.const import Transport

VERSION_STRING_REGEX = re.compile(r'\s*([\d.]+).*')

Expand Down Expand Up @@ -90,6 +91,11 @@ class TarantoolInstance(metaclass=abc.ABCMeta):
def __init__(self, *,
host='127.0.0.1',
port=3301,
transport=Transport.DEFAULT,
ssl_key_file=None,
ssl_cert_file=None,
ssl_ca_file=None,
ssl_ciphers=None,
console_host=None,
console_port=3302,
replication_source=None,
Expand All @@ -113,6 +119,22 @@ def __init__(self, *,
to be listening on (default = 127.0.0.1)
:param port: The port which Tarantool instance is going
to be listening on (default = 3301)
:param transport:
This parameter can be used to configure traffic encryption.
Pass ``asynctnt.Transport.SSL`` value to enable SSL
encryption (by default there is no encryption)
:param str ssl_key_file:
A path to a private SSL key file.
Mandatory if server uses SSL encryption
:param str ssl_cert_file:
A path to an SSL certificate file.
Mandatory if server uses SSL encryption
:param str ssl_ca_file:
A path to a trusted certificate authorities (CA) file.
Optional
:param str ssl_ciphers:
A colon-separated (:) list of SSL cipher suites
the server can use. Optional
:param console_host: The host which Tarantool console is going
to be listening on (to execute admin commands)
(default = host)
Expand Down Expand Up @@ -147,6 +169,11 @@ def __init__(self, *,

self._host = host
self._port = port
self._parameter_transport = transport
self._ssl_key_file = ssl_key_file
self._ssl_cert_file = ssl_cert_file
self._ssl_ca_file = ssl_ca_file
self._ssl_ciphers = ssl_ciphers
self._console_host = console_host or host
self._console_port = console_port
self._replication_source = replication_source
Expand Down Expand Up @@ -248,7 +275,7 @@ def _create_initlua_template(self):
return check_version_internal(expected, version)
end
local cfg = {
listen = "${host}:${port}",
listen = "${host}:${port}${listen_params}",
wal_mode = "${wal_mode}",
custom_proc_title = "${custom_proc_title}",
slab_alloc_arena = ${slab_alloc_arena},
Expand Down Expand Up @@ -289,9 +316,23 @@ def _render_initlua(self):
if self._specify_work_dir:
work_dir = '"' + self._root + '"'

listen_params = ''
if self._parameter_transport == Transport.SSL:
listen_params = "?transport=ssl&"
if self._ssl_key_file:
listen_params += "ssl_key_file={}&".format(self._ssl_key_file)
if self._ssl_cert_file:
listen_params += "ssl_cert_file={}&".format(self._ssl_cert_file)
if self._ssl_ca_file:
listen_params += "ssl_ca_file={}&".format(self._ssl_ca_file)
if self._ssl_ciphers:
listen_params += "ssl_ciphers={}&".format(self._ssl_ciphers)
listen_params = listen_params[:-1]

d = {
'host': self._host,
'port': self._port,
'listen_params': listen_params,
'console_host': self._console_host,
'console_port': self._console_port,
'wal_mode': self._wal_mode,
Expand Down Expand Up @@ -589,7 +630,7 @@ def bin_version(self) -> Optional[tuple]:
proc = subprocess.Popen([self._command_to_run, '-V'],
stdout=subprocess.PIPE)
output = proc.stdout.read().decode()
version_str = output.split('\n')[0].split(' ')[1]
version_str = output.split('\n')[0].replace('Tarantool ', '').replace('Enterprise ', '')
return self._parse_version(version_str)

def command(self, cmd, print_greeting=True):
Expand Down
29 changes: 29 additions & 0 deletions docs/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,32 @@ async def main():

asyncio.run(main())
```

## Connect with SSL encryption
```python
import asyncio
import asynctnt


async def main():
conn = asynctnt.Connection(host='127.0.0.1',
port=3301,
transport=asynctnt.Transport.SSL,
ssl_key_file='./ssl/host.key',
ssl_cert_file='./ssl/host.crt',
ssl_ca_file='./ssl/ca.crt',
ssl_ciphers='ECDHE-RSA-AES256-GCM-SHA384')
await conn.connect()

resp = await conn.ping()
print(resp)

await conn.disconnect()

asyncio.run(main())
```

Stdout:
```
<Response sync=4 rowcount=0 data=None>
```
Loading