diff --git a/forum/models/base_model.py b/forum/models/base_model.py index aa2470e5..62758f16 100644 --- a/forum/models/base_model.py +++ b/forum/models/base_model.py @@ -3,40 +3,39 @@ """ from abc import ABC, abstractmethod -from typing import Any, Dict, Optional +from typing import Any, Optional from bson import ObjectId -from pymongo.collection import Collection +from pymongo.collection import Collection as PymongoCollection from pymongo.cursor import Cursor -from forum.mongo import MongoBackend +from forum.mongo import Database, get_database + +Collection = PymongoCollection[dict[str, Any]] class MongoBaseModel(ABC): """Abstract Class for Mongo model implementation""" - def __init__( - self, - collection_name: Optional[str] = None, - client: Optional[MongoBackend] = None, - ) -> None: - self.client: MongoBackend = client or MongoBackend(collection=collection_name) + MONGODB_DATABASE: Optional[Database] = None + COLLECTION_NAME: str = "default" @property - def _collection(self) -> Collection[Dict[str, Any]]: - """Get mongo db collection""" - return self.get_client.collection + def _collection(self) -> Collection: + return self.__get_database()[self.COLLECTION_NAME] - @property - def get_client(self) -> MongoBackend: - """Get mongo client""" - return self.client + @classmethod + def __get_database(cls) -> Database: + """Get or create static class database.""" + if cls.MONGODB_DATABASE is None: + cls.MONGODB_DATABASE = get_database() + return cls.MONGODB_DATABASE - def get(self, _id: str) -> Optional[Dict[str, Any]]: + def get(self, _id: str) -> Optional[dict[str, Any]]: """Get a document by filter""" return self._collection.find_one({"_id": _id}) - def list(self, **kwargs: Any) -> Cursor[Dict[str, Any]]: + def list(self, **kwargs: Any) -> Cursor[dict[str, Any]]: """Get a list of all documents filtered by kwargs""" return self._collection.find(kwargs) diff --git a/forum/models/contents.py b/forum/models/contents.py index fdc00938..0909205e 100644 --- a/forum/models/contents.py +++ b/forum/models/contents.py @@ -8,7 +8,6 @@ from bson import ObjectId from forum.models.base_model import MongoBaseModel -from forum.mongo import MongoBackend class Contents(MongoBaseModel): @@ -17,18 +16,7 @@ class Contents(MongoBaseModel): """ content_type: str = "" - - def __init__( - self, collection_name: str = "contents", client: Optional[MongoBackend] = None - ) -> None: - """ - Initializes the Content class. - - Args: - collection_name: The name of the MongoDB collection. - client: The MongoDB client. - """ - super().__init__(collection_name, client) + COLLECTION_NAME: str = "contents" def get( self, _id: str diff --git a/forum/models/users.py b/forum/models/users.py index 09eb3617..5adcd118 100644 --- a/forum/models/users.py +++ b/forum/models/users.py @@ -5,7 +5,6 @@ from typing import Any, Dict, List, Optional from forum.models.base_model import MongoBaseModel -from forum.mongo import MongoBackend class Users(MongoBaseModel): @@ -13,18 +12,7 @@ class Users(MongoBaseModel): Users class for cs_comments_service user model """ - def __init__( - self, collection_name: str = "users", client: Optional[MongoBackend] = None - ) -> None: - """ - Initializes the Users class. - - Args: - collection_name: The name of the MongoDB collection. - client: The MongoDB client. - - """ - super().__init__(collection_name, client) + COLLECTION_NAME: str = "users" def insert( self, diff --git a/forum/mongo.py b/forum/mongo.py index 070b06fa..371cfe0b 100644 --- a/forum/mongo.py +++ b/forum/mongo.py @@ -1,60 +1,47 @@ """Mongo module for forum app.""" import logging -from typing import Any, Dict +from typing import Any, Optional from django.conf import settings from pymongo import MongoClient -from pymongo.collection import Collection +from pymongo.database import Database as PymongoDatabase log = logging.getLogger(__name__) - -class MongoBackend: - """Class for mongoDB cs_comments_service backend.""" - - def __init__(self, **kwargs: Any) -> None: - """ - Connect to MongoDB. - - :Parameters: - - - `host`: hostname - - `port`: port - - `user`: collection username - - `password`: collection user password - - `database`: name of the database - - `collection`: name of the collection - - `authsource`: name of the authentication database - - `extra`: parameters to pymongo.MongoClient not listed above - - """ - # Extract connection parameters from kwargs - - host = kwargs.get("host", settings.MONGO_HOST) - port = kwargs.get("port", settings.MONGO_PORT) - - user = kwargs.get("user", "") - password = kwargs.get("password", "") - - db_name = kwargs.get("database", "cs_comments_service") - collection_name = kwargs.get("collection", "") - - auth_source = kwargs.get("authsource") or None - - # Other mongo connection arguments - extra = kwargs.get("extra", {}) - - # Make timezone aware by default - extra["tz_aware"] = extra.get("tz_aware", True) - - # Connect to database and get collection - - self.connection: MongoClient[Any] = MongoClient(host=host, port=port, **extra) - - database = self.connection[db_name] - - if user or password: - database.authenticate(user, password, source=auth_source) - - self.collection: Collection[Dict[str, Any]] = database[collection_name] +Database = PymongoDatabase[dict[str, Any]] + + +def get_database( + host: str = settings.MONGO_HOST, + port: int = settings.MONGO_PORT, + user: str = "", + password: str = "", + database: str = "cs_comments_service", + authsource: Optional[str] = None, + tz_aware: bool = True, + **extra: Any +) -> Database: + """ + Connect to MongoDB. + + :Parameters: + + - `host`: hostname + - `port`: port + - `user`: collection username + - `password`: collection user password + - `database`: name of the database + - `authsource`: name of the authentication database + - `extra`: parameters to pymongo.MongoClient not listed above + + """ + connection: MongoClient[Any] = MongoClient( + host=host, port=port, tz_aware=tz_aware, **extra + ) + db = connection[database] + + if user or password: + db.authenticate(user, password, source=authsource) + + return db diff --git a/tests/conftest.py b/tests/conftest.py index 0e488ea0..50779f48 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ import json from typing import Any, Generator, Union -from unittest.mock import MagicMock, patch import mongomock import pytest @@ -13,59 +12,15 @@ from django.test import Client from pymongo import MongoClient -from forum.models import Comment, CommentThread, Contents, Users -from forum.mongo import MongoBackend - -@pytest.fixture(name="mock_mongo_backend") -def fixture_mock_mongo_backend() -> MagicMock: - """Mock MongoClient for tests.""" +@pytest.fixture(autouse=True) +def patch_default_mongo_database(monkeypatch: pytest.MonkeyPatch) -> None: + """Mock default mongodb database for tests.""" client: MongoClient[Any] = mongomock.MongoClient() - db = client["test_forum_db"] - - collections = { - "contents": db["contents"], - "users": db["users"], - } - - mock_backend: MagicMock = MagicMock(spec=MongoBackend) - for name, collection in collections.items(): - setattr(mock_backend, name, collection) - - return mock_backend - - -@pytest.fixture(name="patch_mongo_backend") -def fixture_patch_mongo_backend( - mock_mongo_backend: MagicMock, -) -> Generator[MagicMock, None, None]: - """Patch the MongoBackend instance with a mock.""" - with patch("forum.mongo.MongoBackend", return_value=mock_mongo_backend): - yield mock_mongo_backend - - -@pytest.fixture(name="users_model") -def fixture_users_model(patch_mongo_backend: MagicMock) -> Users: - """Get Users model with patched backend.""" - return Users(client=patch_mongo_backend.users) - - -@pytest.fixture(name="comment_thread_model") -def fixture_comment_thread_model(patch_mongo_backend: MagicMock) -> CommentThread: - """Get CommentThread model with patched backend.""" - return CommentThread(client=patch_mongo_backend.contents) - - -@pytest.fixture(name="comment_model") -def fixture_comment_model(patch_mongo_backend: MagicMock) -> Comment: - """Get Comment model with patched backend.""" - return Comment(client=patch_mongo_backend.contents) - - -@pytest.fixture(name="content_model") -def fixture_content_model(patch_mongo_backend: MagicMock) -> Contents: - """Get Contents model with patched backend.""" - return Contents(client=patch_mongo_backend.contents) + monkeypatch.setattr( + "forum.models.base_model.MongoBaseModel.MONGODB_DATABASE", + client["test_forum_db"], + ) class APIClient(Client): diff --git a/tests/test_models/test_comments.py b/tests/test_models/test_comments.py index f96678be..48ba9b7d 100644 --- a/tests/test_models/test_comments.py +++ b/tests/test_models/test_comments.py @@ -5,9 +5,9 @@ from forum.models import Comment -def test_insert(comment_model: Comment) -> None: +def test_insert() -> None: """Test insert a comment into MongoDB.""" - comment_id = comment_model.insert( + comment_id = Comment().insert( body="

This is a test comment

", course_id="course1", comment_thread_id="66af33634a1e1f001b7ed57f", @@ -15,51 +15,51 @@ def test_insert(comment_model: Comment) -> None: author_username="author_user", ) assert comment_id is not None - comment_data = comment_model.get(_id=comment_id) + comment_data = Comment().get(_id=comment_id) assert comment_data is not None assert comment_data["body"] == "

This is a test comment

" -def test_delete(comment_model: Comment) -> None: +def test_delete() -> None: """Test delete a comment from MongoDB.""" - comment_id = comment_model.insert( + comment_id = Comment().insert( body="

This is a test comment

", course_id="course1", comment_thread_id="66af33634a1e1f001b7ed57f", author_id="author1", author_username="author_user", ) - result = comment_model.delete(comment_id) + result = Comment().delete(comment_id) assert result == 1 - comment_data = comment_model.get(_id=comment_id) + comment_data = Comment().get(_id=comment_id) assert comment_data is None -def test_list(comment_model: Comment) -> None: +def test_list() -> None: """Test list all comments from MongoDB.""" course_id = "course-xyz" thread_id = "66af33634a1e1f001b7ed57f" author_id = "4" author_username = "edly" - comment_model.insert( + Comment().insert( "

Comment 1

", course_id, thread_id, author_id, author_username ) - comment_model.insert( + Comment().insert( "

Comment 2

", course_id, thread_id, author_id, author_username ) - comment_model.insert( + Comment().insert( "

Comment 3

", course_id, thread_id, author_id, author_username ) - comments_list = comment_model.list() + comments_list = Comment().list() assert len(list(comments_list)) == 3 assert all(comment["body"].startswith("

Comment") for comment in comments_list) -def test_update(comment_model: Comment) -> None: +def test_update() -> None: """Test update a comment in MongoDB.""" - comment_id = comment_model.insert( + comment_id = Comment().insert( body="

This is a test comment

", course_id="course1", comment_thread_id="66af33634a1e1f001b7ed57f", @@ -67,10 +67,10 @@ def test_update(comment_model: Comment) -> None: author_username="author_user", ) - result = comment_model.update( + result = Comment().update( comment_id=comment_id, body="

Updated comment

", ) assert result == 1 - comment_data = comment_model.get(_id=comment_id) or {} + comment_data = Comment().get(_id=comment_id) or {} assert comment_data.get("body", "") == "

Updated comment

" diff --git a/tests/test_models/test_threads.py b/tests/test_models/test_threads.py index 56f562fe..399642c7 100644 --- a/tests/test_models/test_threads.py +++ b/tests/test_models/test_threads.py @@ -6,9 +6,9 @@ from forum.models import CommentThread -def test_insert(comment_thread_model: CommentThread) -> None: +def test_insert() -> None: """Test insert a comment thread into MongoDB.""" - thread_id = comment_thread_model.insert( + thread_id = CommentThread().insert( title="Test Thread", body="This is a test thread", course_id="course1", @@ -17,15 +17,15 @@ def test_insert(comment_thread_model: CommentThread) -> None: author_username="author_user", ) assert thread_id is not None - thread_data = comment_thread_model.get(thread_id) + thread_data = CommentThread().get(thread_id) assert thread_data is not None assert thread_data["title"] == "Test Thread" assert thread_data["body"] == "This is a test thread" -def test_delete(comment_thread_model: CommentThread) -> None: +def test_delete() -> None: """Test delete a comment thread from MongoDB.""" - thread_id = comment_thread_model.insert( + thread_id = CommentThread().insert( title="Test Thread", body="This is a test thread", course_id="course1", @@ -33,15 +33,15 @@ def test_delete(comment_thread_model: CommentThread) -> None: author_id="author1", author_username="author_user", ) - result = comment_thread_model.delete(thread_id) + result = CommentThread().delete(thread_id) assert result == 1 - thread_data = comment_thread_model.get(thread_id) + thread_data = CommentThread().get(thread_id) assert thread_data is None -def test_list(comment_thread_model: CommentThread) -> None: +def test_list() -> None: """Test list all comment threads from MongoDB.""" - comment_thread_model.insert( + CommentThread().insert( "Thread 1", "Body 1", "_type", @@ -49,7 +49,7 @@ def test_list(comment_thread_model: CommentThread) -> None: "1", "user1", ) - comment_thread_model.insert( + CommentThread().insert( "Thread 2", "Body 2", "_type", @@ -57,7 +57,7 @@ def test_list(comment_thread_model: CommentThread) -> None: "1", "user1", ) - comment_thread_model.insert( + CommentThread().insert( "Thread 3", "Body 3", "_type", @@ -65,14 +65,14 @@ def test_list(comment_thread_model: CommentThread) -> None: "1", "user1", ) - threads_list = comment_thread_model.list() + threads_list = CommentThread().list() assert len(list(threads_list)) == 3 assert all(thread["title"].startswith("Thread") for thread in threads_list) -def test_update(comment_thread_model: CommentThread) -> None: +def test_update() -> None: """Test update a comment thread in MongoDB.""" - thread_id = comment_thread_model.insert( + thread_id = CommentThread().insert( title="Test Thread", body="This is a test thread", course_id="course1", @@ -81,14 +81,14 @@ def test_update(comment_thread_model: CommentThread) -> None: author_username="author_user", ) - result = comment_thread_model.update( + result = CommentThread().update( thread_id=thread_id, title="Updated Title", body="Updated body", commentable_id="new_commentable_id", ) assert result == 1 - thread_data = comment_thread_model.get(thread_id) + thread_data = CommentThread().get(thread_id) assert thread_data is not None assert thread_data["title"] == "Updated Title" assert thread_data["body"] == "Updated body" diff --git a/tests/test_models/test_users.py b/tests/test_models/test_users.py index ffa475af..3ec4c80d 100644 --- a/tests/test_models/test_users.py +++ b/tests/test_models/test_users.py @@ -6,17 +6,17 @@ from forum.models import Users -def test_get(users_model: Users) -> None: +def test_get() -> None: """Test get user from mongodb""" external_id = "test_external_id" username = "test_username" email = "test_email" - users_model.insert( + Users().insert( external_id, username, email, ) - user_data = users_model.get(external_id) + user_data = Users().get(external_id) assert user_data is not None assert user_data["_id"] == external_id assert user_data["external_id"] == external_id @@ -24,14 +24,14 @@ def test_get(users_model: Users) -> None: assert user_data["email"] == email -def test_insert(users_model: Users) -> None: +def test_insert() -> None: """Test insert user from mongodb""" external_id = "test_external_id" username = "test_username" email = "test_email" - result = users_model.insert(external_id, username, email) + result = Users().insert(external_id, username, email) assert result is not None - user_data = users_model.get(external_id) + user_data = Users().get(external_id) assert user_data is not None assert user_data["_id"] == external_id assert user_data["external_id"] == external_id @@ -39,44 +39,44 @@ def test_insert(users_model: Users) -> None: assert user_data["email"] == email -def test_delete(users_model: Users) -> None: +def test_delete() -> None: """Test delete user from mongodb""" external_id = "test_external_id" - users_model.insert(external_id, "test_username", "test_email") - result = users_model.delete(external_id) + Users().insert(external_id, "test_username", "test_email") + result = Users().delete(external_id) assert result == 1 - user_data = users_model.get(external_id) + user_data = Users().get(external_id) assert user_data is None -def test_list(users_model: Users) -> None: +def test_list() -> None: """Test list user from mongodb""" - users_model.insert( + Users().insert( external_id="user1", username="user1", email="user1", ) - users_model.insert( + Users().insert( external_id="user2", username="user2", email="user1", ) - users_model.insert( + Users().insert( external_id="user3", username="user3", email="user1", ) - users_list = users_model.list() + users_list = Users().list() assert len(list(users_list)) == 3 assert all(user["username"] in ["user1", "user2", "user3"] for user in users_list) -def test_update(users_model: Users) -> None: +def test_update() -> None: """Test update user from mongodb""" external_id = "test_external_id" username = "test_username" email = "test_email" - users_model.insert( + Users().insert( external_id=external_id, username=username, email=email, @@ -84,11 +84,11 @@ def test_update(users_model: Users) -> None: new_username = "new_username" new_email = "new_email" - result = users_model.update(external_id, username=new_username, email=new_email) + result = Users().update(external_id, username=new_username, email=new_email) assert result is not None assert result == 1 - user_data = users_model.get(external_id) + user_data = Users().get(external_id) assert user_data is not None assert user_data["external_id"] == external_id assert user_data["username"] == new_username diff --git a/tests/test_views/test_flags.py b/tests/test_views/test_flags.py index 7a8023b3..2b6ac8d3 100644 --- a/tests/test_views/test_flags.py +++ b/tests/test_views/test_flags.py @@ -7,9 +7,7 @@ from forum.models import Contents, Users -def test_comment_thread_api( - api_client: Client, users_model: Users, content_model: Contents -) -> None: +def test_comment_thread_api(api_client: Client) -> None: """ Test the comment thread flag API. @@ -17,16 +15,16 @@ def test_comment_thread_api( """ user_id = "1" comment_thread_id = "66ace22474ba69001e1440cd" - users_model.insert(user_id, username="user1", email="email1") - content_model.insert( + Users().insert(user_id, username="user1", email="email1") + Contents().insert( comment_thread_id, "3", abuse_flaggers=[], historical_abuse_flaggers=[], visible=True, ) - mock_users_class = Mock(return_value=users_model) - mock_contents_class = Mock(return_value=content_model) + mock_users_class = Mock(return_value=Users()) + mock_contents_class = Mock(return_value=Contents()) with patch("forum.models.Users", new=mock_users_class): with patch("forum.models.Contents", new=mock_contents_class): response = api_client.put( @@ -42,14 +40,12 @@ def test_comment_thread_api( data={"user_id": str(user_id)}, ) assert response.status_code == 200 - comment = content_model.get(comment_thread_id) + comment = Contents().get(comment_thread_id) assert comment is not None assert comment["abuse_flaggers"] == [] -def test_comment_flag_api( - api_client: Client, users_model: Users, content_model: Contents -) -> None: +def test_comment_flag_api(api_client: Client) -> None: """ Test the comment flag API. @@ -57,16 +53,16 @@ def test_comment_flag_api( """ user_id = "1" comment_id = "66ace22474ba69001e1440cd" - users_model.insert(user_id, username="user1", email="email1") - content_model.insert( + Users().insert(user_id, username="user1", email="email1") + Contents().insert( comment_id, "3", abuse_flaggers=[], historical_abuse_flaggers=[], visible=True, ) - mock_users_class = Mock(return_value=users_model) - mock_contents_class = Mock(return_value=content_model) + mock_users_class = Mock(return_value=Users()) + mock_contents_class = Mock(return_value=Contents()) with patch("forum.models.Users", new=mock_users_class): with patch("forum.models.Contents", new=mock_contents_class): response = api_client.put( @@ -82,7 +78,7 @@ def test_comment_flag_api( data={"user_id": str(user_id)}, ) assert response.status_code == 200 - comment = content_model.get(comment_id) + comment = Contents().get(comment_id) assert comment is not None assert comment["abuse_flaggers"] == [] @@ -91,23 +87,21 @@ def test_comment_flag_api( data={"user_id": str(user_id)}, ) assert response.status_code == 200 - comment = content_model.get(comment_id) + comment = Contents().get(comment_id) assert comment is not None assert comment["abuse_flaggers"] == [] -def test_comment_flag_api_invalid_data( - api_client: Client, users_model: Users, content_model: Contents -) -> None: +def test_comment_flag_api_invalid_data(api_client: Client) -> None: """ Test the comment flag API with invalid data. This test checks that the API returns a 400 error when the user or comment does not exist. """ user_id = "1" - users_model.insert(user_id, username="user1", email="email1") - mock_users_class = Mock(return_value=users_model) - mock_contents_class = Mock(return_value=content_model) + Users().insert(user_id, username="user1", email="email1") + mock_users_class = Mock(return_value=Users()) + mock_contents_class = Mock(return_value=Contents()) with patch("forum.models.Users", new=mock_users_class): with patch("forum.models.Contents", new=mock_contents_class): response = api_client.put(