Permalink
Browse files

add authentication support

* handle authentication transparently on each connection
  • Loading branch information...
1 parent 6e5c771 commit 193df577154789d2362541939640c44decb91a09 @jehiah jehiah committed Oct 10, 2011
View
@@ -24,36 +24,48 @@
import tornado.iostream
import socket
-import helpers
import struct
import logging
-from errors import ProgrammingError, IntegrityError, InterfaceError
+from bson import SON
+from errors import ProgrammingError, IntegrityError, InterfaceError, AuthenticationError
+import message
+import helpers
class Connection(object):
"""
:Parameters:
- `host`: hostname or ip of mongo host
- `port`: port to connect to
+ - `dbuser`: db user to connect with
+ - `dbpass`: db password
- `autoreconnect` (optional): auto reconnect on interface errors
"""
- def __init__(self, host, port, autoreconnect=True, pool=None):
+ def __init__(self, host, port, dbuser=None, dbpass=None, autoreconnect=True, pool=None):
assert isinstance(host, (str, unicode))
assert isinstance(port, int)
assert isinstance(autoreconnect, bool)
+ assert isinstance(dbuser, (str, unicode, None.__class__))
+ assert isinstance(dbpass, (str, unicode, None.__class__))
assert pool
self.__host = host
self.__port = port
+ self.__dbuser = dbuser
+ self.__dbpass = dbpass
self.__stream = None
self.__callback = None
self.__alive = False
- self.__connect()
+ self.__authenticate = False
self.__autoreconnect = autoreconnect
self.__pool = pool
+ self.__deferred_message = None
+ self.__deferred_callback = None
self.usage_count = 0
+ self.__connect()
def __connect(self):
+ self.usage_count = 0
try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
s.connect((self.__host, self.__port))
@@ -62,6 +74,9 @@ def __connect(self):
self.__alive = True
except socket.error, error:
raise InterfaceError(error)
+
+ if self.__dbuser and self.__dbpass:
+ self.__authenticate = True
def _socket_close(self):
"""cleanup after the socket is closed by the other end"""
@@ -88,8 +103,6 @@ def close(self):
def send_message(self, message, callback):
""" send a message over the wire; callback=None indicates a safe=False call where we write and forget about it"""
- self.usage_count +=1
- # TODO: handle reconnect
if self.__callback is not None:
raise ProgrammingError('connection already in use')
@@ -99,13 +112,22 @@ def send_message(self, message, callback):
else:
raise InterfaceError('connection invalid. autoreconnect=False')
- self.__callback=callback
+ if self.__authenticate:
+ self.__deferred_message = message
+ self.__deferred_callback = callback
+ self._get_nonce(self._start_authentication)
+ else:
+ self.__callback = callback
+ self._send_message(message)
+
+ def _send_message(self, message):
+ self.usage_count +=1
# __request_id used by get_more()
(self.__request_id, data) = message
# logging.info('request id %d writing %r' % (self.__request_id, data))
try:
self.__stream.write(data)
- if callback:
+ if self.__callback:
self.__stream.read_bytes(16, callback=self._parse_header)
else:
self.__request_id = None
@@ -140,7 +162,11 @@ def _parse_response(self, response):
request_id = self.__request_id
self.__request_id = None
self.__callback = None
- self.__pool.cache(self)
+ if not self.__deferred_message:
+ # skip adding to the cache because there is something else
+ # that needs to be called on this connection for this request
+ # (ie: we authenticted, but still have to send the real req)
+ self.__pool.cache(self)
try:
response = helpers._unpack_response(response, request_id) # TODO: pass tz_awar
@@ -156,3 +182,55 @@ def _parse_response(self, response):
# logging.info('response: %s' % response)
callback(response)
+ def _start_authentication(self, response, error=None):
+ # this is the nonce response
+ if error:
+ logging.error(error)
+ logging.error(response)
+ raise AuthenticationError(error)
+ nonce = response['data'][0]['nonce']
+ key = helpers._auth_key(nonce, self.__dbuser, self.__dbpass)
+
+ self.__callback = self._finish_authentication
+ self._send_message(
+ message.query(0,
+ "%s.$cmd" % self.__pool._dbname,
+ 0,
+ 1,
+ SON([('authenticate', 1), ('user' , self.__dbuser), ('nonce' , nonce), ('key' , key)]),
+ SON({})))
+
+ def _finish_authentication(self, response, error=None):
+ if error:
+ self.__deferred_message = None
+ self.__deferred_callback = None
+ raise AuthenticationError(error)
+ assert response['number_returned'] == 1
+ response = response['data'][0]
+ if response['ok'] != 1:
+ logging.error('Failed authentication %s' % response['errmsg'])
+ self.__deferred_message = None
+ self.__deferred_callback = None
+ raise AuthenticationError(response['errmsg'])
+
+ message = self.__deferred_message
+ callback = self.__deferred_callback
+ self.__deferred_message = None
+ self.__deferred_callback = None
+ self.__callback = callback
+ # continue the original request
+ self._send_message(message)
+
+ def _get_nonce(self, callback):
+ assert self.__callback is None
+ self.__callback = callback
+ self._send_message(
+ message.query(0,
+ "%s.$cmd" % self.__pool._dbname,
+ 0,
+ 1,
+ SON({'getnonce' : 1}),
+ SON({})
+ ))
+
+
View
@@ -377,10 +377,15 @@ def find(self, spec=None, fields=None, skip=0, limit=0,
connection.send_message(
message.query(self.__query_options(),
self.full_collection_name,
- self.__skip, self.__limit,
- self.__query_spec(), self.__fields), callback=functools.partial(self._handle_response, orig_callback=callback))
- except:
+ self.__skip,
+ self.__limit,
+ self.__query_spec(),
+ self.__fields),
+ callback=functools.partial(self._handle_response, orig_callback=callback))
+ except Exception as e:
+ logging.error('Error sending query %s' % e)
connection.close()
+ raise
def _handle_response(self, result, error=None, orig_callback=None):
if error:
@@ -398,7 +403,7 @@ def __query_options(self):
options = 0
if self.__tailable:
options |= _QUERY_OPTIONS["tailable_cursor"]
- if self.__slave_okay or self.__pool.slave_okay:
+ if self.__slave_okay or self.__pool._slave_okay:
options |= _QUERY_OPTIONS["slave_okay"]
if not self.__timeout:
options |= _QUERY_OPTIONS["no_timeout"]
View
@@ -54,3 +54,6 @@ class NotSupportedError(DatabaseError):
class TooManyConnections(Error):
pass
+
+class AuthenticationError(Error):
+ pass
View
@@ -1,3 +1,4 @@
+import hashlib
import bson
from bson.son import SON
@@ -78,3 +79,24 @@ def _index_document(index_list):
"DESCENDING, or GEO2D")
index[key] = value
return index
+
+def _password_digest(username, password):
+ """Get a password digest to use for authentication.
+ """
+ if not isinstance(password, basestring):
+ raise TypeError("password must be an instance of basestring")
+ if not isinstance(username, basestring):
+ raise TypeError("username must be an instance of basestring")
+
+ md5hash = hashlib.md5()
+ md5hash.update("%s:mongo:%s" % (username.encode('utf-8'),
+ password.encode('utf-8')))
+ return unicode(md5hash.hexdigest())
+
+def _auth_key(nonce, username, password):
+ """Get an auth key to use for authentication.
+ """
+ digest = _password_digest(username, password)
+ md5hash = hashlib.md5()
+ md5hash.update("%s%s%s" % (nonce, unicode(username), digest))
+ return unicode(md5hash.hexdigest())
View
@@ -19,6 +19,7 @@
from errors import TooManyConnections, ProgrammingError
from connection import Connection
+
class ConnectionPools(object):
""" singleton to keep track of named connection pools """
@classmethod
@@ -57,6 +58,7 @@ class ConnectionPool(object):
- `maxconnections` (optional): maximum open connections for this pool. 0 for unlimited
- `maxusage` (optional): number of requests allowed on a connection before it is closed. 0 for unlimited
- `dbname`: mongo database name
+ - `slave_okay` (optional): is it okay to connect directly to and perform queries on a slave instance
- `**kwargs`: passed to `connection.Connection`
"""
@@ -89,6 +91,7 @@ def __init__(self,
self._dbname = dbname
self._slave_okay = slave_okay
self._connections = 0
+
# Establish an initial number of idle database connections:
idle = [self.connection() for i in range(mincached)]
@@ -156,14 +159,4 @@ def close(self):
finally:
self._condition.release()
- def __get_slave_okay(self):
- """Is it OK to perform queries on a secondary or slave?
- """
- return self._slave_okay
-
- def __set_slave_okay(self, value):
- """Property setter for slave_okay"""
- assert isinstance(value, bool)
- self._slave_okay = value
- slave_okay = property(__get_slave_okay, __set_slave_okay)
@@ -0,0 +1,54 @@
+#!/usr/bin/env python
+
+# mkdir /tmp/asyncmongo_sample_app2
+# mongod --port 27017 --oplogSize 10 --dbpath /tmp/asyncmongo_sample_app2
+
+# $mongo
+# >>>use test;
+# db.addUser("testuser", "testpass");
+
+# ab -n 1000 -c 16 http://127.0.0.1:8888/
+
+import sys
+import logging
+import os
+app_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+if app_dir not in sys.path:
+ logging.debug('adding %r to sys.path' % app_dir)
+ sys.path.insert(0, app_dir)
+
+import asyncmongo
+# make sure we get the local asyncmongo
+assert asyncmongo.__file__.startswith(app_dir)
+
+import tornado.ioloop
+import tornado.web
+import tornado.options
+
+class MainHandler(tornado.web.RequestHandler):
+ @tornado.web.asynchronous
+ def get(self):
+ db.users.find_one({"user_id" : 1}, callback=self._on_response)
+
+ def _on_response(self, response, error):
+ assert not error
+ self.write(str(response))
+ self.finish()
+
+
+if __name__ == "__main__":
+ tornado.options.parse_command_line()
+ application = tornado.web.Application([
+ (r"/?", MainHandler)
+ ])
+ application.listen(8888)
+ db = asyncmongo.Client(pool_id="test",
+ host='127.0.0.1',
+ port=27017,
+ mincached=5,
+ maxcached=15,
+ maxconnections=30,
+ dbname='test',
+ dbuser='testuser',
+ dbpass='testpass')
+ tornado.ioloop.IOLoop.instance().start()
@@ -0,0 +1,50 @@
+import tornado.ioloop
+import time
+import logging
+import subprocess
+
+import test_shunt
+import asyncmongo
+
+TEST_TIMESTAMP = int(time.time())
+
+class AuthenticationTest(test_shunt.MongoTest):
+ def setUp(self):
+ super(AuthenticationTest, self).setUp()
+ logging.info('creating user')
+ pipe = subprocess.Popen('''echo -e 'use test;\n db.addUser("testuser", "testpass");\n exit;' | mongo --port 27017 --host 127.0.0.1''', shell=True)
+ pipe.wait()
+
+ def test_authentication(self):
+ try:
+ test_shunt.setup()
+ db = asyncmongo.Client(pool_id='testauth', host='127.0.0.1', port=27017, dbname='test', dbuser='testuser', dbpass='testpass', maxconnections=2)
+
+ def update_callback(response, error):
+ logging.info(response)
+ assert len(response) == 1
+ test_shunt.register_called('update')
+ tornado.ioloop.IOLoop.instance().stop()
+
+ db.test_stats.update({"_id" : TEST_TIMESTAMP}, {'$inc' : {'test_count' : 1}}, upsert=True, callback=update_callback)
+
+ tornado.ioloop.IOLoop.instance().start()
+ test_shunt.assert_called('update')
+
+ def query_callback(response, error):
+ logging.info(response)
+ logging.info(error)
+ assert error is None
+ assert isinstance(response, dict)
+ assert response['_id'] == TEST_TIMESTAMP
+ assert response['test_count'] == 1
+ test_shunt.register_called('retrieved')
+ tornado.ioloop.IOLoop.instance().stop()
+
+ db.test_stats.find_one({"_id" : TEST_TIMESTAMP}, callback=query_callback)
+ tornado.ioloop.IOLoop.instance().start()
+ test_shunt.assert_called('retrieved')
+ except:
+ tornado.ioloop.IOLoop.instance().stop()
+ raise
+
View
@@ -33,7 +33,7 @@ def setUp(self):
os.makedirs(dirname)
self.temp_dirs.append(dirname)
- options = ['mongod', '--bind_ip', '127.0.0.1', '--oplogSize', '10', '--dbpath', dirname] + list(options)
+ options = ['mongod', '--bind_ip', '127.0.0.1', '--oplogSize', '10', '--dbpath', dirname, '-v'] + list(options)
logging.debug(options)
pipe = subprocess.Popen(options)
self.mongods.append(pipe)

0 comments on commit 193df57

Please sign in to comment.