Skip to content

PYTHON-4533 - Convert test/test_client.py to async #1730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pymongo/asynchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ async def close(self) -> None:
self.client_ref = None
self.key_vault_coll = None
if self.mongocryptd_client:
await self.mongocryptd_client.close()
await self.mongocryptd_client.aclose()
self.mongocryptd_client = None


Expand Down Expand Up @@ -439,7 +439,7 @@ async def close(self) -> None:
self._closed = True
await self._auto_encrypter.close()
if self._internal_client:
await self._internal_client.close()
await self._internal_client.aclose()
self._internal_client = None


Expand Down
10 changes: 7 additions & 3 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,10 @@ def __init__(
# This will be used later if we fork.
AsyncMongoClient._clients[self._topology._topology_id] = self

async def aconnect(self) -> None:
"""Explicitly connect to MongoDB asynchronously instead of on the first operation."""
await self._get_topology()

def _init_background(self, old_pid: Optional[int] = None) -> None:
self._topology = Topology(self._topology_settings)
# Seed the topology with the old one's pid so we can detect clients
Expand Down Expand Up @@ -1354,13 +1358,13 @@ async def __aenter__(self) -> AsyncMongoClient[_DocumentType]:
return self

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()
await self.aclose()

# See PYTHON-3084.
__iter__ = None

def __next__(self) -> NoReturn:
raise TypeError("'MongoClient' object is not iterable")
raise TypeError("'AsyncMongoClient' object is not iterable")

next = __next__

Expand Down Expand Up @@ -1490,7 +1494,7 @@ async def _end_sessions(self, session_ids: list[_ServerSession]) -> None:
# command.
pass

async def close(self) -> None:
async def aclose(self) -> None:
"""Cleanup client resources and disconnect from MongoDB.

End all server sessions created by this client by sending one or more
Expand Down
4 changes: 4 additions & 0 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,10 @@ def __init__(
# This will be used later if we fork.
MongoClient._clients[self._topology._topology_id] = self

def _connect(self) -> None:
"""Explicitly connect to MongoDB synchronously instead of on the first operation."""
self._get_topology()

def _init_background(self, old_pid: Optional[int] = None) -> None:
self._topology = Topology(self._topology_settings)
# Seed the topology with the old one's pid so we can detect clients
Expand Down
107 changes: 101 additions & 6 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import asyncio
import base64
import contextlib
import gc
import multiprocessing
import os
Expand All @@ -39,8 +40,6 @@
TEST_SERVERLESS,
TLS_OPTIONS,
SystemCertsPatcher,
_all_users,
_create_user,
client_knobs,
db_pwd,
db_user,
Expand All @@ -62,9 +61,9 @@
except ImportError:
HAVE_IPADDRESS = False
from contextlib import contextmanager
from functools import wraps
from functools import partial, wraps
from test.version import Version
from typing import Any, Callable, Dict, Generator
from typing import Any, Callable, Dict, Generator, overload
from unittest import SkipTest
from urllib.parse import quote_plus

Expand Down Expand Up @@ -812,6 +811,12 @@ def require_no_api_version(self, func):
func=func,
)

def require_sync(self, func):
"""Run a test only if using the synchronous API."""
return self._require(
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
)

def mongos_seeds(self):
return ",".join("{}:{}".format(*address) for address in self.mongoses)

Expand Down Expand Up @@ -919,6 +924,32 @@ def _target() -> None:
self.assertEqual(proc.exitcode, 0)


class UnitTest(PyMongoTestCase):
"""Async base class for TestCases that don't require a connection to MongoDB."""

@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())

@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())

@classmethod
def _setup_class(cls):
cls._setup_class()

@classmethod
def _tearDown_class(cls):
cls._tearDown_class()


class IntegrationTest(PyMongoTestCase):
"""Async base class for TestCases that need a connection to MongoDB to pass."""

Expand All @@ -933,6 +964,13 @@ def setUpClass(cls):
else:
asyncio.run(cls._setup_class())

@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())

@classmethod
@client_context.require_connection
def _setup_class(cls):
Expand All @@ -947,6 +985,10 @@ def _setup_class(cls):
else:
cls.credentials = {}

@classmethod
def _tearDown_class(cls):
pass

def cleanup_colls(self, *collections):
"""Cleanup collections faster than drop_collection."""
for c in collections:
Expand All @@ -959,7 +1001,7 @@ def patch_system_certs(self, ca_certs):
self.addCleanup(patcher.disable)


class MockClientTest(unittest.TestCase):
class MockClientTest(UnitTest):
"""Base class for TestCases that use MockClient.

This class is *not* an IntegrationTest: if properly written, MockClient
Expand All @@ -972,8 +1014,26 @@ class MockClientTest(unittest.TestCase):
# multiple seed addresses, or wait for heartbeat events are incompatible
# with loadBalanced=True.
@classmethod
@client_context.require_no_load_balancer
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())

@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())

@classmethod
@client_context.require_no_load_balancer
def _setup_class(cls):
pass

@classmethod
def _tearDown_class(cls):
pass

def setUp(self):
Expand Down Expand Up @@ -1051,3 +1111,38 @@ def print_running_clients():
processed.add(obj._topology_id)
except ReferenceError:
pass


def _all_users(db):
return {u["user"] for u in (db.command("usersInfo")).get("users", [])}


def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
cmd = SON([("createUser", user)])
# X509 doesn't use a password
if pwd:
cmd["pwd"] = pwd
cmd["roles"] = roles or ["root"]
cmd.update(**kwargs)
return authdb.command(cmd)


def connected(client):
"""Convenience to wait for a newly-constructed client to connect."""
with warnings.catch_warnings():
# Ignore warning that ping is always routed to primary even
# if client's read preference isn't PRIMARY.
warnings.simplefilter("ignore", UserWarning)
client.admin.command("ping") # Force connection.

return client


def drop_collections(db: Database):
# Drop all non-system collections in this database.
for coll in db.list_collection_names(filter={"name": {"$regex": r"^(?!system\.)"}}):
db.drop_collection(coll)


def remove_all_users(db: Database):
db.command("dropAllUsersFromDatabase", 1, writeConcern={"w": client_context.w})
Loading
Loading