Skip to content

Commit

Permalink
fix mongo password authentication and authentication database from au…
Browse files Browse the repository at this point in the history
…thSource mongo URI query param
  • Loading branch information
Charles Sanquer committed Jun 5, 2017
1 parent ba2fe77 commit 24154fb
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 19 deletions.
43 changes: 24 additions & 19 deletions hyperopt/mongoexp.py
Expand Up @@ -171,16 +171,7 @@ class ReserveTimeout(Exception):


def read_pw():
username = 'hyperopt'
password = open(os.path.join(os.getenv('HOME'), ".hyperopt")).read()[:-1]
return dict(
username=username,
password=password)


def authenticate_for_db(db):
d = read_pw()
db.authenticate(d['username'], d['password'])
return open(os.path.join(os.getenv('HOME'), ".hyperopt")).read()[:-1]


def parse_url(url, pwfile=None):
Expand All @@ -202,12 +193,20 @@ def parse_url(url, pwfile=None):

# -- parse the string as if it were an ftp address
tmp = urllib.parse.urlparse(ftp_url)
query_params = urllib.parse.parse_qs(tmp.query)

logger.info('PROTOCOL %s' % protocol)
logger.info('USERNAME %s' % tmp.username)
logger.info('HOSTNAME %s' % tmp.hostname)
logger.info('PORT %s' % tmp.port)
logger.info('PATH %s' % tmp.path)

authdbname = None
if 'authSource' in query_params and len(query_params['authSource']):
authdbname = query_params['authSource'][-1]

logger.info('AUTH DB %s' % authdbname)

try:
_, dbname, collection = tmp.path.split('/')
except:
Expand All @@ -226,11 +225,11 @@ def parse_url(url, pwfile=None):
logger.info('PASS %s' % password)
port = int(float(tmp.port)) # port has to be casted explicitly here.

return (protocol, tmp.username, password, tmp.hostname, port, dbname, collection)
return (protocol, tmp.username, password, tmp.hostname, port, dbname, collection, authdbname)


def connection_with_tunnel(host='localhost',
auth_dbname='admin', port=27017,
def connection_with_tunnel(dbname, host='localhost',
auth_dbname=None, port=27017,
ssh=False, user='hyperopt', pw=None):
if ssh:
local_port = numpy.random.randint(low=27500, high=28000)
Expand All @@ -247,10 +246,14 @@ def connection_with_tunnel(host='localhost',
else:
connection = pymongo.MongoClient(host, port, document_class=SON, w=1, j=True)
if user:
if user == 'hyperopt':
authenticate_for_db(connection[auth_dbname])
else:
raise NotImplementedError()
if not pw:
pw = read_pw()

if user == 'hyperopt' and not auth_dbname:
auth_dbname = 'admin'

connection[dbname].authenticate(user, pw, source=auth_dbname)

ssh_tunnel = None

# Note that the w=1 and j=True args to MongoClient above should:
Expand All @@ -261,19 +264,21 @@ def connection_with_tunnel(host='localhost',


def connection_from_string(s):
protocol, user, pw, host, port, db, collection = parse_url(s)
protocol, user, pw, host, port, db, collection, authdb = parse_url(s)
if protocol == 'mongo':
ssh = False
elif protocol in ('mongo+ssh', 'ssh+mongo'):
ssh = True
else:
raise ValueError('unrecognized protocol for MongoJobs', protocol)
connection, tunnel = connection_with_tunnel(
dbname=db,
ssh=ssh,
user=user,
pw=pw,
host=host,
port=port,
auth_dbname=authdb
)
return connection, tunnel, connection[db], connection[db][collection]

Expand Down Expand Up @@ -334,7 +339,7 @@ def alloc(cls, dbname, host='localhost',
auth_dbname='admin', port=27017,
jobs_coll='jobs', gfs_coll='fs', ssh=False, user=None, pw=None):
connection, tunnel = connection_with_tunnel(
host, auth_dbname, port, ssh, user, pw)
dbname, host, auth_dbname, port, ssh, user, pw)
db = connection[dbname]
gfs = gridfs.GridFS(db, collection=gfs_coll)
return cls(db, db[jobs_coll], gfs, connection, tunnel)
Expand Down
17 changes: 17 additions & 0 deletions hyperopt/tests/test_mongoexp.py
Expand Up @@ -14,6 +14,7 @@
import nose.plugins.skip

from hyperopt.base import JOB_STATE_DONE
from hyperopt.mongoexp import parse_url
from hyperopt.mongoexp import MongoTrials
from hyperopt.mongoexp import MongoWorker
from hyperopt.mongoexp import ReserveTimeout
Expand Down Expand Up @@ -112,6 +113,22 @@ def db_up(self):
except: # XXX: don't know what exceptions to put here
return False


def test_parse_url():
uris = [
'mongo://hyperopt:foobar@127.0.0.1:27017/hyperoptdb/jobs',
'mongo://hyperopt:foobar@127.0.0.1:27017/hyperoptdb/jobs?authSource=db1'
]

expected = [
('mongo', 'hyperopt', 'foobar', '127.0.0.1', 27017, 'hyperoptdb', 'jobs', None),
('mongo', 'hyperopt', 'foobar', '127.0.0.1', 27017, 'hyperoptdb', 'jobs', 'db1')
]

for i, uri in enumerate(uris):
assert parse_url(uri) == expected[i]


# -- If we can't create a TempMongo instance, then
# simply print what happened,
try:
Expand Down

0 comments on commit 24154fb

Please sign in to comment.