From b35b8e66fe20309a3518a2115f48f9c3451bef28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9gis=20Behmo?= Date: Thu, 15 Aug 2024 11:18:12 +0200 Subject: [PATCH] chore: refactor MongoBackend to simplify mocking The MongoBackend now provides access only to a database, and it is the responsibility of each child model to access the collection. This is done transparently via the _collection property. The MongoBackend constructor was also simplified. Because we are now using a single database object, we can drastically simplify mocking in tests. That's because we don't have to mock models one by one. And it turns out that we don't need to store the client attribute. Neither do we need to store a diffent collection name for every model instance, so we can simply store it as a class attribute. This means that we no longer need constructors for models. --- forum/models/base_model.py | 35 ++++++------ forum/models/contents.py | 14 +---- forum/models/users.py | 14 +---- forum/mongo.py | 89 +++++++++++++----------------- tests/conftest.py | 59 +++----------------- tests/test_models/test_comments.py | 32 +++++------ tests/test_models/test_threads.py | 32 +++++------ tests/test_models/test_users.py | 38 ++++++------- tests/test_views/test_flags.py | 40 ++++++-------- 9 files changed, 132 insertions(+), 221 deletions(-) diff --git a/forum/models/base_model.py b/forum/models/base_model.py index aa2470e5..62758f16 100644 --- a/forum/models/base_model.py +++ b/forum/models/base_model.py @@ -3,40 +3,39 @@ """ from abc import ABC, abstractmethod -from typing import Any, Dict, Optional +from typing import Any, Optional from bson import ObjectId -from pymongo.collection import Collection +from pymongo.collection import Collection as PymongoCollection from pymongo.cursor import Cursor -from forum.mongo import MongoBackend +from forum.mongo import Database, get_database + +Collection = PymongoCollection[dict[str, Any]] class MongoBaseModel(ABC): """Abstract Class for Mongo model implementation""" - def __init__( - self, - collection_name: Optional[str] = None, - client: Optional[MongoBackend] = None, - ) -> None: - self.client: MongoBackend = client or MongoBackend(collection=collection_name) + MONGODB_DATABASE: Optional[Database] = None + COLLECTION_NAME: str = "default" @property - def _collection(self) -> Collection[Dict[str, Any]]: - """Get mongo db collection""" - return self.get_client.collection + def _collection(self) -> Collection: + return self.__get_database()[self.COLLECTION_NAME] - @property - def get_client(self) -> MongoBackend: - """Get mongo client""" - return self.client + @classmethod + def __get_database(cls) -> Database: + """Get or create static class database.""" + if cls.MONGODB_DATABASE is None: + cls.MONGODB_DATABASE = get_database() + return cls.MONGODB_DATABASE - def get(self, _id: str) -> Optional[Dict[str, Any]]: + def get(self, _id: str) -> Optional[dict[str, Any]]: """Get a document by filter""" return self._collection.find_one({"_id": _id}) - def list(self, **kwargs: Any) -> Cursor[Dict[str, Any]]: + def list(self, **kwargs: Any) -> Cursor[dict[str, Any]]: """Get a list of all documents filtered by kwargs""" return self._collection.find(kwargs) diff --git a/forum/models/contents.py b/forum/models/contents.py index fdc00938..0909205e 100644 --- a/forum/models/contents.py +++ b/forum/models/contents.py @@ -8,7 +8,6 @@ from bson import ObjectId from forum.models.base_model import MongoBaseModel -from forum.mongo import MongoBackend class Contents(MongoBaseModel): @@ -17,18 +16,7 @@ class Contents(MongoBaseModel): """ content_type: str = "" - - def __init__( - self, collection_name: str = "contents", client: Optional[MongoBackend] = None - ) -> None: - """ - Initializes the Content class. - - Args: - collection_name: The name of the MongoDB collection. - client: The MongoDB client. - """ - super().__init__(collection_name, client) + COLLECTION_NAME: str = "contents" def get( self, _id: str diff --git a/forum/models/users.py b/forum/models/users.py index 09eb3617..5adcd118 100644 --- a/forum/models/users.py +++ b/forum/models/users.py @@ -5,7 +5,6 @@ from typing import Any, Dict, List, Optional from forum.models.base_model import MongoBaseModel -from forum.mongo import MongoBackend class Users(MongoBaseModel): @@ -13,18 +12,7 @@ class Users(MongoBaseModel): Users class for cs_comments_service user model """ - def __init__( - self, collection_name: str = "users", client: Optional[MongoBackend] = None - ) -> None: - """ - Initializes the Users class. - - Args: - collection_name: The name of the MongoDB collection. - client: The MongoDB client. - - """ - super().__init__(collection_name, client) + COLLECTION_NAME: str = "users" def insert( self, diff --git a/forum/mongo.py b/forum/mongo.py index 070b06fa..371cfe0b 100644 --- a/forum/mongo.py +++ b/forum/mongo.py @@ -1,60 +1,47 @@ """Mongo module for forum app.""" import logging -from typing import Any, Dict +from typing import Any, Optional from django.conf import settings from pymongo import MongoClient -from pymongo.collection import Collection +from pymongo.database import Database as PymongoDatabase log = logging.getLogger(__name__) - -class MongoBackend: - """Class for mongoDB cs_comments_service backend.""" - - def __init__(self, **kwargs: Any) -> None: - """ - Connect to MongoDB. - - :Parameters: - - - `host`: hostname - - `port`: port - - `user`: collection username - - `password`: collection user password - - `database`: name of the database - - `collection`: name of the collection - - `authsource`: name of the authentication database - - `extra`: parameters to pymongo.MongoClient not listed above - - """ - # Extract connection parameters from kwargs - - host = kwargs.get("host", settings.MONGO_HOST) - port = kwargs.get("port", settings.MONGO_PORT) - - user = kwargs.get("user", "") - password = kwargs.get("password", "") - - db_name = kwargs.get("database", "cs_comments_service") - collection_name = kwargs.get("collection", "") - - auth_source = kwargs.get("authsource") or None - - # Other mongo connection arguments - extra = kwargs.get("extra", {}) - - # Make timezone aware by default - extra["tz_aware"] = extra.get("tz_aware", True) - - # Connect to database and get collection - - self.connection: MongoClient[Any] = MongoClient(host=host, port=port, **extra) - - database = self.connection[db_name] - - if user or password: - database.authenticate(user, password, source=auth_source) - - self.collection: Collection[Dict[str, Any]] = database[collection_name] +Database = PymongoDatabase[dict[str, Any]] + + +def get_database( + host: str = settings.MONGO_HOST, + port: int = settings.MONGO_PORT, + user: str = "", + password: str = "", + database: str = "cs_comments_service", + authsource: Optional[str] = None, + tz_aware: bool = True, + **extra: Any +) -> Database: + """ + Connect to MongoDB. + + :Parameters: + + - `host`: hostname + - `port`: port + - `user`: collection username + - `password`: collection user password + - `database`: name of the database + - `authsource`: name of the authentication database + - `extra`: parameters to pymongo.MongoClient not listed above + + """ + connection: MongoClient[Any] = MongoClient( + host=host, port=port, tz_aware=tz_aware, **extra + ) + db = connection[database] + + if user or password: + db.authenticate(user, password, source=authsource) + + return db diff --git a/tests/conftest.py b/tests/conftest.py index 0e488ea0..50779f48 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ import json from typing import Any, Generator, Union -from unittest.mock import MagicMock, patch import mongomock import pytest @@ -13,59 +12,15 @@ from django.test import Client from pymongo import MongoClient -from forum.models import Comment, CommentThread, Contents, Users -from forum.mongo import MongoBackend - -@pytest.fixture(name="mock_mongo_backend") -def fixture_mock_mongo_backend() -> MagicMock: - """Mock MongoClient for tests.""" +@pytest.fixture(autouse=True) +def patch_default_mongo_database(monkeypatch: pytest.MonkeyPatch) -> None: + """Mock default mongodb database for tests.""" client: MongoClient[Any] = mongomock.MongoClient() - db = client["test_forum_db"] - - collections = { - "contents": db["contents"], - "users": db["users"], - } - - mock_backend: MagicMock = MagicMock(spec=MongoBackend) - for name, collection in collections.items(): - setattr(mock_backend, name, collection) - - return mock_backend - - -@pytest.fixture(name="patch_mongo_backend") -def fixture_patch_mongo_backend( - mock_mongo_backend: MagicMock, -) -> Generator[MagicMock, None, None]: - """Patch the MongoBackend instance with a mock.""" - with patch("forum.mongo.MongoBackend", return_value=mock_mongo_backend): - yield mock_mongo_backend - - -@pytest.fixture(name="users_model") -def fixture_users_model(patch_mongo_backend: MagicMock) -> Users: - """Get Users model with patched backend.""" - return Users(client=patch_mongo_backend.users) - - -@pytest.fixture(name="comment_thread_model") -def fixture_comment_thread_model(patch_mongo_backend: MagicMock) -> CommentThread: - """Get CommentThread model with patched backend.""" - return CommentThread(client=patch_mongo_backend.contents) - - -@pytest.fixture(name="comment_model") -def fixture_comment_model(patch_mongo_backend: MagicMock) -> Comment: - """Get Comment model with patched backend.""" - return Comment(client=patch_mongo_backend.contents) - - -@pytest.fixture(name="content_model") -def fixture_content_model(patch_mongo_backend: MagicMock) -> Contents: - """Get Contents model with patched backend.""" - return Contents(client=patch_mongo_backend.contents) + monkeypatch.setattr( + "forum.models.base_model.MongoBaseModel.MONGODB_DATABASE", + client["test_forum_db"], + ) class APIClient(Client): diff --git a/tests/test_models/test_comments.py b/tests/test_models/test_comments.py index f96678be..48ba9b7d 100644 --- a/tests/test_models/test_comments.py +++ b/tests/test_models/test_comments.py @@ -5,9 +5,9 @@ from forum.models import Comment -def test_insert(comment_model: Comment) -> None: +def test_insert() -> None: """Test insert a comment into MongoDB.""" - comment_id = comment_model.insert( + comment_id = Comment().insert( body="

This is a test comment

", course_id="course1", comment_thread_id="66af33634a1e1f001b7ed57f", @@ -15,51 +15,51 @@ def test_insert(comment_model: Comment) -> None: author_username="author_user", ) assert comment_id is not None - comment_data = comment_model.get(_id=comment_id) + comment_data = Comment().get(_id=comment_id) assert comment_data is not None assert comment_data["body"] == "

This is a test comment

" -def test_delete(comment_model: Comment) -> None: +def test_delete() -> None: """Test delete a comment from MongoDB.""" - comment_id = comment_model.insert( + comment_id = Comment().insert( body="

This is a test comment

", course_id="course1", comment_thread_id="66af33634a1e1f001b7ed57f", author_id="author1", author_username="author_user", ) - result = comment_model.delete(comment_id) + result = Comment().delete(comment_id) assert result == 1 - comment_data = comment_model.get(_id=comment_id) + comment_data = Comment().get(_id=comment_id) assert comment_data is None -def test_list(comment_model: Comment) -> None: +def test_list() -> None: """Test list all comments from MongoDB.""" course_id = "course-xyz" thread_id = "66af33634a1e1f001b7ed57f" author_id = "4" author_username = "edly" - comment_model.insert( + Comment().insert( "

Comment 1

", course_id, thread_id, author_id, author_username ) - comment_model.insert( + Comment().insert( "

Comment 2

", course_id, thread_id, author_id, author_username ) - comment_model.insert( + Comment().insert( "

Comment 3

", course_id, thread_id, author_id, author_username ) - comments_list = comment_model.list() + comments_list = Comment().list() assert len(list(comments_list)) == 3 assert all(comment["body"].startswith("

Comment") for comment in comments_list) -def test_update(comment_model: Comment) -> None: +def test_update() -> None: """Test update a comment in MongoDB.""" - comment_id = comment_model.insert( + comment_id = Comment().insert( body="

This is a test comment

", course_id="course1", comment_thread_id="66af33634a1e1f001b7ed57f", @@ -67,10 +67,10 @@ def test_update(comment_model: Comment) -> None: author_username="author_user", ) - result = comment_model.update( + result = Comment().update( comment_id=comment_id, body="

Updated comment

", ) assert result == 1 - comment_data = comment_model.get(_id=comment_id) or {} + comment_data = Comment().get(_id=comment_id) or {} assert comment_data.get("body", "") == "

Updated comment

" diff --git a/tests/test_models/test_threads.py b/tests/test_models/test_threads.py index 56f562fe..399642c7 100644 --- a/tests/test_models/test_threads.py +++ b/tests/test_models/test_threads.py @@ -6,9 +6,9 @@ from forum.models import CommentThread -def test_insert(comment_thread_model: CommentThread) -> None: +def test_insert() -> None: """Test insert a comment thread into MongoDB.""" - thread_id = comment_thread_model.insert( + thread_id = CommentThread().insert( title="Test Thread", body="This is a test thread", course_id="course1", @@ -17,15 +17,15 @@ def test_insert(comment_thread_model: CommentThread) -> None: author_username="author_user", ) assert thread_id is not None - thread_data = comment_thread_model.get(thread_id) + thread_data = CommentThread().get(thread_id) assert thread_data is not None assert thread_data["title"] == "Test Thread" assert thread_data["body"] == "This is a test thread" -def test_delete(comment_thread_model: CommentThread) -> None: +def test_delete() -> None: """Test delete a comment thread from MongoDB.""" - thread_id = comment_thread_model.insert( + thread_id = CommentThread().insert( title="Test Thread", body="This is a test thread", course_id="course1", @@ -33,15 +33,15 @@ def test_delete(comment_thread_model: CommentThread) -> None: author_id="author1", author_username="author_user", ) - result = comment_thread_model.delete(thread_id) + result = CommentThread().delete(thread_id) assert result == 1 - thread_data = comment_thread_model.get(thread_id) + thread_data = CommentThread().get(thread_id) assert thread_data is None -def test_list(comment_thread_model: CommentThread) -> None: +def test_list() -> None: """Test list all comment threads from MongoDB.""" - comment_thread_model.insert( + CommentThread().insert( "Thread 1", "Body 1", "_type", @@ -49,7 +49,7 @@ def test_list(comment_thread_model: CommentThread) -> None: "1", "user1", ) - comment_thread_model.insert( + CommentThread().insert( "Thread 2", "Body 2", "_type", @@ -57,7 +57,7 @@ def test_list(comment_thread_model: CommentThread) -> None: "1", "user1", ) - comment_thread_model.insert( + CommentThread().insert( "Thread 3", "Body 3", "_type", @@ -65,14 +65,14 @@ def test_list(comment_thread_model: CommentThread) -> None: "1", "user1", ) - threads_list = comment_thread_model.list() + threads_list = CommentThread().list() assert len(list(threads_list)) == 3 assert all(thread["title"].startswith("Thread") for thread in threads_list) -def test_update(comment_thread_model: CommentThread) -> None: +def test_update() -> None: """Test update a comment thread in MongoDB.""" - thread_id = comment_thread_model.insert( + thread_id = CommentThread().insert( title="Test Thread", body="This is a test thread", course_id="course1", @@ -81,14 +81,14 @@ def test_update(comment_thread_model: CommentThread) -> None: author_username="author_user", ) - result = comment_thread_model.update( + result = CommentThread().update( thread_id=thread_id, title="Updated Title", body="Updated body", commentable_id="new_commentable_id", ) assert result == 1 - thread_data = comment_thread_model.get(thread_id) + thread_data = CommentThread().get(thread_id) assert thread_data is not None assert thread_data["title"] == "Updated Title" assert thread_data["body"] == "Updated body" diff --git a/tests/test_models/test_users.py b/tests/test_models/test_users.py index ffa475af..3ec4c80d 100644 --- a/tests/test_models/test_users.py +++ b/tests/test_models/test_users.py @@ -6,17 +6,17 @@ from forum.models import Users -def test_get(users_model: Users) -> None: +def test_get() -> None: """Test get user from mongodb""" external_id = "test_external_id" username = "test_username" email = "test_email" - users_model.insert( + Users().insert( external_id, username, email, ) - user_data = users_model.get(external_id) + user_data = Users().get(external_id) assert user_data is not None assert user_data["_id"] == external_id assert user_data["external_id"] == external_id @@ -24,14 +24,14 @@ def test_get(users_model: Users) -> None: assert user_data["email"] == email -def test_insert(users_model: Users) -> None: +def test_insert() -> None: """Test insert user from mongodb""" external_id = "test_external_id" username = "test_username" email = "test_email" - result = users_model.insert(external_id, username, email) + result = Users().insert(external_id, username, email) assert result is not None - user_data = users_model.get(external_id) + user_data = Users().get(external_id) assert user_data is not None assert user_data["_id"] == external_id assert user_data["external_id"] == external_id @@ -39,44 +39,44 @@ def test_insert(users_model: Users) -> None: assert user_data["email"] == email -def test_delete(users_model: Users) -> None: +def test_delete() -> None: """Test delete user from mongodb""" external_id = "test_external_id" - users_model.insert(external_id, "test_username", "test_email") - result = users_model.delete(external_id) + Users().insert(external_id, "test_username", "test_email") + result = Users().delete(external_id) assert result == 1 - user_data = users_model.get(external_id) + user_data = Users().get(external_id) assert user_data is None -def test_list(users_model: Users) -> None: +def test_list() -> None: """Test list user from mongodb""" - users_model.insert( + Users().insert( external_id="user1", username="user1", email="user1", ) - users_model.insert( + Users().insert( external_id="user2", username="user2", email="user1", ) - users_model.insert( + Users().insert( external_id="user3", username="user3", email="user1", ) - users_list = users_model.list() + users_list = Users().list() assert len(list(users_list)) == 3 assert all(user["username"] in ["user1", "user2", "user3"] for user in users_list) -def test_update(users_model: Users) -> None: +def test_update() -> None: """Test update user from mongodb""" external_id = "test_external_id" username = "test_username" email = "test_email" - users_model.insert( + Users().insert( external_id=external_id, username=username, email=email, @@ -84,11 +84,11 @@ def test_update(users_model: Users) -> None: new_username = "new_username" new_email = "new_email" - result = users_model.update(external_id, username=new_username, email=new_email) + result = Users().update(external_id, username=new_username, email=new_email) assert result is not None assert result == 1 - user_data = users_model.get(external_id) + user_data = Users().get(external_id) assert user_data is not None assert user_data["external_id"] == external_id assert user_data["username"] == new_username diff --git a/tests/test_views/test_flags.py b/tests/test_views/test_flags.py index 7a8023b3..2b6ac8d3 100644 --- a/tests/test_views/test_flags.py +++ b/tests/test_views/test_flags.py @@ -7,9 +7,7 @@ from forum.models import Contents, Users -def test_comment_thread_api( - api_client: Client, users_model: Users, content_model: Contents -) -> None: +def test_comment_thread_api(api_client: Client) -> None: """ Test the comment thread flag API. @@ -17,16 +15,16 @@ def test_comment_thread_api( """ user_id = "1" comment_thread_id = "66ace22474ba69001e1440cd" - users_model.insert(user_id, username="user1", email="email1") - content_model.insert( + Users().insert(user_id, username="user1", email="email1") + Contents().insert( comment_thread_id, "3", abuse_flaggers=[], historical_abuse_flaggers=[], visible=True, ) - mock_users_class = Mock(return_value=users_model) - mock_contents_class = Mock(return_value=content_model) + mock_users_class = Mock(return_value=Users()) + mock_contents_class = Mock(return_value=Contents()) with patch("forum.models.Users", new=mock_users_class): with patch("forum.models.Contents", new=mock_contents_class): response = api_client.put( @@ -42,14 +40,12 @@ def test_comment_thread_api( data={"user_id": str(user_id)}, ) assert response.status_code == 200 - comment = content_model.get(comment_thread_id) + comment = Contents().get(comment_thread_id) assert comment is not None assert comment["abuse_flaggers"] == [] -def test_comment_flag_api( - api_client: Client, users_model: Users, content_model: Contents -) -> None: +def test_comment_flag_api(api_client: Client) -> None: """ Test the comment flag API. @@ -57,16 +53,16 @@ def test_comment_flag_api( """ user_id = "1" comment_id = "66ace22474ba69001e1440cd" - users_model.insert(user_id, username="user1", email="email1") - content_model.insert( + Users().insert(user_id, username="user1", email="email1") + Contents().insert( comment_id, "3", abuse_flaggers=[], historical_abuse_flaggers=[], visible=True, ) - mock_users_class = Mock(return_value=users_model) - mock_contents_class = Mock(return_value=content_model) + mock_users_class = Mock(return_value=Users()) + mock_contents_class = Mock(return_value=Contents()) with patch("forum.models.Users", new=mock_users_class): with patch("forum.models.Contents", new=mock_contents_class): response = api_client.put( @@ -82,7 +78,7 @@ def test_comment_flag_api( data={"user_id": str(user_id)}, ) assert response.status_code == 200 - comment = content_model.get(comment_id) + comment = Contents().get(comment_id) assert comment is not None assert comment["abuse_flaggers"] == [] @@ -91,23 +87,21 @@ def test_comment_flag_api( data={"user_id": str(user_id)}, ) assert response.status_code == 200 - comment = content_model.get(comment_id) + comment = Contents().get(comment_id) assert comment is not None assert comment["abuse_flaggers"] == [] -def test_comment_flag_api_invalid_data( - api_client: Client, users_model: Users, content_model: Contents -) -> None: +def test_comment_flag_api_invalid_data(api_client: Client) -> None: """ Test the comment flag API with invalid data. This test checks that the API returns a 400 error when the user or comment does not exist. """ user_id = "1" - users_model.insert(user_id, username="user1", email="email1") - mock_users_class = Mock(return_value=users_model) - mock_contents_class = Mock(return_value=content_model) + Users().insert(user_id, username="user1", email="email1") + mock_users_class = Mock(return_value=Users()) + mock_contents_class = Mock(return_value=Contents()) with patch("forum.models.Users", new=mock_users_class): with patch("forum.models.Contents", new=mock_contents_class): response = api_client.put(