diff --git a/forum/models/base_model.py b/forum/models/base_model.py index c64088ba..aa2470e5 100644 --- a/forum/models/base_model.py +++ b/forum/models/base_model.py @@ -23,22 +23,22 @@ def __init__( self.client: MongoBackend = client or MongoBackend(collection=collection_name) @property - def collection(self) -> Collection[Dict[str, Any]]: + def _collection(self) -> Collection[Dict[str, Any]]: """Get mongo db collection""" - return self.client.collection + return self.get_client.collection @property def get_client(self) -> MongoBackend: """Get mongo client""" return self.client - def get(self, **kwargs: Any) -> Optional[Dict[str, Any]]: + def get(self, _id: str) -> Optional[Dict[str, Any]]: """Get a document by filter""" - return self.collection.find_one(kwargs) + return self._collection.find_one({"_id": _id}) def list(self, **kwargs: Any) -> Cursor[Dict[str, Any]]: """Get a list of all documents filtered by kwargs""" - return self.collection.find(kwargs) + return self._collection.find(kwargs) @abstractmethod def insert(self, *args: Any, **kwargs: Any) -> str: @@ -55,7 +55,7 @@ def delete(self, _id: str) -> int: Returns: The number of documents deleted. """ - result = self.collection.delete_one({"_id": ObjectId(_id)}) + result = self._collection.delete_one({"_id": ObjectId(_id)}) return result.deleted_count @abstractmethod diff --git a/forum/models/contents.py b/forum/models/contents.py index db6552f8..f6761969 100644 --- a/forum/models/contents.py +++ b/forum/models/contents.py @@ -1,6 +1,8 @@ """Content Class for mongo backend.""" -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional + +from bson import ObjectId from forum.models.base_model import MongoBaseModel from forum.mongo import MongoBackend @@ -25,19 +27,15 @@ def __init__( """ super().__init__(collection_name, client) - def get(self, **kwargs: Any) -> Optional[Dict[str, Any]]: + def get(self, _id: str) -> Optional[Dict[str, Any]]: # pylint: disable=arguments-differ """ - Retrieves a content document from the database based on provided arguments. - + Retrieves a contents document from the database based on the provided _id. Args: - kwargs: The filter arguments. - + _id: The ObjectId of the contents document to retrieve. Returns: - The thread contents if found, otherwise None. + The contents document if found, otherwise None. """ - if self.content_type: - kwargs["_type"] = self.content_type - return self.collection.find_one(kwargs) + return self._collection.find_one({"_id": ObjectId(_id)}) def list(self, **kwargs: Any) -> Any: """ @@ -51,10 +49,56 @@ def list(self, **kwargs: Any) -> Any: """ if self.content_type: kwargs["_type"] = self.content_type - return self.collection.find(kwargs) + return self._collection.find(kwargs) - def insert(self, *args: Any, **kwargs: Any) -> str: + def insert( # pylint: disable=arguments-differ + self, + _id: str, + author_id: str, + abuse_flaggers: List[str], + historical_abuse_flaggers: List[str], + visible: bool, + ) -> str: """ - Return not implemented error on the insert. + Inserts a new content document into the database. + + Args: + _id (str): The ID of the content. + author_id (str): The ID of the author who created the content. + abuse_flaggers (List[str]): A list of IDs of users who flagged the content as abusive. + historical_abuse_flaggers (List[str]): A list of IDs of users who previously flagged the content as abusive. + visible (bool): Whether the content is visible or not. + + Returns: + str: The ID of the inserted document. """ - raise NotImplementedError + content_data = { + "_id": ObjectId(_id), + "author_id": author_id, + "abuse_flaggers": abuse_flaggers, + "historical_abuse_flaggers": historical_abuse_flaggers, + "visible": visible, + } + result = self._collection.insert_one(content_data) + return str(result.inserted_id) + + def update(self, _id: str, **kwargs: Any) -> int: # pylint: disable=arguments-differ + """ + Updates a contents document in the database based on the provided _id. + + Args: + _id: The id of the contents document to update. + **kwargs: The fields to update in the contents document. + + Returns: + The number of documents modified. + """ + update_data = {} + + update_data["abuse_flaggers"] = kwargs.get("abuse_flaggers") + + result = self._collection.update_one( + {"_id": ObjectId(_id)}, + {"$set": update_data}, + ) + return result.modified_count diff --git a/forum/models/model_utils.py b/forum/models/model_utils.py new file mode 100644 index 00000000..c29073be --- /dev/null +++ b/forum/models/model_utils.py @@ -0,0 +1,86 @@ +"""Model util function for db operations.""" + +from typing import Any, Dict, Union + +from forum.models.contents import Contents +from forum.models.users import Users + + +def flag_as_abuse(user: Dict[str, Any], entity: Dict[str, Any]) -> Union[Dict[str, Any], None]: + """ + Flag an entity as abuse. + + Args: + user (Dict[str, Any]): The user who is flagging the entity as abuse. + entity (Dict[str, Any]): The entity being flagged as abuse. + + Returns: + Dict[str, Any]: The updated entity with the abuse flag. + + Raises: + ValueError: If user ID or entity is not provided. + """ + + abuse_flaggers = entity["abuse_flaggers"] + if user["_id"] not in abuse_flaggers: + abuse_flaggers.append(user["_id"]) + Contents().update( + entity["_id"], + abuse_flaggers=abuse_flaggers, + ) + + # Check if this is the first abuse flag + first_flag_added = len(entity["abuse_flaggers"]) == 1 + + # If this is the first abuse flag, update author's stats + active_flags = user.get("active_flags", 0) + 1 + if first_flag_added: + Users().update( + entity["author_id"], + active_flags=active_flags, + ) + + # Reload the object and return it as a JSON string + return Contents().get(entity["_id"]) + + +def un_flag_as_abuse(user: Dict[str, Any], entity: Dict[str, Any]) -> Union[Dict[str, Any], None]: + """ + Unflag an entity as abuse. + + Args: + user (Dict[str, Any]): The user who is unflagging the entity as abuse. + entity (Dict[str, Any]): The entity being unflagged as abuse. + + Returns: + Dict[str, Any]: The updated entity with the abuse flag removed. + + Raises: + ValueError: If user ID or entity is not provided. + """ + if user["_id"] in entity["abuse_flaggers"]: + entity["abuse_flaggers"].remove(user["_id"]) + Contents().update( + entity["_id"], + abuse_flaggers=entity["abuse_flaggers"], + ) + # TODO: Update course stats for abuse. + return Contents().get(entity["_id"]) + + +def un_flag_all_as_abuse(entity: Dict[str, Any]) -> Union[Dict[str, Any], None]: + """ + Unflag an entity as abuse for all users. + + Args: + entity (Dict[str, Any]): The entity being unflagged as abuse. + + Returns: + Dict[str, Any]: The updated entity with all abuse flags removed. + + Raises: + ValueError: If entity is not provided. + """ + Contents().update(entity["_id"], abuse_flaggers=[]) + # TODO: Update course stats for abuse. + return Contents().get(entity["_id"]) diff --git a/forum/models/threads.py b/forum/models/threads.py index b0c87419..79cba9a5 100644 --- a/forum/models/threads.py +++ b/forum/models/threads.py @@ -48,7 +48,7 @@ def get_votes( } return votes - def insert( + def insert( # type: ignore self, title: str, body: str, @@ -112,10 +112,10 @@ def insert( "updated_at": date, "last_activity_at": date, } - result = self.collection.insert_one(thread_data) + result = self._collection.insert_one(thread_data) return str(result.inserted_id) - def update( + def update( # type: ignore self, thread_id: str, thread_type: Optional[str] = None, @@ -191,7 +191,7 @@ def update( date = datetime.now() update_data["updated_at"] = date update_data["last_activity_at"] = date - result = self.collection.update_one( + result = self._collection.update_one( {"_id": ObjectId(thread_id)}, {"$set": update_data}, ) diff --git a/forum/models/users.py b/forum/models/users.py index 9fee17ed..09eb3617 100644 --- a/forum/models/users.py +++ b/forum/models/users.py @@ -59,7 +59,7 @@ def insert( "read_states": read_states, "course_stats": course_stats, } - result = self.collection.insert_one(user_data) + result = self._collection.insert_one(user_data) return str(result.inserted_id) def delete(self, _id: Any) -> int: @@ -73,7 +73,7 @@ def delete(self, _id: Any) -> int: The number of documents deleted. """ - result = self.collection.delete_one({"_id": _id}) + result = self._collection.delete_one({"_id": _id}) return result.deleted_count def update( @@ -84,17 +84,21 @@ def update( default_sort_key: Optional[str] = None, read_states: Optional[List[Dict[str, Any]]] = None, course_stats: Optional[List[Dict[str, Any]]] = None, + active_flags: Optional[int] = None, ) -> int: """ Updates a user document in the database based on the external_id. Args: external_id: The external ID of the user. - username: The new username of the user. - email: The new email of the user. - default_sort_key: The new default sort key for the user. - read_states: The new read states of the user. - course_stats: The new course statistics of the user. + **kwargs: Keyword arguments to update the user document. + Supported keys: + - username: The new username of the user. + - email: The new email of the user. + - default_sort_key: The new default sort key for the user. + - read_states: The new read states of the user. + - course_stats: The new course statistics of the user. + - active_flags: The new active flags of the user. Returns: The number of documents modified. @@ -106,12 +110,13 @@ def update( ("default_sort_key", default_sort_key), ("read_states", read_states), ("course_stats", course_stats), + ("active_flags", active_flags), ] update_data: dict[str, Any] = { field: value for field, value in fields if value is not None } - result = self.collection.update_one( + result = self._collection.update_one( {"external_id": external_id}, {"$set": update_data}, ) diff --git a/forum/serializers/contents.py b/forum/serializers/contents.py new file mode 100644 index 00000000..b01d1c77 --- /dev/null +++ b/forum/serializers/contents.py @@ -0,0 +1,62 @@ +"""Serializer class for content collection.""" + +from datetime import datetime +from typing import Any, Dict, List, Optional + +from rest_framework import serializers + +from forum.serializers.votes import VotesSerializer + + +class ContentSerializer(serializers.Serializer[Dict[str, Any]]): + """ + Serializer for the content data. + + Attributes: + _id (CharField): The ID of the content. + votes (VotesSerializer): The votes data for the content. + visible (BooleanField): Whether the content is visible. + abuse_flaggers (ListField): List of user IDs who flagged the content for abuse. + historical_abuse_flaggers (ListField): List of user IDs who historically flagged the content for abuse. + parent_ids (ListField): List of parent IDs for the content. + at_position_list (ListField): List of positions for the content. + body (CharField): The body text of the content. + course_id (CharField): The ID of the course associated with the content. + _type (CharField): The type of content. + endorsed (BooleanField): Whether the content is endorsed. + anonymous (BooleanField): Whether the content is anonymous. + anonymous_to_peers (BooleanField): Whether the content is anonymous to peers. + parent_id (CharField): The ID of the parent content. + author_id (CharField): The ID of the author of the content. + comment_thread_id (CharField): The ID of the comment thread associated with the content. + child_count (IntegerField): The number of child comments. + depth (IntegerField): The depth of the content in the comment thread. + author_username (CharField): The username of the author of the content. + sk (CharField): The sorting key for the content. + updated_at (DateTimeField): The date and time the content was last updated. + created_at (DateTimeField): The date and time the content was created. + """ + + _id = serializers.CharField() + votes = VotesSerializer(allow_null=True) + visible = serializers.BooleanField(allow_null=True) + abuse_flaggers = serializers.ListField(child=serializers.CharField(), allow_null=True) + historical_abuse_flaggers = \ + serializers.ListField(child=serializers.CharField(), allow_null=True) + parent_ids = serializers.ListField(child=serializers.CharField(), allow_null=True) + at_position_list = serializers.ListField(allow_null=True) + body = serializers.CharField(allow_null=True) + course_id = serializers.CharField(allow_null=True) + _type = serializers.CharField(allow_null=True) + endorsed = serializers.BooleanField(allow_null=True) + anonymous = serializers.BooleanField(allow_null=True) + anonymous_to_peers = serializers.BooleanField(allow_null=True) + parent_id = serializers.CharField(allow_null=True) + author_id = serializers.CharField(allow_null=True) + comment_thread_id = serializers.CharField(allow_null=True) + child_count = serializers.IntegerField(allow_null=True) + depth = serializers.IntegerField(allow_null=True) + author_username = serializers.CharField(allow_null=True) + sk = serializers.CharField(allow_null=True) + updated_at = serializers.DateTimeField(allow_null=True) + created_at = serializers.DateTimeField(allow_null=True) diff --git a/forum/serializers/votes.py b/forum/serializers/votes.py new file mode 100644 index 00000000..4730b26c --- /dev/null +++ b/forum/serializers/votes.py @@ -0,0 +1,30 @@ +""" +Serializer for votes data. + +Serializes the votes field in the ContentSerializer. +""" + +from rest_framework import serializers + + +class VotesSerializer(serializers.Serializer): # type: ignore + """ + Serializer for votes data. + + Handles data of type dict[str, int]. + + Attributes: + up (list[str]): List of user IDs who upvoted the content. + down (list[str]): List of user IDs who downvoted the content. + up_count (int): Total number of upvotes. + down_count (int): Total number of downvotes. + count (int): Total number of votes. + point (int): The point value of the content. + """ + + up = serializers.ListField(child=serializers.CharField()) + down = serializers.ListField(child=serializers.CharField()) + up_count = serializers.IntegerField() + down_count = serializers.IntegerField() + count = serializers.IntegerField() + point = serializers.IntegerField() diff --git a/forum/urls.py b/forum/urls.py index 858d49ff..9bc4becb 100644 --- a/forum/urls.py +++ b/forum/urls.py @@ -4,10 +4,13 @@ from django.urls import include, path +from forum.views.flags import CommentFlagAPIView, ThreadFlagAPIView from forum.views.proxy import ForumProxyAPIView api_patterns = [ # Proxy view for various API endpoints + path("comments//abuse_", CommentFlagAPIView.as_view(), name="comment-flags-api"), + path("threads//abuse_", ThreadFlagAPIView.as_view(), name="thread-flags-api"), path("", ForumProxyAPIView.as_view(), name="forum_proxy"), ] diff --git a/tests/__init__.py b/forum/views/__init__.py similarity index 100% rename from tests/__init__.py rename to forum/views/__init__.py diff --git a/forum/views/flags.py b/forum/views/flags.py new file mode 100644 index 00000000..194b7e77 --- /dev/null +++ b/forum/views/flags.py @@ -0,0 +1,102 @@ +"""Forum Flag API Views.""" + +from rest_framework import status +from rest_framework.permissions import AllowAny +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.views import APIView + +from forum.models.contents import Contents +from forum.models.model_utils import flag_as_abuse, un_flag_all_as_abuse, un_flag_as_abuse +from forum.models.users import Users +from forum.serializers.contents import ContentSerializer + + +class CommentFlagAPIView(APIView): + """ + API View for flagging/unflagging comments. + + Handles PUT requests to flag or unflag a comment. + """ + + permission_classes = (AllowAny,) + + def put(self, request: Request, comment_id: str, action: str) -> Response: + """ + Flag or unflag a comment. + + Parameters: + request (Request): The incoming request. + comment_id (str): The ID of the comment to flag/unflag. + action (str): The action to take (either "flag" or "unflag"). + + Returns: + Response: A response with the updated comment data. + """ + request_data = request.data + user = Users().get(request_data["user_id"]) + content = Contents().get(comment_id) + if not (user and content): + return Response( + {"error": "User / Comment doesn't exist"}, + status=status.HTTP_400_BAD_REQUEST + ) + if action == "flag": + comment = flag_as_abuse(user, content) + elif action == "unflag": + if request_data.get("all") and request_data.get("all") is True: + comment = un_flag_all_as_abuse(content) + else: + comment = un_flag_as_abuse(user, content) + else: + return Response( + {"error": "Invalid action"}, + status=status.HTTP_400_BAD_REQUEST + ) + serializer = ContentSerializer(comment) + return Response(serializer.data, status=status.HTTP_200_OK) + + +class ThreadFlagAPIView(APIView): + """ + API View for flagging/unflagging threads. + + Handles PUT requests to flag or unflag a thread. + """ + + permission_classes = (AllowAny,) + + def put(self, request: Request, thread_id: str, action: str) -> Response: + """ + Flag or unflag a thread. + + Parameters: + request (Request): The incoming request. + thread_id (str): The ID of the thread to flag/unflag. + action (str): The action to take (either "flag" or "unflag"). + + Returns: + Response: A response with the updated thread data. + """ + request_data = request.data + user = Users().get(request_data["user_id"]) + content = Contents().get(thread_id) + if not (user and content): + return Response( + {"error": "User / Comment doesn't exist"}, + status=status.HTTP_400_BAD_REQUEST + ) + if action == "flag": + thread = flag_as_abuse(user, content) + elif action == "unflag": + if request_data.get("all"): + thread = un_flag_all_as_abuse(content) + else: + thread = un_flag_as_abuse(user, content) + else: + return Response( + {"error": "Invalid action"}, + status=status.HTTP_400_BAD_REQUEST, + ) + serializer = ContentSerializer(thread) + return Response(serializer.data, status=status.HTTP_200_OK) diff --git a/test_settings.py b/test_settings.py index ad6b6aa5..d3be842d 100644 --- a/test_settings.py +++ b/test_settings.py @@ -44,9 +44,9 @@ def root(*args): SECRET_KEY = 'insecure-secret-key' MIDDLEWARE = ( + 'django.contrib.sessions.middleware.SessionMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware', 'django.contrib.messages.middleware.MessageMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', ) TEMPLATES = [{ @@ -61,3 +61,5 @@ def root(*args): }] FORUM_PORT = "4567" +MONGO_HOST = "mongo-test-url" +MONGO_PORT = 27017 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..f9696547 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,100 @@ +# conftest.py +""" +Init file for tests. +""" + +import json +from typing import Any, Generator, Union +from unittest.mock import MagicMock, patch + +import mongomock +import pytest +from django.http.response import HttpResponse +from django.test import Client +from pymongo import MongoClient + +from forum.models.contents import Contents +from forum.models.threads import CommentThread +from forum.models.users import Users +from forum.mongo import MongoBackend + + +@pytest.fixture(name="mock_mongo_backend") +def fixture_mock_mongo_backend() -> MagicMock: + """Mock MongoClient for tests.""" + client: MongoClient[Any] = mongomock.MongoClient() + db = client["test_forum_db"] + + collections = { + "contents": db["contents"], + "users": db["users"], + } + + mock_backend: MagicMock = MagicMock(spec=MongoBackend) + for name, collection in collections.items(): + setattr(mock_backend, name, collection) + + return mock_backend + + +@pytest.fixture(name="patch_mongo_backend") +def fixture_patch_mongo_backend( + mock_mongo_backend: MagicMock, +) -> Generator[MagicMock, None, None]: + """Patch the MongoBackend instance with a mock.""" + with patch("forum.mongo.MongoBackend", return_value=mock_mongo_backend): + yield mock_mongo_backend + + +@pytest.fixture(name="users_model") +def fixture_users_model(patch_mongo_backend: MagicMock) -> Users: + """Get Users model with patched backend.""" + return Users(client=patch_mongo_backend.users) + + +@pytest.fixture(name="comment_thread_model") +def fixture_comment_thread_model(patch_mongo_backend: MagicMock) -> CommentThread: + """Get CommentThread model with patched backend.""" + return CommentThread(client=patch_mongo_backend.contents) + + +@pytest.fixture(name="content_model") +def fixture_content_model(patch_mongo_backend: MagicMock) -> Contents: + """Get Contents model with patched backend.""" + return Contents(client=patch_mongo_backend.contents) + + +class APIClient(Client): + """ + Extends the Django test client to include a custom PUT method. + + This client sends JSON data with the correct headers. + """ + + def put( # type: ignore[override,no-untyped-def] # pylint: disable=arguments-differ + self, path: str, data: Any, **kwargs + ) -> Union[HttpResponse, Any]: + """ + Send a PUT request with JSON data. + + Args: + path (str): The URL path to send the request to. + data (dict): The data to be sent in the request body. + **kwargs: Additional keyword arguments to be passed to the parent method. + + Returns: + The response object from the request. + """ + headers = { + "accept": "application/json", + "content-type": "application/json", + "HTTP_X_API_KEY": "your_api_key", + } + return super().put(path, data=json.dumps(data), headers=headers, **kwargs) + + +@pytest.fixture(name="api_client") +def fixture_api_client() -> Generator[APIClient, Any, Any]: + """Create an API client for testing.""" + client = APIClient() + yield client diff --git a/tests/test_models/test_threads.py b/tests/test_models/test_threads.py index 8be93250..4bca278f 100644 --- a/tests/test_models/test_threads.py +++ b/tests/test_models/test_threads.py @@ -3,8 +3,6 @@ Tests for the `CommentThread` model. """ -from bson import ObjectId - from forum.models.threads import CommentThread @@ -19,7 +17,7 @@ def test_insert(comment_thread_model: CommentThread) -> None: author_username="author_user", ) assert thread_id is not None - thread_data = comment_thread_model.get(_id=ObjectId(thread_id)) + thread_data = comment_thread_model.get(thread_id) assert thread_data is not None assert thread_data["title"] == "Test Thread" assert thread_data["body"] == "This is a test thread" @@ -37,18 +35,20 @@ def test_delete(comment_thread_model: CommentThread) -> None: ) result = comment_thread_model.delete(thread_id) assert result == 1 - thread_data = comment_thread_model.get(_id=ObjectId(thread_id)) + thread_data = comment_thread_model.get(thread_id) assert thread_data is None def test_list(comment_thread_model: CommentThread) -> None: """Test list all comment threads from MongoDB.""" - comment_thread_model.collection.insert_many( - [ - {"title": "Thread 1", "body": "Body 1", "_type": "CommentThread"}, - {"title": "Thread 2", "body": "Body 2", "_type": "CommentThread"}, - {"title": "Thread 3", "body": "Body 3", "_type": "CommentThread"}, - ] + comment_thread_model.insert( + "Thread 1", "Body 1", "_type", "CommentThread", "1", "user1", + ) + comment_thread_model.insert( + "Thread 2", "Body 2", "_type", "CommentThread", "1", "user1", + ) + comment_thread_model.insert( + "Thread 3", "Body 3", "_type", "CommentThread", "1", "user1", ) threads_list = comment_thread_model.list() assert len(list(threads_list)) == 3 @@ -73,7 +73,7 @@ def test_update(comment_thread_model: CommentThread) -> None: commentable_id="new_commentable_id", ) assert result == 1 - thread_data = comment_thread_model.get(_id=ObjectId(thread_id)) + thread_data = comment_thread_model.get(thread_id) assert thread_data is not None assert thread_data["title"] == "Updated Title" assert thread_data["body"] == "Updated body" diff --git a/tests/test_models/test_users.py b/tests/test_models/test_users.py index 487c0004..85186d14 100644 --- a/tests/test_models/test_users.py +++ b/tests/test_models/test_users.py @@ -11,15 +11,12 @@ def test_get(users_model: Users) -> None: external_id = "test_external_id" username = "test_username" email = "test_email" - users_model.collection.insert_one( - { - "_id": external_id, - "external_id": external_id, - "username": username, - "email": email, - } + users_model.insert( + external_id, + username, + email, ) - user_data = users_model.get(external_id=external_id) + user_data = users_model.get(external_id) assert user_data is not None assert user_data["_id"] == external_id assert user_data["external_id"] == external_id @@ -34,7 +31,7 @@ def test_insert(users_model: Users) -> None: email = "test_email" result = users_model.insert(external_id, username, email) assert result is not None - user_data = users_model.get(external_id=external_id) + user_data = users_model.get(external_id) assert user_data is not None assert user_data["_id"] == external_id assert user_data["external_id"] == external_id @@ -45,21 +42,29 @@ def test_insert(users_model: Users) -> None: def test_delete(users_model: Users) -> None: """Test delete user from mongodb""" external_id = "test_external_id" - users_model.collection.insert_one({"_id": external_id, "external_id": external_id}) + users_model.insert(external_id, "test_username", "test_email") result = users_model.delete(external_id) assert result == 1 - user_data = users_model.get(external_id=external_id) + user_data = users_model.get(external_id) assert user_data is None def test_list(users_model: Users) -> None: """Test list user from mongodb""" - users_model.collection.insert_many( - [ - {"_id": "user1", "external_id": "user1", "username": "user1"}, - {"_id": "user2", "external_id": "user2", "username": "user2"}, - {"_id": "user3", "external_id": "user3", "username": "user3"}, - ] + users_model.insert( + external_id="user1", + username="user1", + email="user1", + ) + users_model.insert( + external_id="user2", + username="user2", + email="user1", + ) + users_model.insert( + external_id="user3", + username="user3", + email="user1", ) users_list = users_model.list() assert len(list(users_list)) == 3 @@ -71,21 +76,23 @@ def test_update(users_model: Users) -> None: external_id = "test_external_id" username = "test_username" email = "test_email" - users_model.collection.insert_one( - { - "_id": external_id, - "external_id": external_id, - "username": username, - "email": email, - } + users_model.insert( + external_id=external_id, + username=username, + email=email, ) new_username = "new_username" new_email = "new_email" - result = users_model.update(external_id, new_username, new_email) + result = users_model.update( + external_id, + username=new_username, + email=new_email + ) assert result is not None assert result == 1 - user_data = users_model.get(external_id=external_id) + + user_data = users_model.get(external_id) assert user_data is not None assert user_data["external_id"] == external_id assert user_data["username"] == new_username diff --git a/tests/test_views/test_flags.py b/tests/test_views/test_flags.py new file mode 100644 index 00000000..57ae16df --- /dev/null +++ b/tests/test_views/test_flags.py @@ -0,0 +1,114 @@ +"""Test flags api endpoints.""" + +from unittest.mock import Mock, patch + +from bson import ObjectId +from django.test import Client + +from forum.models.contents import Contents +from forum.models.users import Users + + +def test_comment_thread_api(api_client: Client, users_model: Users, content_model: Contents) -> None: + """ + Test the comment thread flag API. + + This test checks that a user can flag a comment thread for abuse and then unflag it. + """ + user_id = "1" + comment_thread_id = "66ace22474ba69001e1440cd" + users_model.insert(user_id, username="user1", email="email1") + content_model.insert( + comment_thread_id, + "3", + abuse_flaggers=[], + historical_abuse_flaggers=[], + visible=True, + ) + mock_users_class = Mock(return_value=users_model) + mock_contents_class = Mock(return_value=content_model) + with patch("forum.models.users.Users", new=mock_users_class): + with patch("forum.models.contents.Contents", new=mock_contents_class): + response = api_client.put( + f"/api/v2/threads/{comment_thread_id}/abuse_flag", + data={"user_id": str(user_id)}, + ) + assert response.status_code == 200 + comment_thread = response.json() + assert comment_thread["abuse_flaggers"] == [str(user_id)] + + response = api_client.put( + path=f"/api/v2/threads/{comment_thread_id}/abuse_unflag", + data={"user_id": str(user_id)}, + ) + assert response.status_code == 200 + comment = content_model.get(comment_thread_id) + assert comment is not None + assert comment["abuse_flaggers"] == [] + + +def test_comment_flag_api(api_client: Client, users_model: Users, content_model: Contents) -> None: + """ + Test the comment flag API. + + This test checks that a user can flag a comment for abuse and then unflag it. + """ + user_id = "1" + comment_id = "66ace22474ba69001e1440cd" + users_model.insert(user_id, username="user1", email="email1") + content_model.insert( + comment_id, + "3", + abuse_flaggers=[], + historical_abuse_flaggers=[], + visible=True, + ) + mock_users_class = Mock(return_value=users_model) + mock_contents_class = Mock(return_value=content_model) + with patch("forum.models.users.Users", new=mock_users_class): + with patch("forum.models.contents.Contents", new=mock_contents_class): + response = api_client.put( + f"/api/v2/comments/{comment_id}/abuse_flag", + data={"user_id": str(user_id)}, + ) + assert response.status_code == 200 + comment_thread = response.json() + assert comment_thread["abuse_flaggers"] == [str(user_id)] + + response = api_client.put( + path=f"/api/v2/comments/{comment_id}/abuse_unflag", + data={"user_id": str(user_id)}, + ) + assert response.status_code == 200 + comment = content_model.get(comment_id) + assert comment is not None + assert comment["abuse_flaggers"] == [] + + response = api_client.put( + path=f"/api/v2/comments/{comment_id}/abuse_unflag", + data={"user_id": str(user_id)}, + ) + assert response.status_code == 200 + comment = content_model.get(comment_id) + assert comment is not None + assert comment["abuse_flaggers"] == [] + + +def test_comment_flag_api_invalid_data(api_client: Client, users_model: Users, content_model: Contents) -> None: + """ + Test the comment flag API with invalid data. + + This test checks that the API returns a 400 error when the user or comment does not exist. + """ + user_id = "1" + users_model.insert(user_id, username="user1", email="email1") + mock_users_class = Mock(return_value=users_model) + mock_contents_class = Mock(return_value=content_model) + with patch("forum.models.users.Users", new=mock_users_class): + with patch("forum.models.contents.Contents", new=mock_contents_class): + response = api_client.put( + path="/api/v2/comments/66ace22474ba69001e1440bd/abuse_flag", + data={"user_id": str(user_id)}, + ) + assert response.status_code == 400 + assert response.json() == {"error": "User / Comment doesn't exist"}