From 5982f6bbbfd15967c9e63fd071ee5ca9a0e1f695 Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Tue, 2 Sep 2025 10:50:19 -0400 Subject: [PATCH 1/2] Add tests for DatabaseWrapper.get_connection_params() --- tests/backend_/test_base.py | 44 ++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/tests/backend_/test_base.py b/tests/backend_/test_base.py index 7695b6f4d..75af0efdf 100644 --- a/tests/backend_/test_base.py +++ b/tests/backend_/test_base.py @@ -6,7 +6,7 @@ from django_mongodb_backend.base import DatabaseWrapper -class DatabaseWrapperTests(SimpleTestCase): +class GetConnectionParamsTests(SimpleTestCase): def test_database_name_empty(self): settings = connection.settings_dict.copy() settings["NAME"] = "" @@ -14,6 +14,48 @@ def test_database_name_empty(self): with self.assertRaisesMessage(ImproperlyConfigured, msg): DatabaseWrapper(settings).get_connection_params() + def test_host(self): + settings = connection.settings_dict.copy() + settings["HOST"] = "host" + params = DatabaseWrapper(settings).get_connection_params() + self.assertEqual(params["host"], "host") + + def test_host_empty(self): + settings = connection.settings_dict.copy() + settings["HOST"] = "" + params = DatabaseWrapper(settings).get_connection_params() + self.assertIsNone(params["host"]) + + def test_user(self): + settings = connection.settings_dict.copy() + settings["USER"] = "user" + params = DatabaseWrapper(settings).get_connection_params() + self.assertEqual(params["username"], "user") + + def test_password(self): + settings = connection.settings_dict.copy() + settings["PASSWORD"] = "password" # noqa: S105 + params = DatabaseWrapper(settings).get_connection_params() + self.assertEqual(params["password"], "password") + + def test_port(self): + settings = connection.settings_dict.copy() + settings["PORT"] = 123 + params = DatabaseWrapper(settings).get_connection_params() + self.assertEqual(params["port"], 123) + + def test_port_as_string(self): + settings = connection.settings_dict.copy() + settings["PORT"] = "123" + params = DatabaseWrapper(settings).get_connection_params() + self.assertEqual(params["port"], 123) + + def test_options(self): + settings = connection.settings_dict.copy() + settings["OPTIONS"] = {"extra": "option"} + params = DatabaseWrapper(settings).get_connection_params() + self.assertEqual(params["extra"], "option") + class DatabaseWrapperConnectionTests(TestCase): def test_set_autocommit(self): From 862c2cd204dd5767ec7dc6cb4f55e43c4f7491c8 Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Fri, 29 Aug 2025 18:54:08 -0400 Subject: [PATCH 2/2] INTPYTHON-743 Allow using MongoDB connection string in DATABASES["HOST"] --- .github/workflows/mongodb_settings.py | 13 ++++--- README.md | 8 ++-- django_mongodb_backend/base.py | 29 ++++++++++++--- docs/intro/configure.rst | 38 +++++++++++-------- docs/releases/5.2.x.rst | 5 ++- tests/backend_/test_base.py | 53 ++++++++++++++++++++++++++- 6 files changed, 113 insertions(+), 33 deletions(-) diff --git a/.github/workflows/mongodb_settings.py b/.github/workflows/mongodb_settings.py index 49d44a5fc..20cd41cbc 100644 --- a/.github/workflows/mongodb_settings.py +++ b/.github/workflows/mongodb_settings.py @@ -1,13 +1,16 @@ import os -from django_mongodb_backend import parse_uri +from pymongo.uri_parser import parse_uri if mongodb_uri := os.getenv("MONGODB_URI"): - db_settings = parse_uri(mongodb_uri, db_name="dummy") - + db_settings = { + "ENGINE": "django_mongodb_backend", + "HOST": mongodb_uri, + } # Workaround for https://github.com/mongodb-labs/mongo-orchestration/issues/268 - if db_settings["USER"] and db_settings["PASSWORD"]: - db_settings["OPTIONS"].update({"tls": True, "tlsAllowInvalidCertificates": True}) + uri = parse_uri(mongodb_uri) + if uri.get("username") and uri.get("password"): + db_settings["OPTIONS"] = {"tls": True, "tlsAllowInvalidCertificates": True} DATABASES = { "default": {**db_settings, "NAME": "djangotests"}, "other": {**db_settings, "NAME": "djangotests-other"}, diff --git a/README.md b/README.md index 8c1824083..264e30bb0 100644 --- a/README.md +++ b/README.md @@ -45,9 +45,11 @@ setting like so: ```python DATABASES = { - "default": django_mongodb_backend.parse_uri( - "", db_name="example" - ), + "default": { + "ENGINE": "django_mongodb_backend", + "HOST": "", + "NAME": "db_name", + }, } ``` diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index a70c7fbdd..f751c27fa 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -11,6 +11,7 @@ from pymongo.collection import Collection from pymongo.driver_info import DriverInfo from pymongo.mongo_client import MongoClient +from pymongo.uri_parser import parse_uri from . import __version__ as django_mongodb_backend_version from . import dbapi as Database @@ -157,6 +158,18 @@ def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS): self.in_atomic_block_mongo = False # Current number of nested 'atomic' calls. self.nested_atomics = 0 + # If database "NAME" isn't specified, try to get it from HOST, if it's + # a connection string. + if self.settings_dict["NAME"] == "": # Empty string = unspecified; None = _nodb_cursor() + name_is_missing = True + host = self.settings_dict["HOST"] + if host.startswith(("mongodb://", "mongodb+srv://")): + uri = parse_uri(host) + if database := uri.get("database"): + self.settings_dict["NAME"] = database + name_is_missing = False + if name_is_missing: + raise ImproperlyConfigured('settings.DATABASES is missing the "NAME" value.') def get_collection(self, name, **kwargs): collection = Collection(self.database, name, **kwargs) @@ -183,15 +196,19 @@ def init_connection_state(self): def get_connection_params(self): settings_dict = self.settings_dict - if not settings_dict["NAME"]: - raise ImproperlyConfigured('settings.DATABASES is missing the "NAME" value.') - return { + params = { "host": settings_dict["HOST"] or None, - "port": int(settings_dict["PORT"] or 27017), - "username": settings_dict.get("USER"), - "password": settings_dict.get("PASSWORD"), **settings_dict["OPTIONS"], } + # MongoClient uses any of these parameters (including "OPTIONS" above) + # to override any corresponding values in a connection string "HOST". + if user := settings_dict.get("USER"): + params["username"] = user + if password := settings_dict.get("PASSWORD"): + params["password"] = password + if port := settings_dict.get("PORT"): + params["port"] = int(port) + return params @async_unsafe def get_new_connection(self, conn_params): diff --git a/docs/intro/configure.rst b/docs/intro/configure.rst index 831cc137e..9aebcc030 100644 --- a/docs/intro/configure.rst +++ b/docs/intro/configure.rst @@ -105,8 +105,26 @@ to match the first two numbers from your version.) Configuring the ``DATABASES`` setting ===================================== -After you've set up a project, configure Django's :setting:`DATABASES` setting -similar to this:: +After you've set up a project, configure Django's :setting:`DATABASES` setting. + +If you have a connection string, you can provide it like this:: + + DATABASES = { + "default": { + "ENGINE": "django_mongodb_backend", + "HOST": "mongodb+srv://my_user:my_password@cluster0.example.mongodb.net/?retryWrites=true&w=majority&tls=false", + "NAME": "my_database", + }, + } + +.. versionchanged:: 5.2.1 + + Support for the connection string in ``"HOST"`` was added. Previous + versions recommended using :func:`~django_mongodb_backend.utils.parse_uri`. + +Alternatively, you can separate the connection string so that your settings +look more like what you usually see with Django. This constructs a +:setting:`DATABASES` setting equivalent to the first example:: DATABASES = { "default": { @@ -117,7 +135,6 @@ similar to this:: "PASSWORD": "my_password", "PORT": 27017, "OPTIONS": { - # Example: "retryWrites": "true", "w": "majority", "tls": "false", @@ -128,8 +145,8 @@ similar to this:: For a localhost configuration, you can omit :setting:`HOST` or specify ``"HOST": "localhost"``. -:setting:`HOST` only needs a scheme prefix for SRV connections -(``mongodb+srv://``). A ``mongodb://`` prefix is never required. +If you provide a connection string in ``HOST``, any of the other values below +will override the values in the connection string. :setting:`OPTIONS` is an optional dictionary of parameters that will be passed to :class:`~pymongo.mongo_client.MongoClient`. @@ -143,17 +160,6 @@ For a replica set or sharded cluster where you have multiple hosts, include all of them in :setting:`HOST`, e.g. ``"mongodb://mongos0.example.com:27017,mongos1.example.com:27017"``. -Alternatively, if you prefer to simply paste in a MongoDB URI rather than parse -it into the format above, you can use -:func:`~django_mongodb_backend.utils.parse_uri`:: - - import django_mongodb_backend - - MONGODB_URI = "mongodb+srv://my_user:my_password@cluster0.example.mongodb.net/myDatabase?retryWrites=true&w=majority&tls=false" - DATABASES["default"] = django_mongodb_backend.parse_uri(MONGODB_URI) - -This constructs a :setting:`DATABASES` setting equivalent to the first example. - .. _configuring-database-routers-setting: Configuring the ``DATABASE_ROUTERS`` setting diff --git a/docs/releases/5.2.x.rst b/docs/releases/5.2.x.rst index cf106a497..8079ac439 100644 --- a/docs/releases/5.2.x.rst +++ b/docs/releases/5.2.x.rst @@ -10,7 +10,10 @@ Django MongoDB Backend 5.2.x New features ------------ -- ... +- Allowed :ref:`specifying the MongoDB connection string + ` in ``DATABASES["HOST"]``, eliminating the + need to use :func:`~django_mongodb_backend.utils.parse_uri` to configure the + :setting:`DATABASES` setting. Bug fixes --------- diff --git a/tests/backend_/test_base.py b/tests/backend_/test_base.py index 75af0efdf..00993ef8f 100644 --- a/tests/backend_/test_base.py +++ b/tests/backend_/test_base.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + from django.core.exceptions import ImproperlyConfigured from django.db import connection from django.db.backends.signals import connection_created @@ -6,14 +8,45 @@ from django_mongodb_backend.base import DatabaseWrapper -class GetConnectionParamsTests(SimpleTestCase): +class DatabaseWrapperTests(SimpleTestCase): def test_database_name_empty(self): settings = connection.settings_dict.copy() settings["NAME"] = "" msg = 'settings.DATABASES is missing the "NAME" value.' with self.assertRaisesMessage(ImproperlyConfigured, msg): - DatabaseWrapper(settings).get_connection_params() + DatabaseWrapper(settings) + + def test_database_name_empty_and_host_does_not_contain_database(self): + settings = connection.settings_dict.copy() + settings["NAME"] = "" + settings["HOST"] = "mongodb://localhost" + msg = 'settings.DATABASES is missing the "NAME" value.' + with self.assertRaisesMessage(ImproperlyConfigured, msg): + DatabaseWrapper(settings) + + def test_database_name_parsed_from_host(self): + settings = connection.settings_dict.copy() + settings["NAME"] = "" + settings["HOST"] = "mongodb://localhost/db" + self.assertEqual(DatabaseWrapper(settings).settings_dict["NAME"], "db") + def test_database_name_parsed_from_srv_host(self): + settings = connection.settings_dict.copy() + settings["NAME"] = "" + settings["HOST"] = "mongodb+srv://localhost/db" + # patch() prevents a crash when PyMongo attempts to resolve the + # nonexistent SRV record. + with patch("dns.resolver.resolve"): + self.assertEqual(DatabaseWrapper(settings).settings_dict["NAME"], "db") + + def test_database_name_not_overridden_by_host(self): + settings = connection.settings_dict.copy() + settings["NAME"] = "not overridden" + settings["HOST"] = "mongodb://localhost/db" + self.assertEqual(DatabaseWrapper(settings).settings_dict["NAME"], "not overridden") + + +class GetConnectionParamsTests(SimpleTestCase): def test_host(self): settings = connection.settings_dict.copy() settings["HOST"] = "host" @@ -56,6 +89,22 @@ def test_options(self): params = DatabaseWrapper(settings).get_connection_params() self.assertEqual(params["extra"], "option") + def test_unspecified_settings_omitted(self): + settings = connection.settings_dict.copy() + # django.db.utils.ConnectionHandler sets unspecified values to an empty + # string. + settings.update( + { + "USER": "", + "PASSWORD": "", + "PORT": "", + } + ) + params = DatabaseWrapper(settings).get_connection_params() + self.assertNotIn("username", params) + self.assertNotIn("password", params) + self.assertNotIn("port", params) + class DatabaseWrapperConnectionTests(TestCase): def test_set_autocommit(self):