Skip to content

Commit

Permalink
support TLS.
Browse files Browse the repository at this point in the history
  • Loading branch information
samrushing committed Apr 14, 2015
1 parent abf0002 commit eac9451
Showing 1 changed file with 24 additions and 4 deletions.
28 changes: 24 additions & 4 deletions coro/db/postgres/postgres.py
Expand Up @@ -233,8 +233,7 @@ class postgres_client:
DEFAULT_ADDRESS = '/tmp/.s.PGSQL.5432'
DEFAULT_USER = 'pgsql'

def __init__ (self, database, username='', password='',
address=None):
def __init__ (self, database, username='', password='', address=None, ssl_context=None):

self._backend_pid = 0
self._secret_key = 0 # used for cancellations
Expand All @@ -249,6 +248,7 @@ def __init__ (self, database, username='', password='',
else:
self.address = self.DEFAULT_ADDRESS

self.ssl_context = ssl_context
self.database = database
self.backend_parameters = {}

Expand Down Expand Up @@ -549,6 +549,19 @@ def read_packet(self):
def _startup(self):
"""Implement startup phase of Postgres protocol"""

if self.ssl_context:
self.send_packet (PG_SSLREQUEST_MSG, 80877103)
msg = self._socket.recv_exact (1) # not a normal packet?
if msg == 'S':
# willing
import coro.ssl
sock = coro.ssl.sock (self.ssl_context, fd=self._socket.fd)
sock.ssl_connect()
self._orig_socket = self._socket
self._socket = sock
else:
raise ConnectError ("unable to negotiate TLS")

# send startup packet
self.send_packet (
PG_STARTUP_MSG,
Expand Down Expand Up @@ -988,6 +1001,7 @@ def get_result(self, timeout):
# See http://www.postgresql.org/docs/9.2/static/protocol-message-formats.html

PG_STARTUP_MSG = ''
PG_SSLREQUEST_MSG = ''
PG_CANCELREQUEST_MSG = ''
PG_COMMAND_COMPLETE_MSG = 'C'
PG_COPY_DONE_MSG = 'c'
Expand Down Expand Up @@ -1466,6 +1480,12 @@ def sleep(x):
# db.lo_close(fd)
# db.query("ROLLBACK")

def test_ssl():
import coro.ssl.openssl
ctx = coro.ssl.openssl.ssl_ctx()
db = postgres_client ('mydb', 'myuser', 'mypass', ('192.168.1.99', 5432), ssl_context = ctx)
return db

def test_concurrent_dbm(tries):
dbm = database_manager(debug=True)
for i in xrange(tries):
Expand All @@ -1485,8 +1505,8 @@ def watcher(thread_ids):

if __name__ == '__main__':

import backdoor
coro.spawn (backdoor.serve)
import coro.backdoor
coro.spawn (coro.backdoor.serve)

# thread_ids = []
#
Expand Down

0 comments on commit eac9451

Please sign in to comment.