Skip to content

Commit

Permalink
Merge pull request #63 from hbldh/pymongo3.0_feature
Browse files Browse the repository at this point in the history
Modifications to enable pymongo 3.0.x compatability.
  • Loading branch information
ranman committed Oct 19, 2015
2 parents 475e7a9 + 0d58b2e commit a21fda7
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 38 deletions.
2 changes: 1 addition & 1 deletion examples/wiki/wiki.py
Expand Up @@ -44,7 +44,7 @@ def save_page(pagepath):
mongo.db.pages.update(
{'_id': pagepath},
{'$set': {'body': request.form['body']}},
safe=True, upsert=True)
w=1, upsert=True)
return redirect(url_for('show_page', pagepath=pagepath))

@app.errorhandler(404)
Expand Down
32 changes: 21 additions & 11 deletions flask_pymongo/__init__.py
Expand Up @@ -132,14 +132,16 @@ def key(suffix):
raise ValueError('MongoDB URI does not contain database name')
app.config[key('DBNAME')] = parsed['database']
app.config[key('READ_PREFERENCE')] = parsed['options'].get('read_preference')
app.config[key('AUTO_START_REQUEST')] = parsed['options'].get('auto_start_request', True)
app.config[key('USERNAME')] = parsed['username']
app.config[key('PASSWORD')] = parsed['password']
app.config[key('REPLICA_SET')] = parsed['options'].get('replica_set')
app.config[key('MAX_POOL_SIZE')] = parsed['options'].get('max_pool_size')
app.config[key('SOCKET_TIMEOUT_MS')] = parsed['options'].get('socket_timeout_ms', None)
app.config[key('CONNECT_TIMEOUT_MS')] = parsed['options'].get('connect_timeout_ms', None)

if pymongo.version_tuple[0] < 3:
app.config[key('AUTO_START_REQUEST')] = parsed['options'].get('auto_start_request', True)

# we will use the URI for connecting instead of HOST/PORT
app.config.pop(key('HOST'), None)
app.config.setdefault(key('PORT'), 27017)
Expand All @@ -150,10 +152,12 @@ def key(suffix):
app.config.setdefault(key('PORT'), 27017)
app.config.setdefault(key('DBNAME'), app.name)
app.config.setdefault(key('READ_PREFERENCE'), None)
app.config.setdefault(key('AUTO_START_REQUEST'), True)
app.config.setdefault(key('SOCKET_TIMEOUT_MS'), None)
app.config.setdefault(key('CONNECT_TIMEOUT_MS'), None)

if pymongo.version_tuple[0] < 3:
app.config.setdefault(key('AUTO_START_REQUEST'), True)

# these don't have defaults
app.config.setdefault(key('USERNAME'), None)
app.config.setdefault(key('PASSWORD'), None)
Expand Down Expand Up @@ -190,23 +194,26 @@ def key(suffix):

replica_set = app.config[key('REPLICA_SET')]
dbname = app.config[key('DBNAME')]
auto_start_request = app.config[key('AUTO_START_REQUEST')]
max_pool_size = app.config[key('MAX_POOL_SIZE')]
socket_timeout_ms = app.config[key('SOCKET_TIMEOUT_MS')]
connect_timeout_ms = app.config[key('CONNECT_TIMEOUT_MS')]

if pymongo.version_tuple[0] < 3:
auto_start_request = app.config[key('AUTO_START_REQUEST')]
if auto_start_request not in (True, False):
raise TypeError('%s_AUTO_START_REQUEST must be a bool' % config_prefix)

# document class is not supported by URI, using setdefault in all cases
document_class = app.config.setdefault(key('DOCUMENT_CLASS'), None)

if auto_start_request not in (True, False):
raise TypeError('%s_AUTO_START_REQUEST must be a bool' % config_prefix)

args = [host]

kwargs = {
'port': int(app.config[key('PORT')]),
'auto_start_request': auto_start_request,
'tz_aware': True,
}
if pymongo.version_tuple[0] < 3:
kwargs['auto_start_request'] = auto_start_request

if read_preference is not None:
kwargs['read_preference'] = read_preference
Expand All @@ -217,10 +224,14 @@ def key(suffix):
if connect_timeout_ms is not None:
kwargs['connectTimeoutMS'] = connect_timeout_ms

if replica_set is not None:
kwargs['replicaSet'] = replica_set
connection_cls = MongoReplicaSetClient
if pymongo.version_tuple[0] < 3:
if replica_set is not None:
kwargs['replicaSet'] = replica_set
connection_cls = MongoReplicaSetClient
else:
connection_cls = MongoClient
else:
kwargs['replicaSet'] = replica_set
connection_cls = MongoClient

if max_pool_size is not None:
Expand Down Expand Up @@ -294,7 +305,6 @@ def get_upload(filename):
except NoFile:
abort(404)


# mostly copied from flask/helpers.py, with
# modifications for GridFS
data = wrap_file(request.environ, fileobj, buffer_size=1024 * 256)
Expand Down
25 changes: 25 additions & 0 deletions flask_pymongo/wrappers.py
Expand Up @@ -42,6 +42,12 @@ def __getattr__(self, name):
return Database(self, name)
return attr

def __getitem__(self, item):
attr = super(MongoClient, self).__getitem__(item)
if isinstance(attr, database.Database):
return Database(self, item)
return attr

class MongoReplicaSetClient(mongo_replica_set_client.MongoReplicaSetClient):
"""Returns instances of :class:`flask_pymongo.wrappers.Database`
instead of :class:`pymongo.database.Database` when accessed with dot
Expand All @@ -53,6 +59,12 @@ def __getattr__(self, name):
return Database(self, name)
return attr

def __getitem__(self, item):
item_ = super(MongoReplicaSetClient, self).__getitem__(item)
if isinstance(item_, database.Database):
return Database(self, item)
return item_

class Database(database.Database):
"""Returns instances of :class:`flask_pymongo.wrappers.Collection`
instead of :class:`pymongo.collection.Collection` when accessed with dot
Expand All @@ -65,6 +77,12 @@ def __getattr__(self, name):
return Collection(self, name)
return attr

def __getitem__(self, item):
item_ = super(Database, self).__getitem__(item)
if isinstance(item_, collection.Collection):
return Collection(self, item)
return item_

class Collection(collection.Collection):
"""Custom sub-class of :class:`pymongo.collection.Collection` which
adds Flask-specific helper methods.
Expand All @@ -77,6 +95,13 @@ def __getattr__(self, name):
return Collection(db, attr.name)
return attr

def __getitem__(self, item):
item_ = super(Collection, self).__getitem__(item)
if isinstance(item_, collection.Collection):
db = self._Collection__database
return Collection(db, item_.name)
return item_

def find_one_or_404(self, *args, **kwargs):
"""Find and return a single document, or raise a 404 Not Found
exception if no document matches the query spec. See
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,2 +1,2 @@
Flask >= 0.8
PyMongo >= 2.4,<3.0
PyMongo >= 2.4
89 changes: 67 additions & 22 deletions tests/test_config.py
@@ -1,5 +1,8 @@
from tests import util

import time

import pymongo
import flask
import flask.ext.pymongo
import warnings
Expand All @@ -10,7 +13,6 @@ class CustomDict(dict):


class FlaskPyMongoConfigTest(util.FlaskRequestTest):

def setUp(self):
self.app = flask.Flask('test')
self.context = self.app.test_request_context('/')
Expand All @@ -26,8 +28,12 @@ def test_default_config_prefix(self):

mongo = flask.ext.pymongo.PyMongo(self.app)
assert mongo.db.name == 'flask_pymongo_test_db', 'wrong dbname: %s' % mongo.db.name
assert mongo.cx.host == 'localhost'
assert mongo.cx.port == 27017
if pymongo.version_tuple[0] > 2:
time.sleep(0.2)
assert ('localhost', 27017) == mongo.cx.address
else:
assert mongo.cx.host == 'localhost'
assert mongo.cx.port == 27017

def test_custom_config_prefix(self):
self.app.config['CUSTOM_DBNAME'] = 'flask_pymongo_test_db'
Expand All @@ -36,8 +42,12 @@ def test_custom_config_prefix(self):

mongo = flask.ext.pymongo.PyMongo(self.app, 'CUSTOM')
assert mongo.db.name == 'flask_pymongo_test_db', 'wrong dbname: %s' % mongo.db.name
assert mongo.cx.host == 'localhost'
assert mongo.cx.port == 27017
if pymongo.version_tuple[0] > 2:
time.sleep(0.2)
assert ('localhost', 27017) == mongo.cx.address
else:
assert mongo.cx.host == 'localhost'
assert mongo.cx.port == 27017

def test_converts_str_to_int(self):
self.app.config['MONGO_DBNAME'] = 'flask_pymongo_test_db'
Expand All @@ -46,8 +56,12 @@ def test_converts_str_to_int(self):

mongo = flask.ext.pymongo.PyMongo(self.app)
assert mongo.db.name == 'flask_pymongo_test_db', 'wrong dbname: %s' % mongo.db.name
assert mongo.cx.host == 'localhost'
assert mongo.cx.port == 27017
if pymongo.version_tuple[0] > 2:
time.sleep(0.2)
assert ('localhost', 27017) == mongo.cx.address
else:
assert mongo.cx.host == 'localhost'
assert mongo.cx.port == 27017

def test_rejects_invalid_string(self):
self.app.config['MONGO_PORT'] = '27017x'
Expand All @@ -61,7 +75,7 @@ def test_multiple_pymongos(self):
for prefix in ('ONE', 'TWO'):
flask.ext.pymongo.PyMongo(self.app, config_prefix=prefix)

# this test passes if it raises no exceptions
# this test passes if it raises no exceptions

def test_config_with_uri(self):
self.app.config['MONGO_URI'] = 'mongodb://localhost:27017/flask_pymongo_test_db'
Expand All @@ -72,8 +86,12 @@ def test_config_with_uri(self):
warnings.simplefilter('ignore')
mongo = flask.ext.pymongo.PyMongo(self.app)
assert mongo.db.name == 'flask_pymongo_test_db', 'wrong dbname: %s' % mongo.db.name
assert mongo.cx.host == 'localhost'
assert mongo.cx.port == 27017
if pymongo.version_tuple[0] > 2:
time.sleep(0.2)
assert ('localhost', 27017) == mongo.cx.address
else:
assert mongo.cx.host == 'localhost'
assert mongo.cx.port == 27017

def test_config_with_uri_no_port(self):
self.app.config['MONGO_URI'] = 'mongodb://localhost/flask_pymongo_test_db'
Expand All @@ -84,17 +102,27 @@ def test_config_with_uri_no_port(self):
warnings.simplefilter('ignore')
mongo = flask.ext.pymongo.PyMongo(self.app)
assert mongo.db.name == 'flask_pymongo_test_db', 'wrong dbname: %s' % mongo.db.name
assert mongo.cx.host == 'localhost'
assert mongo.cx.port == 27017
if pymongo.version_tuple[0] > 2:
time.sleep(0.2)
assert ('localhost', 27017) == mongo.cx.address
else:
assert mongo.cx.host == 'localhost'
assert mongo.cx.port == 27017

def test_config_with_document_class(self):
self.app.config['MONGO_DOCUMENT_CLASS'] = CustomDict
mongo = flask.ext.pymongo.PyMongo(self.app)
assert mongo.cx.document_class == CustomDict
if pymongo.version_tuple[0] > 2:
assert mongo.cx.codec_options.document_class == CustomDict
else:
assert mongo.cx.document_class == CustomDict

def test_config_without_document_class(self):
mongo = flask.ext.pymongo.PyMongo(self.app)
assert mongo.cx.document_class == dict
if pymongo.version_tuple[0] > 2:
assert mongo.cx.codec_options.document_class == dict
else:
assert mongo.cx.document_class == dict

def test_host_with_port_does_not_get_overridden_by_separate_port_config_value(self):
self.app.config['MONGO_HOST'] = 'localhost:27017'
Expand All @@ -105,8 +133,12 @@ def test_host_with_port_does_not_get_overridden_by_separate_port_config_value(se
# work, but warn that auth should be supplied
warnings.simplefilter('ignore')
mongo = flask.ext.pymongo.PyMongo(self.app)
assert mongo.cx.host == 'localhost'
assert mongo.cx.port == 27017
if pymongo.version_tuple[0] > 2:
time.sleep(0.2)
assert ('localhost', 27017) == mongo.cx.address
else:
assert mongo.cx.host == 'localhost'
assert mongo.cx.port == 27017

def test_uri_prioritised_over_host_and_port(self):
self.app.config['MONGO_URI'] = 'mongodb://localhost:27017/database_name'
Expand All @@ -119,8 +151,12 @@ def test_uri_prioritised_over_host_and_port(self):
# work, but warn that auth should be supplied
warnings.simplefilter('ignore')
mongo = flask.ext.pymongo.PyMongo(self.app)
assert mongo.cx.host == 'localhost'
assert mongo.cx.port == 27017
if pymongo.version_tuple[0] > 2:
time.sleep(0.2)
assert ('localhost', 27017) == mongo.cx.address
else:
assert mongo.cx.host == 'localhost'
assert mongo.cx.port == 27017
assert mongo.db.name == 'database_name'

def test_uri_without_database_errors_sensibly(self):
Expand All @@ -130,6 +166,7 @@ def test_uri_without_database_errors_sensibly(self):

class CustomDocumentClassTest(util.FlaskPyMongoTest):
""" Class that tests reading from DB with custom document_class """

def test_create_with_document_class(self):
""" This test doesn't use self.mongo, because it has to change config
Expand All @@ -144,14 +181,22 @@ def test_create_with_document_class(self):
# not using self.mongo, because we want to use updated config
# also using CUSTOM, to avoid duplicate config_prefix exception
mongo = flask.ext.pymongo.PyMongo(self.app, 'CUSTOM')
assert mongo.db.things.find_one() == None
assert mongo.db.things.find_one() is None
# write document and retrieve, to check if type is really CustomDict
mongo.db.things.insert({'_id': 'thing', 'val': 'foo'}, safe=True)
if pymongo.version_tuple[0] > 2:
# Write Concern is set to w=1 by default in pymongo > 3.0
mongo.db.things.insert_one({'_id': 'thing', 'val': 'foo'})
else:
mongo.db.things.insert({'_id': 'thing', 'val': 'foo'}, w=1)
assert type(mongo.db.things.find_one()) == CustomDict

def test_create_without_document_class(self):
""" This uses self.mongo, which uses config without document_class """
assert self.mongo.db.things.find_one() == None
assert self.mongo.db.things.find_one() is None
# write document and retrieve, to check if type is dict (default)
self.mongo.db.things.insert({'_id': 'thing', 'val': 'foo'}, safe=True)
if pymongo.version_tuple[0] > 2:
# Write Concern is set to w=1 by default in pymongo > 3.0
self.mongo.db.things.insert_one({'_id': 'thing', 'val': 'foo'})
else:
self.mongo.db.things.insert({'_id': 'thing', 'val': 'foo'}, w=1)
assert type(self.mongo.db.things.find_one()) == dict
14 changes: 11 additions & 3 deletions tests/test_wrappers.py
@@ -1,6 +1,8 @@
from tests import util
from werkzeug.exceptions import HTTPException

import pymongo

class CollectionTest(util.FlaskPyMongoTest):

def test_find_one_or_404(self):
Expand All @@ -11,21 +13,27 @@ def test_find_one_or_404(self):
except HTTPException as notfound:
assert notfound.code == 404, "raised wrong exception"

self.mongo.db.things.insert({'_id': 'thing', 'val': 'foo'}, safe=True)
if pymongo.version_tuple[0] > 2:
self.mongo.db.things.insert_one({'_id': 'thing', 'val': 'foo'})
else:
self.mongo.db.things.insert({'_id': 'thing', 'val': 'foo'}, w=1)

# now it should not raise
thing = self.mongo.db.things.find_one_or_404({'_id': 'thing'})
assert thing['val'] == 'foo', 'got wrong thing'


# also test with dotted-named collections
self.mongo.db.things.morethings.remove()
try:
self.mongo.db.things.morethings.find_one_or_404({'_id': 'thing'})
except HTTPException as notfound:
assert notfound.code == 404, "raised wrong exception"

self.mongo.db.things.morethings.insert({'_id': 'thing', 'val': 'foo'}, safe=True)
if pymongo.version_tuple[0] > 2:
# Write Concern is set to w=1 by default in pymongo > 3.0
self.mongo.db.things.morethings.insert_one({'_id': 'thing', 'val': 'foo'})
else:
self.mongo.db.things.morethings.insert({'_id': 'thing', 'val': 'foo'}, w=1)

# now it should not raise
thing = self.mongo.db.things.morethings.find_one_or_404({'_id': 'thing'})
Expand Down

0 comments on commit a21fda7

Please sign in to comment.