Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix mongo password authentication #309

Merged
merged 1 commit into from Feb 9, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 24 additions & 19 deletions hyperopt/mongoexp.py
Expand Up @@ -171,16 +171,7 @@ class ReserveTimeout(Exception):


def read_pw():
username = 'hyperopt'
password = open(os.path.join(os.getenv('HOME'), ".hyperopt")).read()[:-1]
return dict(
username=username,
password=password)


def authenticate_for_db(db):
d = read_pw()
db.authenticate(d['username'], d['password'])
return open(os.path.join(os.getenv('HOME'), ".hyperopt")).read()[:-1]


def parse_url(url, pwfile=None):
Expand All @@ -202,12 +193,20 @@ def parse_url(url, pwfile=None):

# -- parse the string as if it were an ftp address
tmp = urllib.parse.urlparse(ftp_url)
query_params = urllib.parse.parse_qs(tmp.query)

logger.info('PROTOCOL %s' % protocol)
logger.info('USERNAME %s' % tmp.username)
logger.info('HOSTNAME %s' % tmp.hostname)
logger.info('PORT %s' % tmp.port)
logger.info('PATH %s' % tmp.path)

authdbname = None
if 'authSource' in query_params and len(query_params['authSource']):
authdbname = query_params['authSource'][-1]

logger.info('AUTH DB %s' % authdbname)

try:
_, dbname, collection = tmp.path.split('/')
except:
Expand All @@ -226,11 +225,11 @@ def parse_url(url, pwfile=None):
logger.info('PASS %s' % password)
port = int(float(tmp.port)) # port has to be casted explicitly here.

return (protocol, tmp.username, password, tmp.hostname, port, dbname, collection)
return (protocol, tmp.username, password, tmp.hostname, port, dbname, collection, authdbname)


def connection_with_tunnel(host='localhost',
auth_dbname='admin', port=27017,
def connection_with_tunnel(dbname, host='localhost',
auth_dbname=None, port=27017,
ssh=False, user='hyperopt', pw=None):
if ssh:
local_port = numpy.random.randint(low=27500, high=28000)
Expand All @@ -247,10 +246,14 @@ def connection_with_tunnel(host='localhost',
else:
connection = pymongo.MongoClient(host, port, document_class=SON, w=1, j=True)
if user:
if user == 'hyperopt':
authenticate_for_db(connection[auth_dbname])
else:
raise NotImplementedError()
if not pw:
pw = read_pw()

if user == 'hyperopt' and not auth_dbname:
auth_dbname = 'admin'

connection[dbname].authenticate(user, pw, source=auth_dbname)

ssh_tunnel = None

# Note that the w=1 and j=True args to MongoClient above should:
Expand All @@ -261,19 +264,21 @@ def connection_with_tunnel(host='localhost',


def connection_from_string(s):
protocol, user, pw, host, port, db, collection = parse_url(s)
protocol, user, pw, host, port, db, collection, authdb = parse_url(s)
if protocol == 'mongo':
ssh = False
elif protocol in ('mongo+ssh', 'ssh+mongo'):
ssh = True
else:
raise ValueError('unrecognized protocol for MongoJobs', protocol)
connection, tunnel = connection_with_tunnel(
dbname=db,
ssh=ssh,
user=user,
pw=pw,
host=host,
port=port,
auth_dbname=authdb
)
return connection, tunnel, connection[db], connection[db][collection]

Expand Down Expand Up @@ -334,7 +339,7 @@ def alloc(cls, dbname, host='localhost',
auth_dbname='admin', port=27017,
jobs_coll='jobs', gfs_coll='fs', ssh=False, user=None, pw=None):
connection, tunnel = connection_with_tunnel(
host, auth_dbname, port, ssh, user, pw)
dbname, host, auth_dbname, port, ssh, user, pw)
db = connection[dbname]
gfs = gridfs.GridFS(db, collection=gfs_coll)
return cls(db, db[jobs_coll], gfs, connection, tunnel)
Expand Down
17 changes: 17 additions & 0 deletions hyperopt/tests/test_mongoexp.py
Expand Up @@ -14,6 +14,7 @@
import nose.plugins.skip

from hyperopt.base import JOB_STATE_DONE
from hyperopt.mongoexp import parse_url
from hyperopt.mongoexp import MongoTrials
from hyperopt.mongoexp import MongoWorker
from hyperopt.mongoexp import ReserveTimeout
Expand Down Expand Up @@ -112,6 +113,22 @@ def db_up(self):
except: # XXX: don't know what exceptions to put here
return False


def test_parse_url():
uris = [
'mongo://hyperopt:foobar@127.0.0.1:27017/hyperoptdb/jobs',
'mongo://hyperopt:foobar@127.0.0.1:27017/hyperoptdb/jobs?authSource=db1'
]

expected = [
('mongo', 'hyperopt', 'foobar', '127.0.0.1', 27017, 'hyperoptdb', 'jobs', None),
('mongo', 'hyperopt', 'foobar', '127.0.0.1', 27017, 'hyperoptdb', 'jobs', 'db1')
]

for i, uri in enumerate(uris):
assert parse_url(uri) == expected[i]


# -- If we can't create a TempMongo instance, then
# simply print what happened,
try:
Expand Down