From f35d53ea2f36685136e09b1814c0f81d7aae3c19 Mon Sep 17 00:00:00 2001 From: Taimoor Ahmed Date: Tue, 1 Oct 2024 12:53:38 +0500 Subject: [PATCH] feat: integrate mysql models with forum apis - Add backend interface - Refactor mongo api - Refactor mysql api - modify api views to use backend based on get_backend() close: #051 --- forum/api/commentables.py | 5 +- forum/api/comments.py | 115 +- forum/api/flags.py | 56 +- forum/api/pins.py | 21 +- forum/api/search.py | 5 +- forum/api/subscriptions.py | 55 +- forum/api/threads.py | 90 +- forum/api/users.py | 170 +- forum/api/votes.py | 102 +- forum/backend.py | 22 + forum/backends/backend.py | 446 +++ forum/backends/mongodb/api.py | 2867 +++++++++++--------- forum/backends/mongodb/contents.py | 3 +- forum/backends/mongodb/subscriptions.py | 7 +- forum/backends/mongodb/users.py | 2 +- forum/backends/mysql/api.py | 2860 ++++++++++++------- forum/backends/mysql/models.py | 97 +- forum/migrations/0001_initial.py | 10 +- forum/serializers/comment.py | 18 +- forum/serializers/thread.py | 49 +- forum/serializers/users.py | 1 + forum/settings/common.py | 2 + forum/settings/test.py | 2 + forum/views/commentables.py | 2 +- forum/views/users.py | 2 +- tests/conftest.py | 16 +- tests/e2e/test_search.py | 7 +- tests/e2e/test_users.py | 4 +- tests/test_backends/test_mysql/test_api.py | 24 +- tests/test_views/test_commentables.py | 54 +- tests/test_views/test_comments.py | 72 +- tests/test_views/test_flags.py | 150 +- tests/test_views/test_pins.py | 57 +- tests/test_views/test_search.py | 318 ++- tests/test_views/test_subscriptions.py | 220 +- tests/test_views/test_threads.py | 327 ++- tests/test_views/test_users.py | 236 +- tests/test_views/test_votes.py | 67 +- 38 files changed, 5183 insertions(+), 3378 deletions(-) create mode 100644 forum/backend.py create mode 100644 forum/backends/backend.py diff --git a/forum/api/commentables.py b/forum/api/commentables.py index 8c142297..8161367c 100644 --- a/forum/api/commentables.py +++ b/forum/api/commentables.py @@ -2,7 +2,7 @@ Native Python Commenttables APIs. """ -from forum.backends.mongodb.api import get_commentables_counts_based_on_type +from forum.backend import get_backend def get_commentables_stats(course_id: str) -> dict[str, int]: @@ -18,4 +18,5 @@ def get_commentables_stats(course_id: str) -> dict[str, int]: e.g. reponse = {'course': {'discussion': 1, 'question': 1}} """ - return get_commentables_counts_based_on_type(course_id) + backend = get_backend(course_id)() + return backend.get_commentables_counts_based_on_type(course_id) diff --git a/forum/api/comments.py b/forum/api/comments.py index e2440577..3df4168e 100644 --- a/forum/api/comments.py +++ b/forum/api/comments.py @@ -8,21 +8,7 @@ from django.core.exceptions import ObjectDoesNotExist from rest_framework.serializers import ValidationError -from forum.backends.mongodb.api import ( - create_comment, - delete_comment_by_id, - get_course_id_by_comment_id, - get_thread_by_id, - get_thread_id_by_comment_id, - get_user_by_id, - mark_as_read, - update_comment_and_get_updated_comment, - update_stats_for_course, - validate_object, -) -from forum.backends.mongodb.comments import Comment -from forum.backends.mongodb.threads import CommentThread -from forum.backends.mongodb import api +from forum.backend import get_backend from forum.serializers.comment import CommentSerializer from forum.utils import ForumV2RequestError @@ -31,6 +17,7 @@ def prepare_comment_api_response( comment: dict[str, Any], + backend: Any, exclude_fields: Optional[list[str]] = None, ) -> dict[str, Any]: """ @@ -58,6 +45,7 @@ def prepare_comment_api_response( serializer = CommentSerializer( data=comment_data, exclude_fields=exclude_fields, + backend=backend, ) if not serializer.is_valid(raise_exception=True): raise ValidationError(serializer.errors) @@ -65,7 +53,9 @@ def prepare_comment_api_response( return serializer.data -def get_parent_comment(comment_id: str) -> dict[str, Any]: +def get_parent_comment( + comment_id: str, course_id: Optional[str] = None +) -> dict[str, Any]: """ Get a parent comment. @@ -76,8 +66,9 @@ def get_parent_comment(comment_id: str) -> dict[str, Any]: Response: The details of the comment for the given comment_id. """ + backend = get_backend(course_id)() try: - comment = validate_object(Comment, comment_id) + comment = backend.validate_object("Comment", comment_id) except ObjectDoesNotExist as exc: log.error("Forumv2RequestError for get parent comment request.") raise ForumV2RequestError( @@ -85,6 +76,7 @@ def get_parent_comment(comment_id: str) -> dict[str, Any]: ) from exc return prepare_comment_api_response( comment, + backend, exclude_fields=["sk"], ) @@ -110,35 +102,40 @@ def create_child_comment( Response: The details of the comment that is created. """ + backend = get_backend(course_id)() try: - parent_comment = validate_object(Comment, parent_comment_id) + parent_comment = backend.validate_object("Comment", parent_comment_id) except ObjectDoesNotExist as exc: log.error("Forumv2RequestError for create child comment request.") raise ForumV2RequestError( f"Comment does not exists with Id: {parent_comment_id}" ) from exc - comment = create_comment( - body, - user_id, - course_id, - anonymous, - anonymous_to_peers, - 1, - get_thread_id_by_comment_id(parent_comment_id), - parent_id=parent_comment_id, + comment_id = backend.create_comment( + { + "body": body, + "author_id": user_id, + "course_id": course_id, + "anonymous": anonymous, + "anonymous_to_peers": anonymous_to_peers, + "depth": 1, + "comment_thread_id": backend.get_thread_id_by_comment_id(parent_comment_id), + "parent_id": parent_comment_id, + } ) + comment = backend.get_comment(comment_id) if not comment: log.error("Forumv2RequestError for create child comment request.") raise ForumV2RequestError("comment is not created") - user = get_user_by_id(user_id) - thread = get_thread_by_id(parent_comment["comment_thread_id"]) + user = backend.get_user(user_id) + thread = backend.get_thread(parent_comment["comment_thread_id"]) if user and thread and comment: - mark_as_read(user, thread) + backend.mark_as_read(user_id, parent_comment["comment_thread_id"]) try: comment_data = prepare_comment_api_response( comment, + backend, exclude_fields=["endorsement", "sk"], ) return comment_data @@ -158,6 +155,7 @@ def update_comment( editing_user_id: Optional[str] = None, edit_reason_code: Optional[str] = None, endorsement_user_id: Optional[str] = None, + course_key: Optional[str] = None, ) -> dict[str, Any]: """ Update an existing child/parent comment. @@ -177,15 +175,16 @@ def update_comment( Response: The details of the comment that is updated. """ + backend = get_backend(course_key)() try: - validate_object(Comment, comment_id) + backend.validate_object("Comment", comment_id) except ObjectDoesNotExist as exc: log.error("Forumv2RequestError for update comment request.") raise ForumV2RequestError( f"Comment does not exists with Id: {comment_id}" ) from exc - updated_comment = update_comment_and_get_updated_comment( + updated_comment = backend.update_comment_and_get_updated_comment( comment_id, body, course_id, @@ -204,6 +203,7 @@ def update_comment( try: return prepare_comment_api_response( updated_comment, + backend, exclude_fields=( ["endorsement", "sk"] if updated_comment.get("parent_id") else ["sk"] ), @@ -212,7 +212,7 @@ def update_comment( raise error -def delete_comment(comment_id: str) -> dict[str, Any]: +def delete_comment(comment_id: str, course_id: Optional[str] = None) -> dict[str, Any]: """ Delete a comment. @@ -223,8 +223,9 @@ def delete_comment(comment_id: str) -> dict[str, Any]: Response: The details of the comment that is deleted. """ + backend = get_backend(course_id)() try: - comment = validate_object(Comment, comment_id) + comment = backend.validate_object("Comment", comment_id) except ObjectDoesNotExist as exc: log.error("Forumv2RequestError for delete comment request.") raise ForumV2RequestError( @@ -232,16 +233,17 @@ def delete_comment(comment_id: str) -> dict[str, Any]: ) from exc data = prepare_comment_api_response( comment, + backend, exclude_fields=["endorsement", "sk"], ) - delete_comment_by_id(comment_id) + backend.delete_comment(comment_id) author_id = comment["author_id"] - course_id = comment["course_id"] + comment_course_id = comment["course_id"] parent_comment_id = data["parent_id"] if parent_comment_id: - update_stats_for_course(author_id, course_id, replies=-1) + backend.update_stats_for_course(author_id, comment_course_id, replies=-1) else: - update_stats_for_course(author_id, course_id, responses=-1) + backend.update_stats_for_course(author_id, comment_course_id, responses=-1) return data @@ -266,32 +268,37 @@ def create_parent_comment( Response: The details of the comment that is created. """ + backend = get_backend(course_id)() try: - thread = validate_object(CommentThread, thread_id) + backend.validate_object("CommentThread", thread_id) except ObjectDoesNotExist as exc: log.error("Forumv2RequestError for create parent comment request.") raise ForumV2RequestError( f"Thread does not exists with Id: {thread_id}" ) from exc - comment = create_comment( - body, - user_id, - course_id, - anonymous, - anonymous_to_peers, - 0, - thread_id=thread_id, + comment_id = backend.create_comment( + { + "body": body, + "author_id": user_id, + "course_id": course_id, + "anonymous": anonymous, + "anonymous_to_peers": anonymous_to_peers, + "depth": 0, + "comment_thread_id": thread_id, + } ) - if not comment: + if not comment_id: log.error("Forumv2RequestError for create parent comment request.") raise ForumV2RequestError("comment is not created") - user = get_user_by_id(user_id) + comment = backend.get_comment(comment_id) or {} + user = backend.get_user(user_id) if user and comment: - mark_as_read(user, thread) + backend.mark_as_read(user_id, thread_id) try: return prepare_comment_api_response( comment, + backend, exclude_fields=["endorsement", "sk"], ) except ValidationError as error: @@ -303,8 +310,12 @@ def get_course_id_by_comment(comment_id: str) -> str | None: Return course_id for the matching comment. It searches for comment_id both in mongodb and mysql. """ + # pylint: disable=C0415 + from forum.backends.mongodb.api import MongoBackend + from forum.backends.mysql.api import MySQLBackend + return ( - get_course_id_by_comment_id(comment_id) - or api.get_course_id_by_comment_id(comment_id) + MongoBackend.get_course_id_by_comment_id(comment_id) + or MySQLBackend.get_course_id_by_comment_id(comment_id) or None ) diff --git a/forum/api/flags.py b/forum/api/flags.py index fb7029ec..cd7c71e9 100644 --- a/forum/api/flags.py +++ b/forum/api/flags.py @@ -4,14 +4,7 @@ from typing import Any, Optional -from forum.backends.mongodb.api import ( - flag_as_abuse, - un_flag_all_as_abuse, - un_flag_as_abuse, -) -from forum.backends.mongodb.comments import Comment -from forum.backends.mongodb.threads import CommentThread -from forum.backends.mongodb.users import Users +from forum.backend import get_backend from forum.serializers.comment import CommentSerializer from forum.serializers.thread import ThreadSerializer from forum.utils import ForumV2RequestError @@ -22,6 +15,7 @@ def update_comment_flag( action: str, user_id: Optional[str] = None, update_all: Optional[bool] = False, + course_id: Optional[str] = None, ) -> dict[str, Any]: """ Update the flag status of a comment. @@ -32,20 +26,30 @@ def update_comment_flag( action (str): The action to perform ("flag" or "unflag"). update_all (bool, optional): Whether to update all flags. Defaults to False. """ + backend = get_backend(course_id)() if not user_id: raise ForumV2RequestError("user_id not provided in params") - user = Users().get(user_id) - comment = Comment().get(comment_id) + user = backend.get_user(user_id) + try: + comment = backend.get_comment(comment_id) + except ValueError as exc: + raise ForumV2RequestError("User / Comment doesn't exist") from exc if not user or not comment: raise ForumV2RequestError("User / Comment doesn't exist") if action == "flag": - updated_comment = flag_as_abuse(user, comment) + updated_comment = backend.flag_as_abuse( + user_id, comment_id, entity_type="Comment" + ) elif action == "unflag": if update_all: - updated_comment = un_flag_all_as_abuse(comment) + updated_comment = backend.un_flag_all_as_abuse( + comment_id, entity_type="Comment" + ) else: - updated_comment = un_flag_as_abuse(user, comment) + updated_comment = backend.un_flag_as_abuse( + user_id, comment_id, entity_type="Comment" + ) else: raise ForumV2RequestError("Invalid action") @@ -53,14 +57,14 @@ def update_comment_flag( raise ForumV2RequestError("Failed to update comment") context = { - "id": str(updated_comment["_id"]), + "id": updated_comment["_id"], **updated_comment, "user_id": user["_id"], "username": user["username"], "type": "comment", - "thread_id": str(updated_comment.get("comment_thread_id", None)), + "comment_thread_id": str(updated_comment.get("comment_thread_id", None)), } - return CommentSerializer(context).data + return CommentSerializer(context, backend=backend).data def update_thread_flag( @@ -68,6 +72,7 @@ def update_thread_flag( action: str, user_id: Optional[str] = None, update_all: Optional[bool] = False, + course_id: Optional[str] = None, ) -> dict[str, Any]: """ Update the flag status of a thread. @@ -78,20 +83,27 @@ def update_thread_flag( action (str): The action to perform ("flag" or "unflag"). update_all (bool, optional): Whether to update all flags. Defaults to False. """ + backend = get_backend(course_id)() if not user_id: raise ForumV2RequestError("user_id not provided in params") - user = Users().get(user_id) - thread = CommentThread().get(thread_id) + user = backend.get_user(user_id) + thread = backend.get_thread(thread_id) if not user or not thread: raise ForumV2RequestError("User / Thread doesn't exist") if action == "flag": - updated_thread = flag_as_abuse(user, thread) + updated_thread = backend.flag_as_abuse( + user_id, thread_id, entity_type="CommentThread" + ) elif action == "unflag": if update_all: - updated_thread = un_flag_all_as_abuse(thread) + updated_thread = backend.un_flag_all_as_abuse( + thread_id, entity_type="CommentThread" + ) else: - updated_thread = un_flag_as_abuse(user, thread) + updated_thread = backend.un_flag_as_abuse( + user_id, thread_id, entity_type="CommentThread" + ) else: raise ForumV2RequestError("Invalid action") @@ -106,4 +118,4 @@ def update_thread_flag( "type": "thread", "thread_id": str(updated_thread.get("comment_thread_id", None)), } - return ThreadSerializer(context).data + return ThreadSerializer(context, backend=backend).data diff --git a/forum/api/pins.py b/forum/api/pins.py index ce3ab2be..33a027aa 100644 --- a/forum/api/pins.py +++ b/forum/api/pins.py @@ -3,9 +3,9 @@ """ import logging -from typing import Any +from typing import Any, Optional -from forum.backends.mongodb.api import handle_pin_unpin_thread_request +from forum.backend import get_backend from forum.serializers.thread import ThreadSerializer from forum.utils import ForumV2RequestError @@ -16,6 +16,7 @@ def pin_unpin_thread( user_id: str, thread_id: str, action: str, + course_id: Optional[str] = None, ) -> dict[str, Any]: """ Helper method to Pin or Unpin a thread. @@ -26,8 +27,9 @@ def pin_unpin_thread( Response: A response with the updated thread data. """ + backend = get_backend(course_id)() try: - thread_data: dict[str, Any] = handle_pin_unpin_thread_request( + thread_data: dict[str, Any] = backend.handle_pin_unpin_thread_request( user_id, thread_id, action, ThreadSerializer ) except ValueError as e: @@ -37,7 +39,9 @@ def pin_unpin_thread( return thread_data -def pin_thread(user_id: str, thread_id: str) -> dict[str, Any]: +def pin_thread( + user_id: str, thread_id: str, course_id: Optional[str] = None +) -> dict[str, Any]: """ Pin a thread. Parameters: @@ -46,10 +50,13 @@ def pin_thread(user_id: str, thread_id: str) -> dict[str, Any]: Response: A response with the updated thread data. """ - return pin_unpin_thread(user_id, thread_id, "pin") + + return pin_unpin_thread(user_id, thread_id, "pin", course_id) -def unpin_thread(user_id: str, thread_id: str) -> dict[str, Any]: +def unpin_thread( + user_id: str, thread_id: str, course_id: Optional[str] = None +) -> dict[str, Any]: """ Unpin a thread. Parameters: @@ -58,4 +65,4 @@ def unpin_thread(user_id: str, thread_id: str) -> dict[str, Any]: Response: A response with the updated thread data. """ - return pin_unpin_thread(user_id, thread_id, "unpin") + return pin_unpin_thread(user_id, thread_id, "unpin", course_id) diff --git a/forum/api/search.py b/forum/api/search.py index 789dbd92..d6daad53 100644 --- a/forum/api/search.py +++ b/forum/api/search.py @@ -4,7 +4,7 @@ from typing import Any, Optional -from forum.backends.mongodb.api import handle_threads_query +from forum.backends.mongodb.api import MongoBackend as backend from forum.constants import FORUM_DEFAULT_PAGE, FORUM_DEFAULT_PER_PAGE from forum.search.comment_search import ThreadSearch from forum.serializers.thread import ThreadSerializer @@ -85,7 +85,7 @@ def search_threads( context, group_ids, text, commentable_ids, course_id ) - data = handle_threads_query( + data = backend.handle_threads_query( thread_ids, user_id, course_id, @@ -112,6 +112,7 @@ def search_threads( "include_endorsed": True, "include_read_state": True, }, + backend=backend, ) data["collection"] = thread_serializer.data diff --git a/forum/api/subscriptions.py b/forum/api/subscriptions.py index 5737609c..02587dcc 100644 --- a/forum/api/subscriptions.py +++ b/forum/api/subscriptions.py @@ -2,22 +2,13 @@ API for subscriptions. """ -from typing import Any +from typing import Any, Optional from django.http import QueryDict from rest_framework.request import Request from rest_framework.test import APIRequestFactory -from forum.backends.mongodb.api import ( - find_subscribed_threads, - get_threads, - subscribe_user, - unsubscribe_user, - validate_params, -) -from forum.backends.mongodb.subscriptions import Subscriptions -from forum.backends.mongodb.threads import CommentThread -from forum.backends.mongodb.users import Users +from forum.backend import get_backend from forum.pagination import ForumPagination from forum.serializers.subscriptions import SubscriptionSerializer from forum.serializers.thread import ThreadSerializer @@ -25,42 +16,50 @@ def validate_user_and_thread( - user_id: str, source_id: str + user_id: str, source_id: str, course_id: Optional[str] = None ) -> tuple[dict[str, Any], dict[str, Any]]: """ Validate if user and thread exist. """ - user = Users().get(user_id) - thread = CommentThread().get(source_id) + backend = get_backend(course_id)() + user = backend.get_user(user_id) + thread = backend.get_thread(source_id) if not (user and thread): raise ForumV2RequestError("User / Thread doesn't exist") return user, thread -def create_subscription(user_id: str, source_id: str) -> dict[str, Any]: +def create_subscription( + user_id: str, source_id: str, course_id: Optional[str] = None +) -> dict[str, Any]: """ Create a subscription for a user. """ - _, thread = validate_user_and_thread(user_id, source_id) - subscription = subscribe_user(user_id, source_id, thread["_type"]) + backend = get_backend(course_id)() + _, _ = validate_user_and_thread(user_id, source_id) + subscription = backend.subscribe_user( + user_id, source_id, source_type="CommentThread" + ) serializer = SubscriptionSerializer(subscription) return serializer.data -def delete_subscription(user_id: str, source_id: str) -> dict[str, Any]: +def delete_subscription( + user_id: str, source_id: str, course_id: Optional[str] = None +) -> dict[str, Any]: """ Delete a subscription for a user. """ + backend = get_backend(course_id)() _, _ = validate_user_and_thread(user_id, source_id) - subscription = Subscriptions().get_subscription( - user_id, - source_id, + subscription = backend.get_subscription( + user_id, source_id, source_type="CommentThread" ) if not subscription: raise ForumV2RequestError("Subscription doesn't exist") - unsubscribe_user(user_id, source_id) + backend.unsubscribe_user(user_id, source_id, source_type="CommentThread") serializer = SubscriptionSerializer(subscription) return serializer.data @@ -71,14 +70,15 @@ def get_user_subscriptions( """ Get a user's subscriptions. """ - validate_params(query_params, user_id) - thread_ids = find_subscribed_threads(user_id, course_id) - threads = get_threads(query_params, ThreadSerializer, thread_ids, user_id) + backend = get_backend(course_id)() + backend.validate_params(query_params, user_id) + thread_ids = backend.find_subscribed_threads(user_id, course_id) + threads = backend.get_threads(query_params, user_id, ThreadSerializer, thread_ids) return threads def get_thread_subscriptions( - thread_id: str, page: int = 1, per_page: int = 20 + thread_id: str, page: int = 1, per_page: int = 20, course_id: Optional[str] = None ) -> dict[str, Any]: """ Retrieve subscriptions to a specific thread. @@ -91,8 +91,9 @@ def get_thread_subscriptions( Returns: dict: A dictionary containing the paginated subscription data. """ + backend = get_backend(course_id)() query = {"source_id": thread_id, "source_type": "CommentThread"} - subscriptions_list = list(Subscriptions().find(query)) + subscriptions_list = list(backend.get_subscriptions(query)) factory = APIRequestFactory() query_params = QueryDict("", mutable=True) diff --git a/forum/api/threads.py b/forum/api/threads.py index cf070913..a396a4fa 100644 --- a/forum/api/threads.py +++ b/forum/api/threads.py @@ -8,21 +8,8 @@ from django.core.exceptions import ObjectDoesNotExist from rest_framework.serializers import ValidationError -from forum.backends.mongodb.api import ( - delete_comments_of_a_thread, - delete_subscriptions_of_a_thread, - get_course_id_by_thread_id, - get_threads, -) -from forum.backends.mongodb.api import mark_as_read as mark_thread_as_read -from forum.backends.mongodb.api import ( - update_stats_for_course, - validate_object, - validate_params, -) -from forum.backends.mongodb.threads import CommentThread -from forum.backends.mongodb.users import Users -from forum.backends.mysql import api +from forum.api.users import mark_thread_as_read +from forum.backend import get_backend from forum.serializers.thread import ThreadSerializer from forum.utils import ForumV2RequestError, get_int_value_from_collection, str_to_bool @@ -65,7 +52,7 @@ def get_thread_data(thread: dict[str, Any]) -> dict[str, Any]: thread_data = { **thread, "id": str(thread.get("_id")), - "type": "thread" if _type == "commentthread" else _type, + "type": "thread" if _type.lower() == "commentthread" else _type, "user_id": thread.get("author_id"), "username": str(thread.get("author_username")), "comments_count": thread["comment_count"], @@ -75,6 +62,7 @@ def get_thread_data(thread: dict[str, Any]) -> dict[str, Any]: def prepare_thread_api_response( thread: dict[str, Any], + backend: Any, include_context: Optional[bool] = False, data_or_params: Optional[dict[str, Any]] = None, include_data_from_params: Optional[bool] = False, @@ -109,12 +97,13 @@ def prepare_thread_api_response( for param in params: if value := data_or_params.get(param): context[param] = str_to_bool(value) - if user_id and (user := Users().get(user_id)): - mark_thread_as_read(user, thread) + if user_id and backend.get_user(user_id): + mark_thread_as_read(user_id, thread["_id"]) serializer = ThreadSerializer( data=thread_data, context=context, + backend=backend, ) if not serializer.is_valid(raise_exception=True): log.error(f"validation error in thread API call: {serializer.errors}") @@ -126,6 +115,7 @@ def prepare_thread_api_response( def get_thread( thread_id: str, params: Optional[dict[str, Any]] = None, + course_id: Optional[str] = None, ) -> dict[str, Any]: """ Get the thread for the given thread_id. @@ -143,8 +133,9 @@ def get_thread( Response: The details of the thread for the given thread_id. """ + backend = get_backend(course_id)() try: - thread = validate_object(CommentThread, thread_id) + thread = backend.validate_object("CommentThread", thread_id) except ObjectDoesNotExist as exc: log.error("Forumv2RequestError for get thread request.") raise ForumV2RequestError( @@ -154,6 +145,7 @@ def get_thread( try: return prepare_thread_api_response( thread, + backend, True, params, True, @@ -163,7 +155,7 @@ def get_thread( raise ForumV2RequestError("Failed to prepare thread API response") from error -def delete_thread(thread_id: str) -> dict[str, Any]: +def delete_thread(thread_id: str, course_id: Optional[str] = None) -> dict[str, Any]: """ Delete the thread for the given thread_id. @@ -172,27 +164,30 @@ def delete_thread(thread_id: str) -> dict[str, Any]: Response: The details of the thread that is deleted. """ + backend = get_backend(course_id)() try: - thread = validate_object(CommentThread, thread_id) + thread = backend.validate_object("CommentThread", thread_id) except ObjectDoesNotExist as exc: log.error("Forumv2RequestError for delete thread request.") raise ForumV2RequestError( f"Thread does not exist with Id: {thread_id}" ) from exc - delete_comments_of_a_thread(thread_id) - thread = validate_object(CommentThread, thread_id) + backend.delete_comments_of_a_thread(thread_id) + thread = backend.validate_object("CommentThread", thread_id) try: - serialized_data = prepare_thread_api_response(thread) + serialized_data = prepare_thread_api_response(thread, backend) except ValidationError as error: log.error(f"Validation error in get_thread: {error}") raise ForumV2RequestError("Failed to prepare thread API response") from error - result = CommentThread().delete(thread_id) - delete_subscriptions_of_a_thread(thread_id) + backend.delete_subscriptions_of_a_thread(thread_id) + result = backend.delete_thread(thread_id) if result and not (thread["anonymous"] or thread["anonymous_to_peers"]): - update_stats_for_course(thread["author_id"], thread["course_id"], threads=-1) + backend.update_stats_for_course( + thread["author_id"], thread["course_id"], threads=-1 + ) return serialized_data @@ -214,6 +209,7 @@ def update_thread( close_reason_code: Optional[str] = None, closing_user_id: Optional[str] = None, endorsed: Optional[bool] = None, + course_key: Optional[str] = None, ) -> dict[str, Any]: """ Update the thread for the given thread_id. @@ -224,8 +220,9 @@ def update_thread( Response: The details of the thread that is updated. """ + backend = get_backend(course_key)() try: - thread = validate_object(CommentThread, thread_id) + thread = backend.validate_object("CommentThread", thread_id) except ObjectDoesNotExist as exc: log.error("Forumv2RequestError for update thread request.") raise ForumV2RequestError( @@ -262,12 +259,13 @@ def update_thread( raise ForumV2RequestError( f"Missing required fields: {', '.join(missing_fields)}" ) - CommentThread().update(thread_id, **update_thread_data) - thread = CommentThread().get(thread_id) + backend.update_thread(thread_id, **update_thread_data) + thread = backend.get_thread(thread_id) try: return prepare_thread_api_response( thread, + backend, True, data, ) @@ -303,6 +301,7 @@ def create_thread( Response: The details of the thread that is created. """ + backend = get_backend(course_id)() data = { "title": title, "body": body, @@ -316,17 +315,20 @@ def create_thread( } thread_data: dict[str, Any] = _get_thread_data_from_request_data(data) - thread_id = CommentThread().insert(**thread_data) - thread = CommentThread().get(thread_id) + thread_id = backend.create_thread(thread_data) + thread = backend.get_thread(thread_id) if not thread: raise ForumV2RequestError(f"Failed to create thread with data: {data}") if not (anonymous or anonymous_to_peers): - update_stats_for_course(thread["author_id"], thread["course_id"], threads=1) + backend.update_stats_for_course( + thread["author_id"], thread["course_id"], threads=1 + ) try: return prepare_thread_api_response( thread, + backend, True, data, ) @@ -336,7 +338,7 @@ def create_thread( def get_user_threads( - course_id: Optional[str] = None, + course_id: str, author_id: Optional[str] = None, thread_type: Optional[str] = None, flagged: Optional[bool] = None, @@ -354,6 +356,7 @@ def get_user_threads( """ Get the threads for the given thread_ids. """ + backend = get_backend(course_id)() params = { "course_id": course_id, "author_id": author_id, @@ -371,15 +374,12 @@ def get_user_threads( "user_id": user_id, } params = {k: v for k, v in params.items() if v is not None} - validate_params(params) + backend.validate_params(params) - thread_filter = { - "_type": {"$in": [CommentThread.content_type]}, - "course_id": {"$in": [course_id]}, - } - filtered_threads = CommentThread().find(thread_filter) + thread_filter = backend.get_user_thread_filter(course_id) + filtered_threads = backend.get_filtered_threads(thread_filter) thread_ids = [thread["_id"] for thread in filtered_threads] - threads = get_threads(params, ThreadSerializer, thread_ids, user_id or "") + threads = backend.get_threads(params, user_id or "", ThreadSerializer, thread_ids) return threads @@ -389,8 +389,12 @@ def get_course_id_by_thread(thread_id: str) -> str | None: Return course_id for the matching thread. It searches for thread_id both in mongodb and mysql. """ + # pylint: disable=C0415 + from forum.backends.mongodb.api import MongoBackend + from forum.backends.mysql.api import MySQLBackend + return ( - get_course_id_by_thread_id(thread_id) - or api.get_course_id_by_thread_id(thread_id) + MongoBackend.get_course_id_by_thread_id(thread_id) + or MySQLBackend.get_course_id_by_thread_id(thread_id) or None ) diff --git a/forum/api/users.py b/forum/api/users.py index 30831d7c..9ada2daa 100644 --- a/forum/api/users.py +++ b/forum/api/users.py @@ -6,20 +6,7 @@ import math from typing import Any, Optional -from forum.backends.mongodb import Users -from forum.backends.mongodb.api import ( - find_or_create_user, - get_user_by_username, - handle_threads_query, - mark_as_read, - replace_username_in_all_content, - retire_all_content, - unsubscribe_all, - update_all_users_in_course, - user_to_hash, -) -from forum.backends.mongodb.contents import Contents -from forum.backends.mongodb.threads import CommentThread +from forum.backend import get_backend from forum.constants import FORUM_DEFAULT_PAGE, FORUM_DEFAULT_PER_PAGE from forum.serializers.thread import ThreadSerializer from forum.serializers.users import UserSerializer @@ -43,7 +30,8 @@ def get_user( Response: A response with the users data. """ - user = Users().get(user_id) + backend = get_backend(course_id)() + user = backend.get_user(user_id) if not user: log.error(f"Forumv2RequestError for retrieving user's data for id {user_id}.") raise ForumV2RequestError(str(f"user not found with id: {user_id}")) @@ -53,7 +41,7 @@ def get_user( "group_ids": group_ids, "course_id": course_id, } - hashed_user = user_to_hash(user, params) + hashed_user = backend.user_to_hash(user_id, params) serializer = UserSerializer(hashed_user) return serializer.data @@ -67,17 +55,21 @@ def update_user( complete: Optional[bool] = False, ) -> dict[str, Any]: """Update user.""" - user = Users().get(user_id) - user_by_username = get_user_by_username(username) + backend = get_backend(course_id)() + user = backend.get_user(user_id) + user_by_username = backend.get_user_by_username(username) if user and user_by_username: if user["external_id"] != user_by_username["external_id"]: raise ForumV2RequestError("user does not match") elif user_by_username: raise ForumV2RequestError(f"user already exists with username: {username}") else: - user_id = find_or_create_user(user_id) - Users().update(user_id, username=username, default_sort_key=default_sort_key) - updated_user = Users().get(user_id) + user_id = backend.find_or_create_user(user_id) + update_data = {"username": username} + if default_sort_key is not None: + update_data["default_sort_key"] = default_sort_key + backend.update_user(user_id, update_data) + updated_user = backend.get_user(user_id) if not updated_user: raise ForumV2RequestError(f"user not found with id: {user_id}") params = { @@ -85,7 +77,7 @@ def update_user( "group_ids": group_ids, "course_id": course_id, } - hashed_user = user_to_hash(updated_user, params) + hashed_user = backend.user_to_hash(user_id, params) serializer = UserSerializer(hashed_user) return serializer.data @@ -99,16 +91,17 @@ def create_user( complete: bool = False, ) -> dict[str, Any]: """Create user.""" - user_by_id = Users().get(user_id) - user_by_username = get_user_by_username(username) + backend = get_backend(course_id)() + user_by_id = backend.get_user(user_id) + user_by_username = backend.get_user_by_username(username) if user_by_id or user_by_username: raise ForumV2RequestError(f"user already exists with id: {id}") - Users().insert( - external_id=user_id, username=username, default_sort_key=default_sort_key + backend.find_or_create_user( + user_id, username=username, default_sort_key=default_sort_key ) - user = Users().get(user_id) + user = backend.get_user(user_id) if not user: raise ForumV2RequestError(f"user not found with id: {user_id}") params = { @@ -116,34 +109,42 @@ def create_user( "group_ids": group_ids, "course_id": course_id, } - hashed_user = user_to_hash(user, params) + hashed_user = backend.user_to_hash(user_id, params) serializer = UserSerializer(hashed_user) return serializer.data -def update_username(user_id: str, new_username: str) -> dict[str, str]: +def update_username( + user_id: str, new_username: str, course_id: Optional[str] = None +) -> dict[str, str]: """Update username.""" - user = Users().get(user_id) + backend = get_backend(course_id)() + user = backend.get_user(user_id) if not user: raise ForumV2RequestError(str(f"user not found with id: {user_id}")) - Users().update(user_id, username=new_username) - replace_username_in_all_content(user_id, new_username) + backend.update_user(user_id, {"username": new_username}) + backend.replace_username_in_all_content(user_id, new_username) return {"message": "Username updated successfully"} -def retire_user(user_id: str, retired_username: str) -> dict[str, str]: +def retire_user( + user_id: str, retired_username: str, course_id: Optional[str] = None +) -> dict[str, str]: """Retire user.""" - user = Users().get(user_id) + backend = get_backend(course_id)() + user = backend.get_user(user_id) if not user: raise ForumV2RequestError(f"user not found with id: {user_id}") - Users().update( + backend.update_user( user_id, - email="", - username=retired_username, - read_states=[], + data={ + "email": "", + "username": retired_username, + "read_states": [], + }, ) - unsubscribe_all(user_id) - retire_all_content(user_id, retired_username) + backend.unsubscribe_all(user_id) + backend.retire_all_content(user_id, retired_username) return {"message": "User retired successfully"} @@ -156,17 +157,18 @@ def mark_thread_as_read( group_ids: Optional[list[int]] = None, ) -> dict[str, Any]: """Mark thread as read.""" - user = Users().get(user_id) + backend = get_backend(course_id)() + user = backend.get_user(user_id) if not user: raise ForumV2RequestError(str(f"user not found with id: {user_id}")) - thread = CommentThread().get(source_id) + thread = backend.get_thread(source_id) if not thread: raise ForumV2RequestError(str(f"source not found with id: {source_id}")) - mark_as_read(user, thread) + backend.mark_as_read(user_id, source_id) - user = Users().get(user_id) + user = backend.get_user(user_id) if not user: raise ForumV2RequestError(str(f"user not found with id: {user_id}")) @@ -176,7 +178,7 @@ def mark_thread_as_read( "course_id": course_id, } - hashed_user = user_to_hash(user, params) + hashed_user = backend.user_to_hash(user_id, params) serializer = UserSerializer(hashed_user) return serializer.data @@ -197,11 +199,12 @@ def get_user_active_threads( group_id: Optional[str] = None, ) -> dict[str, Any]: """Get user active threads.""" + backend = get_backend(course_id)() raw_query = bool(sort_key == "user_activity") if not course_id: return {} active_contents = list( - Contents().get_list( + backend.get_contents( author_id=user_id, anonymous=False, anonymous_to_peers=False, @@ -247,7 +250,7 @@ def get_user_active_threads( "context": "course", "raw_query": raw_query, } - data = handle_threads_query(**params) + data = backend.handle_threads_query(**params) if collections := data.get("collection"): thread_serializer = ThreadSerializer( @@ -258,6 +261,7 @@ def get_user_active_threads( "include_endorsed": True, "include_read_state": True, }, + backend=backend, ) data["collection"] = thread_serializer.data else: @@ -265,64 +269,13 @@ def get_user_active_threads( for thread in collection: thread["_id"] = str(thread.pop("_id")) thread["type"] = str(thread.get("_type", "")).lower() - data["collection"] = ThreadSerializer(collection, many=True).data + data["collection"] = ThreadSerializer( + collection, many=True, backend=backend + ).data return data -def _create_pipeline( - course_id: str, page: int, per_page: int, sort_criterion: dict[str, Any] -) -> list[dict[str, Any]]: - """Get pipeline for course stats api.""" - pipeline: list[dict[str, Any]] = [ - {"$match": {"course_stats.course_id": course_id}}, - {"$project": {"username": 1, "course_stats": 1}}, - {"$unwind": "$course_stats"}, - {"$match": {"course_stats.course_id": course_id}}, - {"$sort": sort_criterion}, - { - "$facet": { - "pagination": [{"$count": "total_count"}], - "data": [ - {"$skip": (page - 1) * per_page}, - {"$limit": per_page}, - ], - } - }, - ] - return pipeline - - -def _get_sort_criterion(sort_by: str) -> dict[str, Any]: - """Get sort criterion based on sort_by parameter.""" - if sort_by == "flagged": - return { - "course_stats.active_flags": -1, - "course_stats.inactive_flags": -1, - "username": -1, - } - elif sort_by == "recency": - return { - "course_stats.last_activity_at": -1, - "username": -1, - } - else: - return { - "course_stats.threads": -1, - "course_stats.responses": -1, - "course_stats.replies": -1, - "username": -1, - } - - -def _get_paginated_stats( - course_id: str, page: int, per_page: int, sort_criterion: dict[str, Any] -) -> dict[str, Any]: - """Get paginated stats for a course.""" - pipeline = _create_pipeline(course_id, page, per_page, sort_criterion) - return list(Users().aggregate(pipeline))[0] - - def _get_user_data( user_stats: dict[str, Any], exclude_from_stats: list[str] ) -> dict[str, Any]: @@ -335,10 +288,10 @@ def _get_user_data( def _get_stats_for_usernames( - course_id: str, usernames: list[str] + course_id: str, usernames: list[str], backend: Any ) -> list[dict[str, Any]]: """Get stats for specific usernames.""" - users = Users().get_list() + users = backend.get_users() stats_query = [] for user in users: if user["username"] not in usernames: @@ -363,8 +316,8 @@ def get_user_course_stats( with_timestamps: bool = False, ) -> dict[str, Any]: """Get user course stats.""" - - sort_criterion = _get_sort_criterion(sort_key) + backend = get_backend(course_id)() + sort_criterion = backend.get_user_sort_criterion(sort_key) exclude_from_stats = ["_id", "course_id"] if not with_timestamps: exclude_from_stats.append("last_activity_at") @@ -373,7 +326,7 @@ def get_user_course_stats( data = [] if not usernames_list: - paginated_stats = _get_paginated_stats( + paginated_stats = backend.get_paginated_user_stats( course_id, page, per_page, sort_criterion ) num_pages = 0 @@ -387,7 +340,7 @@ def get_user_course_stats( for user_stats in paginated_stats["data"] ] else: - stats_query = _get_stats_for_usernames(course_id, usernames_list) + stats_query = _get_stats_for_usernames(course_id, usernames_list, backend) total_count = len(stats_query) num_pages = 1 data = [ @@ -412,5 +365,6 @@ def get_user_course_stats( def update_users_in_course(course_id: str) -> dict[str, int]: """Update all user stats in a course.""" - updated_users = update_all_users_in_course(course_id) + backend = get_backend(course_id)() + updated_users = backend.update_all_users_in_course(course_id) return {"user_count": len(updated_users)} diff --git a/forum/api/votes.py b/forum/api/votes.py index 085e6a44..45413f26 100644 --- a/forum/api/votes.py +++ b/forum/api/votes.py @@ -2,12 +2,9 @@ API for votes. """ -from typing import Any +from typing import Any, Optional -from forum.backends.mongodb.api import downvote_content, remove_vote, upvote_content -from forum.backends.mongodb.comments import Comment -from forum.backends.mongodb.threads import CommentThread -from forum.backends.mongodb.users import Users +from forum.backend import get_backend from forum.serializers.comment import CommentSerializer from forum.serializers.thread import ThreadSerializer from forum.serializers.votes import VotesInputSerializer @@ -15,7 +12,9 @@ def _get_thread_and_user( - thread_id: str, user_id: str + thread_id: str, + user_id: str, + course_id: Optional[str] = None, ) -> tuple[dict[str, Any], dict[str, Any]]: """ Fetches the thread and user based on provided IDs. @@ -30,11 +29,12 @@ def _get_thread_and_user( Raises: ValueError: If the thread or user is not found. """ - thread = CommentThread().get(_id=thread_id) + backend = get_backend(course_id)() + thread = backend.get_thread(thread_id) if not thread: raise ValueError("Thread not found") - user = Users().get(_id=user_id) + user = backend.get_user(user_id) if not user: raise ValueError("User not found") @@ -42,7 +42,7 @@ def _get_thread_and_user( def _prepare_thread_response( - thread: dict[str, Any], user: dict[str, Any] + thread: dict[str, Any], user: dict[str, Any], backend: Any ) -> dict[str, Any]: """ Prepares the serialized response data after voting. @@ -64,13 +64,15 @@ def _prepare_thread_response( "username": user["username"], "type": "thread", } - serializer = ThreadSerializer(data=context) + serializer = ThreadSerializer(data=context, backend=backend) if not serializer.is_valid(): raise ValueError(serializer.errors) return serializer.data -def update_thread_votes(thread_id: str, user_id: str, value: str) -> dict[str, Any]: +def update_thread_votes( + thread_id: str, user_id: str, value: str, course_id: Optional[str] = None +) -> dict[str, Any]: """ Updates the votes for a thread. @@ -79,6 +81,7 @@ def update_thread_votes(thread_id: str, user_id: str, value: str) -> dict[str, A user_id (str): The ID of the user. value (str): The vote value ("up" or "down"). """ + backend = get_backend(course_id)() data = {"user_id": user_id, "value": value} vote_serializer = VotesInputSerializer(data=data) @@ -91,17 +94,23 @@ def update_thread_votes(thread_id: str, user_id: str, value: str) -> dict[str, A raise ForumV2RequestError(str(error)) from error if vote_serializer.data["value"] == "up": - is_updated = upvote_content(thread, user) + is_updated = backend.upvote_content( + thread_id, user_id, entity_type="CommentThread" + ) else: - is_updated = downvote_content(thread, user) + is_updated = backend.downvote_content( + thread_id, user_id, entity_type="CommentThread" + ) if is_updated: - thread = CommentThread().get(_id=thread_id) or {} + thread = backend.get_thread(thread_id) or {} - return _prepare_thread_response(thread, user) + return _prepare_thread_response(thread, user, backend) -def delete_thread_vote(thread_id: str, user_id: str) -> dict[str, Any]: +def delete_thread_vote( + thread_id: str, user_id: str, course_id: Optional[str] = None +) -> dict[str, Any]: """ Deletes the vote for a thread. @@ -109,19 +118,24 @@ def delete_thread_vote(thread_id: str, user_id: str) -> dict[str, Any]: thread_id (str): The ID of the thread. user_id (str): The ID of the user. """ + backend = get_backend(course_id)() try: - thread, user = _get_thread_and_user(thread_id, user_id) + _, user = _get_thread_and_user(thread_id, user_id) except ValueError as error: raise ForumV2RequestError(str(error)) from error - if remove_vote(thread, user): - thread = CommentThread().get(_id=thread_id) or {} + deleted_thread = None + if backend.remove_vote(thread_id, user_id, entity_type="CommentThread"): + deleted_thread = backend.get_thread(thread_id) + + if not deleted_thread: + raise ForumV2RequestError("Thread not found") - return _prepare_thread_response(thread, user) + return _prepare_thread_response(deleted_thread, user, backend) def _get_comment_and_user( - comment_id: str, user_id: str + comment_id: str, user_id: str, backend: Any ) -> tuple[dict[str, Any], dict[str, Any]]: """ Fetches the comment and user based on provided IDs. @@ -136,11 +150,11 @@ def _get_comment_and_user( Raises: ValueError: If the comment or user is not found. """ - comment = Comment().get(_id=comment_id) + comment = backend.get_comment(comment_id) if not comment: raise ValueError("Comment not found") - user = Users().get(_id=user_id) + user = backend.get_user(user_id) if not user: raise ValueError("User not found") @@ -148,7 +162,7 @@ def _get_comment_and_user( def _prepare_comment_response( - comment: dict[str, Any], user: dict[str, Any] + comment: dict[str, Any], user: dict[str, Any], backend: Any ) -> dict[str, Any]: """ Prepares the serialized response data after voting. @@ -171,13 +185,15 @@ def _prepare_comment_response( "type": "comment", "thread_id": str(comment.get("comment_thread_id", None)), } - serializer = CommentSerializer(data=context) + serializer = CommentSerializer(data=context, backend=backend) if not serializer.is_valid(): raise ValueError(serializer.errors) return serializer.data -def update_comment_votes(comment_id: str, user_id: str, value: str) -> dict[str, Any]: +def update_comment_votes( + comment_id: str, user_id: str, value: str, course_id: Optional[str] = None +) -> dict[str, Any]: """ Updates the votes for a comment. @@ -186,6 +202,7 @@ def update_comment_votes(comment_id: str, user_id: str, value: str) -> dict[str, user_id (str): The ID of the user. value (str): The vote value ("up" or "down"). """ + backend = get_backend(course_id)() data = {"user_id": user_id, "value": value} vote_serializer = VotesInputSerializer(data=data) @@ -193,22 +210,30 @@ def update_comment_votes(comment_id: str, user_id: str, value: str) -> dict[str, raise ForumV2RequestError(vote_serializer.errors) try: - comment, user = _get_comment_and_user(comment_id, user_id) + _, user = _get_comment_and_user(comment_id, user_id, backend) except ValueError as error: raise ForumV2RequestError(str(error)) from error if vote_serializer.data["value"] == "up": - is_updated = upvote_content(comment, user) + is_updated = backend.upvote_content(comment_id, user_id, entity_type="Comment") else: - is_updated = downvote_content(comment, user) + is_updated = backend.downvote_content( + comment_id, user_id, entity_type="Comment" + ) + updated_comment = None if is_updated: - comment = Comment().get(_id=comment_id) or {} + updated_comment = backend.get_comment(comment_id) + + if not updated_comment: + raise ForumV2RequestError("Comment not found") - return _prepare_comment_response(comment, user) + return _prepare_comment_response(updated_comment, user, backend) -def delete_comment_vote(comment_id: str, user_id: str) -> dict[str, Any]: +def delete_comment_vote( + comment_id: str, user_id: str, course_id: Optional[str] = None +) -> dict[str, Any]: """ Deletes the vote for a comment. @@ -216,12 +241,17 @@ def delete_comment_vote(comment_id: str, user_id: str) -> dict[str, Any]: comment_id (str): The ID of the comment. user_id (str): The ID of the user. """ + backend = get_backend(course_id)() try: - comment, user = _get_comment_and_user(comment_id, user_id) + _, user = _get_comment_and_user(comment_id, user_id, backend) except ValueError as error: raise ForumV2RequestError(str(error)) from error - if remove_vote(comment, user): - comment = Comment().get(_id=comment_id) or {} + deleted_comment = None + if backend.remove_vote(comment_id, user_id, entity_type="Comment"): + deleted_comment = backend.get_comment(comment_id) + + if not deleted_comment: + raise ForumV2RequestError("Comment not found") - return _prepare_comment_response(comment, user) + return _prepare_comment_response(deleted_comment, user, backend) diff --git a/forum/backend.py b/forum/backend.py new file mode 100644 index 00000000..34198c7f --- /dev/null +++ b/forum/backend.py @@ -0,0 +1,22 @@ +"""Backend module for forum.""" + +from typing import Callable, Optional + +from forum.backends.mongodb.api import MongoBackend +from forum.backends.mysql.api import MySQLBackend + + +def get_backend( + course_id: Optional[str] = "", +) -> Callable[[], MongoBackend | MySQLBackend]: + """Return a factory function that lazily loads the backend API based on course_id.""" + + def _get_backend() -> MongoBackend | MySQLBackend: + if not course_id: + # Lazy loading MongoBackend + return MongoBackend() + # TODO: add condition for course waffle flag. + # Lazy loading MySQLBackend + return MySQLBackend() + + return _get_backend diff --git a/forum/backends/backend.py b/forum/backends/backend.py new file mode 100644 index 00000000..01323cd4 --- /dev/null +++ b/forum/backends/backend.py @@ -0,0 +1,446 @@ +"""Forum backend interface class.""" + +from typing import Any, Optional + + +class AbstractBackend: + """Abstract backend interface class.""" + + @classmethod + def update_stats_for_course( + cls, user_id: str, course_id: str, **kwargs: Any + ) -> None: + """Update statistics for a course.""" + raise NotImplementedError + + @classmethod + def flag_as_abuse( + cls, user_id: str, entity_id: str, **kwargs: Any + ) -> dict[str, Any]: + """Flag an entity as abuse.""" + raise NotImplementedError + + @classmethod + def update_stats_after_unflag( + cls, user_id: str, entity_id: str, has_no_historical_flags: bool, **kwargs: Any + ) -> None: + """Update statistics after unflagging an entity.""" + raise NotImplementedError + + @classmethod + def un_flag_as_abuse( + cls, user_id: str, entity_id: str, **kwargs: Any + ) -> dict[str, Any]: + """Unflag an entity as abuse.""" + raise NotImplementedError + + @classmethod + def un_flag_all_as_abuse(cls, entity_id: str, **kwargs: Any) -> dict[str, Any]: + """Unflag all entities as abuse.""" + raise NotImplementedError + + @staticmethod + def update_vote( + content_id: str, + user_id: str, + vote_type: str = "", + is_deleted: bool = False, + **kwargs: Any + ) -> bool: + """Update vote for a content.""" + raise NotImplementedError + + @classmethod + def upvote_content(cls, entity_id: str, user_id: str, **kwargs: Any) -> bool: + """Upvote a content.""" + raise NotImplementedError + + @classmethod + def downvote_content(cls, entity_id: str, user_id: str, **kwargs: Any) -> bool: + """Downvote a content.""" + raise NotImplementedError + + @classmethod + def remove_vote(cls, entity_id: str, user_id: str, **kwargs: Any) -> bool: + """Remove a vote for a content.""" + raise NotImplementedError + + @staticmethod + def validate_thread_and_user( + user_id: str, thread_id: str + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Validate a thread and user.""" + raise NotImplementedError + + @staticmethod + def pin_unpin_thread(thread_id: str, action: str) -> None: + """Pin or unpin a thread.""" + raise NotImplementedError + + @staticmethod + def get_pinned_unpinned_thread_serialized_data( + user_id: str, thread_id: str, serializer_class: Any + ) -> dict[str, Any]: + """Get pinned or unpinned thread serialized data.""" + raise NotImplementedError + + @classmethod + def handle_pin_unpin_thread_request( + cls, user_id: str, thread_id: str, action: str, serializer_class: Any + ) -> dict[str, Any]: + """Handle pin or unpin thread request.""" + raise NotImplementedError + + @staticmethod + def get_abuse_flagged_count(thread_ids: list[str]) -> dict[str, int]: + """Get abuse flagged count.""" + raise NotImplementedError + + @staticmethod + def get_read_states( + thread_ids: list[str], user_id: str, course_id: str + ) -> dict[str, list[Any]]: + """Get read states.""" + raise NotImplementedError + + @staticmethod + def get_endorsed(thread_ids: list[str]) -> dict[str, bool]: + """Get endorsed.""" + raise NotImplementedError + + @staticmethod + def get_user_read_state_by_course_id( + user_id: str, course_id: str + ) -> dict[str, Any]: + """Get user read state by course id.""" + raise NotImplementedError + + @classmethod + def handle_threads_query( + cls, + comment_thread_ids: list[str], + user_id: str, + course_id: str, + group_ids: list[int], + author_id: Optional[str], + thread_type: Optional[str], + filter_flagged: bool, + filter_unread: bool, + filter_unanswered: bool, + filter_unresponded: bool, + count_flagged: bool, + sort_key: str, + page: int, + per_page: int, + context: str = "course", + raw_query: bool = False, + ) -> dict[str, Any]: + """Handle threads query.""" + raise NotImplementedError + + @staticmethod + def prepare_thread( + thread_id: str, + is_read: bool, + unread_count: int, + is_endorsed: bool, + abuse_flagged_count: int, + ) -> dict[str, Any]: + """Prepare thread.""" + raise NotImplementedError + + @classmethod + def threads_presentor( + cls, + thread_ids: list[str], + user_id: str, + course_id: str, + count_flagged: bool = False, + ) -> list[dict[str, Any]]: + """Threads presenter.""" + raise NotImplementedError + + @staticmethod + def get_username_from_id(user_id: str) -> Optional[str]: + """Get username from id.""" + raise NotImplementedError + + @staticmethod + def validate_object(model: str, obj_id: str) -> Any: + """Validate object.""" + raise NotImplementedError + + @staticmethod + def find_subscribed_threads( + user_id: str, course_id: Optional[str] = None + ) -> list[str]: + """Find subscribed threads.""" + raise NotImplementedError + + @staticmethod + def subscribe_user( + user_id: str, source_id: str, source_type: str + ) -> dict[str, Any] | None: + """Subscribe user.""" + raise NotImplementedError + + @staticmethod + def unsubscribe_user(user_id: str, source_id: str, source_type: str) -> None: + """Unsubscribe user.""" + raise NotImplementedError + + @staticmethod + def delete_comments_of_a_thread(thread_id: str) -> None: + """Delete comments of a thread.""" + raise NotImplementedError + + @staticmethod + def delete_subscriptions_of_a_thread(thread_id: str) -> None: + """Delete subscriptions of a thread.""" + raise NotImplementedError + + @staticmethod + def validate_params(params: dict[str, Any], user_id: Optional[str] = None) -> Any: + """Validate params.""" + raise NotImplementedError + + @classmethod + def get_threads( + cls, + params: dict[str, Any], + user_id: str, + serializer: Any, + thread_ids: list[str], + ) -> dict[str, Any]: + """Get threads.""" + raise NotImplementedError + + @staticmethod + def get_commentables_counts_based_on_type(course_id: str) -> dict[str, Any]: + """Get commentables counts based on type.""" + raise NotImplementedError + + @classmethod + def get_user_voted_ids(cls, user_id: str, vote: str) -> list[str]: + """Get user voted ids.""" + raise NotImplementedError + + @staticmethod + def filter_standalone_threads(comment_ids: list[str]) -> list[str]: + """Filter standalone threads.""" + raise NotImplementedError + + @classmethod + def user_to_hash( + cls, user_id: str, params: Optional[dict[str, Any]] = None + ) -> dict[str, Any]: + """User to hash.""" + raise NotImplementedError + + @staticmethod + def replace_username_in_all_content(user_id: str, username: str) -> None: + """Replace username in all content.""" + raise NotImplementedError + + @staticmethod + def unsubscribe_all(user_id: str) -> None: + """Unsubscribe all.""" + raise NotImplementedError + + @staticmethod + def retire_all_content(user_id: str, username: str) -> None: + """Retire all content.""" + raise NotImplementedError + + @staticmethod + def find_or_create_read_state(user_id: str, thread_id: str) -> dict[str, Any]: + """Find or create read state.""" + raise NotImplementedError + + @classmethod + def mark_as_read(cls, user_id: str, thread_id: str) -> None: + """Mark as read.""" + raise NotImplementedError + + @staticmethod + def find_or_create_user_stats(user_id: str, course_id: str) -> dict[str, Any]: + """Find or create user stats.""" + raise NotImplementedError + + @staticmethod + def update_user_stats_for_course(user_id: str, stat: dict[str, Any]) -> None: + """Update user stats for course.""" + raise NotImplementedError + + @classmethod + def build_course_stats(cls, author_id: str, course_id: str) -> None: + """Build course stats.""" + raise NotImplementedError + + @classmethod + def update_all_users_in_course(cls, course_id: str) -> list[str]: + """Update all users in course.""" + raise NotImplementedError + + @staticmethod + def get_user_by_username(username: str | None) -> dict[str, Any] | None: + """Get user by username.""" + raise NotImplementedError + + @staticmethod + def find_or_create_user( + user_id: str, username: Optional[str] = "", default_sort_key: Optional[str] = "" + ) -> str: + """Find or create user.""" + raise NotImplementedError + + @staticmethod + def get_comment(comment_id: str) -> dict[str, Any] | None: + """Get comment.""" + raise NotImplementedError + + @staticmethod + def get_thread(thread_id: str) -> dict[str, Any] | None: + """Get thread.""" + raise NotImplementedError + + @staticmethod + def get_comments(**kwargs: Any) -> list[dict[str, Any]]: + """Get comments.""" + raise NotImplementedError + + @classmethod + def create_comment(cls, data: dict[str, Any]) -> Any: + """Create comment.""" + raise NotImplementedError + + @staticmethod + def delete_comment(comment_id: str) -> None: + """Delete comment.""" + raise NotImplementedError + + @staticmethod + def update_comment(comment_id: str, **kwargs: Any) -> int: + """Update comment.""" + raise NotImplementedError + + @staticmethod + def get_thread_id_from_comment(comment_id: str) -> dict[str, Any] | None: + """Get thread id from comment.""" + raise NotImplementedError + + @staticmethod + def get_user(user_id: str) -> dict[str, Any] | None: + """Get user.""" + raise NotImplementedError + + @staticmethod + def get_subscription( + subscriber_id: str, source_id: str, **kwargs: Any + ) -> dict[str, Any] | None: + """Get subscription.""" + raise NotImplementedError + + @staticmethod + def get_subscriptions(query: dict[str, Any]) -> list[dict[str, Any]]: + """Get subscriptions.""" + raise NotImplementedError + + @staticmethod + def delete_thread(thread_id: str) -> int: + """Delete thread.""" + raise NotImplementedError + + @staticmethod + def create_thread(data: dict[str, Any]) -> str: + """Create thread.""" + raise NotImplementedError + + @staticmethod + def update_thread(thread_id: str, **kwargs: Any) -> int: + """Update thread.""" + raise NotImplementedError + + @staticmethod + def get_filtered_threads(query: dict[str, Any]) -> list[dict[str, Any]]: + """Get filtered threads.""" + raise NotImplementedError + + @staticmethod + def update_user(user_id: str, data: dict[str, Any]) -> int: + """Update user.""" + raise NotImplementedError + + @staticmethod + def get_thread_id_by_comment_id(parent_comment_id: str) -> str: + """ + The thread Id from the parent comment. + """ + raise NotImplementedError + + @staticmethod + def update_comment_and_get_updated_comment( + comment_id: str, + body: Optional[str] = None, + course_id: Optional[str] = None, + user_id: Optional[str] = None, + anonymous: Optional[bool] = False, + anonymous_to_peers: Optional[bool] = False, + endorsed: Optional[bool] = False, + closed: Optional[bool] = False, + editing_user_id: Optional[str] = None, + edit_reason_code: Optional[str] = None, + endorsement_user_id: Optional[str] = None, + ) -> dict[str, Any] | None: + """Update comment and get updated comment.""" + raise NotImplementedError + + @staticmethod + def get_contents(**kwargs: Any) -> list[dict[str, Any]]: + """Get contents.""" + raise NotImplementedError + + @staticmethod + def get_users(**kwargs: Any) -> list[dict[str, Any]]: + """Get users.""" + raise NotImplementedError + + @staticmethod + def get_user_sort_criterion(sort_by: str) -> dict[str, Any]: + """Get sort criterion.""" + raise NotImplementedError + + @staticmethod + def get_thread_index_name() -> str: + """Get the name of the thread index.""" + return "comment_threads" + + @staticmethod + def get_votes_dict(up: list[str], down: list[str]) -> dict[str, Any]: + """ + Calculates and returns the vote summary for a thread. + + Args: + up (list): A list of user IDs who upvoted the thread. + down (list): A list of user IDs who downvoted the thread. + + Returns: + dict: A dictionary containing the vote summary with the following keys: + - "up" (list): The list of user IDs who upvoted. + - "down" (list): The list of user IDs who downvoted. + - "up_count" (int): The count of upvotes. + - "down_count" (int): The count of downvotes. + - "count" (int): The total number of votes (upvotes + downvotes). + - "point" (int): The vote score (upvotes - downvotes). + """ + up = up or [] + down = down or [] + votes = { + "up": up, + "down": down, + "up_count": len(up), + "down_count": len(down), + "count": len(up) + len(down), + "point": len(up) - len(down), + } + return votes diff --git a/forum/backends/mongodb/api.py b/forum/backends/mongodb/api.py index a978b454..c8d45d91 100644 --- a/forum/backends/mongodb/api.py +++ b/forum/backends/mongodb/api.py @@ -2,11 +2,12 @@ import math from datetime import datetime, timezone -from typing import Any, Optional, Union +from typing import Any, Optional from bson import ObjectId from django.core.exceptions import ObjectDoesNotExist +from forum.backends.backend import AbstractBackend from forum.backends.mongodb import ( Comment, CommentThread, @@ -20,1414 +21,1648 @@ get_group_ids_from_params, get_sort_criteria, make_aware, + str_to_bool, ) -def update_stats_for_course(user_id: str, course_id: str, **kwargs: Any) -> None: - """Update stats for a course.""" - user = Users().get(user_id) - if not user: - raise ObjectDoesNotExist - course_stats = user.get("course_stats", []) - for course_stat in course_stats: - if course_stat["course_id"] == course_id: - course_stat.update( - {k: course_stat[k] + v for k, v in kwargs.items() if k in course_stat} +class MongoBackend(AbstractBackend): + """Mongodb Backend API.""" + + @classmethod + def update_stats_for_course( + cls, user_id: str, course_id: str, **kwargs: Any + ) -> None: + """Update stats for a course.""" + user = Users().get(user_id) + if not user: + return + course_stats = user.get("course_stats", []) + for course_stat in course_stats: + if course_stat["course_id"] == course_id: + course_stat.update( + { + k: course_stat[k] + v + for k, v in kwargs.items() + if k in course_stat + } + ) + Users().update( + user_id, + course_stats=course_stats, + ) + return + cls.build_course_stats(user["_id"], course_id) + + @classmethod + def flag_as_abuse( + cls, user_id: str, entity_id: str, **kwargs: Any + ) -> dict[str, Any]: + """ + Flag an entity as abuse. + + Args: + user (dict[str, Any]): The user who is flagging the entity as abuse. + entity (dict[str, Any]): The entity being flagged as abuse. + + Returns: + dict[str, Any]: The updated entity with the abuse flag. + + Raises: + ValueError: If user ID or entity is not provided. + """ + user = Users().get(user_id) + entity = Contents().get(entity_id) + if not (user and entity): + raise ValueError("User ID or entity is not provided") + abuse_flaggers = entity["abuse_flaggers"] + first_flag_added = False + if user["_id"] not in abuse_flaggers: + abuse_flaggers.append(user["_id"]) + first_flag_added = len(abuse_flaggers) == 1 + Contents().update( + entity["_id"], + abuse_flaggers=abuse_flaggers, ) - Users().update( - user_id, - course_stats=course_stats, + if first_flag_added: + cls.update_stats_for_course( + entity["author_id"], + entity["course_id"], + active_flags=1, ) - return - build_course_stats(user["_id"], course_id) - + updated_content = Contents().get(entity["_id"]) + if not updated_content: + raise ValueError("Entity not found") + return updated_content + + @classmethod + def update_stats_after_unflag( + cls, user_id: str, entity_id: str, has_no_historical_flags: bool, **kwargs: Any + ) -> None: + """Update the stats for the course after unflagging an entity.""" + entity = Contents().get(entity_id) + if not entity: + raise ObjectDoesNotExist + + first_historical_flag = ( + has_no_historical_flags and not entity["historical_abuse_flaggers"] + ) + if first_historical_flag: + cls.update_stats_for_course(user_id, entity["course_id"], inactive_flags=1) -def flag_as_abuse( - user: dict[str, Any], entity: dict[str, Any] -) -> Union[dict[str, Any], None]: - """ - Flag an entity as abuse. + if not entity["abuse_flaggers"]: + cls.update_stats_for_course(user_id, entity["course_id"], active_flags=-1) - Args: - user (dict[str, Any]): The user who is flagging the entity as abuse. - entity (dict[str, Any]): The entity being flagged as abuse. + @classmethod + def un_flag_as_abuse( + cls, user_id: str, entity_id: str, **kwargs: Any + ) -> dict[str, Any]: + """ + Unflag an entity as abuse. - Returns: - dict[str, Any]: The updated entity with the abuse flag. + Args: + user (dict[str, Any]): The user who is unflagging the entity as abuse. + entity (dict[str, Any]): The entity being unflagged as abuse. - Raises: - ValueError: If user ID or entity is not provided. - """ + Returns: + dict[str, Any]: The updated entity with the abuse flag removed. - abuse_flaggers = entity["abuse_flaggers"] - first_flag_added = False - if user["_id"] not in abuse_flaggers: - abuse_flaggers.append(user["_id"]) - first_flag_added = len(abuse_flaggers) == 1 - Contents().update( - entity["_id"], - abuse_flaggers=abuse_flaggers, - ) - if first_flag_added: - update_stats_for_course( - entity["author_id"], - entity["course_id"], - active_flags=1, + Raises: + ValueError: If user ID or entity is not provided. + """ + user = Users().get(user_id) + entity = Contents().get(entity_id) + if not (user and entity): + raise ValueError("User ID or entity is not provided") + + has_no_historical_flags = len(entity["historical_abuse_flaggers"]) == 0 + if user["_id"] in entity["abuse_flaggers"]: + entity["abuse_flaggers"].remove(user["_id"]) + Contents().update( + entity["_id"], + abuse_flaggers=entity["abuse_flaggers"], + ) + cls.update_stats_after_unflag( + entity["author_id"], entity["_id"], has_no_historical_flags + ) + updated_content = Contents().get(entity["_id"]) + if not updated_content: + raise ValueError("Entity not found") + return updated_content + + @classmethod + def un_flag_all_as_abuse(cls, entity_id: str, **kwargs: Any) -> dict[str, Any]: + """ + Unflag an entity as abuse for all users. + + Args: + entity (dict[str, Any]): The entity being unflagged as abuse. + + Returns: + dict[str, Any]: The updated entity with all abuse flags removed. + + Raises: + ValueError: If entity is not provided. + """ + entity = Contents().get(entity_id) + if not entity: + raise ValueError("Entity is not provided") + has_no_historical_flags = len(entity["historical_abuse_flaggers"]) == 0 + historical_abuse_flaggers = list( + set(entity["historical_abuse_flaggers"]) | set(entity["abuse_flaggers"]) ) - return Contents().get(entity["_id"]) - - -def update_stats_after_unflag( - user_id: str, entity_id: str, has_no_historical_flags: bool -) -> None: - """Update the stats for the course after unflagging an entity.""" - entity = Contents().get(entity_id) - if not entity: - raise ObjectDoesNotExist - - first_historical_flag = ( - has_no_historical_flags and not entity["historical_abuse_flaggers"] - ) - if first_historical_flag: - update_stats_for_course(user_id, entity["course_id"], inactive_flags=1) - - if not entity["abuse_flaggers"]: - update_stats_for_course(user_id, entity["course_id"], active_flags=-1) - - -def un_flag_as_abuse( - user: dict[str, Any], entity: dict[str, Any] -) -> Union[dict[str, Any], None]: - """ - Unflag an entity as abuse. - - Args: - user (dict[str, Any]): The user who is unflagging the entity as abuse. - entity (dict[str, Any]): The entity being unflagged as abuse. - - Returns: - dict[str, Any]: The updated entity with the abuse flag removed. - - Raises: - ValueError: If user ID or entity is not provided. - """ - has_no_historical_flags = len(entity["historical_abuse_flaggers"]) == 0 - if user["_id"] in entity["abuse_flaggers"]: - entity["abuse_flaggers"].remove(user["_id"]) Contents().update( entity["_id"], - abuse_flaggers=entity["abuse_flaggers"], + abuse_flaggers=[], + historical_abuse_flaggers=historical_abuse_flaggers, ) - update_stats_after_unflag( + cls.update_stats_after_unflag( entity["author_id"], entity["_id"], has_no_historical_flags ) + updated_content = Contents().get(entity["_id"]) + if not updated_content: + raise ValueError("Entity not found") + return updated_content + + @staticmethod + def update_vote( + content_id: str, + user_id: str, + vote_type: str = "", + is_deleted: bool = False, + **kwargs: Any, + ) -> bool: + """ + Update a vote on a thread (either upvote or downvote). + + :param content: The content document containing vote data. + :param user: The user document for the user voting. + :param vote_type: String indicating the type of vote ('up' or 'down'). + :param is_deleted: Boolean indicating if the user is removing their vote (True) or voting (False). + :return: True if the vote was successfully updated, False otherwise. + """ + user = Users().get(user_id) + content = Contents().get(content_id) + if not (user and content): + raise ValueError("User ID or entity is not provided") - return Contents().get(entity["_id"]) - - -def un_flag_all_as_abuse(entity: dict[str, Any]) -> Union[dict[str, Any], None]: - """ - Unflag an entity as abuse for all users. - - Args: - entity (dict[str, Any]): The entity being unflagged as abuse. - - Returns: - dict[str, Any]: The updated entity with all abuse flags removed. - - Raises: - ValueError: If entity is not provided. - """ - has_no_historical_flags = len(entity["historical_abuse_flaggers"]) == 0 - historical_abuse_flaggers = list( - set(entity["historical_abuse_flaggers"]) | set(entity["abuse_flaggers"]) - ) - Contents().update( - entity["_id"], - abuse_flaggers=[], - historical_abuse_flaggers=historical_abuse_flaggers, - ) - update_stats_after_unflag( - entity["author_id"], entity["_id"], has_no_historical_flags - ) - - return Contents().get(entity["_id"]) - - -def update_vote( - content: dict[str, Any], - user: dict[str, Any], - vote_type: str = "", - is_deleted: bool = False, -) -> bool: - """ - Update a vote on a thread (either upvote or downvote). - - :param content: The content document containing vote data. - :param user: The user document for the user voting. - :param vote_type: String indicating the type of vote ('up' or 'down'). - :param is_deleted: Boolean indicating if the user is removing their vote (True) or voting (False). - :return: True if the vote was successfully updated, False otherwise. - """ - user_id: str = user["_id"] - content_id: str = str(content["_id"]) - votes: dict[str, Any] = content["votes"] - - update_needed: bool = False - - if not is_deleted: - if vote_type not in ["up", "down"]: - raise ValueError("Invalid vote_type, use ('up' or 'down')") - - if vote_type == "up": - current_votes = set(votes["up"]) - opposite_votes = set(votes["down"]) - else: - current_votes = set(votes["down"]) - opposite_votes = set(votes["up"]) - - # Check if user is voting - if user_id not in current_votes: - current_votes.add(user_id) - update_needed = True - if user_id in opposite_votes: - opposite_votes.remove(user_id) - - updated_up_votes = opposite_votes if vote_type == "down" else current_votes - updated_down_votes = current_votes if vote_type == "down" else opposite_votes - - else: - # Handle vote deletion - updated_up_votes = set(votes["up"]) - updated_down_votes = set(votes["down"]) - - if user_id in updated_up_votes: - updated_up_votes.remove(user_id) - update_needed = True - if user_id in updated_down_votes: - updated_down_votes.remove(user_id) - update_needed = True - - if update_needed: - # Prepare updated votes - content_model = Contents() - updated_votes = content_model.get_votes_dict( - list(updated_up_votes), list(updated_down_votes) - ) - updated_count = content_model.update_votes( - content_id=content_id, votes=updated_votes - ) - return bool(updated_count) + votes: dict[str, Any] = content["votes"] + update_needed: bool = False - return False + if not is_deleted: + if vote_type not in ["up", "down"]: + raise ValueError("Invalid vote_type, use ('up' or 'down')") + if vote_type == "up": + current_votes = set(votes["up"]) + opposite_votes = set(votes["down"]) + else: + current_votes = set(votes["down"]) + opposite_votes = set(votes["up"]) + + # Check if user is voting + if user_id not in current_votes: + current_votes.add(user_id) + update_needed = True + if user_id in opposite_votes: + opposite_votes.remove(user_id) + + updated_up_votes = opposite_votes if vote_type == "down" else current_votes + updated_down_votes = ( + current_votes if vote_type == "down" else opposite_votes + ) -def upvote_content(thread: dict[str, Any], user: dict[str, Any]) -> bool: - """ - Upvotes the specified thread or comment by the given user. + else: + # Handle vote deletion + updated_up_votes = set(votes["up"]) + updated_down_votes = set(votes["down"]) + + if user_id in updated_up_votes: + updated_up_votes.remove(user_id) + update_needed = True + if user_id in updated_down_votes: + updated_down_votes.remove(user_id) + update_needed = True + + if update_needed: + # Prepare updated votes + content_model = Contents() + updated_votes = content_model.get_votes_dict( + list(updated_up_votes), list(updated_down_votes) + ) + updated_count = content_model.update_votes( + content_id=content_id, votes=updated_votes + ) + return bool(updated_count) - Args: - thread (dict): The thread or comment data to be upvoted. - user (dict): The user who is performing the upvote. + return False - Returns: - bool: True if the vote was successfully updated, False otherwise. - """ - return update_vote(thread, user, vote_type="up") + @classmethod + def upvote_content(cls, entity_id: str, user_id: str, **kwargs: Any) -> bool: + """ + Upvotes the specified thread or comment by the given user. + Args: + thread (dict): The thread or comment data to be upvoted. + user (dict): The user who is performing the upvote. -def downvote_content(thread: dict[str, Any], user: dict[str, Any]) -> bool: - """ - Downvotes the specified thread or comment by the given user. + Returns: + bool: True if the vote was successfully updated, False otherwise. + """ + user = Users().get(user_id) + entity = Contents().get(entity_id) + if not (user and entity): + raise ValueError("User ID or entity is not provided") - Args: - thread (dict): The thread or comment data to be downvoted. - user (dict): The user who is performing the downvote. + return cls.update_vote(entity["_id"], user["external_id"], vote_type="up") - Returns: - bool: True if the vote was successfully updated, False otherwise. - """ - return update_vote(thread, user, vote_type="down") + @classmethod + def downvote_content(cls, entity_id: str, user_id: str, **kwargs: Any) -> bool: + """ + Downvotes the specified thread or comment by the given user. + Args: + thread (dict): The thread or comment data to be downvoted. + user (dict): The user who is performing the downvote. -def remove_vote(thread: dict[str, Any], user: dict[str, Any]) -> bool: - """ - Remove the vote (upvote or downvote) from the specified thread or comment for the given user. + Returns: + bool: True if the vote was successfully updated, False otherwise. + """ + user = Users().get(user_id) + entity = Contents().get(entity_id) + if not (user and entity): + raise ValueError("User ID or entity is not provided") - Args: - thread (dict): The thread or comment data from which the vote should be removed. - user (dict): The user who is removing their vote. + return cls.update_vote(entity["_id"], user["external_id"], vote_type="down") - Returns: - bool: True if the vote was successfully removed, False otherwise. - """ - return update_vote(thread, user, is_deleted=True) + @classmethod + def remove_vote(cls, entity_id: str, user_id: str, **kwargs: Any) -> bool: + """ + Remove the vote (upvote or downvote) from the specified thread or comment for the given user. + Args: + thread (dict): The thread or comment data from which the vote should be removed. + user (dict): The user who is removing their vote. -def validate_thread_and_user( - user_id: str, thread_id: str -) -> tuple[dict[str, Any], dict[str, Any]]: - """ - Validate thread and user. + Returns: + bool: True if the vote was successfully removed, False otherwise. + """ + user = Users().get(user_id) + entity = Contents().get(entity_id) + if not (user and entity): + raise ValueError("User ID or entity is not provided") + + return cls.update_vote(entity["_id"], user["external_id"], is_deleted=True) + + @staticmethod + def validate_thread_and_user( + user_id: str, thread_id: str + ) -> tuple[dict[str, Any], dict[str, Any]]: + """ + Validate thread and user. + + Arguments: + user_id (str): The ID of the user making the request. + thread_id (str): The ID of the thread. + + Returns: + tuple[dict[str, Any], dict[str, Any]]: A tuple containing the user and thread data. + + Raises: + ValueError: If the thread or user is not found. + """ + thread = CommentThread().get(thread_id) + user = Users().get(user_id) + if not (thread and user): + raise ValueError("User / Thread doesn't exist") + + return user, thread + + @staticmethod + def pin_unpin_thread(thread_id: str, action: str) -> None: + """ + Pin or unpin the thread based on action parameter. + + Arguments: + thread_id (str): The ID of the thread to pin/unpin. + action (str): The action to perform ("pin" or "unpin"). + """ + CommentThread().update(thread_id, pinned=action == "pin") + + @classmethod + def get_pinned_unpinned_thread_serialized_data( + cls, user_id: str, thread_id: str, serializer_class: Any + ) -> dict[str, Any]: + """ + Return serialized data of pinned or unpinned thread. + + Arguments: + user (dict[str, Any]): The user who requested the action. + thread_id (str): The ID of the thread to pin/unpin. + + Returns: + dict[str, Any]: The serialized data of the pinned/unpinned thread. + + Raises: + ValueError: If the serialization is not valid. + """ + user = Users().get(user_id) + updated_thread = CommentThread().get(thread_id) + if not (user and updated_thread): + raise ValueError("User ID or entity is not provided") + + context = { + "user_id": user["_id"], + "username": user["username"], + "type": "thread", + "id": thread_id, + } + if updated_thread is not None: + context = {**context, **updated_thread} + serializer = serializer_class(data=context, backend=cls) + if not serializer.is_valid(): + raise ValueError(serializer.errors) + + return serializer.data + + @classmethod + def handle_pin_unpin_thread_request( + cls, user_id: str, thread_id: str, action: str, serializer_class: Any + ) -> dict[str, Any]: + """ + Catches pin/unpin thread request. + + - validates thread and user. + - pin or unpin the thread based on action parameter. + - return serialized data of thread. + + Arguments: + user_id (str): The ID of the user making the request. + thread_id (str): The ID of the thread to pin/unpin. + action (str): The action to perform ("pin" or "unpin"). + + Returns: + dict[str, Any]: The serialized data of the pinned/unpinned thread. + """ + user, _ = cls.validate_thread_and_user(user_id, thread_id) + cls.pin_unpin_thread(thread_id, action) + return cls.get_pinned_unpinned_thread_serialized_data( + user["external_id"], thread_id, serializer_class + ) - Arguments: - user_id (str): The ID of the user making the request. - thread_id (str): The ID of the thread. + @staticmethod + def get_abuse_flagged_count(thread_ids: list[str]) -> dict[str, int]: + """ + Retrieves the count of abuse-flagged comments for each thread in the provided list of thread IDs. - Returns: - tuple[dict[str, Any], dict[str, Any]]: A tuple containing the user and thread data. + Args: + thread_ids (list[str]): List of thread IDs to check for abuse flags. - Raises: - ValueError: If the thread or user is not found. - """ - thread = CommentThread().get(thread_id) - user = Users().get(user_id) - if not (thread and user): - raise ValueError("User / Thread doesn't exist") + Returns: + dict[str, int]: A dictionary mapping thread IDs to their corresponding abuse-flagged comment count. + """ + pipeline: list[dict[str, Any]] = [ + { + "$match": { + "comment_thread_id": {"$in": [ObjectId(tid) for tid in thread_ids]}, + "abuse_flaggers": {"$ne": []}, + } + }, + {"$group": {"_id": "$comment_thread_id", "flagged_count": {"$sum": 1}}}, + ] + flagged_threads = Contents().aggregate(pipeline) + + return {str(item["_id"]): item["flagged_count"] for item in flagged_threads} + + @staticmethod + def get_read_states( + thread_ids: list[str], user_id: str, course_id: str + ) -> dict[str, list[Any]]: + """ + Retrieves the read state and unread comment count for each thread in the provided list. + + Args: + threads (list[dict[str, Any]]): list of threads to check read state for. + user_id (str): The ID of the user whose read states are being retrieved. + course_id (str): The course ID associated with the threads. + + Returns: + dict[str, list[Any]]: A dictionary mapping thread IDs to a list containing + whether the thread is read and the unread comment count. + """ + threads = CommentThread().find( + {"_id": {"$in": [ObjectId(thread_id) for thread_id in thread_ids]}} + ) + read_states = {} + if user_id: + user = Users().find_one( + {"_id": user_id, "read_states.course_id": course_id} + ) + read_state = user["read_states"][0] if user else {} + if read_state: + read_dates = read_state.get("last_read_times", {}) + for thread in threads: + thread_key = str(thread["_id"]) + if thread_key in read_dates: + read_date = make_aware(read_dates[thread_key]) + last_activity_at = make_aware(thread["last_activity_at"]) + is_read = read_date >= last_activity_at + unread_comment_count = Contents().count_documents( + { + "comment_thread_id": ObjectId(thread_key), + "created_at": {"$gte": read_dates[thread_key]}, + "author_id": {"$ne": str(user_id)}, + } + ) + read_states[thread_key] = [is_read, unread_comment_count] + + return read_states + + @staticmethod + def get_endorsed(thread_ids: list[str]) -> dict[str, bool]: + """ + Retrieves endorsed status for each thread in the provided list of thread IDs. + + Args: + thread_ids (list[str]): List of thread IDs to check for endorsement. + + Returns: + dict[str, bool]: A dictionary of thread IDs to their endorsed status (True if endorsed, False otherwise). + """ + endorsed_comments = Comment().find( + { + "comment_thread_id": {"$in": [ObjectId(tid) for tid in thread_ids]}, + "endorsed": True, + } + ) - return user, thread + return {str(item["comment_thread_id"]): True for item in endorsed_comments} + @staticmethod + def get_user_read_state_by_course_id( + user_id: str, course_id: str + ) -> dict[str, Any]: + """ + Retrieves the user's read state for a specific course. -def pin_unpin_thread(thread_id: str, action: str) -> None: - """ - Pin or unpin the thread based on action parameter. + Args: + user (dict[str, Any]): The user object containing read states. + course_id (str): The course ID to filter the user's read state by. - Arguments: - thread_id (str): The ID of the thread to pin/unpin. - action (str): The action to perform ("pin" or "unpin"). - """ - CommentThread().update(thread_id, pinned=action == "pin") + Returns: + dict[str, Any]: The user's read state for the specified course, or an empty dictionary if not found. + """ + user = Users().get(user_id) + if not user: + raise ValueError("User does not exist.") + + for read_state in user.get("read_states", []): + if read_state["course_id"] == course_id: + return read_state + return {} + + # TODO: Make this function modular + # pylint: disable=too-many-nested-blocks,too-many-statements + @classmethod + def handle_threads_query( + cls, + comment_thread_ids: list[str], + user_id: str, + course_id: str, + group_ids: list[int], + author_id: Optional[str], + thread_type: Optional[str], + filter_flagged: bool, + filter_unread: bool, + filter_unanswered: bool, + filter_unresponded: bool, + count_flagged: bool, + sort_key: str, + page: int, + per_page: int, + context: str = "course", + raw_query: bool = False, + ) -> dict[str, Any]: + """ + Handles complex thread queries based on various filters and returns paginated results. + + Args: + comment_thread_ids (list[str]): List of comment thread IDs to filter. + user_id (str): The ID of the user making the request. + course_id (str): The course ID associated with the threads. + group_ids (list[int]): List of group IDs for group-based filtering. + author_id (str): The ID of the author to filter threads by. + thread_type (str): The type of thread to filter by. + filter_flagged (bool): Whether to filter threads flagged for abuse. + filter_unread (bool): Whether to filter unread threads. + filter_unanswered (bool): Whether to filter unanswered questions. + filter_unresponded (bool): Whether to filter threads with no responses. + count_flagged (bool): Whether to include flagged content count. + sort_key (str): The key to sort the threads by. + page (int): The page number for pagination. + per_page (int): The number of threads per page. + context (str): The context to filter threads by. + raw_query (bool): Whether to return raw query results without further processing. + + Returns: + dict[str, Any]: A dictionary containing the paginated thread results and associated metadata. + """ + # Convert thread_ids to ObjectId + comment_thread_obj_ids: list[ObjectId] = [ + ObjectId(tid) for tid in comment_thread_ids + ] + # Base query + base_query: dict[str, Any] = { + "_id": {"$in": comment_thread_obj_ids}, + "context": context, + } -def get_pinned_unpinned_thread_serialized_data( - user: dict[str, Any], thread_id: str, serializer_class: Any -) -> dict[str, Any]: - """ - Return serialized data of pinned or unpinned thread. + # Group filtering + if group_ids: + base_query["$or"] = [ + {"group_id": {"$in": group_ids}}, + {"group_id": {"$exists": False}}, + ] - Arguments: - user (dict[str, Any]): The user who requested the action. - thread_id (str): The ID of the thread to pin/unpin. + # Author filtering + if author_id: + base_query["author_id"] = author_id + if author_id != user_id: + base_query["anonymous"] = False + base_query["anonymous_to_peers"] = False - Returns: - dict[str, Any]: The serialized data of the pinned/unpinned thread. + # Thread type filtering + if thread_type: + base_query["thread_type"] = thread_type - Raises: - ValueError: If the serialization is not valid. - """ - updated_thread = CommentThread().get(thread_id) - context = { - "user_id": user["_id"], - "username": user["username"], - "type": "thread", - "id": thread_id, - } - if updated_thread is not None: - context = {**context, **updated_thread} - serializer = serializer_class(data=context) - if not serializer.is_valid(): - raise ValueError(serializer.errors) - - return serializer.data - - -def handle_pin_unpin_thread_request( - user_id: str, thread_id: str, action: str, serializer_class: Any -) -> dict[str, Any]: - """ - Catches pin/unpin thread request. - - - validates thread and user. - - pin or unpin the thread based on action parameter. - - return serialized data of thread. - - Arguments: - user_id (str): The ID of the user making the request. - thread_id (str): The ID of the thread to pin/unpin. - action (str): The action to perform ("pin" or "unpin"). - - Returns: - dict[str, Any]: The serialized data of the pinned/unpinned thread. - """ - user, _ = validate_thread_and_user(user_id, thread_id) - pin_unpin_thread(thread_id, action) - return get_pinned_unpinned_thread_serialized_data(user, thread_id, serializer_class) - - -def get_abuse_flagged_count(thread_ids: list[str]) -> dict[str, int]: - """ - Retrieves the count of abuse-flagged comments for each thread in the provided list of thread IDs. - - Args: - thread_ids (list[str]): List of thread IDs to check for abuse flags. - - Returns: - dict[str, int]: A dictionary mapping thread IDs to their corresponding abuse-flagged comment count. - """ - pipeline: list[dict[str, Any]] = [ - { - "$match": { - "comment_thread_id": {"$in": [ObjectId(tid) for tid in thread_ids]}, - "abuse_flaggers": {"$ne": []}, + # Flagged content filtering + if filter_flagged: + flagged_query = { + "course_id": course_id, + "abuse_flaggers": {"$ne": [], "$exists": True}, } - }, - {"$group": {"_id": "$comment_thread_id", "flagged_count": {"$sum": 1}}}, - ] - flagged_threads = Contents().aggregate(pipeline) - - return {str(item["_id"]): item["flagged_count"] for item in flagged_threads} - - -def get_read_states( - threads: list[dict[str, Any]], user_id: str, course_id: str -) -> dict[str, list[Any]]: - """ - Retrieves the read state and unread comment count for each thread in the provided list. - - Args: - threads (list[dict[str, Any]]): list of threads to check read state for. - user_id (str): The ID of the user whose read states are being retrieved. - course_id (str): The course ID associated with the threads. - - Returns: - dict[str, list[Any]]: A dictionary mapping thread IDs to a list containing - whether the thread is read and the unread comment count. - """ - read_states = {} - if user_id: - user = Users().find_one({"_id": user_id, "read_states.course_id": course_id}) - read_state = user["read_states"][0] if user else {} - if read_state: - read_dates = read_state.get("last_read_times", {}) - for thread in threads: - thread_key = str(thread["_id"]) - if thread_key in read_dates: - read_date = make_aware(read_dates[thread_key]) - last_activity_at = make_aware(thread["last_activity_at"]) - is_read = read_date >= last_activity_at - unread_comment_count = Contents().count_documents( - { - "comment_thread_id": ObjectId(thread_key), - "created_at": {"$gte": read_dates[thread_key]}, - "author_id": {"$ne": str(user_id)}, - } - ) - read_states[thread_key] = [is_read, unread_comment_count] - - return read_states - - -def get_filtered_thread_ids( - thread_ids: list[str], context: str, group_ids: list[str] -) -> set[str]: - """ - Filters thread IDs based on context and group ID criteria. - - Args: - thread_ids (list[str]): List of thread IDs to filter. - context (str): The context to filter by. - group_ids (list[str]): List of group IDs for group-based filtering. - - Returns: - set: A set of filtered thread IDs based on the context and group ID criteria. - """ - context_query = { - "_id": {"$in": [ObjectId(tid) for tid in thread_ids]}, - "context": context, - } - context_threads = CommentThread().find(context_query) - context_thread_ids = {str(thread["_id"]) for thread in context_threads} - - if not group_ids: - return context_thread_ids - - group_query = { - "_id": {"$in": [ObjectId(tid) for tid in thread_ids]}, - "$or": [ - {"group_id": {"$in": group_ids}}, - {"group_id": {"$exists": False}}, - ], - } - group_threads = CommentThread().find(group_query) - group_thread_ids = {str(thread["_id"]) for thread in group_threads} - - return context_thread_ids.union(group_thread_ids) - - -def get_endorsed(thread_ids: list[str]) -> dict[str, bool]: - """ - Retrieves endorsed status for each thread in the provided list of thread IDs. - - Args: - thread_ids (list[str]): List of thread IDs to check for endorsement. - - Returns: - dict[str, bool]: A dictionary mapping thread IDs to their endorsed status (True if endorsed, False otherwise). - """ - endorsed_comments = Comment().find( - { - "comment_thread_id": {"$in": [ObjectId(tid) for tid in thread_ids]}, - "endorsed": True, - } - ) - - return {str(item["comment_thread_id"]): True for item in endorsed_comments} - - -def get_user_read_state_by_course_id( - user: dict[str, Any], course_id: str -) -> dict[str, Any]: - """ - Retrieves the user's read state for a specific course. - - Args: - user (dict[str, Any]): The user object containing read states. - course_id (str): The course ID to filter the user's read state by. - - Returns: - dict[str, Any]: The user's read state for the specified course, or an empty dictionary if not found. - """ - for read_state in user.get("read_states", []): - if read_state["course_id"] == course_id: - return read_state - return {} - - -# TODO: Make this function modular -# pylint: disable=too-many-nested-blocks,too-many-statements -def handle_threads_query( - comment_thread_ids: list[str], - user_id: str, - course_id: str, - group_ids: list[int], - author_id: Optional[str], - thread_type: Optional[str], - filter_flagged: bool, - filter_unread: bool, - filter_unanswered: bool, - filter_unresponded: bool, - count_flagged: bool, - sort_key: str, - page: int, - per_page: int, - context: str = "course", - raw_query: bool = False, -) -> dict[str, Any]: - """ - Handles complex thread queries based on various filters and returns paginated results. - - Args: - comment_thread_ids (list[str]): List of comment thread IDs to filter. - user_id (str): The ID of the user making the request. - course_id (str): The course ID associated with the threads. - group_ids (list[int]): List of group IDs for group-based filtering. - author_id (str): The ID of the author to filter threads by. - thread_type (str): The type of thread to filter by. - filter_flagged (bool): Whether to filter threads flagged for abuse. - filter_unread (bool): Whether to filter unread threads. - filter_unanswered (bool): Whether to filter unanswered questions. - filter_unresponded (bool): Whether to filter threads with no responses. - count_flagged (bool): Whether to include flagged content count. - sort_key (str): The key to sort the threads by. - page (int): The page number for pagination. - per_page (int): The number of threads per page. - context (str): The context to filter threads by. - raw_query (bool): Whether to return raw query results without further processing. - - Returns: - dict[str, Any]: A dictionary containing the paginated thread results and associated metadata. - """ - # Convert thread_ids to ObjectId - comment_thread_obj_ids: list[ObjectId] = [ - ObjectId(tid) for tid in comment_thread_ids - ] - - # Base query - base_query: dict[str, Any] = { - "_id": {"$in": comment_thread_obj_ids}, - "context": context, - } - - # Group filtering - if group_ids: - base_query["$or"] = [ - {"group_id": {"$in": group_ids}}, - {"group_id": {"$exists": False}}, - ] - - # Author filtering - if author_id: - base_query["author_id"] = author_id - if author_id != user_id: - base_query["anonymous"] = False - base_query["anonymous_to_peers"] = False - - # Thread type filtering - if thread_type: - base_query["thread_type"] = thread_type + flagged_comments = Comment().distinct("comment_thread_id", flagged_query) + flagged_threads = CommentThread().distinct("_id", flagged_query) + base_query["_id"]["$in"] = list( + set(comment_thread_obj_ids) & set(flagged_comments + flagged_threads) + ) - # Flagged content filtering - if filter_flagged: - flagged_query = { - "course_id": course_id, - "abuse_flaggers": {"$ne": [], "$exists": True}, - } - flagged_comments = Comment().distinct("comment_thread_id", flagged_query) - flagged_threads = CommentThread().distinct("_id", flagged_query) - base_query["_id"]["$in"] = list( - set(comment_thread_obj_ids) & set(flagged_comments + flagged_threads) - ) + # Unanswered questions filtering + if filter_unanswered: + endorsed_threads = Comment().distinct( + "comment_thread_id", + { + "course_id": course_id, + "parent_id": {"$exists": False}, + "endorsed": True, + }, + ) + base_query["thread_type"] = "question" + base_query["_id"]["$nin"] = endorsed_threads + + # Unresponded threads filtering + if filter_unresponded: + base_query["comment_count"] = 0 + + sort_criteria = get_sort_criteria(sort_key) + + comment_threads = CommentThread().find(base_query) + thread_count = CommentThread().count_documents(base_query) + + if sort_criteria or raw_query: + request_user = Users().get(user_id) if user_id else None + + if not raw_query: + comment_threads = comment_threads.sort(sort_criteria) + + if filter_unread and request_user: + read_state = cls.get_user_read_state_by_course_id( + request_user["external_id"], course_id + ) + read_dates = read_state.get("last_read_times", {}) + + threads = [] + skipped = 0 + to_skip = (page - 1) * per_page + has_more = False + batch_size = 100 + + for thread in comment_threads.batch_size(batch_size): + thread_key = str(thread["_id"]) + if ( + thread_key not in read_dates + or read_dates[thread_key] < thread["last_activity_at"] + ): + if raw_query: + threads.append(thread) + elif skipped >= to_skip: + if len(threads) == per_page: + has_more = True + break + threads.append(thread) + else: + skipped += 1 + num_pages = page + 1 if has_more else page + else: + if raw_query: + threads = list(comment_threads) + else: + page = max(1, page) + paginated_collection = comment_threads.skip( + (page - 1) * per_page + ).limit(per_page) + threads = list(paginated_collection) + num_pages = max(1, math.ceil(thread_count / per_page)) - # Unanswered questions filtering - if filter_unanswered: - endorsed_threads = Comment().distinct( - "comment_thread_id", - {"course_id": course_id, "parent_id": {"$exists": False}, "endorsed": True}, - ) - base_query["thread_type"] = "question" - base_query["_id"]["$nin"] = endorsed_threads - - # Unresponded threads filtering - if filter_unresponded: - base_query["comment_count"] = 0 - - sort_criteria = get_sort_criteria(sort_key) - - comment_threads = CommentThread().find(base_query) - thread_count = CommentThread().count_documents(base_query) - - if sort_criteria or raw_query: - request_user = Users().get(_id=user_id) if user_id else None - - if not raw_query: - comment_threads = comment_threads.sort(sort_criteria) - - if filter_unread and request_user: - read_state = get_user_read_state_by_course_id(request_user, course_id) - read_dates = read_state.get("last_read_times", {}) - - threads = [] - skipped = 0 - to_skip = (page - 1) * per_page - has_more = False - batch_size = 100 - - for thread in comment_threads.batch_size(batch_size): - thread_key = str(thread["_id"]) - if ( - thread_key not in read_dates - or read_dates[thread_key] < thread["last_activity_at"] - ): - if raw_query: - threads.append(thread) - elif skipped >= to_skip: - if len(threads) == per_page: - has_more = True - break - threads.append(thread) - else: - skipped += 1 - num_pages = page + 1 if has_more else page - else: if raw_query: - threads = list(comment_threads) + return {"result": threads} + if len(threads) == 0: + collection = [] else: - page = max(1, page) - paginated_collection = comment_threads.skip( - (page - 1) * per_page - ).limit(per_page) - threads = list(paginated_collection) - num_pages = max(1, math.ceil(thread_count / per_page)) - - if raw_query: - return {"result": threads} - if len(threads) == 0: - collection = [] - else: - collection = threads_presentor(threads, user_id, course_id, count_flagged) + thread_ids = [str(thread["_id"]) for thread in threads] + collection = cls.threads_presentor( + thread_ids, user_id, course_id, count_flagged + ) + + return { + "collection": collection, + "num_pages": num_pages, + "page": page, + "thread_count": thread_count, + } + + return {} + + @staticmethod + def prepare_thread( + thread_id: str, + is_read: bool, + unread_count: int, + is_endorsed: bool, + abuse_flagged_count: int, + ) -> dict[str, Any]: + """ + Prepares thread data for presentation. + + Args: + thread (dict[str, Any]): The thread data. + is_read (bool): Whether the thread is read. + unread_count (int): The count of unread comments. + is_endorsed (bool): Whether the thread is endorsed. + abuse_flagged_count (int): The abuse flagged count. + + Returns: + dict[str, Any]: A dictionary representing the prepared thread data. + """ + thread = CommentThread().get(thread_id) + if not thread: + raise ValueError("Thread does not exist.") return { - "collection": collection, - "num_pages": num_pages, - "page": page, - "thread_count": thread_count, + "id": str(thread["_id"]), + **thread, + "type": "thread", + "read": is_read, + "unread_comments_count": unread_count, + "endorsed": is_endorsed, + "abuse_flagged_count": abuse_flagged_count, } - return {} - - -def prepare_thread( - thread: dict[str, Any], - is_read: bool, - unread_count: int, - is_endorsed: bool, - abuse_flagged_count: int, -) -> dict[str, Any]: - """ - Prepares thread data for presentation. - - Args: - thread (dict[str, Any]): The thread data. - is_read (bool): Whether the thread is read. - unread_count (int): The count of unread comments. - is_endorsed (bool): Whether the thread is endorsed. - abuse_flagged_count (int): The abuse flagged count. - - Returns: - dict[str, Any]: A dictionary representing the prepared thread data. - """ - return { - "id": str(thread["_id"]), - **thread, - "type": "thread", - "read": is_read, - "unread_comments_count": unread_count, - "endorsed": is_endorsed, - "abuse_flagged_count": abuse_flagged_count, - } - - -def threads_presentor( - threads: list[dict[str, Any]], - user_id: str, - course_id: str, - count_flagged: bool = False, -) -> list[dict[str, Any]]: - """ - Presents the threads by preparing them for display. - - Args: - threads (list[dict[str, Any]]): List of threads to present. - user_id (str): The ID of the user presenting the threads. - course_id (str): The course ID associated with the threads. - count_flagged (bool, optional): Whether to include flagged content count. Defaults to False. - - Returns: - list[dict[str, Any]]: A list of prepared thread data. - """ - thread_ids = [str(thread["_id"]) for thread in threads] - read_states = get_read_states(threads, user_id, course_id) - threads_endorsed = get_endorsed(thread_ids) - threads_flagged = get_abuse_flagged_count(thread_ids) if count_flagged else {} - - presenters = [] - for thread in threads: - thread_key = str(thread["_id"]) - is_read, unread_count = read_states.get( - thread_key, (False, thread["comment_count"]) + @classmethod + def threads_presentor( + cls, + thread_ids: list[str], + user_id: str, + course_id: str, + count_flagged: bool = False, + ) -> list[dict[str, Any]]: + """ + Presents the threads by preparing them for display. + + Args: + threads (list[dict[str, Any]]): List of threads to present. + user_id (str): The ID of the user presenting the threads. + course_id (str): The course ID associated with the threads. + count_flagged (bool, optional): Whether to include flagged content count. Defaults to False. + + Returns: + list[dict[str, Any]]: A list of prepared thread data. + """ + threads = CommentThread().find( + {"_id": {"$in": [ObjectId(thread_id) for thread_id in thread_ids]}} ) - is_endorsed = threads_endorsed.get(thread_key, False) - abuse_flagged_count = threads_flagged.get(thread_key, 0) - presenters.append( - prepare_thread( - thread, - is_read, - unread_count, - is_endorsed, - abuse_flagged_count, - ) + read_states = cls.get_read_states(thread_ids, user_id, course_id) + threads_endorsed = cls.get_endorsed(thread_ids) + threads_flagged = ( + cls.get_abuse_flagged_count(thread_ids) if count_flagged else {} ) + threads_dict = {str(thread["_id"]): thread for thread in threads} + + presenters = [] + for thread_id in thread_ids: + thread = threads_dict.get(thread_id) + if thread: + thread_key = thread_id + is_read, unread_count = read_states.get( + thread_key, (False, thread["comment_count"]) + ) + is_endorsed = threads_endorsed.get(thread_key, False) + abuse_flagged_count = threads_flagged.get(thread_key, 0) + presenters.append( + cls.prepare_thread( + thread["_id"], + is_read, + unread_count, + is_endorsed, + abuse_flagged_count, + ) + ) - return presenters - - -def get_username_from_id(user_id: str) -> Optional[str]: - """ - Retrieve the username associated with a given user ID. - - Args: - _id (int): The unique identifier of the user. - - Returns: - Optional[str]: The username of the user if found, or None if not. - - """ - user = Users().get(_id=user_id) or {} - if username := user.get("username"): - return username - return None - - -def validate_object(model: Any, obj_id: str) -> Any: - """ - Validates the object if it exists or not. - - Parameters: - model: The model for which to validate the id. - id: The ID of the object to validate in the model. - Response: - raise exception if object does not exists. - return object - """ - instance = model().get(obj_id) - if not instance: - raise ObjectDoesNotExist - return instance - - -def find_subscribed_threads(user_id: str, course_id: Optional[str] = None) -> list[str]: - """ - Find threads that a user is subscribed to in a specific course. - - Args: - user_id (str): The ID of the user. - course_id (str): The ID of the course. - - Returns: - list: A list of thread ids that the user is subscribed to in the course. - """ - subscriptions = Subscriptions() - threads = CommentThread() - - subscription_filter = {"subscriber_id": user_id} - subscriptions_cursor = subscriptions.find(subscription_filter) - - thread_ids = [] - for subscription in subscriptions_cursor: - thread_ids.append(ObjectId(subscription["source_id"])) + return presenters - thread_filter: dict[str, Any] = {"_id": {"$in": thread_ids}} - if course_id: - thread_filter["course_id"] = course_id - threads_cursor = threads.find(thread_filter) + @staticmethod + def get_username_from_id(user_id: str) -> Optional[str]: + """ + Retrieve the username associated with a given user ID. - subscribed_ids = [] - for thread in threads_cursor: - subscribed_ids.append(thread["_id"]) + Args: + _id (int): The unique identifier of the user. - return subscribed_ids + Returns: + Optional[str]: The username of the user if found, or None if not. + """ + user = Users().get(_id=user_id) or {} + if username := user.get("username"): + return username + return None -def subscribe_user( - user_id: str, source_id: str, source_type: str -) -> dict[str, Any] | None: - """Subscribe a user to a source.""" - subscription = Subscriptions().get_subscription(user_id, source_id) - if not subscription: - Subscriptions().insert(user_id, source_id, source_type) + @staticmethod + def validate_object(model: str, obj_id: str) -> Any: + """ + Validates the object if it exists or not. + + Parameters: + model: The model for which to validate the id. + id: The ID of the object to validate in the model. + Response: + raise exception if object does not exists. + return object + """ + models = { + "Comment": Comment, + "CommentThread": CommentThread, + } + instance = models[model]().get(obj_id) + if not instance: + raise ObjectDoesNotExist + return instance + + @staticmethod + def find_subscribed_threads( + user_id: str, course_id: Optional[str] = None + ) -> list[str]: + """ + Find threads that a user is subscribed to in a specific course. + + Args: + user_id (str): The ID of the user. + course_id (str): The ID of the course. + + Returns: + list: A list of thread ids that the user is subscribed to in the course. + """ + subscriptions = Subscriptions() + threads = CommentThread() + + subscription_filter = {"subscriber_id": user_id} + subscriptions_cursor = subscriptions.find(subscription_filter) + + thread_ids = [] + for subscription in subscriptions_cursor: + thread_ids.append(ObjectId(subscription["source_id"])) + + thread_filter: dict[str, Any] = {"_id": {"$in": thread_ids}} + if course_id: + thread_filter["course_id"] = course_id + threads_cursor = threads.find(thread_filter) + + subscribed_ids = [] + for thread in threads_cursor: + subscribed_ids.append(str(thread["_id"])) + + return subscribed_ids + + @staticmethod + def subscribe_user( + user_id: str, source_id: str, source_type: str + ) -> dict[str, Any] | None: + """Subscribe a user to a source.""" subscription = Subscriptions().get_subscription(user_id, source_id) - return subscription + if not subscription: + Subscriptions().insert(user_id, source_id, source_type) + subscription = Subscriptions().get_subscription(user_id, source_id) + if subscription: + subscription["_id"] = str(subscription["_id"]) + return subscription + + @staticmethod + def unsubscribe_user( + user_id: str, source_id: str, source_type: Optional[str] = "" + ) -> None: + """Unsubscribe a user from a source.""" + Subscriptions().delete_subscription(user_id, source_id, source_type=source_type) + + @staticmethod + def delete_comments_of_a_thread(thread_id: str) -> None: + """Delete comments of a thread.""" + for comment in Comment().get_list( + comment_thread_id=ObjectId(thread_id), + depth=0, + parent_id=None, + ): + Comment().delete(comment["_id"]) + + @staticmethod + def delete_subscriptions_of_a_thread(thread_id: str) -> None: + """Delete subscriptions of a thread.""" + for subscription in Subscriptions().get_list( + source_id=thread_id, source_type="CommentThread" + ): + Subscriptions().delete_subscription( + subscription["subscriber_id"], + subscription["source_id"], + source_type="CommentThread", + ) + @staticmethod + def validate_params(params: dict[str, Any], user_id: Optional[str] = None) -> None: + """ + Validate the request parameters. + Args: + params (dict): The request parameters. + user_id (optional[str]): The Id of the user for validation. + + Returns: + Response: A Response object with an error message if doesn't exist. + """ + valid_params = [ + "course_id", + "author_id", + "thread_type", + "flagged", + "unread", + "unanswered", + "unresponded", + "count_flagged", + "sort_key", + "page", + "per_page", + "request_id", + "commentable_ids", + ] + if not user_id: + valid_params.append("user_id") + user_id = params.get("user_id") + + for key in params: + if key not in valid_params: + raise ForumV2RequestError(f"Invalid parameter: {key}") + + if "course_id" not in params: + raise ForumV2RequestError("Missing required parameter: course_id") + + if user_id: + user = Users().get(user_id) + if not user: + raise ForumV2RequestError("User doesn't exist") + + @classmethod + def get_threads( + cls, + params: dict[str, Any], + user_id: str, + serializer: Any, + thread_ids: list[str], + ) -> dict[str, Any]: + """get subscribed or all threads of a specific course for a specific user.""" + count_flagged = bool(params.get("count_flagged", False)) + threads = cls.handle_threads_query( + thread_ids, + user_id, + params["course_id"], + get_group_ids_from_params(params), + params.get("author_id", ""), + params.get("thread_type"), + bool(params.get("flagged", False)), + bool(params.get("unread", False)), + bool(params.get("unanswered", False)), + bool(params.get("unresponded", False)), + count_flagged, + params.get("sort_key", ""), + int(params.get("page", 1)), + int(params.get("per_page", 100)), + ) + context: dict[str, Any] = { + "count_flagged": count_flagged, + "include_endorsed": True, + "include_read_state": True, + } + if user_id: + context["user_id"] = user_id + serializer = serializer( + threads.pop("collection"), many=True, context=context, backend=cls + ) + threads["collection"] = serializer.data + return threads + + @staticmethod + def generate_id() -> str: + return str(ObjectId()) + + @staticmethod + def find_or_create_user( + user_id: str, username: Optional[str] = "", default_sort_key: Optional[str] = "" + ) -> str: + """Find or create user.""" + user = Users().get(user_id) + if user: + return user["external_id"] + user_id = Users().insert( + user_id, username=username, default_sort_key=default_sort_key + ) + return user_id + + @classmethod + def create_comment(cls, data: dict[str, Any]) -> str: + """ + handle comment creation and returns a comment. + + Parameters: + data: The content of the comment. + + Response: + The details of the comment that is created. + """ + new_comment_id = Comment().insert( + body=data["body"], + author_id=data["author_id"], + author_username=data.get("author_username"), + course_id=data["course_id"], + anonymous=data.get("anonymous", False), + anonymous_to_peers=data.get("anonymous_to_peers", False), + depth=data.get("depth", 0), + comment_thread_id=data["comment_thread_id"], + parent_id=data.get("parent_id"), + ) -def unsubscribe_user(user_id: str, source_id: str) -> None: - """Unsubscribe a user from a source.""" - Subscriptions().delete_subscription(user_id, source_id) + if data.get("parent_id"): + cls.update_stats_for_course(data["author_id"], data["course_id"], replies=1) + else: + cls.update_stats_for_course( + data["author_id"], data["course_id"], responses=1 + ) + return str(new_comment_id) + + @staticmethod + def update_comment_and_get_updated_comment( + comment_id: str, + body: Optional[str] = None, + course_id: Optional[str] = None, + user_id: Optional[str] = None, + anonymous: Optional[bool] = False, + anonymous_to_peers: Optional[bool] = False, + endorsed: Optional[bool] = None, + closed: Optional[bool] = False, + editing_user_id: Optional[str] = None, + edit_reason_code: Optional[str] = None, + endorsement_user_id: Optional[str] = None, + ) -> dict[str, Any] | None: + """ + Update an existing child/parent comment. + + Parameters: + comment_id: The ID of the comment to be edited. + body (Optional[str]): The content of the comment. + course_id (Optional[str]): The Id of the respective course. + user_id (Optional[str]): The requesting user id. + anonymous (Optional[bool]): anonymous flag(True or False). + anonymous_to_peers (Optional[bool]): anonymous to peers flag(True or False). + endorsed (Optional[bool]): Flag indicating if the comment is endorsed by any user. + closed (Optional[bool]): Flag indicating if the comment thread is closed. + editing_user_id (Optional[str]): The ID of the user editing the comment. + edit_reason_code (Optional[str]): The reason for editing the comment, typically represented by a code. + endorsement_user_id (Optional[str]): The ID of the user endorsing the comment. + Response: + The details of the comment that is updated. + """ + Comment().update( + comment_id, + body=body, + course_id=course_id, + author_id=user_id, + anonymous=anonymous, + anonymous_to_peers=anonymous_to_peers, + endorsed=endorsed, + closed=closed, + editing_user_id=editing_user_id, + edit_reason_code=edit_reason_code, + endorsement_user_id=endorsement_user_id, + ) + return Comment().get(comment_id) -def delete_comments_of_a_thread(thread_id: str) -> None: - """Delete comments of a thread.""" - for comment in Comment().get_list( - comment_thread_id=ObjectId(thread_id), - depth=0, - parent_id=None, - ): - Comment().delete(comment["_id"]) + @staticmethod + def get_commentables_counts_based_on_type(course_id: str) -> dict[str, Any]: + """Return commentables counts in a course based on thread's type.""" + pipeline: list[dict[str, Any]] = [ + {"$match": {"course_id": course_id, "_type": "CommentThread"}}, + { + "$group": { + "_id": {"topic_id": "$commentable_id", "type": "$thread_type"}, + "count": {"$sum": 1}, + } + }, + ] + result = CommentThread().aggregate(pipeline) + commentable_counts = {} + for commentable in result: + topic_id = commentable["_id"]["topic_id"] + if topic_id not in commentable_counts: + commentable_counts[topic_id] = {"discussion": 0, "question": 0} + commentable_counts[topic_id].update( + {commentable["_id"]["type"]: commentable["count"]} + ) -def delete_subscriptions_of_a_thread(thread_id: str) -> None: - """Delete subscriptions of a thread.""" - for subscription in Subscriptions().get_list( - source_id=thread_id, source_type="CommentThread" - ): - Subscriptions().delete_subscription( - subscription["subscriber_id"], subscription["source_id"] - ) + return commentable_counts + @classmethod + def get_user_voted_ids(cls, user_id: str, vote: str) -> list[str]: + """Get the IDs of the posts voted by a user.""" + if vote not in ["up", "down"]: + raise ValueError("Invalid vote type") -def validate_params(params: dict[str, Any], user_id: Optional[str] = None) -> None: - """ - Validate the request parameters. - - Args: - params (dict): The request parameters. - user_id (optional[str]): The Id of the user for validation. - - Returns: - Response: A Response object with an error message if doesn't exist. - """ - valid_params = [ - "course_id", - "author_id", - "thread_type", - "flagged", - "unread", - "unanswered", - "unresponded", - "count_flagged", - "sort_key", - "page", - "per_page", - "request_id", - "commentable_ids", - ] - if not user_id: - valid_params.append("user_id") - user_id = params.get("user_id") - - for key in params: - if key not in valid_params: - raise ForumV2RequestError(f"Invalid parameter: {key}") - - if "course_id" not in params: - raise ForumV2RequestError("Missing required parameter: course_id") - - if user_id: + content_model = Contents() + contents = content_model.get_list() + voted_ids = [] + for content in contents: + votes = content["votes"][vote] + if user_id in votes: + voted_ids.append(content["_id"]) + + return voted_ids + + @staticmethod + def filter_standalone_threads(comment_ids: list[str]) -> list[str]: + """Filter out standalone threads from the list of threads.""" + comments = Comment().find({"_id": {"$in": comment_ids}}) + filtered_comments = [] + for comment in comments: + if not comment["context"] == "standalone": + filtered_comments.append(comment) + return [str([comment["comment_thread_id"]]) for comment in filtered_comments] + + @classmethod + def user_to_hash( + cls, user_id: str, params: Optional[dict[str, Any]] = None + ) -> dict[str, Any]: + """ + Converts user data to a hash + """ user = Users().get(user_id) if not user: - raise ForumV2RequestError("User doesn't exist") - - -def get_threads( - params: dict[str, Any], - serializer: Any, - thread_ids: list[str], - user_id: str = "", -) -> dict[str, Any]: - """get subscribed or all threads of a specific course for a specific user.""" - count_flagged = bool(params.get("count_flagged", False)) - threads = handle_threads_query( - thread_ids, - user_id, - params["course_id"], - get_group_ids_from_params(params), - params.get("author_id", ""), - params.get("thread_type"), - bool(params.get("flagged", False)), - bool(params.get("unread", False)), - bool(params.get("unanswered", False)), - bool(params.get("unresponded", False)), - count_flagged, - params.get("sort_key", ""), - int(params.get("page", 1)), - int(params.get("per_page", 100)), - ) - context: dict[str, Any] = { - "count_flagged": count_flagged, - "include_endorsed": True, - "include_read_state": True, - } - if user_id: - context["user_id"] = user_id - serializer = serializer(threads.pop("collection"), many=True, context=context) - threads["collection"] = serializer.data - return threads - - -def get_commentables_counts_based_on_type(course_id: str) -> dict[str, Any]: - """Return commentables counts in a course based on thread's type.""" - pipeline: list[dict[str, Any]] = [ - {"$match": {"course_id": course_id, "_type": "CommentThread"}}, - { - "$group": { - "_id": {"topic_id": "$commentable_id", "type": "$thread_type"}, - "count": {"$sum": 1}, - } - }, - ] - - result = CommentThread().aggregate(pipeline) - commentable_counts = {} - for commentable in result: - topic_id = commentable["_id"]["topic_id"] - if topic_id not in commentable_counts: - commentable_counts[topic_id] = {"discussion": 0, "question": 0} - commentable_counts[topic_id].update( - {commentable["_id"]["type"]: commentable["count"]} - ) - - return commentable_counts - - -def get_user_voted_ids(user_id: str, vote: str) -> list[str]: - """Get the IDs of the posts voted by a user.""" - if vote not in ["up", "down"]: - raise ValueError("Invalid vote type") - - content_model = Contents() - contents = content_model.get_list() - voted_ids = [] - for content in contents: - votes = content["votes"][vote] - if user_id in votes: - voted_ids.append(content["_id"]) + raise ValueError("User not found.") + if params is None: + params = {} + + hash_data = {} + hash_data["username"] = user["username"] + hash_data["external_id"] = user["external_id"] + + comment_model = Comment() + thread_model = CommentThread() + + if params.get("complete"): + subscribed_thread_ids = cls.find_subscribed_threads(user["external_id"]) + upvoted_ids = cls.get_user_voted_ids(user["external_id"], "up") + downvoted_ids = cls.get_user_voted_ids(user["external_id"], "down") + hash_data.update( + { + "subscribed_thread_ids": subscribed_thread_ids, + "subscribed_commentable_ids": [], + "subscribed_user_ids": [], + "follower_ids": [], + "id": user["external_id"], + "upvoted_ids": upvoted_ids, + "downvoted_ids": downvoted_ids, + "default_sort_key": user["default_sort_key"], + } + ) - return voted_ids + if params.get("course_id"): + threads = thread_model.find( + { + "author_id": user["external_id"], + "course_id": params["course_id"], + "anonymous": False, + "anonymouse_to_peers": False, + } + ) + comments = comment_model.find( + { + "author_id": user["external_id"], + "course_id": params["course_id"], + "anonymous": False, + "anonymouse_to_peers": False, + } + ) + if params.get("group_ids"): + specified_groups_or_global = params["group_ids"] + [None] + group_query = { + "_id": {"$in": [thread["_id"] for thread in threads]}, + "$and": [ + {"group_id": {"$in": specified_groups_or_global}}, + {"group_id": {"$exists": False}}, + ], + } + group_threads = CommentThread().find(group_query) + group_thread_ids = [str(thread["_id"]) for thread in group_threads] + threads_count = len(group_thread_ids) + comment_ids = [comment["_id"] for comment in comments] + comment_thread_ids = cls.filter_standalone_threads(comment_ids) + + group_query = { + "_id": {"$in": [ObjectId(tid) for tid in comment_thread_ids]}, + "$and": [ + {"group_id": {"$in": specified_groups_or_global}}, + {"group_id": {"$exists": False}}, + ], + } + group_comment_threads = thread_model.find(group_query) + group_comment_thread_ids = [ + str(thread["_id"]) for thread in group_comment_threads + ] + comments_count = sum( + 1 + for comment_thread_id in comment_thread_ids + if comment_thread_id in group_comment_thread_ids + ) + else: + thread_ids = [str(thread["_id"]) for thread in threads] + threads_count = len(thread_ids) + comment_ids = [comment["_id"] for comment in comments] + comment_thread_ids = cls.filter_standalone_threads(comment_ids) + comments_count = len(comment_thread_ids) + + hash_data.update( + { + "threads_count": threads_count, + "comments_count": comments_count, + } + ) + return hash_data -def filter_standalone_threads(comments: list[dict[str, Any]]) -> list[str]: - """Filter out standalone threads from the list of threads.""" - filtered_comments = [] - for comment in comments: - if not comment["context"] == "standalone": - filtered_comments.append(comment) - return [str([comment["comment_thread_id"]]) for comment in filtered_comments] + @staticmethod + def replace_username_in_all_content(user_id: str, username: str) -> None: + """Replace new username in all content documents.""" + content_model = Contents() + contents = content_model.get_list(author_id=user_id) + for content in contents: + content_model.update( + content["_id"], + author_username=username, + ) + @staticmethod + def unsubscribe_all(user_id: str) -> None: + """Unsubscribe user from all content.""" + subscriptions = Subscriptions() + subscription_filter = {"subscriber_id": user_id} + subscriptions_cursor = subscriptions.find(subscription_filter) -def user_to_hash( - user: dict[str, Any], params: Optional[dict[str, Any]] = None -) -> dict[str, Any]: - """ - Converts user data to a hash - """ - if params is None: - params = {} + for subscription in subscriptions_cursor: + subscriptions.delete(subscription["_id"]) - hash_data = {} - hash_data["username"] = user["username"] - hash_data["external_id"] = user["external_id"] - hash_data["id"] = user["external_id"] + @staticmethod + def retire_all_content(user_id: str, username: str) -> None: + """Retire all content from user.""" + content_model = Contents() + contents = content_model.get_list(author_id=user_id) + for content in contents: + content_model.update( + content["_id"], + author_username=username, + body=RETIRED_BODY, + ) + if content["_type"] == "CommentThread": + content_model.update( + content["_id"], + title=RETIRED_TITLE, + ) + + @staticmethod + def find_or_create_read_state(user_id: str, thread_id: str) -> dict[str, Any]: + """Find or create user read states.""" + user = Users().get(user_id) + if not user: + raise ObjectDoesNotExist + thread = CommentThread().get(thread_id) + if not thread: + raise ObjectDoesNotExist + + read_states = user.get("read_states", []) + for state in read_states: + if state["course_id"] == thread["course_id"]: + return state + + read_state = { + "_id": ObjectId(), + "course_id": thread["course_id"], + "last_read_times": {}, + } + read_states.append(read_state) + Users().update(user_id, read_states=read_states) + return read_state - comment_model = Comment() - thread_model = CommentThread() + @classmethod + def mark_as_read(cls, user_id: str, thread_id: str) -> None: + """Mark thread as read.""" + user = Users().get(user_id) + thread = CommentThread().get(thread_id) + if not (user and thread): + raise ValueError("User and/or Thread not found.") + read_state = cls.find_or_create_read_state(user["external_id"], thread["_id"]) - if params.get("complete"): - subscribed_thread_ids = find_subscribed_threads(user["external_id"]) - upvoted_ids = get_user_voted_ids(user["external_id"], "up") - downvoted_ids = get_user_voted_ids(user["external_id"], "down") - hash_data.update( + read_state["last_read_times"].update( { - "subscribed_thread_ids": subscribed_thread_ids, - "subscribed_commentable_ids": [], - "subscribed_user_ids": [], - "follower_ids": [], - "id": user["external_id"], - "upvoted_ids": upvoted_ids, - "downvoted_ids": downvoted_ids, - "default_sort_key": user["default_sort_key"], + str(thread["_id"]): datetime.now(timezone.utc), } ) + update_user = Users().get(user["external_id"]) + if not update_user: + raise ObjectDoesNotExist + new_read_states = update_user["read_states"] + updated_read_states = [] + for state in new_read_states: + if state["course_id"] == thread["course_id"]: + state = read_state + updated_read_states.append(state) + + Users().update(user["external_id"], read_states=updated_read_states) + + @staticmethod + def find_or_create_user_stats(user_id: str, course_id: str) -> dict[str, Any]: + """Find or create user stats document.""" + user = Users().get(user_id) + if not user: + raise ObjectDoesNotExist + + course_stats = user.get("course_stats", []) + for stat in course_stats: + if stat["course_id"] == course_id: + return stat + + course_stat = { + "_id": ObjectId(), + "active_flags": 0, + "inactive_flags": 0, + "threads": 0, + "responses": 0, + "replies": 0, + "course_id": course_id, + "last_activity_at": "", + } + course_stats.append(course_stat) + Users().update(user["external_id"], course_stats=course_stats) + return course_stat - if params.get("course_id"): - threads = thread_model.find( + @staticmethod + def update_user_stats_for_course(user_id: str, stat: dict[str, Any]) -> None: + """Update user stats for course.""" + user = Users().get(user_id) + if not user: + raise ObjectDoesNotExist + updated_course_stats = [] + course_stats = user["course_stats"] + for course_stat in course_stats: + if course_stat["course_id"] == stat["course_id"]: + course_stat.update(stat) + updated_course_stats.append(course_stat) + Users().update(user_id, course_stats=updated_course_stats) + + @classmethod + def build_course_stats(cls, author_id: str, course_id: str) -> None: + """Build course stats.""" + user = Users().get(author_id) + if not user: + raise ObjectDoesNotExist + pipeline = [ { - "author_id": user["external_id"], - "course_id": params["course_id"], - "anonymous": False, - "anonymouse_to_peers": False, - } - ) - comments = comment_model.find( + "$match": { + "course_id": course_id, + "author_id": user["external_id"], + "anonymous_to_peers": False, + "anonymous": False, + } + }, { - "author_id": user["external_id"], - "course_id": params["course_id"], - "anonymous": False, - "anonymouse_to_peers": False, - } - ) - if params.get("group_ids"): - specified_groups_or_global = params["group_ids"] + [None] - group_query = { - "_id": {"$in": [thread["_id"] for thread in threads]}, - "$and": [ - {"group_id": {"$in": specified_groups_or_global}}, - {"group_id": {"$exists": False}}, - ], - } - group_threads = CommentThread().find(group_query) - group_thread_ids = [str(thread["_id"]) for thread in group_threads] - threads_count = len(group_thread_ids) - comment_thread_ids = filter_standalone_threads(list(comments)) - - group_query = { - "_id": {"$in": [ObjectId(tid) for tid in comment_thread_ids]}, - "$and": [ - {"group_id": {"$in": specified_groups_or_global}}, - {"group_id": {"$exists": False}}, - ], - } - group_comment_threads = thread_model.find(group_query) - group_comment_thread_ids = [ - str(thread["_id"]) for thread in group_comment_threads - ] - comments_count = sum( - 1 - for comment_thread_id in comment_thread_ids - if comment_thread_id in group_comment_thread_ids - ) - else: - thread_ids = [str(thread["_id"]) for thread in threads] - threads_count = len(thread_ids) - comment_thread_ids = filter_standalone_threads(list(comments)) - comments_count = len(comment_thread_ids) - - hash_data.update( + "$addFields": { + "is_reply": {"$ne": [{"$ifNull": ["$parent_id", None]}, None]} + } + }, { - "threads_count": threads_count, - "comments_count": comments_count, - } - ) - - return hash_data - + "$group": { + "_id": {"type": "$_type", "is_reply": "$is_reply"}, + "count": {"$sum": 1}, + "active_flags": { + "$sum": { + "$cond": { + "if": {"$gt": [{"$size": "$abuse_flaggers"}, 0]}, + "then": 1, + "else": 0, + } + } + }, + "inactive_flags": { + "$sum": { + "$cond": { + "if": { + "$gt": [{"$size": "$historical_abuse_flaggers"}, 0] + }, + "then": 1, + "else": 0, + } + } + }, + "latest_update_at": {"$max": "$updated_at"}, + } + }, + ] -def replace_username_in_all_content(user_id: str, username: str) -> None: - """Replace new username in all content documents.""" - content_model = Contents() - contents = content_model.get_list(author_id=user_id) - for content in contents: - content_model.update( - content["_id"], - author_username=username, + data = list(Contents().aggregate(pipeline)) + active_flags = 0 + inactive_flags = 0 + threads = 0 + responses = 0 + replies = 0 + updated_at = datetime.utcfromtimestamp(0) + + for counts in data: + _type, is_reply = counts["_id"]["type"], counts["_id"]["is_reply"] + last_update_at = counts.get("latest_update_at", datetime(1970, 1, 1)) + if _type == "Comment" and is_reply: + replies = counts["count"] + elif _type == "Comment" and not is_reply: + responses = counts["count"] + else: + threads = counts["count"] + last_update_at = make_aware(last_update_at) + updated_at = make_aware(updated_at) + updated_at = max(last_update_at, updated_at) + active_flags += counts["active_flags"] + inactive_flags += counts["inactive_flags"] + + stats = cls.find_or_create_user_stats(user["external_id"], course_id) + stats["replies"] = replies + stats["responses"] = responses + stats["threads"] = threads + stats["active_flags"] = active_flags + stats["inactive_flags"] = inactive_flags + stats["last_activity_at"] = updated_at + cls.update_user_stats_for_course(user["external_id"], stats) + + @classmethod + def update_all_users_in_course(cls, course_id: str) -> list[str]: + """Update all user stats in a course.""" + course_contents = Contents().get_list( + anonymous=False, + anonymous_to_peers=False, + course_id=course_id, ) - - -def unsubscribe_all(user_id: str) -> None: - """Unsubscribe user from all content.""" - subscriptions = Subscriptions() - subscription_filter = {"subscriber_id": user_id} - subscriptions_cursor = subscriptions.find(subscription_filter) - - for subscription in subscriptions_cursor: - subscriptions.delete(subscription["_id"]) - - -def retire_all_content(user_id: str, username: str) -> None: - """Retire all content from user.""" - content_model = Contents() - contents = content_model.get_list(author_id=user_id) - for content in contents: - content_model.update( - content["_id"], - author_username=username, - body=RETIRED_BODY, + author_ids = [] + for content in course_contents: + if content["author_id"] not in author_ids: + author_ids.append(content["author_id"]) + + for author_id in author_ids: + cls.build_course_stats(author_id, course_id) + return author_ids + + @staticmethod + def get_user_by_username(username: str | None) -> dict[str, Any] | None: + """Return user from username.""" + cursor = Users().find({"username": username}) + try: + return next(cursor) + except StopIteration: + return None + + @staticmethod + def get_comment(comment_id: str) -> dict[str, Any] | None: + """Get comment from id.""" + comment = Comment().get(comment_id) + return comment + + @staticmethod + def get_thread(thread_id: str) -> dict[str, Any] | None: + """Get thread from id.""" + thread = CommentThread().get(thread_id) + if not thread: + return None + return thread + + @staticmethod + def get_comments(**kwargs: Any) -> list[dict[str, Any]]: + """Return comments from kwargs.""" + if "comment_thread_id" in kwargs: + kwargs["comment_thread_id"] = ObjectId(kwargs["comment_thread_id"]) + + return list(Comment().get_list(**kwargs)) + + @staticmethod + def update_comment(comment_id: str, **kwargs: Any) -> int: + """Update comment.""" + return Comment().update(comment_id, **kwargs) + + @staticmethod + def delete_comment(comment_id: str) -> None: + """Delete comment.""" + Comment().delete(comment_id) + + @staticmethod + def get_thread_id_from_comment(comment_id: str) -> dict[str, Any] | None: + """Return thread_id from comment_id.""" + parent_comment = Comment().get(comment_id) + if parent_comment: + return parent_comment["comment_thread_id"] + raise ValueError("Comment doesn't have the thread.") + + @staticmethod + def get_user(user_id: str) -> dict[str, Any] | None: + """Return user from user_id.""" + return Users().get(user_id) + + @staticmethod + def get_subscription( + subscriber_id: str, source_id: str, **kwargs: Any + ) -> dict[str, Any] | None: + """Return subscription from subscriber_id and source_id.""" + subscription = Subscriptions().get_subscription(subscriber_id, source_id) + if not subscription: + return None + return subscription + + @staticmethod + def get_subscriptions(query: dict[str, Any]) -> list[dict[str, Any]]: + """Return subscriptions from filter.""" + return list(Subscriptions().find(query)) + + @staticmethod + def delete_thread(thread_id: str) -> int: + """Delete thread.""" + return CommentThread().delete(thread_id) + + @staticmethod + def create_thread(data: dict[str, Any]) -> str: + """Create thread.""" + new_thread_id = CommentThread().insert( + title=data["title"], + body=data["body"], + course_id=data["course_id"], + anonymous=str_to_bool(data.get("anonymous", "False")), + anonymous_to_peers=str_to_bool(data.get("anonymous_to_peers", "False")), + author_id=data["author_id"], + commentable_id=data.get("commentable_id", "course"), + thread_type=data.get("thread_type", "discussion"), + author_username=data.get("author_username"), + context=data.get("context", "course"), + pinned=data.get("pinned", False), + visible=data.get("visible", True), + abuse_flaggers=data.get("abuse_flaggers"), + historical_abuse_flaggers=data.get("historical_abuse_flaggers"), + group_id=data.get("group_id"), ) - if content["_type"] == "CommentThread": - content_model.update( - content["_id"], - title=RETIRED_TITLE, - ) + return new_thread_id + + @staticmethod + def update_thread(thread_id: str, **kwargs: Any) -> int: + """Update thread.""" + return CommentThread().update(thread_id, **kwargs) + + @staticmethod + def get_filtered_threads(query: dict[str, Any]) -> list[dict[str, Any]]: + """Return threads from filter.""" + thread_filter = { + "_type": {"$in": [CommentThread().content_type]}, + "course_id": query.get("course_id"), + } + return list(CommentThread().find(thread_filter)) + + @staticmethod + def update_user(user_id: str, data: dict[str, Any]) -> int: + """Update user.""" + return Users().update(user_id, **data) + + @staticmethod + def get_thread_id_by_comment_id(parent_comment_id: str) -> str: + """ + The thread Id from the parent comment. + """ + parent_comment = Comment().get(parent_comment_id) + if parent_comment: + return parent_comment["comment_thread_id"] + raise ValueError("Comment doesn't have the thread.") + + @staticmethod + def get_course_id_by_thread_id(thread_id: str) -> str | None: + """ + Return course_id for the matching thread. + """ + thread = CommentThread().get(thread_id) + if thread: + return thread.get("course_id") + return None + @staticmethod + def get_course_id_by_comment_id(comment_id: str) -> str | None: + """ + Return course_id for the matching comment. + """ + comment = Comment().get(comment_id) + if comment: + return comment.get("course_id") + return None -def find_or_create_read_state(user_id: str, thread_id: str) -> dict[str, Any]: - """Find or create user read states.""" - user = Users().get(user_id) - if not user: - raise ObjectDoesNotExist - thread = CommentThread().get(thread_id) - if not thread: - raise ObjectDoesNotExist - - read_states = user.get("read_states", []) - for state in read_states: - if state["course_id"] == thread["course_id"]: - return state - - read_state = { - "_id": ObjectId(), - "course_id": thread["course_id"], - "last_read_times": {}, - } - read_states.append(read_state) - Users().update(user_id, read_states=read_states) - return read_state - - -def mark_as_read(user: dict[str, Any], thread: dict[str, Any]) -> None: - """Mark thread as read.""" - read_state = find_or_create_read_state(user["external_id"], thread["_id"]) - - read_state["last_read_times"].update( - { - str(thread["_id"]): datetime.now(timezone.utc), - } - ) - update_user = Users().get(user["external_id"]) - if not update_user: - raise ObjectDoesNotExist - new_read_states = update_user["read_states"] - updated_read_states = [] - for state in new_read_states: - if state["course_id"] == thread["course_id"]: - state = read_state - updated_read_states.append(state) - - Users().update(user["external_id"], read_states=updated_read_states) - - -def find_or_create_user_stats(user_id: str, course_id: str) -> dict[str, Any]: - """Find or create user stats document.""" - user = Users().get(user_id) - if not user: - raise ObjectDoesNotExist - - course_stats = user.get("course_stats", []) - for stat in course_stats: - if stat["course_id"] == course_id: - return stat - - course_stat = { - "_id": ObjectId(), - "active_flags": 0, - "inactive_flags": 0, - "threads": 0, - "responses": 0, - "replies": 0, - "course_id": course_id, - "last_activity_at": "", - } - course_stats.append(course_stat) - Users().update(user["external_id"], course_stats=course_stats) - return course_stat - - -def update_user_stats_for_course(user_id: str, stat: dict[str, Any]) -> None: - """Update user stats for course.""" - user = Users().get(user_id) - if not user: - raise ObjectDoesNotExist - updated_course_stats = [] - course_stats = user["course_stats"] - for course_stat in course_stats: - if course_stat["course_id"] == stat["course_id"]: - course_stat.update(stat) - updated_course_stats.append(course_stat) - Users().update(user_id, course_stats=updated_course_stats) - - -def build_course_stats(author_id: str, course_id: str) -> None: - """Build course stats.""" - user = Users().get(author_id) - if not user: - raise ObjectDoesNotExist - pipeline = [ - { - "$match": { - "course_id": course_id, - "author_id": user["external_id"], - "anonymous_to_peers": False, - "anonymous": False, + @staticmethod + def get_users(**kwargs: Any) -> list[dict[str, Any]]: + """Get users.""" + return list(Users().get_list(**kwargs)) + + @staticmethod + def get_user_sort_criterion(sort_by: str) -> dict[str, Any]: + """Get sort criterion based on sort_by parameter.""" + if sort_by == "flagged": + return { + "course_stats.active_flags": -1, + "course_stats.inactive_flags": -1, + "username": -1, } - }, - { - "$addFields": { - "is_reply": {"$ne": [{"$ifNull": ["$parent_id", None]}, None]} + elif sort_by == "recency": + return { + "course_stats.last_activity_at": -1, + "username": -1, } - }, - { - "$group": { - "_id": {"type": "$_type", "is_reply": "$is_reply"}, - "count": {"$sum": 1}, - "active_flags": { - "$sum": { - "$cond": { - "if": {"$gt": [{"$size": "$abuse_flaggers"}, 0]}, - "then": 1, - "else": 0, - } - } - }, - "inactive_flags": { - "$sum": { - "$cond": { - "if": {"$gt": [{"$size": "$historical_abuse_flaggers"}, 0]}, - "then": 1, - "else": 0, - } - } - }, - "latest_update_at": {"$max": "$updated_at"}, - } - }, - ] - - data = list(Contents().aggregate(pipeline)) - active_flags = 0 - inactive_flags = 0 - threads = 0 - responses = 0 - replies = 0 - updated_at = datetime.utcfromtimestamp(0) - - for counts in data: - _type, is_reply = counts["_id"]["type"], counts["_id"]["is_reply"] - last_update_at = counts.get("latest_update_at", datetime(1970, 1, 1)) - if _type == "Comment" and is_reply: - replies = counts["count"] - elif _type == "Comment" and not is_reply: - responses = counts["count"] else: - threads = counts["count"] - last_update_at = make_aware(last_update_at) - updated_at = make_aware(updated_at) - updated_at = max(last_update_at, updated_at) - active_flags += counts["active_flags"] - inactive_flags += counts["inactive_flags"] - - stats = find_or_create_user_stats(user["external_id"], course_id) - stats["replies"] = replies - stats["responses"] = responses - stats["threads"] = threads - stats["active_flags"] = active_flags - stats["inactive_flags"] = inactive_flags - stats["last_activity_at"] = updated_at - update_user_stats_for_course(user["external_id"], stats) - - -def update_all_users_in_course(course_id: str) -> list[str]: - """Update all user stats in a course.""" - course_contents = Contents().get_list( - anonymous=False, - anonymous_to_peers=False, - course_id=course_id, - ) - author_ids = [] - for content in course_contents: - if content["author_id"] not in author_ids: - author_ids.append(content["author_id"]) - - for author_id in author_ids: - build_course_stats(author_id, course_id) - return author_ids - - -def get_user_by_username(username: str | None) -> dict[str, Any] | None: - """Return user from username.""" - cursor = Users().find({"username": username}) - try: - return next(cursor) - except StopIteration: - return None - + return { + "course_stats.threads": -1, + "course_stats.responses": -1, + "course_stats.replies": -1, + "username": -1, + } -def find_or_create_user(user_id: str) -> str: - """Find or create user.""" - user = Users().get(user_id) - if user: - return user["external_id"] - user_id = Users().insert(user_id) - return user_id - - -def create_comment( - body: str, - user_id: str, - course_id: str, - anonymous: bool, - anonymous_to_peers: bool, - depth: int, - thread_id: str, - parent_id: Optional[str] = None, -) -> Any: - """ - handle comment creation and returns a comment. - - Parameters: - body: The content of the comment. - course_id: The Id of the respective course. - user_id: The requesting user id. - anonymous: anonymous flag(True or False). - anonymous_to_peers: anonymous to peers flag(True or False). - depth: It's value is 0 for parent comment and 1 for child comment. - thread_id (Optional): Id of the Thread where this comment will belong. - parent_id (Optional): Id of the parent comment. It will be given - if creating a child comment. - Response: - The details of the comment that is created. - """ - new_comment_id = Comment().insert( - body=body, - author_id=user_id, - course_id=course_id, - anonymous=anonymous, - anonymous_to_peers=anonymous_to_peers, - depth=depth, - comment_thread_id=thread_id, - parent_id=parent_id, - ) - if parent_id: - update_stats_for_course(user_id, course_id, replies=1) - else: - update_stats_for_course(user_id, course_id, responses=1) - return Comment().get(new_comment_id) - - -def get_user_by_id(user_id: str) -> dict[str, Any] | None: - """Get user by it's id.""" - return Users().get(user_id) - - -def get_thread_by_id(comment_thread_id: str) -> dict[str, Any] | None: - """Get thread by it's id.""" - return CommentThread().get(comment_thread_id) - - -def update_comment_and_get_updated_comment( - comment_id: str, - body: Optional[str] = None, - course_id: Optional[str] = None, - user_id: Optional[str] = None, - anonymous: Optional[bool] = False, - anonymous_to_peers: Optional[bool] = False, - endorsed: Optional[bool] = False, - closed: Optional[bool] = False, - editing_user_id: Optional[str] = None, - edit_reason_code: Optional[str] = None, - endorsement_user_id: Optional[str] = None, -) -> dict[str, Any] | None: - """ - Update an existing child/parent comment. - - Parameters: - comment_id: The ID of the comment to be edited. - body (Optional[str]): The content of the comment. - course_id (Optional[str]): The Id of the respective course. - user_id (Optional[str]): The requesting user id. - anonymous (Optional[bool]): anonymous flag(True or False). - anonymous_to_peers (Optional[bool]): anonymous to peers flag(True or False). - endorsed (Optional[bool]): Flag indicating if the comment is endorsed by any user. - closed (Optional[bool]): Flag indicating if the comment thread is closed. - editing_user_id (Optional[str]): The ID of the user editing the comment. - edit_reason_code (Optional[str]): The reason for editing the comment, typically represented by a code. - endorsement_user_id (Optional[str]): The ID of the user endorsing the comment. - Response: - The details of the comment that is updated. - """ - Comment().update( - comment_id, - body=body, - course_id=course_id, - author_id=user_id, - anonymous=anonymous, - anonymous_to_peers=anonymous_to_peers, - endorsed=endorsed, - closed=closed, - editing_user_id=editing_user_id, - edit_reason_code=edit_reason_code, - endorsement_user_id=endorsement_user_id, - ) - return Comment().get(comment_id) - - -def delete_comment_by_id(comment_id: str) -> None: - """Delete a comment by it's Id.""" - Comment().delete(comment_id) - - -def get_thread_id_by_comment_id(parent_comment_id: str) -> str: - """ - The thread Id from the parent comment. - """ - parent_comment = Comment().get(parent_comment_id) - if parent_comment: - return parent_comment["comment_thread_id"] - raise ValueError("Comment doesn't have the thread.") - - -def get_course_id_by_thread_id(thread_id: str) -> str | None: - """ - Return course_id for the matching thread. - """ - thread = CommentThread().get(thread_id) - if thread: - return thread.get("course_id") - return None - - -def get_course_id_by_comment_id(comment_id: str) -> str | None: - """ - Return course_id for the matching comment. - """ - comment = Comment().get(comment_id) - if comment: - return comment.get("course_id") - return None + @staticmethod + def create_user_pipeline( + course_id: str, page: int, per_page: int, sort_criterion: dict[str, Any] + ) -> list[dict[str, Any]]: + """Get pipeline for course stats api.""" + pipeline: list[dict[str, Any]] = [ + {"$match": {"course_stats.course_id": course_id}}, + {"$project": {"username": 1, "course_stats": 1}}, + {"$unwind": "$course_stats"}, + {"$match": {"course_stats.course_id": course_id}}, + {"$sort": sort_criterion}, + { + "$facet": { + "pagination": [{"$count": "total_count"}], + "data": [ + {"$skip": (page - 1) * per_page}, + {"$limit": per_page}, + ], + } + }, + ] + return pipeline + + # pylint: disable=E1121 + @classmethod + def get_paginated_user_stats( + cls, course_id: str, page: int, per_page: int, sort_criterion: dict[str, Any] + ) -> dict[str, Any]: + """Get paginated stats for a course.""" + pipeline = cls.create_user_pipeline(course_id, page, per_page, sort_criterion) + return list(Users().aggregate(pipeline))[0] + + @staticmethod + def get_contents(**kwargs: Any) -> list[dict[str, Any]]: + """Return contents.""" + return list(Contents().get_list(**kwargs)) + + @staticmethod + def get_user_thread_filter(course_id: str) -> dict[str, Any]: + """Get user thread filter.""" + return { + "_type": {"$in": [CommentThread.content_type]}, + "course_id": {"$in": [course_id]}, + } diff --git a/forum/backends/mongodb/contents.py b/forum/backends/mongodb/contents.py index 09307c61..9d16202f 100644 --- a/forum/backends/mongodb/contents.py +++ b/forum/backends/mongodb/contents.py @@ -4,6 +4,7 @@ from typing import Any, Optional from bson import ObjectId +from pymongo.cursor import Cursor from forum.backends.mongodb.base_model import MongoBaseModel @@ -110,7 +111,7 @@ def override_query(self, query: dict[str, Any]) -> dict[str, Any]: query = {**query, "_type": self.content_type} return super().override_query(query) - def get_list(self, **kwargs: Any) -> Any: + def get_list(self, **kwargs: Any) -> Cursor[dict[str, Any]]: """ Retrieves a list of all content documents in the database based on provided filters. diff --git a/forum/backends/mongodb/subscriptions.py b/forum/backends/mongodb/subscriptions.py index 2d7e030c..99b36743 100644 --- a/forum/backends/mongodb/subscriptions.py +++ b/forum/backends/mongodb/subscriptions.py @@ -74,7 +74,9 @@ def get_subscription( subscription = self._collection.find_one(filter_query) return subscription - def delete_subscription(self, subscriber_id: str, source_id: str) -> int: + def delete_subscription( + self, subscriber_id: str, source_id: str, source_type: Optional[str] = "" + ) -> int: """ Deletes a subscription from the MongoDB collection. @@ -90,5 +92,8 @@ def delete_subscription(self, subscriber_id: str, source_id: str) -> int: "subscriber_id": subscriber_id, "source_id": source_id, } + if source_type: + filter_query["source_type"] = source_type + result = self._collection.delete_one(filter_query) return result.deleted_count diff --git a/forum/backends/mongodb/users.py b/forum/backends/mongodb/users.py index bf1bbb5d..885b1905 100644 --- a/forum/backends/mongodb/users.py +++ b/forum/backends/mongodb/users.py @@ -23,7 +23,7 @@ def insert( external_id: str, username: Optional[str] = None, email: Optional[str] = None, - default_sort_key: str = "date", + default_sort_key: Optional[str] = "date", read_states: Optional[list[dict[str, Any]]] = None, course_stats: Optional[list[dict[str, Any]]] = None, ) -> str: diff --git a/forum/backends/mysql/api.py b/forum/backends/mysql/api.py index dd52037e..c9fd6620 100644 --- a/forum/backends/mysql/api.py +++ b/forum/backends/mysql/api.py @@ -1,22 +1,37 @@ """Client backend for forum v2.""" import math -from datetime import datetime +import random +from datetime import timedelta from typing import Any, Optional, Union from django.contrib.auth.models import User # pylint: disable=E5142 from django.contrib.contenttypes.models import ContentType from django.core.exceptions import ObjectDoesNotExist -from django.db.models import Count, F, Max, Q +from django.core.paginator import Paginator +from django.db.models import ( + Count, + Case, + Exists, + F, + IntegerField, + Max, + OuterRef, + Q, + Subquery, + When, +) from django.utils import timezone from rest_framework import status from rest_framework.response import Response +from forum.backends.backend import AbstractBackend from forum.backends.mysql.models import ( AbuseFlagger, Comment, CommentThread, CourseStat, + EditHistory, ForumUser, HistoricalAbuseFlagger, LastReadTime, @@ -25,1111 +40,1914 @@ UserVote, ) from forum.constants import RETIRED_BODY, RETIRED_TITLE -from forum.utils import get_group_ids_from_params, get_sort_criteria - - -def update_stats_for_course(user_id: str, course_id: str, **kwargs: Any) -> None: - """Update stats for a course.""" - user = User.objects.get(pk=user_id) - course_stat, created = CourseStat.objects.get_or_create( - user=user, course_id=course_id - ) - if created: - course_stat.active_flags = 0 - course_stat.inactive_flags = 0 - course_stat.threads = 0 - course_stat.responses = 0 - course_stat.replies = 0 - - for key, value in kwargs.items(): - if hasattr(course_stat, key): - setattr(course_stat, key, F(key) + value) - - course_stat.save() - - -def _get_entity_from_type( - entity_id: str, entity_type: str -) -> Union[Comment, CommentThread]: - """Get entity from type.""" - if entity_type == "Comment": - return Comment.objects.get(pk=entity_id) - else: - return CommentThread.objects.get(pk=entity_id) - - -def flag_as_abuse(user_id: str, entity_id: str, entity_type: str) -> dict[str, Any]: - """Flag an entity as abuse.""" - user = User.objects.get(pk=user_id) - entity = _get_entity_from_type(entity_id, entity_type) - abuse_flaggers = entity.abuse_flaggers - first_flag_added = False - if user.pk not in abuse_flaggers: - AbuseFlagger.objects.create( - user=user, content=entity, flagged_at=timezone.now() +from forum.utils import get_group_ids_from_params + + +class MySQLBackend(AbstractBackend): + """MySQL backend api.""" + + @classmethod + def update_stats_for_course( + cls, user_id: str, course_id: str, **kwargs: Any + ) -> None: + """Update stats for a course.""" + user = User.objects.get(pk=user_id) + course_stat, created = CourseStat.objects.get_or_create( + user=user, course_id=course_id + ) + if created: + course_stat.active_flags = 0 + course_stat.inactive_flags = 0 + course_stat.threads = 0 + course_stat.responses = 0 + course_stat.replies = 0 + + for key, value in kwargs.items(): + if hasattr(course_stat, key): + setattr(course_stat, key, F(key) + value) + + course_stat.save() + cls.build_course_stats(user_id, course_id) + + @staticmethod + def _get_entity_from_type( + entity_id: str, entity_type: str + ) -> Union[Comment, CommentThread, None]: + """Get entity from type.""" + try: + if entity_type == "Comment": + return Comment.objects.get(pk=entity_id) + else: + return CommentThread.objects.get(pk=entity_id) + except ObjectDoesNotExist: + return None + + @classmethod + def flag_as_abuse( + cls, user_id: str, entity_id: str, **kwargs: Any + ) -> dict[str, Any]: + """Flag an entity as abuse.""" + user = User.objects.get(pk=user_id) + entity = cls._get_entity_from_type( + entity_id, entity_type=kwargs.get("entity_type", "") ) - first_flag_added = len(abuse_flaggers) == 1 - if first_flag_added: - update_stats_for_course(user_id, entity.course_id, active_flags=1) - return entity.to_dict() - - -def un_flag_as_abuse(user_id: str, entity_id: str, entity_type: str) -> dict[str, Any]: - """Unflag an entity as abuse.""" - user = User.objects.get(pk=user_id) - entity = _get_entity_from_type(entity_id, entity_type) - has_no_historical_flags = len(entity.historical_abuse_flaggers) == 0 - if user.pk in entity.abuse_flaggers: + if not entity: + raise ValueError("Entity doesn't exist.") + + abuse_flaggers = entity.abuse_flaggers + first_flag_added = False + if user.pk not in abuse_flaggers: + AbuseFlagger.objects.create( + user=user, content=entity, flagged_at=timezone.now() + ) + first_flag_added = len(abuse_flaggers) == 1 + if first_flag_added: + cls.update_stats_for_course(user_id, entity.course_id, active_flags=1) + return entity.to_dict() + + @classmethod + def un_flag_as_abuse( + cls, user_id: str, entity_id: str, **kwargs: Any + ) -> dict[str, Any]: + """Unflag an entity as abuse.""" + user = User.objects.get(pk=user_id) + entity = cls._get_entity_from_type( + entity_id, entity_type=kwargs.get("entity_type", "") + ) + if not entity: + raise ValueError("Entity doesn't exist.") + + has_no_historical_flags = len(entity.historical_abuse_flaggers) == 0 + if user.pk in entity.abuse_flaggers: + AbuseFlagger.objects.filter( + user=user, + content_object_id=entity.pk, + content_type=entity.content_type, + ).delete() + cls.update_stats_after_unflag( + entity.author.pk, + entity.pk, + has_no_historical_flags, + entity_tpye=entity.type, + ) + + return entity.to_dict() + + @classmethod + def un_flag_all_as_abuse(cls, entity_id: str, **kwargs: Any) -> dict[str, Any]: + """Unflag all users from an entity.""" + entity = cls._get_entity_from_type( + entity_id, entity_type=kwargs.get("entity_type", "") + ) + if not entity: + raise ValueError("Entity doesn't exist.") + + has_no_historical_flags = len(entity.historical_abuse_flaggers) == 0 + historical_abuse_flaggers = list( + set(entity.historical_abuse_flaggers) | set(entity.abuse_flaggers) + ) + for flagger_id in historical_abuse_flaggers: + HistoricalAbuseFlagger.objects.create( + content=entity, + user=User.objects.get(pk=flagger_id), + flagged_at=timezone.now(), + ) AbuseFlagger.objects.filter( - user=user, - content_object_id=entity.pk, - content_type=entity.content_type, + content_object_id=entity.pk, content_type=entity.content_type ).delete() - update_stats_after_unflag( - entity.author.pk, entity.pk, entity.type, has_no_historical_flags + cls.update_stats_after_unflag( + entity.author.pk, + entity.pk, + has_no_historical_flags, + entity_tpye=entity.type, ) - return entity.to_dict() - - -def un_flag_all_as_abuse(entity_id: str, entity_type: str) -> dict[str, Any]: - """Unflag all users from an entity.""" - entity = _get_entity_from_type(entity_id, entity_type) - has_no_historical_flags = len(entity.historical_abuse_flaggers) == 0 - historical_abuse_flaggers = list( - set(entity.historical_abuse_flaggers) | set(entity.abuse_flaggers) - ) - for flagger_id in historical_abuse_flaggers: - HistoricalAbuseFlagger.objects.create( - content=entity, - user=User.objects.get(pk=flagger_id), - flagged_at=timezone.now(), + return entity.to_dict() + + @classmethod + def update_stats_after_unflag( + cls, user_id: str, entity_id: str, has_no_historical_flags: bool, **kwargs: Any + ) -> None: + """Update the stats for the course after unflagging an entity.""" + entity = cls._get_entity_from_type( + entity_id, entity_type=kwargs.get("entity_type", "") ) - AbuseFlagger.objects.filter( - content_object_id=entity.pk, content_type=entity.content_type - ).delete() - update_stats_after_unflag( - entity.author.pk, entity.pk, entity.type, has_no_historical_flags - ) - - return entity.to_dict() - - -def update_stats_after_unflag( - user_id: str, entity_id: str, entity_type: str, has_no_historical_flags: bool -) -> None: - """Update the stats for the course after unflagging an entity.""" - entity = _get_entity_from_type(entity_id, entity_type) - if not entity: - raise ObjectDoesNotExist - - first_historical_flag = ( - has_no_historical_flags and not entity.historical_abuse_flaggers - ) - if first_historical_flag: - update_stats_for_course(user_id, entity.course_id, inactive_flags=1) - - if not entity.abuse_flaggers: - update_stats_for_course(user_id, entity.course_id, active_flags=-1) - - -def update_vote( - content_id: str, - content_type: str, - user_id: str, - vote_type: str = "", - is_deleted: bool = False, -) -> bool: - """ - Update a vote on a thread (either upvote or downvote). - - :param content: The content containing vote data. - :param user: The user for the user voting. - :param vote_type: String indicating the type of vote ('up' or 'down'). - :param is_deleted: Boolean indicating if the user is removing their vote (True) or voting (False). - :return: True if the vote was successfully updated, False otherwise. - """ - user = User.objects.get(pk=user_id) - content = _get_entity_from_type(content_id, content_type) - votes = content.votes - user_vote = votes.filter(user__pk=user.pk).first() - - if not is_deleted: - if vote_type not in ["up", "down"]: - raise ValueError("Invalid vote_type, use ('up' or 'down')") - if not user_vote: - user_vote = UserVote.objects.create(user=user, content=content) - if vote_type == "up": - user_vote.vote = 1 - else: - user_vote.vote = -1 - user_vote.save() - return True - else: - if user_vote: - user_vote.delete() + if not entity: + raise ObjectDoesNotExist + + first_historical_flag = ( + has_no_historical_flags and not entity.historical_abuse_flaggers + ) + if first_historical_flag: + cls.update_stats_for_course(user_id, entity.course_id, inactive_flags=1) + + if not entity.abuse_flaggers: + cls.update_stats_for_course(user_id, entity.course_id, active_flags=-1) + + @classmethod + def update_vote( + cls, + content_id: str, + user_id: str, + vote_type: str = "", + is_deleted: bool = False, + **kwargs: Any, + ) -> bool: + """ + Update a vote on a thread (either upvote or downvote). + + :param content: The content containing vote data. + :param user: The user for the user voting. + :param vote_type: String indicating the type of vote ('up' or 'down'). + :param is_deleted: Boolean indicating if the user is removing their vote (True) or voting (False). + :return: True if the vote was successfully updated, False otherwise. + """ + user = User.objects.get(pk=user_id) + content = cls._get_entity_from_type( + content_id, entity_type=kwargs.get("entity_type", "") + ) + if not content: + raise ValueError("Entity doesn't exist.") + + votes = content.votes + user_vote = votes.filter(user__pk=user.pk).first() + if not is_deleted: + if vote_type not in ["up", "down"]: + raise ValueError("Invalid vote_type, use ('up' or 'down')") + if not user_vote: + vote = 1 if vote_type == "up" else -1 + user_vote = UserVote.objects.create( + user=user, + content=content, + vote=vote, + content_type=content.content_type, + ) + if vote_type == "up": + user_vote.vote = 1 + else: + user_vote.vote = -1 + user_vote.save() return True + else: + if user_vote: + user_vote.delete() + return True + + return False + + @classmethod + def upvote_content(cls, entity_id: str, user_id: str, **kwargs: Any) -> bool: + """ + Upvotes the specified thread or comment by the given user. + + Args: + thread (dict): The thread or comment data to be upvoted. + user (dict): The user who is performing the upvote. + + Returns: + bool: True if the vote was successfully updated, False otherwise. + """ + return cls.update_vote( + entity_id, user_id, vote_type="up", entity_type=kwargs.get("entity_type") + ) - return False + @classmethod + def downvote_content(cls, entity_id: str, user_id: str, **kwargs: Any) -> bool: + """ + Downvotes the specified thread or comment by the given user. + Args: + thread (dict): The thread or comment data to be downvoted. + user (dict): The user who is performing the downvote. -def upvote_content(entity_id: str, entity_type: str, user_id: str) -> bool: - """ - Upvotes the specified thread or comment by the given user. + Returns: + bool: True if the vote was successfully updated, False otherwise. + """ + return cls.update_vote( + entity_id, user_id, vote_type="down", entity_type=kwargs.get("entity_type") + ) - Args: - thread (dict): The thread or comment data to be upvoted. - user (dict): The user who is performing the upvote. + @classmethod + def remove_vote(cls, entity_id: str, user_id: str, **kwargs: Any) -> bool: + """ + Remove the vote (upvote or downvote) from the specified thread or comment for the given user. - Returns: - bool: True if the vote was successfully updated, False otherwise. - """ - return update_vote(entity_id, entity_type, user_id, vote_type="up") + Args: + thread (dict): The thread or comment data from which the vote should be removed. + user (dict): The user who is removing their vote. + Returns: + bool: True if the vote was successfully removed, False otherwise. + """ + return cls.update_vote( + entity_id, user_id, is_deleted=True, entity_type=kwargs.get("entity_type") + ) -def downvote_content(entity_id: str, entity_type: str, user_id: str) -> bool: - """ - Downvotes the specified thread or comment by the given user. + @staticmethod + def validate_thread_and_user( + user_id: str, thread_id: str + ) -> tuple[dict[str, Any], dict[str, Any]]: + """ + Validate thread and user. - Args: - thread (dict): The thread or comment data to be downvoted. - user (dict): The user who is performing the downvote. + Arguments: + user_id (str): The ID of the user making the request. + thread_id (str): The ID of the thread. - Returns: - bool: True if the vote was successfully updated, False otherwise. - """ - return update_vote(entity_id, entity_type, user_id, vote_type="down") + Returns: + tuple[dict[str, Any], dict[str, Any]]: A tuple containing the user and thread data. + Raises: + ValueError: If the thread or user is not found. + """ + try: + thread = CommentThread.objects.get(pk=int(thread_id)) + user = ForumUser.objects.get(user__pk=user_id) + except ObjectDoesNotExist as exc: + raise ValueError("User / Thread doesn't exist") from exc + + return user.to_dict(), thread.to_dict() + + @staticmethod + def pin_unpin_thread(thread_id: str, action: str) -> None: + """ + Pin or unpin the thread based on action parameter. + + Arguments: + thread_id (str): The ID of the thread to pin/unpin. + action (str): The action to perform ("pin" or "unpin"). + """ + try: + comment_thread = CommentThread.objects.get(pk=int(thread_id)) + except ObjectDoesNotExist as exc: + raise ValueError("Thread doesn't exist") from exc + comment_thread.pinned = action == "pin" + comment_thread.save() -def remove_vote(entity_id: str, entity_type: str, user_id: str) -> bool: - """ - Remove the vote (upvote or downvote) from the specified thread or comment for the given user. + @classmethod + def get_pinned_unpinned_thread_serialized_data( + cls, user_id: str, thread_id: str, serializer_class: Any + ) -> dict[str, Any]: + """ + Return serialized data of pinned or unpinned thread. - Args: - thread (dict): The thread or comment data from which the vote should be removed. - user (dict): The user who is removing their vote. + Arguments: + user (dict[str, Any]): The user who requested the action. + thread_id (str): The ID of the thread to pin/unpin. - Returns: - bool: True if the vote was successfully removed, False otherwise. - """ - return update_vote(entity_id, entity_type, user_id, is_deleted=True) + Returns: + dict[str, Any]: The serialized data of the pinned/unpinned thread. + Raises: + ValueError: If the serialization is not valid. + """ + user = ForumUser.objects.get(user__pk=user_id) + updated_thread = CommentThread.objects.get(pk=thread_id) + user_data = user.to_dict() + context = { + "user_id": user_data["_id"], + "username": user_data["username"], + "type": "thread", + "id": thread_id, + } + if updated_thread is not None: + context = {**context, **updated_thread.to_dict()} + serializer = serializer_class(data=context, backend=cls) + if not serializer.is_valid(): + raise ValueError(serializer.errors) + + return serializer.data + + @classmethod + def handle_pin_unpin_thread_request( + cls, user_id: str, thread_id: str, action: str, serializer_class: Any + ) -> dict[str, Any]: + """ + Catches pin/unpin thread request. + + - validates thread and user. + - pin or unpin the thread based on action parameter. + - return serialized data of thread. + + Arguments: + user_id (str): The ID of the user making the request. + thread_id (str): The ID of the thread to pin/unpin. + action (str): The action to perform ("pin" or "unpin"). + + Returns: + dict[str, Any]: The serialized data of the pinned/unpinned thread. + """ + user, _ = cls.validate_thread_and_user(user_id, thread_id) + cls.pin_unpin_thread(thread_id, action) + return cls.get_pinned_unpinned_thread_serialized_data( + user["_id"], thread_id, serializer_class + ) -def validate_thread_and_user( - user_id: str, thread_id: str -) -> tuple[dict[str, Any], dict[str, Any]]: - """ - Validate thread and user. + @staticmethod + def get_abuse_flagged_count(thread_ids: list[str]) -> dict[str, int]: + """ + Retrieves the count of abuse-flagged comments for each thread in the provided list of thread IDs. + + Args: + thread_ids (list[str]): List of thread IDs to check for abuse flags. + + Returns: + dict[str, int]: A dictionary mapping thread IDs to their corresponding abuse-flagged comment count. + """ + abuse_flagger_count_subquery = ( + AbuseFlagger.objects.filter( + content_type=ContentType.objects.get_for_model(Comment), + content_object_id=OuterRef("pk"), + ) + .values("content_object_id") + .annotate(count=Count("pk")) + .values("count") + ) - Arguments: - user_id (str): The ID of the user making the request. - thread_id (str): The ID of the thread. + abuse_flagged_comments = ( + Comment.objects.filter( + comment_thread__pk__in=thread_ids, + ) + .annotate( + abuse_flaggers_count=Subquery( + abuse_flagger_count_subquery, output_field=IntegerField() + ) + ) + .filter(abuse_flaggers_count__gt=0) + ) - Returns: - tuple[dict[str, Any], dict[str, Any]]: A tuple containing the user and thread data. + result = {} + for comment in abuse_flagged_comments: + thread_pk = str(comment.comment_thread.pk) + if thread_pk not in result: + result[thread_pk] = 0 + abuse_flaggers = "abuse_flaggers_count" + result[thread_pk] += getattr(comment, abuse_flaggers) + + return result + + @staticmethod + def get_read_states( + thread_ids: list[str], user_id: str, course_id: str + ) -> dict[str, list[Any]]: + """ + Retrieves the read state and unread comment count for each thread in the provided list. + + Args: + threads (list[dict[str, Any]]): list of threads to check read state for. + user_id (str): The ID of the user whose read states are being retrieved. + course_id (str): The course ID associated with the threads. + + Returns: + dict[str, list[Any]]: A dictionary mapping thread IDs to a list containing + whether the thread is read and the unread comment count. + """ + read_states: dict[str, list[Any]] = {} + if user_id == "": + return read_states + try: + user = User.objects.get(pk=user_id) + except User.DoesNotExist: + return read_states + + threads = CommentThread.objects.filter(pk__in=thread_ids) + read_state = ReadState.objects.filter(user=user, course_id=course_id).first() + if not read_state: + return read_states + + read_dates = read_state.last_read_times + + for thread in threads: + read_date = read_dates.filter(comment_thread=thread).first() + if not read_date: + continue + + last_activity_at = thread.last_activity_at + is_read = read_date.timestamp >= last_activity_at + unread_comment_count = ( + Comment.objects.filter( + comment_thread=thread, created_at__gte=read_date.timestamp + ) + .exclude(author__pk=user_id) + .count() + ) + read_states[str(thread.pk)] = [is_read, unread_comment_count] - Raises: - ValueError: If the thread or user is not found. - """ - try: - thread = CommentThread.objects.get(pk=thread_id) - user = ForumUser.objects.get(user__pk=user_id) - except ObjectDoesNotExist as exc: - raise ValueError("User / Thread doesn't exist") from exc - - return user.to_dict(), thread.to_dict() - - -def pin_unpin_thread(thread_id: str, action: str) -> None: - """ - Pin or unpin the thread based on action parameter. - - Arguments: - thread_id (str): The ID of the thread to pin/unpin. - action (str): The action to perform ("pin" or "unpin"). - """ - try: - comment_thread = CommentThread.objects.get(pk=thread_id) - except ObjectDoesNotExist as exc: - raise ValueError("Thread doesn't exist") from exc - comment_thread.pinned = action == "pin" - comment_thread.save() - - -def get_pinned_unpinned_thread_serialized_data( - user_id: str, thread_id: str, serializer_class: Any -) -> dict[str, Any]: - """ - Return serialized data of pinned or unpinned thread. - - Arguments: - user (dict[str, Any]): The user who requested the action. - thread_id (str): The ID of the thread to pin/unpin. - - Returns: - dict[str, Any]: The serialized data of the pinned/unpinned thread. - - Raises: - ValueError: If the serialization is not valid. - """ - user = ForumUser.objects.get(user__pk=user_id) - updated_thread = CommentThread.objects.get(pk=thread_id) - user_data = user.to_dict() - context = { - "user_id": user_data["_id"], - "username": user_data["username"], - "type": "thread", - "id": thread_id, - } - if updated_thread is not None: - context = {**context, **updated_thread.to_dict()} - serializer = serializer_class(data=context) - if not serializer.is_valid(): - raise ValueError(serializer.errors) - - return serializer.data - - -def handle_pin_unpin_thread_request( - user_id: str, thread_id: str, action: str, serializer_class: Any -) -> dict[str, Any]: - """ - Catches pin/unpin thread request. - - - validates thread and user. - - pin or unpin the thread based on action parameter. - - return serialized data of thread. - - Arguments: - user_id (str): The ID of the user making the request. - thread_id (str): The ID of the thread to pin/unpin. - action (str): The action to perform ("pin" or "unpin"). - - Returns: - dict[str, Any]: The serialized data of the pinned/unpinned thread. - """ - user, _ = validate_thread_and_user(user_id, thread_id) - pin_unpin_thread(thread_id, action) - return get_pinned_unpinned_thread_serialized_data( - user["_id"], thread_id, serializer_class - ) - - -def get_abuse_flagged_count(thread_ids: list[str]) -> dict[str, int]: - """ - Retrieves the count of abuse-flagged comments for each thread in the provided list of thread IDs. - - Args: - thread_ids (list[str]): List of thread IDs to check for abuse flags. - - Returns: - dict[str, int]: A dictionary mapping thread IDs to their corresponding abuse-flagged comment count. - """ - flagged_threads = ( - Comment.objects.filter( - comment_thread__pk__in=thread_ids, - abuse_flaggers__isnull=False, - ) - .values("comment_thread_id") - .annotate(flagged_count=Count("id")) - .order_by() - ) - - result = { - str(item["comment_thread_id"]): item["flagged_count"] - for item in flagged_threads - } - return result - - -def get_read_states( - thread_ids: list[str], user_id: str, course_id: str -) -> dict[str, list[Any]]: - """ - Retrieves the read state and unread comment count for each thread in the provided list. - - Args: - threads (list[dict[str, Any]]): list of threads to check read state for. - user_id (str): The ID of the user whose read states are being retrieved. - course_id (str): The course ID associated with the threads. - - Returns: - dict[str, list[Any]]: A dictionary mapping thread IDs to a list containing - whether the thread is read and the unread comment count. - """ - read_states: dict[str, list[Any]] = {} - user = User.objects.get(pk=user_id) - threads = CommentThread.objects.filter(pk__in=thread_ids) - read_state = ReadState.objects.filter(user=user, course_id=course_id).first() - if not read_state: return read_states - read_dates = read_state.last_read_times + @staticmethod + def get_filtered_thread_ids( + thread_ids: list[str], context: str, group_ids: list[str] + ) -> set[str]: + """ + Filters thread IDs based on context and group ID criteria. + + Args: + thread_ids (list[str]): List of thread IDs to filter. + context (str): The context to filter by. + group_ids (list[str]): List of group IDs for group-based filtering. + + Returns: + set: A set of filtered thread IDs based on the context and group ID criteria. + """ + context_threads = CommentThread.objects.filter( + pk__in=thread_ids, context=context + ) + context_thread_ids = set(thread.pk for thread in context_threads) - for thread in threads: - read_date = read_dates.filter(comment_thread=thread).first() - if not read_date: - continue + if not group_ids: + return context_thread_ids - last_activity_at = thread.last_activity_at - is_read = read_date.timestamp >= last_activity_at - unread_comment_count = Comment.objects.filter( - comment_thread=thread, created_at__gte=read_date, author__pk__ne=user_id - ).count() - read_states[thread.pk] = [is_read, unread_comment_count] - - return read_states - - -def get_filtered_thread_ids( - thread_ids: list[str], context: str, group_ids: list[str] -) -> set[str]: - """ - Filters thread IDs based on context and group ID criteria. - - Args: - thread_ids (list[str]): List of thread IDs to filter. - context (str): The context to filter by. - group_ids (list[str]): List of group IDs for group-based filtering. - - Returns: - set: A set of filtered thread IDs based on the context and group ID criteria. - """ - context_threads = CommentThread.objects.filter(pk__in=thread_ids, context=context) - context_thread_ids = set(thread.pk for thread in context_threads) - - if not group_ids: - return context_thread_ids - - group_threads = CommentThread.objects.filter( - Q(group_id__in=group_ids) | Q(group_id__isnull=True), - id__in=thread_ids, - ) - group_thread_ids = set(thread.pk for thread in group_threads) - - return context_thread_ids.union(group_thread_ids) - - -def get_endorsed(thread_ids: list[str]) -> dict[str, bool]: - """ - Retrieves endorsed status for each thread in the provided list of thread IDs. - - Args: - thread_ids (list[str]): List of thread IDs to check for endorsement. - - Returns: - dict[str, bool]: A dictionary mapping thread IDs to their endorsed status (True if endorsed, False otherwise). - """ - endorsed_comments = Comment.objects.filter( - comment_thread__pk__in=thread_ids, endorsed=True - ) - - return {str(comment.comment_thread.pk): True for comment in endorsed_comments} - - -def get_user_read_state_by_course_id(user_id: str, course_id: str) -> dict[str, Any]: - """ - Retrieves the user's read state for a specific course. - - Args: - user (dict[str, Any]): The user object containing read states. - course_id (str): The course ID to filter the user's read state by. - - Returns: - dict[str, Any]: The user's read state for the specified course, or an empty dictionary if not found. - """ - user = User.objects.get(pk=user_id) - try: - read_state = ReadState.objects.get(user=user, course_id=course_id) - except ObjectDoesNotExist: - return {} - return read_state.to_dict() - - -# TODO: Make this function modular -# pylint: disable=too-many-nested-blocks,too-many-statements -def handle_threads_query( - comment_thread_ids: list[str], - user_id: str, - course_id: str, - group_ids: list[int], - author_id: Optional[int], - thread_type: Optional[str], - filter_flagged: bool, - filter_unread: bool, - filter_unanswered: bool, - filter_unresponded: bool, - count_flagged: bool, - sort_key: str, - page: int, - per_page: int, - context: str = "course", - raw_query: bool = False, -) -> dict[str, Any]: - """ - Handles complex thread queries based on various filters and returns paginated results. - - Args: - comment_thread_ids (list[int]): List of comment thread IDs to filter. - user (User): The user making the request. - course_id (str): The course ID associated with the threads. - group_ids (list[int]): List of group IDs for group-based filtering. - author_id (int): The ID of the author to filter threads by. - thread_type (str): The type of thread to filter by. - filter_flagged (bool): Whether to filter threads flagged for abuse. - filter_unread (bool): Whether to filter unread threads. - filter_unanswered (bool): Whether to filter unanswered questions. - filter_unresponded (bool): Whether to filter threads with no responses. - count_flagged (bool): Whether to include flagged content count. - sort_key (str): The key to sort the threads by. - page (int): The page number for pagination. - per_page (int): The number of threads per page. - context (str): The context to filter threads by. - raw_query (bool): Whether to return raw query results without further processing. - - Returns: - dict[str, Any]: A dictionary containing the paginated thread results and associated metadata. - """ - user = User.objects.get(pk=user_id) - # Base query - base_query = CommentThread.objects.filter( - pk__in=comment_thread_ids, context=context - ) - - # Group filtering - if group_ids: - base_query = base_query.filter( - Q(group_id__in=group_ids) | Q(group_id__isnull=True) + group_threads = CommentThread.objects.filter( + Q(group_id__in=group_ids) | Q(group_id__isnull=True), + id__in=thread_ids, ) + group_thread_ids = set(thread.pk for thread in group_threads) - # Author filtering - if author_id: - base_query = base_query.filter(author__pk=author_id) - if author_id != user.pk: - base_query = base_query.filter(anonymous=False, anonymous_to_peers=False) - - # Thread type filtering - if thread_type: - base_query = base_query.filter(thread_type=thread_type) - - # Flagged content filtering - if filter_flagged: - flagged_comments = Comment.objects.filter( - course_id=course_id, abuse_flaggers__isnull=False - ).values_list("comment_thread_id", flat=True) - flagged_threads = CommentThread.objects.filter( - course_id=course_id, abuse_flaggers__isnull=False - ).values_list("id", flat=True) - base_query = base_query.filter( - pk__in=list( - set(comment_thread_ids) & set(flagged_comments) | set(flagged_threads) - ) + return context_thread_ids.union(group_thread_ids) + + @staticmethod + def get_endorsed(thread_ids: list[str]) -> dict[str, bool]: + """ + Retrieves endorsed status for each thread in the provided list of thread IDs. + + Args: + thread_ids (list[str]): List of thread IDs to check for endorsement. + + Returns: + dict[str, bool]: A dictionary of thread IDs to their endorsed status (True if endorsed, False otherwise). + """ + endorsed_comments = Comment.objects.filter( + comment_thread__pk__in=thread_ids, endorsed=True ) - # Unanswered questions filtering - if filter_unanswered: - endorsed_threads = Comment.objects.filter( - course_id=course_id, - parent__isnull=True, - endorsed=True, - ).values_list("comment_thread_id", flat=True) - base_query = base_query.filter(thread_type="question", id__nin=endorsed_threads) + return {str(comment.comment_thread.pk): True for comment in endorsed_comments} + + @staticmethod + def get_user_read_state_by_course_id( + user_id: str, course_id: str + ) -> dict[str, Any]: + """ + Retrieves the user's read state for a specific course. + + Args: + user (dict[str, Any]): The user object containing read states. + course_id (str): The course ID to filter the user's read state by. + + Returns: + dict[str, Any]: The user's read state for the specified course, or an empty dictionary if not found. + """ + user = User.objects.get(pk=user_id) + try: + read_state = ReadState.objects.get(user=user, course_id=course_id) + except ObjectDoesNotExist: + return {} + return read_state.to_dict() + + @staticmethod + def get_sort_criteria(sort_key: str) -> list[str]: + """ + Generate sorting criteria based on the provided key. + + Parameters: + ----------- + sort_key : str + Key to determine sort order ("date", "activity", "votes", "comments"). + + Returns: + -------- + list + List of strings for sorting, including "pinned" and the relevant field, + optionally adding "created_at" if needed. + """ + sort_key_mapper = { + "date": "-created_at", + "activity": "-last_activity_at", + "votes": "-votes__point", + "comments": "-comment_count", + } + sort_key = sort_key or "date" + sort_key = sort_key_mapper.get(sort_key, "") + + if sort_key: + # only sort order of -1 (descending) is supported. + sort_criteria = ["-pinned", sort_key] + if sort_key not in ["-created_at", "-last_activity_at"]: + sort_criteria.append("-created_at") + return sort_criteria + else: + return [] + + # TODO: Make this function modular + # pylint: disable=too-many-nested-blocks,too-many-statements + @classmethod + def handle_threads_query( + cls, + comment_thread_ids: list[str], + user_id: str, + course_id: str, + group_ids: list[int], + author_id: Optional[str], + thread_type: Optional[str], + filter_flagged: bool, + filter_unread: bool, + filter_unanswered: bool, + filter_unresponded: bool, + count_flagged: bool, + sort_key: str, + page: int, + per_page: int, + context: str = "course", + raw_query: bool = False, + ) -> dict[str, Any]: + """ + Handles complex thread queries based on various filters and returns paginated results. + + Args: + comment_thread_ids (list[int]): List of comment thread IDs to filter. + user (User): The user making the request. + course_id (str): The course ID associated with the threads. + group_ids (list[int]): List of group IDs for group-based filtering. + author_id (int): The ID of the author to filter threads by. + thread_type (str): The type of thread to filter by. + filter_flagged (bool): Whether to filter threads flagged for abuse. + filter_unread (bool): Whether to filter unread threads. + filter_unanswered (bool): Whether to filter unanswered questions. + filter_unresponded (bool): Whether to filter threads with no responses. + count_flagged (bool): Whether to include flagged content count. + sort_key (str): The key to sort the threads by. + page (int): The page number for pagination. + per_page (int): The number of threads per page. + context (str): The context to filter threads by. + raw_query (bool): Whether to return raw query results without further processing. + + Returns: + dict[str, Any]: A dictionary containing the paginated thread results and associated metadata. + """ + if user_id is None or user_id == "": + user = None + else: + try: + user = User.objects.get(pk=user_id) + except User.DoesNotExist as exc: + raise ValueError("User does not exist") from exc + # Base query + base_query = CommentThread.objects.filter( + pk__in=comment_thread_ids, context=context + ) + + # Group filtering + if group_ids: + base_query = base_query.filter( + Q(group_id__in=group_ids) | Q(group_id__isnull=True) + ) + + # Author filtering + if author_id: + base_query = base_query.filter(author__pk=author_id) + if user and author_id != str(user.pk): + base_query = base_query.filter( + anonymous=False, anonymous_to_peers=False + ) + + # Thread type filtering + if thread_type: + base_query = base_query.filter(thread_type=thread_type) + + # Flagged content filtering + if filter_flagged: + comment_abuse_flaggers = AbuseFlagger.objects.filter( + content_object_id=OuterRef("pk"), + content_type=ContentType.objects.get_for_model(Comment), + ) + + flagged_comments = ( + Comment.objects.filter(course_id=course_id) + .annotate(has_abuse_flaggers=Exists(comment_abuse_flaggers)) + .filter(has_abuse_flaggers=True) + .values_list("comment_thread_id", flat=True) + ) + thread_abuse_flaggers = AbuseFlagger.objects.filter( + content_object_id=OuterRef("pk"), + content_type=ContentType.objects.get_for_model(CommentThread), + ) + + flagged_threads = ( + CommentThread.objects.filter(course_id=course_id) + .annotate(has_abuse_flaggers=Exists(thread_abuse_flaggers)) + .filter(has_abuse_flaggers=True) + .values_list("id", flat=True) + ) + + base_query = base_query.filter( + pk__in=list( + set(comment_thread_ids) & set(flagged_comments) + | set(flagged_threads) + ) + ) + + # Unanswered questions filtering + if filter_unanswered: + endorsed_threads = Comment.objects.filter( + course_id=course_id, + parent__isnull=True, + endorsed=True, + ).values_list("comment_thread_id", flat=True) + base_query = base_query.filter( + thread_type="question", + ).exclude(pk__in=endorsed_threads) + + # Unresponded threads filtering + if filter_unresponded: + base_query = base_query.annotate(num_comments=Count("comment")).filter( + num_comments=0 + ) + + sort_criteria = cls.get_sort_criteria(sort_key) + + comment_threads = ( + base_query.order_by(*sort_criteria) if sort_criteria else base_query + ) + thread_count = base_query.count() - # Unresponded threads filtering - if filter_unresponded: - base_query = base_query.filter(comment_count=0) + if raw_query: + return { + "result": [ + comment_thread.to_dict() for comment_thread in comment_threads + ] + } - sort_criteria = get_sort_criteria(sort_key) + if filter_unread and user: + read_state = cls.get_user_read_state_by_course_id(str(user.pk), course_id) + read_dates = read_state.get("last_read_times", {}) + + threads: list[str] = [] + skipped = 0 + to_skip = (page - 1) * per_page + has_more = False + + for thread in comment_threads.iterator(): + thread_key = thread.pk + if ( + thread_key not in read_dates + or read_dates[thread_key] < thread.last_activity_at + ): + if skipped >= to_skip: + if len(threads) == per_page: + has_more = True + break + threads.append(thread.pk) + else: + skipped += 1 + num_pages = page + 1 if has_more else page + else: + threads = [thread.pk for thread in comment_threads] + page = max(1, page) + start = per_page * (page - 1) + end = per_page * page + paginated_collection = threads[start:end] + threads = list(paginated_collection) + num_pages = max(1, math.ceil(thread_count / per_page)) + + if len(threads) == 0: + collection = [] + else: + collection = cls.threads_presentor( + threads, user_id, course_id, count_flagged + ) - comment_threads = ( - base_query.order_by(sort_criteria) if sort_criteria else base_query - ) - thread_count = base_query.count() + return { + "collection": collection, + "num_pages": num_pages, + "page": page, + "thread_count": thread_count, + } - if raw_query: + @staticmethod + def prepare_thread( + thread_id: str, + is_read: bool, + unread_count: int, + is_endorsed: bool, + abuse_flagged_count: int, + ) -> dict[str, Any]: + """ + Prepares thread data for presentation. + + Args: + thread (dict[str, Any]): The thread data. + is_read (bool): Whether the thread is read. + unread_count (int): The count of unread comments. + is_endorsed (bool): Whether the thread is endorsed. + abuse_flagged_count (int): The abuse flagged count. + + Returns: + dict[str, Any]: A dictionary representing the prepared thread data. + """ + thread = CommentThread.objects.get(pk=thread_id) return { - "result": [comment_thread.to_dict() for comment_thread in comment_threads] + **thread.to_dict(), + "type": "thread", + "read": is_read, + "unread_comments_count": unread_count, + "endorsed": is_endorsed, + "abuse_flagged_count": abuse_flagged_count, } - if filter_unread and user: - read_state = get_user_read_state_by_course_id(user.pk, course_id) - read_dates = read_state.get("last_read_times", {}) - - threads: list[str] = [] - skipped = 0 - to_skip = (page - 1) * per_page - has_more = False - - for thread in comment_threads.iterator(): - thread_key = thread.pk - if ( - thread_key not in read_dates - or read_dates[thread_key] < thread.last_activity_at - ): - if skipped >= to_skip: - if len(threads) == per_page: - has_more = True - break - threads.append(thread.pk) - else: - skipped += 1 - num_pages = page + 1 if has_more else page - else: - threads = [thread.pk for thread in comment_threads] - page = max(1, page) - start = per_page * (page - 1) - end = per_page * page - paginated_collection = threads[start:end] - threads = list(paginated_collection) - num_pages = max(1, math.ceil(thread_count / per_page)) - - if len(threads) == 0: - collection = [] - else: - collection = threads_presentor(threads, user.pk, course_id, count_flagged) - - return { - "collection": collection, - "num_pages": num_pages, - "page": page, - "thread_count": thread_count, - } - - -def prepare_thread( - thread_id: str, - is_read: bool, - unread_count: int, - is_endorsed: bool, - abuse_flagged_count: int, -) -> dict[str, Any]: - """ - Prepares thread data for presentation. - - Args: - thread (dict[str, Any]): The thread data. - is_read (bool): Whether the thread is read. - unread_count (int): The count of unread comments. - is_endorsed (bool): Whether the thread is endorsed. - abuse_flagged_count (int): The abuse flagged count. - - Returns: - dict[str, Any]: A dictionary representing the prepared thread data. - """ - thread = CommentThread.objects.get(pk=thread_id) - return { - **thread.to_dict(), - "type": "thread", - "read": is_read, - "unread_comments_count": unread_count, - "endorsed": is_endorsed, - "abuse_flagged_count": abuse_flagged_count, - } - - -def threads_presentor( - thread_ids: list[str], - user_id: str, - course_id: str, - count_flagged: bool = False, -) -> list[dict[str, Any]]: - """ - Presents the threads by preparing them for display. - - Args: - threads (list[CommentThread]): List of threads to present. - user (User): The user presenting the threads. - course_id (str): The course ID associated with the threads. - count_flagged (bool, optional): Whether to include flagged content count. Defaults to False. - - Returns: - list[dict[str, Any]]: A list of prepared thread data. - """ - threads = CommentThread.objects.filter(pk__in=thread_ids) - read_states = get_read_states(thread_ids, user_id, course_id) - threads_endorsed = get_endorsed(thread_ids) - threads_flagged = get_abuse_flagged_count(thread_ids) if count_flagged else {} - - presenters = [] - for thread in threads: - is_read, unread_count = read_states.get( - thread.pk, (False, thread.comment_count) + @classmethod + def threads_presentor( + cls, + thread_ids: list[str], + user_id: str, + course_id: str, + count_flagged: bool = False, + ) -> list[dict[str, Any]]: + """ + Presents the threads by preparing them for display. + + Args: + threads (list[CommentThread]): List of threads to present. + user (User): The user presenting the threads. + course_id (str): The course ID associated with the threads. + count_flagged (bool, optional): Whether to include flagged content count. Defaults to False. + + Returns: + list[dict[str, Any]]: A list of prepared thread data. + """ + threads = CommentThread.objects.filter(pk__in=thread_ids) + read_states = cls.get_read_states(thread_ids, user_id, course_id) + threads_endorsed = cls.get_endorsed(thread_ids) + threads_flagged = ( + cls.get_abuse_flagged_count(thread_ids) if count_flagged else {} ) - is_endorsed = threads_endorsed.get(thread.pk, False) - abuse_flagged_count = threads_flagged.get(thread.pk, 0) - presenters.append( - prepare_thread( - thread.pk, - is_read, - unread_count, - is_endorsed, - abuse_flagged_count, + + presenters = [] + for thread in threads: + is_read, unread_count = read_states.get( + thread.pk, (False, thread.comment_count) + ) + is_endorsed = threads_endorsed.get(thread.pk, False) + abuse_flagged_count = threads_flagged.get(str(thread.pk), 0) + presenters.append( + cls.prepare_thread( + thread.pk, + is_read, + unread_count, + is_endorsed, + abuse_flagged_count, + ) ) - ) - return presenters + return presenters + @staticmethod + def get_username_from_id(user_id: str) -> Optional[str]: + """ + Retrieve the username associated with a given user ID. -def get_username_from_id(user_id: str) -> Optional[str]: - """ - Retrieve the username associated with a given user ID. + Args: + _id (int): The unique identifier of the user. - Args: - _id (int): The unique identifier of the user. + Returns: + Optional[str]: The username of the user if found, or None if not. - Returns: - Optional[str]: The username of the user if found, or None if not. + """ + try: + user = User.objects.get(pk=user_id) + except ObjectDoesNotExist: + return None + return user.username + + @staticmethod + def validate_object(model: str, obj_id: str) -> Any: + """ + Validates the object if it exists or not. + + Parameters: + model: The model for which to validate the id. + id: The ID of the object to validate in the model. + Response: + raise exception if object does not exists. + return object + """ + modelss = { + "CommentThread": CommentThread, + "Comment": Comment, + } - """ - try: - user = User().objects.get(pk=user_id) - except ObjectDoesNotExist: - return None - return user.username - - -def validate_object(model: Any, obj_id: str) -> Any: - """ - Validates the object if it exists or not. - - Parameters: - model: The model for which to validate the id. - id: The ID of the object to validate in the model. - Response: - raise exception if object does not exists. - return object - """ - try: - instance = model.objects.get(pk=int(obj_id)) - except ObjectDoesNotExist as exc: - raise ObjectDoesNotExist from exc - - return instance - - -def find_subscribed_threads(user_id: str, course_id: Optional[str] = None) -> list[str]: - """ - Find threads that a user is subscribed to in a specific course. - - Args: - user_id (str): The ID of the user. - course_id (str): The ID of the course. - - Returns: - list: A list of thread ids that the user is subscribed to in the course. - """ - subscriptions = Subscription.objects.filter( - subscriber__pk=user_id, - source_content_type=ContentType.objects.get_for_model(CommentThread), - ) - thread_ids = [str(subscription.source_object_id) for subscription in subscriptions] - if course_id: - thread_ids = list( - CommentThread.objects.filter( - pk__in=thread_ids, - course_id=course_id, - ).values_list("pk", flat=True) + try: + instance = modelss[model].objects.get(pk=int(obj_id)) + except ObjectDoesNotExist as exc: + raise ObjectDoesNotExist from exc + + return instance.to_dict() + + @staticmethod + def find_subscribed_threads( + user_id: str, course_id: Optional[str] = None + ) -> list[str]: + """ + Find threads that a user is subscribed to in a specific course. + + Args: + user_id (str): The ID of the user. + course_id (str): The ID of the course. + + Returns: + list: A list of thread ids that the user is subscribed to in the course. + """ + subscriptions = Subscription.objects.filter( + subscriber__pk=user_id, + source_content_type=ContentType.objects.get_for_model(CommentThread), ) + thread_ids = [ + str(subscription.source_object_id) for subscription in subscriptions + ] + if course_id: + thread_ids = list( + CommentThread.objects.filter( + pk__in=thread_ids, + course_id=course_id, + ).values_list("pk", flat=True) + ) + + return thread_ids + + @classmethod + def subscribe_user( + cls, user_id: str, source_id: str, source_type: str + ) -> dict[str, Any] | None: + """Subscribe a user to a source.""" + source = cls._get_entity_from_type(source_id, source_type) + if source is None: + return None + + subscription, _ = Subscription.objects.get_or_create( + subscriber=User.objects.get(pk=int(user_id)), + source_object_id=source.pk, + source_content_type=source.content_type, + ) + return subscription.to_dict() + + @classmethod + def unsubscribe_user( + cls, user_id: str, source_id: str, source_type: Optional[str] = "" + ) -> None: + """Unsubscribe a user from a source.""" + source = cls._get_entity_from_type(source_id, source_type or "") + if source is None: + return + + Subscription.objects.filter( + subscriber=User.objects.get(pk=int(user_id)), + source_object_id=source.pk, + source_content_type=source.content_type, + ).delete() + + @staticmethod + def delete_comments_of_a_thread(thread_id: str) -> None: + """Delete comments of a thread.""" + Comment.objects.filter(comment_thread__pk=thread_id, parent=None).delete() + + @classmethod + def delete_subscriptions_of_a_thread(cls, thread_id: str) -> None: + """Delete subscriptions of a thread.""" + source = cls._get_entity_from_type(thread_id, "CommentThread") + if source is None: + return + + Subscription.objects.filter( + source_object_id=source.pk, + source_content_type=source.content_type, + ).delete() - return thread_ids - - -def subscribe_user( - user_id: str, source_id: str, source_type: str -) -> dict[str, Any] | None: - """Subscribe a user to a source.""" - subscription, _ = Subscription.objects.get_or_create( - subscriber__pk=user_id, source__pk=source_id, source_type=source_type - ) - return subscription.to_dict() - - -def unsubscribe_user(user_id: str, source_id: str) -> None: - """Unsubscribe a user from a source.""" - Subscription.objects.filter(subscriber__pk=user_id, source__pk=source_id).delete() - - -def delete_comments_of_a_thread(thread_id: str) -> None: - """Delete comments of a thread.""" - Comment.objects.filter(comment_thread__pk=thread_id, parent=None).delete() - - -def delete_subscriptions_of_a_thread(thread_id: str) -> None: - """Delete subscriptions of a thread.""" - Subscription.objects.filter( - source__pk=thread_id, - source_type=ContentType.objects.get_for_model(CommentThread), - ).delete() - - -def validate_params( - params: dict[str, Any], user_id: Optional[str] = None -) -> Response | None: - """ - Validate the request parameters. - - Args: - params (dict): The request parameters. - user_id (optional[str]): The Id of the user for validation. - - Returns: - Response: A Response object with an error message if doesn't exist. - """ - valid_params = [ - "course_id", - "author_id", - "thread_type", - "flagged", - "unread", - "unanswered", - "unresponded", - "count_flagged", - "sort_key", - "page", - "per_page", - "request_id", - "commentable_ids", - ] - if not user_id: - valid_params.append("user_id") - - for key in params: - if key not in valid_params: + @staticmethod + def validate_params( + params: dict[str, Any], user_id: Optional[str] = None + ) -> Response | None: + """ + Validate the request parameters. + + Args: + params (dict): The request parameters. + user_id (optional[str]): The Id of the user for validation. + + Returns: + Response: A Response object with an error message if doesn't exist. + """ + valid_params = [ + "course_id", + "author_id", + "thread_type", + "flagged", + "unread", + "unanswered", + "unresponded", + "count_flagged", + "sort_key", + "page", + "per_page", + "request_id", + "commentable_ids", + ] + if not user_id: + valid_params.append("user_id") + + for key in params: + if key not in valid_params: + return Response( + {"error": f"Invalid parameter: {key}"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + if "course_id" not in params: return Response( - {"error": f"Invalid parameter: {key}"}, + {"error": "Missing required parameter: course_id"}, status=status.HTTP_400_BAD_REQUEST, ) - if "course_id" not in params: - return Response( - {"error": "Missing required parameter: course_id"}, - status=status.HTTP_400_BAD_REQUEST, + if user_id: + try: + User.objects.get(pk=user_id) + except ObjectDoesNotExist: + return Response( + {"error": "User doesn't exist"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + return None + + @classmethod + def get_threads( + cls, + params: dict[str, Any], + user_id: str, + serializer: Any, + thread_ids: list[str], + ) -> dict[str, Any]: + """get subscribed or all threads of a specific course for a specific user.""" + count_flagged = bool(params.get("count_flagged", False)) + threads = cls.handle_threads_query( + thread_ids, + user_id, + params["course_id"], + get_group_ids_from_params(params), + params.get("author_id", ""), + params.get("thread_type"), + bool(params.get("flagged", False)), + bool(params.get("unread", False)), + bool(params.get("unanswered", False)), + bool(params.get("unresponded", False)), + count_flagged, + params.get("sort_key", ""), + int(params.get("page", 1)), + int(params.get("per_page", 100)), + ) + context: dict[str, Any] = { + "count_flagged": count_flagged, + "include_endorsed": True, + "include_read_state": True, + } + if user_id: + context["user_id"] = user_id + serializer = serializer( + threads.pop("collection"), many=True, context=context, backend=cls + ) + threads["collection"] = serializer.data + return threads + + @classmethod + def get_user_voted_ids(cls, user_id: str, vote: str) -> list[str]: + """Get the IDs of the posts voted by a user.""" + if vote not in ["up", "down"]: + raise ValueError("Invalid vote type") + + vote_value = 1 if vote == "up" else -1 + voted_ids = UserVote.objects.filter( + user__pk=user_id, vote=vote_value + ).values_list("content_object_id", flat=True) + return list(voted_ids) + + @staticmethod + def filter_standalone_threads(comment_ids: list[str]) -> list[str]: + """Filter out standalone threads from the list of threads.""" + comments = Comment.objects.filter(pk__in=comment_ids) + filtered_threads = [ + comment.comment_thread + for comment in comments + if comment.comment_thread.context != "standalone" + ] + return [str(thread.pk) for thread in filtered_threads] + + @classmethod + def user_to_hash( + cls, user_id: str, params: Optional[dict[str, Any]] = None + ) -> dict[str, Any]: + """ + Converts user data to a hash + """ + user = User.objects.get(pk=user_id) + forum_user = ForumUser.objects.get(user__pk=user_id) + if params is None: + params = {} + + user_data = forum_user.to_dict() + hash_data = {} + hash_data["username"] = user_data["username"] + hash_data["external_id"] = user_data["external_id"] + + if params.get("complete"): + subscribed_thread_ids = cls.find_subscribed_threads(user_id) + upvoted_ids = cls.get_user_voted_ids(user_id, "up") + downvoted_ids = cls.get_user_voted_ids(user_id, "down") + hash_data.update( + { + "subscribed_thread_ids": subscribed_thread_ids, + "subscribed_commentable_ids": [], + "subscribed_user_ids": [], + "follower_ids": [], + "id": user_id, + "upvoted_ids": upvoted_ids, + "downvoted_ids": downvoted_ids, + "default_sort_key": user_data["default_sort_key"], + } + ) + + if params.get("course_id"): + threads = CommentThread.objects.filter( + author=user, + course_id=params["course_id"], + anonymous=False, + anonymous_to_peers=False, + ) + comments = Comment.objects.filter( + author=user, + course_id=params["course_id"], + anonymous=False, + anonymous_to_peers=False, + ) + comment_ids = list(comments.values_list("pk", flat=True)) + if params.get("group_ids"): + group_threads = threads.filter( + group_id__in=params["group_ids"] + [None] + ) + group_thread_ids = [str(thread.pk) for thread in group_threads] + threads_count = len(group_thread_ids) + comment_thread_ids = cls.filter_standalone_threads(comment_ids) + + group_comment_threads = CommentThread.objects.filter( + id__in=comment_thread_ids, group_id__in=params["group_ids"] + [None] + ) + group_comment_thread_ids = [ + str(thread.pk) for thread in group_comment_threads + ] + comments_count = sum( + 1 + for comment_thread_id in comment_thread_ids + if comment_thread_id in group_comment_thread_ids + ) + else: + thread_ids = [str(thread.pk) for thread in threads] + threads_count = len(thread_ids) + comment_thread_ids = cls.filter_standalone_threads(comment_ids) + comments_count = len(comment_thread_ids) + + hash_data.update( + { + "threads_count": threads_count, + "comments_count": comments_count, + } + ) + + return hash_data + + @staticmethod + def replace_username(user_id: str, username: str) -> None: + """Replace the username of a Django user.""" + try: + user = User.objects.get(id=user_id) + user.username = username + user.save() + except User.DoesNotExist as exc: + raise ValueError("User does not exist") from exc + + @staticmethod + def unsubscribe_all(user_id: str) -> None: + """Unsubscribe user from all content.""" + Subscription.objects.filter(subscriber__pk=user_id).delete() + + # Kept method signature same as mongo implementation + @staticmethod + def retire_all_content( + user_id: str, username: str + ) -> None: # pylint: disable=W0613 + """Retire all content from user.""" + comments = Comment.objects.filter(author__pk=user_id) + for comment in comments: + comment.body = RETIRED_BODY + comment.save() + + comment_threads = CommentThread.objects.filter(author__pk=user_id) + for comment_thread in comment_threads: + comment_thread.body = RETIRED_BODY + comment_thread.title = RETIRED_TITLE + comment_thread.save() + + @staticmethod + def find_or_create_read_state(user_id: str, thread_id: str) -> dict[str, Any]: + """Find or create user read states.""" + try: + user = User.objects.get(pk=user_id) + thread = CommentThread.objects.get(pk=thread_id) + except (User.DoesNotExist, CommentThread.DoesNotExist) as exc: + raise ObjectDoesNotExist from exc + + read_state, _ = ReadState.objects.get_or_create( + user=user, course_id=thread.course_id + ) + return read_state.to_dict() + + @classmethod + def mark_as_read(cls, user_id: str, thread_id: str) -> None: + """Mark thread as read.""" + user = User.objects.get(pk=user_id) + thread = CommentThread.objects.get(pk=thread_id) + read_state, _ = ReadState.objects.get_or_create( + user=user, course_id=thread.course_id + ) + + LastReadTime.objects.update_or_create( + read_state=read_state, + comment_thread=thread, + defaults={ + "timestamp": timezone.now(), + }, ) - if user_id: + @staticmethod + def find_or_create_user_stats(user_id: str, course_id: str) -> dict[str, Any]: + """Find or create user stats document.""" + user = User.objects.get(pk=user_id) try: - User.objects.get(pk=user_id) - except ObjectDoesNotExist: - return Response( - {"error": "User doesn't exist"}, - status=status.HTTP_400_BAD_REQUEST, + course_stat = CourseStat.objects.get(user=user, course_id=course_id) + return course_stat.to_dict() + except CourseStat.DoesNotExist: + course_stat = CourseStat( + user=user, + course_id=course_id, + active_flags=0, + inactive_flags=0, + threads=0, + responses=0, + replies=0, + last_activity_at=None, ) + course_stat.save() + return course_stat.to_dict() - return None - - -def get_threads( - params: dict[str, Any], - user_id: str, - serializer: Any, - thread_ids: list[str], -) -> dict[str, Any]: - """get subscribed or all threads of a specific course for a specific user.""" - count_flagged = bool(params.get("count_flagged", False)) - threads = handle_threads_query( - thread_ids, - user_id, - params["course_id"], - get_group_ids_from_params(params), - params.get("author_id", ""), - params.get("thread_type"), - bool(params.get("flagged", False)), - bool(params.get("unread", False)), - bool(params.get("unanswered", False)), - bool(params.get("unresponded", False)), - count_flagged, - params.get("sort_key", ""), - int(params.get("page", 1)), - int(params.get("per_page", 100)), - ) - context: dict[str, Any] = { - "count_flagged": count_flagged, - "include_endorsed": True, - "include_read_state": True, - } - if user_id: - context["user_id"] = user_id - serializer = serializer(threads.pop("collection"), many=True, context=context) - threads["collection"] = serializer.data - return threads - - -def get_user_voted_ids(user_id: str, vote: str) -> list[str]: - """Get the IDs of the posts voted by a user.""" - if vote not in ["up", "down"]: - raise ValueError("Invalid vote type") - - vote_value = 1 if vote == "up" else -1 - voted_ids = UserVote.objects.filter(user__pk=user_id, vote=vote_value).values_list( - "content_object_id", flat=True - ) - - return list(voted_ids) - - -def filter_standalone_threads(comment_thread_ids: list[str]) -> list[str]: - """Filter out standalone threads from the list of threads.""" - comment_threads = CommentThread.objects.filter(pk__in=comment_thread_ids) - filtered_threads = [ - thread for thread in comment_threads if thread.context != "standalone" - ] - return [str(thread.pk) for thread in filtered_threads] - - -def user_to_hash( - user_id: str, params: Optional[dict[str, Any]] = None -) -> dict[str, Any]: - """ - Converts user data to a hash - """ - user = User.objects.get(pk=user_id) - forum_user = ForumUser.objects.get(user__pk=user_id) - if params is None: - params = {} - - user_data = forum_user.to_dict() - hash_data = {} - hash_data["username"] = user_data["username"] - hash_data["external_id"] = user_data["external_id"] - - if params.get("complete"): - subscribed_thread_ids = find_subscribed_threads(user_id) - upvoted_ids = get_user_voted_ids(user_id, "up") - downvoted_ids = get_user_voted_ids(user_id, "down") - hash_data.update( - { - "subscribed_thread_ids": subscribed_thread_ids, - "subscribed_commentable_ids": [], - "subscribed_user_ids": [], - "follower_ids": [], - "id": user_id, - "upvoted_ids": upvoted_ids, - "downvoted_ids": downvoted_ids, - "default_sort_key": user_data["default_sort_key"], - } + @staticmethod + def update_user_stats_for_course(user_id: str, stat: dict[str, Any]) -> None: + """Update user stats for course.""" + user = User.objects.get(pk=user_id) + try: + course_stat = CourseStat.objects.get(user=user, course_id=stat["course_id"]) + for key, value in stat.items(): + setattr(course_stat, key, value) + course_stat.save() + except CourseStat.DoesNotExist: + course_stat = CourseStat(user=user, **stat) + course_stat.save() + + @classmethod + def build_course_stats(cls, author_id: str, course_id: str) -> None: + """Build course stats.""" + author = User.objects.get(pk=author_id) + threads = CommentThread.objects.filter( + author=author, + course_id=course_id, + anonymous_to_peers=False, + anonymous=False, + ) + comments = Comment.objects.filter( + author=author, + course_id=course_id, + anonymous_to_peers=False, + anonymous=False, ) - if params.get("course_id"): - threads = CommentThread.objects.filter( - author=user, - course_id=params["course_id"], + responses = comments.filter(comment_thread__isnull=False) + replies = comments.filter(comment_thread__isnull=True) + + comment_ids = [comment.pk for comment in comments] + active_flags = AbuseFlagger.objects.filter( + content_type=ContentType.objects.get_for_model(Comment), + content_object_id__in=comment_ids, + ).count() + + inactive_flags = HistoricalAbuseFlagger.objects.filter( + content_type=ContentType.objects.get_for_model(Comment), + content_object_id__in=comment_ids, + ).count() + + threads_updated_at = threads.aggregate(Max("updated_at"))["updated_at__max"] + comments_updated_at = comments.aggregate(Max("updated_at"))["updated_at__max"] + + updated_at = max( + threads_updated_at or timezone.now() - timedelta(days=365 * 100), + comments_updated_at or timezone.now() - timedelta(days=365 * 100), + ) + + stats, _ = CourseStat.objects.get_or_create(user=author, course_id=course_id) + stats.threads = threads.count() + stats.responses = responses.count() + stats.replies = replies.count() + stats.active_flags = active_flags + stats.inactive_flags = inactive_flags + stats.last_activity_at = updated_at + stats.save() + cls.update_user_stats_for_course(author_id, stats.to_dict()) + + @classmethod + def update_all_users_in_course(cls, course_id: str) -> list[str]: + """Update all user stats in a course.""" + course_comments = Comment.objects.filter( anonymous=False, anonymous_to_peers=False, + course_id=course_id, ) - comments = Comment.objects.filter( - author=user, - course_id=params["course_id"], + course_threads = CommentThread.objects.filter( anonymous=False, anonymous_to_peers=False, + course_id=course_id, + ) + + comment_authors = set(course_comments.values_list("author__id", flat=True)) + thread_authors = set(course_threads.values_list("author__id", flat=True)) + author_ids = list(comment_authors | thread_authors) + + for author_id in author_ids: + cls.build_course_stats(author_id, course_id) + return author_ids + + @staticmethod + def get_user_by_username(username: str | None) -> dict[str, Any] | None: + """Return user from username.""" + try: + user = User.objects.get(username=username) + except User.DoesNotExist: + return None + try: + forum_user = ForumUser.objects.get(user=user) + except ForumUser.DoesNotExist: + return None + return forum_user.to_dict() + + @staticmethod + def generate_id() -> str: + """Generate a random id.""" + return str(random.randint(1, 1000000)) + + @staticmethod + def find_or_create_user( + user_id: str, + username: Optional[str] = None, + default_sort_key: Optional[str] = "date", + ) -> str: + """Find or create user.""" + username = username or user_id + try: + user = User.objects.get(pk=int(user_id)) + except User.DoesNotExist: + user = None + + if user is None: + if User.objects.filter(username=username).exists(): + raise ValueError(f"User with username {username} already exists") + user = User.objects.create(pk=int(user_id), username=username) + + forum_user, _ = ForumUser.objects.get_or_create( + user=user, defaults={"default_sort_key": default_sort_key} + ) + return forum_user.user.pk + + @staticmethod + def get_comment(comment_id: str) -> dict[str, Any] | None: + """Return comment from comment_id.""" + try: + comment = Comment.objects.get(pk=comment_id) + except Comment.DoesNotExist: + return None + return comment.to_dict() + + @staticmethod + def get_comments(**kwargs: Any) -> list[dict[str, Any]]: + """Return comments from kwargs.""" + return Comment.get_list(**kwargs) + + @staticmethod + def update_child_count_in_parent_comment(parent_id: str, count: int) -> None: + """ + Update(increment/decrement) child_count in parent comment. + + Args: + parent_id: The ID of the parent comment whose child_count will be updated. + count: It can be any number. + If positive, this function will increase child_count by the count. + If negative, this function will decrease child_count by the count. + + Returns: + None. + """ + Comment.objects.filter(pk=int(parent_id)).update( + child_count=F("child_count") + count + ) + + @classmethod + def create_comment(cls, data: dict[str, Any]) -> str: + """Handle comment creation and returns a comment.""" + comment_thread = None + parent = None + comment_thread_id = data.get("comment_thread_id") + parent_id = data.get("parent_id") + if comment_thread_id: + comment_thread = CommentThread.objects.get(pk=int(comment_thread_id)) + if parent_id: + parent = Comment.objects.get(pk=int(parent_id)) + new_comment = Comment.objects.create( + body=data.get("body"), + course_id=data.get("course_id"), + anonymous=data.get("anonymous", False), + anonymous_to_peers=data.get("anonymous_to_peers", False), + author=User.objects.get(pk=int(data["author_id"])), + comment_thread=comment_thread, + parent=parent, + depth=data.get("depth", 0), ) - comment_ids = list(comments.values_list("pk", flat=True)) - if params.get("group_ids"): - group_threads = threads.filter(group_id__in=params["group_ids"] + [None]) - group_thread_ids = [str(thread.pk) for thread in group_threads] - threads_count = len(group_thread_ids) - comment_thread_ids = filter_standalone_threads(comment_ids) - - group_comment_threads = CommentThread.objects.filter( - id__in=comment_thread_ids, group_id__in=params["group_ids"] + [None] + new_comment.sort_key = new_comment.get_sort_key() + new_comment.save() + if data.get("parent_id"): + cls.update_child_count_in_parent_comment(data["parent_id"], 1) + cls.update_stats_for_course(data["author_id"], data["course_id"], replies=1) + else: + cls.update_stats_for_course( + data["author_id"], data["course_id"], responses=1 ) - group_comment_thread_ids = [ - str(thread.pk) for thread in group_comment_threads - ] - comments_count = sum( - 1 - for comment_thread_id in comment_thread_ids - if comment_thread_id in group_comment_thread_ids + return str(new_comment.pk) + + @classmethod + def delete_comment(cls, comment_id: str) -> None: + """Delete comment from comment_id.""" + comment = Comment.objects.get(pk=comment_id) + if comment.parent: + cls.update_child_count_in_parent_comment(str(comment.parent.pk), -1) + + comment.delete() + + @staticmethod + def get_commentables_counts_based_on_type(course_id: str) -> dict[str, Any]: + """Return commentables counts in a course based on thread's type.""" + result = ( + CommentThread.objects.filter(course_id=course_id) + .values("commentable_id") + .annotate( + discussion_count=Count( + Case( + When(thread_type="discussion", then=1), + output_field=IntegerField(), + ) + ), + question_count=Count( + Case( + When(thread_type="question", then=1), + output_field=IntegerField(), + ) + ), ) - else: - thread_ids = [str(thread.pk) for thread in threads] - threads_count = len(thread_ids) - comment_thread_ids = filter_standalone_threads(comment_ids) - comments_count = len(comment_thread_ids) - - hash_data.update( - { - "threads_count": threads_count, - "comments_count": comments_count, + .order_by() + ) + commentable_counts = {} + for commentable in result: + topic_id = commentable["commentable_id"] + commentable_counts[topic_id] = { + "discussion": commentable["discussion_count"], + "question": commentable["question_count"], } + return commentable_counts + + @staticmethod + def update_comment(comment_id: str, **kwargs: Any) -> int: + """Updates a comment in the database.""" + try: + comment = Comment.objects.get(id=comment_id) + except Comment.DoesNotExist: + return 0 + + if kwargs.get("body"): + comment.body = kwargs["body"] + if kwargs.get("course_id"): + comment.course_id = kwargs["course_id"] + if kwargs.get("anonymous"): + comment.anonymous = kwargs["anonymous"] + if kwargs.get("anonymous_to_peers"): + comment.anonymous_to_peers = kwargs["anonymous_to_peers"] + if kwargs.get("comment_thread_id"): + comment.comment_thread = CommentThread.objects.get( + pk=kwargs["comment_thread_id"] + ) + if kwargs.get("visible"): + comment.visible = kwargs["visible"] + if kwargs.get("author_id"): + comment.author = User.objects.get(pk=kwargs["author_id"]) + if kwargs.get("endorsed"): + comment.endorsed = kwargs["endorsed"] + if kwargs.get("child_count"): + comment.child_count = kwargs["child_count"] + if kwargs.get("depth"): + comment.depth = kwargs["depth"] + + if kwargs.get("endorsed") and kwargs.get("endorsement_user_id"): + comment.endorsement = { + "user_id": kwargs["endorsement_user_id"], + "time": timezone.now(), + } + else: + comment.endorsement = {} + + if "abuse_flaggers" in kwargs: + existing_abuse_flaggers = AbuseFlagger.objects.filter( + content_object_id=comment.pk, + content_type=ContentType.objects.get_for_model(Comment), + ).values_list("user_id", flat=True) + + new_abuse_flaggers = [ + user_id + for user_id in kwargs["abuse_flaggers"] + if user_id not in existing_abuse_flaggers + ] + for user_id in new_abuse_flaggers: + AbuseFlagger.objects.create( + user=User.objects.get(pk=user_id), + content_object_id=comment.pk, + content_type=ContentType.objects.get_for_model(Comment), + ) + + if kwargs.get("editing_user_id"): + EditHistory.objects.create( + comment=comment, + author=User.objects.get(pk=kwargs["editing_user_id"]), + original_body=kwargs.get("original_body"), + reason_code=kwargs.get("edit_reason_code"), + created_at=timezone.now(), + ) + + comment.updated_at = timezone.now() + comment.save() + return 1 + + @staticmethod + def get_thread_id_from_comment(comment_id: str) -> dict[str, Any] | None: + """Return thread_id from comment_id.""" + comment = Comment.objects.get(pk=comment_id) + if comment.comment_thread: + return comment.comment_thread.to_dict() + raise ValueError("Comment doesn't have the thread.") + + @staticmethod + def get_user(user_id: str) -> dict[str, Any] | None: + """Return user from user_id.""" + try: + return ForumUser.objects.get(user__pk=user_id).to_dict() + except ObjectDoesNotExist: + return None + + @staticmethod + def get_thread(thread_id: str) -> dict[str, Any] | None: + """Return thread from thread_id.""" + try: + thread = CommentThread.objects.get(pk=thread_id) + except CommentThread.DoesNotExist: + return None + return thread.to_dict() + + @classmethod + def get_subscription( + cls, subscriber_id: str, source_id: str, **kwargs: Any + ) -> dict[str, Any] | None: + """Return subscription from subscriber_id and source_id.""" + source = cls._get_entity_from_type( + source_id, entity_type=kwargs.get("source_type", "") ) + if not source: + return None + try: + subscription = Subscription.objects.get( + subscriber_id=User.objects.get(pk=int(subscriber_id)), + source_object_id=source.pk, + source_content_type=source.content_type, + ) + except Subscription.DoesNotExist: + return None + return subscription.to_dict() + + @classmethod + def get_subscriptions(cls, query: dict[str, Any]) -> list[dict[str, Any]]: + """Return subscriptions from filter.""" + source = cls._get_entity_from_type( + entity_id=query["source_id"], entity_type=query.get("source_type", "") + ) + if not source: + return [] - return hash_data + subscriptions = ( + Subscription.objects.filter( + source_object_id=source.pk, + source_content_type=source.content_type, + ) + .distinct() + .order_by("subscriber_id", "source_object_id") + ) + + return [subscription.to_dict() for subscription in subscriptions] + @staticmethod + def delete_thread(thread_id: str) -> int: + """Delete thread from thread_id.""" + try: + thread = CommentThread.objects.get(pk=thread_id) + except ObjectDoesNotExist: + return 0 + thread.delete() + return 1 + + @staticmethod + def create_thread(data: dict[str, Any]) -> str: + """Create thread.""" + new_thread = CommentThread.objects.create( + title=data["title"], + body=data["body"], + course_id=data["course_id"], + anonymous=data.get("anonymous", False), + anonymous_to_peers=data.get("anonymous_to_peers", False), + author=User.objects.get(pk=int(data["author_id"])), + commentable_id=data.get("commentable_id", "course"), + thread_type=data.get("thread_type", "discussion"), + context=data.get("context", "course"), + last_activity_at=timezone.now(), + ) + return str(new_thread.pk) + + @staticmethod + def update_thread( + thread_id: str, + **kwargs: Any, + ) -> int: + """Updates a thread document in the database.""" + thread = CommentThread.objects.get(id=thread_id) + + if "thread_type" in kwargs: + thread.thread_type = kwargs["thread_type"] + if "title" in kwargs: + thread.title = kwargs["title"] + if "body" in kwargs: + thread.body = kwargs["body"] + if "course_id" in kwargs: + thread.course_id = kwargs["course_id"] + if "anonymous" in kwargs: + thread.anonymous = kwargs["anonymous"] + if "anonymous_to_peers" in kwargs: + thread.anonymous_to_peers = kwargs["anonymous_to_peers"] + if "commentable_id" in kwargs: + thread.commentable_id = kwargs["commentable_id"] + if "author_id" in kwargs and kwargs["author_id"]: + thread.author = User.objects.get(pk=int(kwargs["author_id"])) + if "closed_by_id" in kwargs and kwargs["closed_by_id"]: + thread.closed_by = User.objects.get(pk=int(kwargs["closed_by_id"])) + if "pinned" in kwargs: + thread.pinned = kwargs["pinned"] + if "close_reason_code" in kwargs: + thread.close_reason_code = kwargs["close_reason_code"] + if "closed" in kwargs: + thread.closed = kwargs["closed"] + if not kwargs["closed"]: + thread.closed_by = None # type: ignore + thread.close_reason_code = None + if "endorsed" in kwargs: + thread.endorsed = kwargs["endorsed"] + if "group_id" in kwargs: + thread.group_id = kwargs["group_id"] + if "abuse_flaggers" in kwargs: + existing_abuse_flaggers = AbuseFlagger.objects.filter( + content_object_id=thread.pk, + content_type=ContentType.objects.get_for_model(CommentThread), + ).values_list("user_id", flat=True) + + new_abuse_flaggers = [ + user_id + for user_id in kwargs["abuse_flaggers"] + if user_id not in existing_abuse_flaggers + ] -def unsubscribe_all(user_id: str) -> None: - """Unsubscribe user from all content.""" - Subscription.objects.filter(subscriber__pk=user_id).delete() + for user_id in new_abuse_flaggers: + AbuseFlagger.objects.create( + user=User.objects.get(pk=user_id), + content_object_id=thread.pk, + content_type=ContentType.objects.get_for_model(CommentThread), + ) + if "editing_user_id" in kwargs and kwargs["editing_user_id"]: + EditHistory.objects.create( + content_object_id=thread.pk, + content_type=thread.content_type, + reason_code=kwargs.get("edit_reason_code"), + original_body=kwargs.get("original_body"), + editor=User.objects.get(pk=kwargs["editing_user_id"]), + created_at=timezone.now(), + ) + thread.updated_at = timezone.now() + thread.save() + return 1 + + @staticmethod + def get_user_thread_filter(course_id: str) -> dict[str, Any]: + """Get user thread filter""" + return {"course_id": course_id} + + @staticmethod + def get_filtered_threads(query: dict[str, Any]) -> list[dict[str, Any]]: + """Return a list of threads that match the given filter.""" + threads = CommentThread.objects.filter(**query) + return [thread.to_dict() for thread in threads] + + @staticmethod + def update_user(user_id: str, data: dict[str, Any]) -> int: + """ + Updates user info and ForumUser fields. + + Args: + user_id: ID of the user to update. + data: Dictionary containing updated user info and ForumUser fields. + """ + try: + user = User.objects.get(id=user_id) + forum_user = ForumUser.objects.get(user=user) + except ObjectDoesNotExist: + return 0 + + if "username" in data: + user.username = data["username"] + if "email" in data: + user.email = data["email"] + if "default_sort_key" in data: + forum_user.default_sort_key = data["default_sort_key"] + if "read_states" in data and data["read_states"] == []: + user_read_states = ReadState.objects.filter(user=user) + user_read_states.delete() + + user.save() + forum_user.save() + return 1 + + @staticmethod + def replace_username_in_all_content(user_id: str, username: str) -> None: + """Replace the username of a Django user.""" + try: + user = User.objects.get(pk=user_id) + user.username = username + user.save() + except User.DoesNotExist as exc: + raise ValueError("User does not exist") from exc + + @staticmethod + def get_thread_id_by_comment_id(parent_comment_id: str) -> str: + """ + The thread Id from the parent comment. + """ + try: + comment = Comment.objects.get(pk=parent_comment_id) + except ObjectDoesNotExist as exc: + raise ValueError("comment does not exist.") from exc + return comment.comment_thread.pk + + @staticmethod + def update_comment_and_get_updated_comment( + comment_id: str, + body: Optional[str] = None, + course_id: Optional[str] = None, + user_id: Optional[str] = None, + anonymous: Optional[bool] = False, + anonymous_to_peers: Optional[bool] = False, + endorsed: Optional[bool] = None, + closed: Optional[bool] = False, + editing_user_id: Optional[str] = None, + edit_reason_code: Optional[str] = None, + endorsement_user_id: Optional[str] = None, + ) -> dict[str, Any] | None: + """ + Update an existing child/parent comment. + + Parameters: + comment_id: The ID of the comment to be edited. + body (Optional[str]): The content of the comment. + course_id (Optional[str]): The Id of the respective course. + user_id (Optional[str]): The requesting user id. + anonymous (Optional[bool]): anonymous flag(True or False). + anonymous_to_peers (Optional[bool]): anonymous to peers flag(True or False). + endorsed (Optional[bool]): Flag indicating if the comment is endorsed by any user. + closed (Optional[bool]): Flag indicating if the comment thread is closed. + editing_user_id (Optional[str]): The ID of the user editing the comment. + edit_reason_code (Optional[str]): The reason for editing the comment, typically represented by a code. + endorsement_user_id (Optional[str]): The ID of the user endorsing the comment. + Response: + The details of the comment that is updated. + """ + try: + comment = Comment.objects.get(id=comment_id) + except Comment.DoesNotExist: + return None + + original_body = comment.body + if body: + comment.body = body + if course_id: + comment.course_id = course_id + if user_id: + comment.author = User.objects.get(pk=user_id) + if anonymous is not None: + comment.anonymous = anonymous + if anonymous_to_peers is not None: + comment.anonymous_to_peers = anonymous_to_peers + if endorsed is not None: + comment.endorsed = endorsed + if endorsed is False: + comment.endorsement = {} + if endorsement_user_id: + comment.endorsement = { + "user_id": endorsement_user_id, + "time": str(timezone.now()), + } + + if editing_user_id: + EditHistory.objects.create( + content_object_id=comment.pk, + content_type=ContentType.objects.get_for_model(Comment), + editor=User.objects.get(pk=editing_user_id), + original_body=original_body, + reason_code=edit_reason_code, + created_at=timezone.now(), + ) -# Kept method signature same as mongo implementation -def retire_all_content(user_id: str, username: str) -> None: # pylint: disable=W0613 - """Retire all content from user.""" - comments = Comment.objects.filter(author__pk=user_id) - for comment in comments: - comment.body = RETIRED_BODY + comment.updated_at = timezone.now() comment.save() + return comment.to_dict() + + @staticmethod + def get_course_id_by_thread_id(thread_id: str) -> str | None: + """ + Return course_id for the matching thread. + """ + thread = CommentThread.objects.filter(id=thread_id).first() + if thread: + return thread.course_id + return None - comment_threads = CommentThread.objects.filter(author__pk=user_id) - for comment_thread in comment_threads: - comment_thread.body = RETIRED_BODY - comment_thread.title = RETIRED_TITLE - comment_thread.save() + @staticmethod + def get_course_id_by_comment_id(comment_id: str) -> str | None: + """ + Return course_id for the matching comment. + """ + comment = Comment.objects.filter(id=comment_id).first() + if comment: + return comment.course_id + return None + @staticmethod + def get_users(**kwargs: Any) -> list[dict[str, Any]]: + """ + Retrieves a list of users in the database based on provided filters. + + Args: + kwargs: The filter arguments. + + Returns: + A list of users. + """ + forum_users = ForumUser.objects.filter(**kwargs) + sort_key = kwargs.get("sort_key") + if sort_key: + forum_users = forum_users.order_by(sort_key) + + result = [user.to_dict() for user in forum_users] + return result + + @staticmethod + def get_user_sort_criterion(sort_by: str) -> dict[str, Any]: + """ + Get sort criterion based on sort_by parameter. + + Args: + sort_by (str): The sort_by parameter. + + Returns: + A dictionary representing the sort criterion. + """ + if sort_by == "flagged": + return {"-active_flags": None, "-inactive_flags": None, "-username": None} + elif sort_by == "recency": + return {"-last_activity_at": None, "-username": None} + else: + return { + "-threads": None, + "-responses": None, + "-replies": None, + "-username": None, + } -def find_or_create_read_state(user_id: str, thread_id: str) -> dict[str, Any]: - """Find or create user read states.""" - try: - user = User.objects.get(pk=user_id) - thread = CommentThread.objects.get(pk=thread_id) - except (User.DoesNotExist, CommentThread.DoesNotExist) as exc: - raise ObjectDoesNotExist from exc - - read_state, _ = ReadState.objects.get_or_create( - user=user, course_id=thread.course_id - ) - return read_state.to_dict() - - -def mark_as_read(user_id: str, thread_id: str) -> None: - """Mark thread as read.""" - user = User.objects.get(pk=user_id) - thread = CommentThread.objects.get(pk=thread_id) - read_state, _ = ReadState.objects.get_or_create( - user=user, course_id=thread.course_id - ) - - LastReadTime.objects.update_or_create( - read_state=read_state, - comment_thread=thread, - defaults={ - "timestamp": timezone.now(), - }, - ) - - -def find_or_create_user_stats(user_id: str, course_id: str) -> dict[str, Any]: - """Find or create user stats document.""" - user = User.objects.get(pk=user_id) - try: - course_stat = CourseStat.objects.get(user=user, course_id=course_id) - return course_stat.to_dict() - except CourseStat.DoesNotExist: - course_stat = CourseStat( - user=user, - course_id=course_id, - active_flags=0, - inactive_flags=0, - threads=0, - responses=0, - replies=0, - last_activity_at=None, + @classmethod + def get_paginated_user_stats( + cls, course_id: str, page: int, per_page: int, sort_criterion: dict[str, Any] + ) -> dict[str, Any]: + """Get paginated user stats.""" + users = User.objects.filter( + Q(course_stats__course_id=course_id) + & Q(course_stats__course_id__isnull=False) + ).order_by( + *[key for key, value in sort_criterion.items() if value == -1], + *[key for key, value in sort_criterion.items() if value == 1], ) - course_stat.save() - return course_stat.to_dict() - - -def update_user_stats_for_course(user_id: str, stat: dict[str, Any]) -> None: - """Update user stats for course.""" - user = User.objects.get(pk=user_id) - try: - course_stat = CourseStat.objects.get(user=user, course_id=stat["course_id"]) - for key, value in stat.items(): - setattr(course_stat, key, value) - course_stat.save() - except CourseStat.DoesNotExist: - course_stat = CourseStat(user=user, **stat) - course_stat.save() + paginator = Paginator(users, per_page) + paginated_users = paginator.page(page) -def build_course_stats(author_id: str, course_id: str) -> None: - """Build course stats.""" - author = User.objects.get(pk=author_id) - threads = CommentThread.objects.filter( - author=author, - course_id=course_id, - anonymous_to_peers=False, - anonymous=False, - ) - comments = Comment.objects.filter( - author=author, - course_id=course_id, - anonymous_to_peers=False, - anonymous=False, - ) - - responses = comments.filter(comment_thread__isnull=False) - replies = comments.filter(comment_thread__isnull=True) - - active_flags = comments.filter(abuse_flaggers__isnull=False).count() - inactive_flags = comments.filter(historical_abuse_flaggers__isnull=False).count() - - updated_at = max( - threads.aggregate(Max("updated_at"))["updated_at__max"] or datetime(1970, 1, 1), - comments.aggregate(Max("updated_at"))["updated_at__max"] - or datetime(1970, 1, 1), - ) - - stats, _ = CourseStat.objects.get_or_create(user=author, course_id=course_id) - stats.threads = threads.count() - stats.responses = responses.count() - stats.replies = replies.count() - stats.active_flags = active_flags - stats.inactive_flags = inactive_flags - stats.last_activity_at = updated_at - stats.save() - - -def update_all_users_in_course(course_id: str) -> list[str]: - """Update all user stats in a course.""" - course_comments = Comment.objects.filter( - anonymous=False, - anonymous_to_peers=False, - course_id=course_id, - ) - course_threads = CommentThread.objects.filter( - anonymous=False, - anonymous_to_peers=False, - course_id=course_id, - ) - - comment_authors = set(course_comments.values_list("author__id", flat=True)) - thread_authors = set(course_threads.values_list("author__id", flat=True)) - author_ids = list(comment_authors | thread_authors) - - for author_id in author_ids: - build_course_stats(author_id, course_id) - return author_ids - - -def get_user_by_username(username: str | None) -> dict[str, Any] | None: - """Return user from username.""" - try: - return ForumUser.objects.get(user__username=username).to_dict() - except User.DoesNotExist: - return None + return { + "total_count": paginator.count, + "data": paginated_users.object_list, + } + @staticmethod + def get_contents(**kwargs: Any) -> list[dict[str, Any]]: + """ + Retrieves a list of comments and comment threads in the database based on provided filters. -def find_or_create_user(user_id: str) -> str: - """Find or create user.""" - user, _ = ForumUser.objects.get_or_create(user__pk=user_id) - return user.pk + Args: + kwargs: The filter arguments. + Returns: + A list of comments and comment threads. + """ + comment_filters = { + key: value for key, value in kwargs.items() if hasattr(Comment, key) + } + thread_filters = { + key: value for key, value in kwargs.items() if hasattr(CommentThread, key) + } -def get_course_id_by_thread_id(thread_id: str) -> str | None: - """ - Return course_id for the matching thread. - """ - thread = CommentThread.objects.filter(id=thread_id).first() - if thread: - return thread.course_id - return None + comments = Comment.objects.filter(**comment_filters) + threads = CommentThread.objects.filter(**thread_filters) + sort_key = kwargs.get("sort_key") + if sort_key: + comments = comments.order_by(sort_key) + threads = threads.order_by(sort_key) -def get_course_id_by_comment_id(comment_id: str) -> str | None: - """ - Return course_id for the matching comment. - """ - comment = Comment.objects.filter(id=comment_id).first() - if comment: - return comment.course_id - return None + result = [content.to_dict() for content in list(comments) + list(threads)] + return result diff --git a/forum/backends/mysql/models.py b/forum/backends/mysql/models.py index 584179d1..bb90e0e9 100644 --- a/forum/backends/mysql/models.py +++ b/forum/backends/mysql/models.py @@ -39,6 +39,7 @@ def to_dict(self) -> dict[str, Any]: "default_sort_key": self.default_sort_key, "external_id": self.user.pk, "username": self.user.username, + "email": self.user.email, "course_stats": [stat.to_dict() for stat in course_stats], "read_states": [state.to_dict() for state in read_states], } @@ -158,7 +159,7 @@ def get_votes(self) -> dict[str, Any]: "count": 0, "point": 0, } - for vote in self.votes.all(): + for vote in self.votes: if vote.vote == 1: votes["up"].append(vote.user.pk) votes["up_count"] += 1 @@ -169,6 +170,10 @@ def get_votes(self) -> dict[str, Any]: votes["count"] = votes["count"] return votes + def to_dict(self) -> dict[str, Any]: + """Return a dictionary representation of the content.""" + raise NotImplementedError + class Meta: app_label = "forum" abstract = True @@ -198,9 +203,6 @@ class CommentThread(Content): pinned: models.BooleanField[Optional[bool], bool] = models.BooleanField( null=True, blank=True ) - comment_count: models.PositiveIntegerField[int, int] = models.PositiveIntegerField( - default=0 - ) last_activity_at: models.DateTimeField[Optional[datetime], datetime] = ( models.DateTimeField(null=True, blank=True) ) @@ -214,6 +216,17 @@ class CommentThread(Content): blank=True, on_delete=models.SET_NULL, ) + commentable_id: models.CharField[str, str] = models.CharField( + max_length=255, + default=None, + blank=True, + null=True, + ) + + @property + def comment_count(self) -> int: + """Return the number of comments in the thread.""" + return Comment.objects.filter(comment_thread=self).count() @classmethod def get(cls, thread_id: str) -> CommentThread: @@ -241,18 +254,25 @@ def to_dict(self) -> dict[str, Any]: "_id": str(self.pk), "votes": self.get_votes, "visible": self.visible, - "abuse_flaggers": self.abuse_flaggers, - "historical_abuse_flaggers": self.historical_abuse_flaggers, + "abuse_flaggers": [str(flagger) for flagger in self.abuse_flaggers], + "historical_abuse_flaggers": [ + str(flagger) for flagger in self.historical_abuse_flaggers + ], "thread_type": self.thread_type, + "_type": "CommentThread", + "commentable_id": self.commentable_id, "context": self.context, "comment_count": self.comment_count, "at_position_list": [], + "pinned": self.pinned if self.pinned else False, "title": self.title, "body": self.body, "course_id": self.course_id, "anonymous": self.anonymous, "anonymous_to_peers": self.anonymous_to_peers, "closed": self.closed, + "closed_by_id": str(self.closed_by.pk) if self.closed_by else None, + "close_reason_code": self.close_reason_code, "author_id": self.author.pk, "author_username": self.author.username, "updated_at": self.updated_at.isoformat() if self.updated_at else None, @@ -297,11 +317,53 @@ class Comment(Content): parent: models.ForeignKey[Comment, Comment] = models.ForeignKey( "self", on_delete=models.CASCADE, null=True, blank=True ) + depth: models.PositiveIntegerField[int, int] = models.PositiveIntegerField( + default=0 + ) + + def get_sort_key(self) -> str: + """Get the sort key for the comment""" + if self.parent: + return f"{self.parent.pk}-{self.pk}" + return str(self.pk) + + @staticmethod + def get_list(**kwargs: Any) -> list[dict[str, Any]]: + """ + Retrieves a list of all comments in the database based on provided filters. + + Args: + kwargs: The filter arguments. + + Returns: + A list of comments. + """ + sort = kwargs.pop("sort", None) + comments = Comment.objects.filter(**kwargs) + if sort: + if sort == 1: + result = sorted( + comments, key=lambda x: (x.sort_key is None, x.sort_key or "") + ) + elif sort == -1: + result = sorted( + comments, + key=lambda x: (x.sort_key is None, x.sort_key or ""), + reverse=True, + ) + return [content.to_dict() for content in result] + + def get_parent_ids(self) -> list[str]: + """Return a list of all parent IDs of a comment.""" + parent_ids = [] + current_comment = self + while current_comment.parent: + parent_ids.append(str(current_comment.parent.pk)) + current_comment = current_comment.parent + return parent_ids def to_dict(self) -> dict[str, Any]: """Return a dictionary representation of the model.""" - abuse_flaggers = self.abuse_flaggers - historical_abuse_flaggers = self.historical_abuse_flaggers edit_history = [] for edit in self.edit_history.all(): edit_history.append( @@ -316,18 +378,22 @@ def to_dict(self) -> dict[str, Any]: ), } ) + endorsement = { "user_id": self.endorsement.get("user_id") if self.endorsement else None, "time": self.endorsement.get("time") if self.endorsement else None, } - return { + data = { "_id": str(self.pk), "votes": self.get_votes, "visible": self.visible, - "abuse_flaggers": abuse_flaggers, - "historical_abuse_flaggers": historical_abuse_flaggers, - "parent_ids": [], + "abuse_flaggers": [str(flagger) for flagger in self.abuse_flaggers], + "historical_abuse_flaggers": [ + str(flagger) for flagger in self.historical_abuse_flaggers + ], + "parent_ids": self.get_parent_ids(), + "parent_id": str(self.parent.pk) if self.parent else "None", "at_position_list": [], "body": self.body, "course_id": self.course_id, @@ -342,9 +408,12 @@ def to_dict(self) -> dict[str, Any]: "sk": str(self.pk), "updated_at": self.updated_at.isoformat() if self.updated_at else None, "created_at": self.created_at.isoformat() if self.created_at else None, - "edit_history": edit_history, - "endorsement": endorsement, + "endorsement": endorsement if self.endorsement else None, } + if edit_history: + data["edit_history"] = edit_history + + return data @classmethod def get(cls, comment_id: str) -> Comment: diff --git a/forum/migrations/0001_initial.py b/forum/migrations/0001_initial.py index 870a4552..db34332f 100644 --- a/forum/migrations/0001_initial.py +++ b/forum/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.14 on 2024-09-23 06:12 +# Generated by Django 4.2.14 on 2024-10-10 15:57 from django.conf import settings from django.db import migrations, models @@ -60,12 +60,17 @@ class Migration(migrations.Migration): ), ("closed", models.BooleanField(default=False)), ("pinned", models.BooleanField(blank=True, null=True)), - ("comment_count", models.PositiveIntegerField(default=0)), ("last_activity_at", models.DateTimeField(blank=True, null=True)), ( "close_reason_code", models.CharField(blank=True, max_length=255, null=True), ), + ( + "commentable_id", + models.CharField( + blank=True, default=None, max_length=255, null=True + ), + ), ( "author", models.ForeignKey( @@ -303,6 +308,7 @@ class Migration(migrations.Migration): "retired_username", models.CharField(blank=True, max_length=255, null=True), ), + ("depth", models.PositiveIntegerField(default=0)), ( "author", models.ForeignKey( diff --git a/forum/serializers/comment.py b/forum/serializers/comment.py index 7d95cb45..4c68b0f1 100644 --- a/forum/serializers/comment.py +++ b/forum/serializers/comment.py @@ -4,10 +4,8 @@ from typing import Any -from bson import ObjectId from rest_framework import serializers -from forum.backends.mongodb import Comment, CommentThread from forum.serializers.contents import ContentSerializer from forum.serializers.custom_datetime import CustomDateTimeField from forum.utils import prepare_comment_data_for_get_children @@ -61,6 +59,7 @@ class CommentSerializer(ContentSerializer): def __init__(self, *args: Any, **kwargs: Any) -> None: exclude_fields = kwargs.pop("exclude_fields", None) + self.backend = kwargs.pop("backend") super().__init__(*args, **kwargs) if exclude_fields: for field in exclude_fields: @@ -71,12 +70,10 @@ def get_children(self, obj: Any) -> list[dict[str, Any]]: if not self.context.get("recursive", False): return [] - children = list( - Comment().get_list( - parent_id=ObjectId(obj["_id"]), - depth=1, - sort=self.context.get("sort", -1), - ) + children = self.backend.get_comments( + parent_id=obj["_id"], + depth=1, + sort=self.context.get("sort", -1), ) children_data = prepare_comment_data_for_get_children(children) serializer = CommentSerializer( @@ -84,6 +81,7 @@ def get_children(self, obj: Any) -> list[dict[str, Any]]: many=True, context={"recursive": False}, exclude_fields=["sk"], + backend=self.backend, ) return list(serializer.data) @@ -93,8 +91,8 @@ def to_representation(self, instance: Any) -> dict[str, Any]: if comment["parent_id"] == "None": comment["parent_id"] = None - thread = CommentThread().get(comment["thread_id"]) - comment_from_db = Comment().get(comment["id"]) + thread = self.backend.get_thread(comment["thread_id"]) + comment_from_db = self.backend.get_comment(comment["id"]) if ( not comment["endorsed"] and comment_from_db diff --git a/forum/serializers/thread.py b/forum/serializers/thread.py index 38db25d7..12997dc5 100644 --- a/forum/serializers/thread.py +++ b/forum/serializers/thread.py @@ -4,18 +4,9 @@ from typing import Any, Optional -from bson import ObjectId -from pymongo import ASCENDING, DESCENDING from rest_framework import serializers from rest_framework.serializers import ValidationError -from forum.backends.mongodb import Comment -from forum.backends.mongodb.api import ( - get_abuse_flagged_count, - get_endorsed, - get_read_states, - get_username_from_id, -) from forum.serializers.comment import CommentSerializer from forum.serializers.contents import ContentSerializer from forum.serializers.custom_datetime import CustomDateTimeField @@ -57,7 +48,7 @@ class ThreadSerializer(ContentSerializer): thread_type = serializers.CharField() title = serializers.CharField() context = serializers.CharField() # type: ignore - last_activity_at = CustomDateTimeField() + last_activity_at = CustomDateTimeField(allow_null=True, default=None) closed_by_id = serializers.CharField(allow_null=True, default=None) closed_by = serializers.SerializerMethodField() close_reason_code = serializers.CharField(allow_null=True, default=None) @@ -90,6 +81,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: - 'include_read_state' (bool): Whether to include read state information. """ self.context_data = kwargs.get("context", {}) + self.backend = kwargs.pop("backend") self.with_responses = self.context_data.pop("with_responses", False) self.count_flagged = self.context_data.pop("count_flagged", False) self.include_endorsed = self.context_data.pop("include_endorsed", False) @@ -133,9 +125,9 @@ def get_read(self, obj: dict[str, Any]) -> Optional[bool]: user_id = self.context_data.get("user_id", None) course_id = obj["course_id"] thread_key = obj["_id"] - is_read, _ = get_read_states([obj], user_id, course_id).get( - thread_key, (False, obj["comment_count"]) - ) + is_read, _ = self.backend.get_read_states( + [obj["_id"]], user_id, course_id + ).get(thread_key, (False, obj["comment_count"])) return is_read return None @@ -155,9 +147,9 @@ def get_unread_comments_count(self, obj: dict[str, Any]) -> Optional[int]: user_id = self.context_data.get("user_id", None) course_id = obj["course_id"] thread_key = obj["_id"] - _, unread_count = get_read_states([obj], user_id, course_id).get( - thread_key, (False, obj["comment_count"]) - ) + _, unread_count = self.backend.get_read_states( + [obj["_id"]], user_id, course_id + ).get(thread_key, (False, obj["comment_count"])) return unread_count return None @@ -175,7 +167,7 @@ def get_endorsed(self, obj: dict[str, Any]) -> Optional[bool]: if isinstance(obj, dict) and obj.get("endorsed") is not None: return obj.get("endorsed", True) thread_key = obj["_id"] - return get_endorsed([thread_key]).get(thread_key, False) + return self.backend.get_endorsed([thread_key]).get(thread_key, False) return None def get_abuse_flagged_count(self, obj: dict[str, Any]) -> int: @@ -192,7 +184,7 @@ def get_abuse_flagged_count(self, obj: dict[str, Any]) -> int: if isinstance(obj, dict) and obj.get("abuse_flagged_count") is not None: return obj.get("abuse_flagged_count", 0) thread_key = obj["_id"] - return get_abuse_flagged_count([thread_key]).get(thread_key, 0) + return self.backend.get_abuse_flagged_count([thread_key]).get(thread_key, 0) return 0 def get_children(self, obj: dict[str, Any]) -> Optional[Any]: @@ -206,18 +198,12 @@ def get_children(self, obj: dict[str, Any]) -> Optional[Any]: Optional[Any]: The responses or children related to the thread, or None if not included. """ if self.with_responses: - sorting_order = ( - DESCENDING - if self.context_data.get("reverse_order", True) - else ASCENDING - ) - children = list( - Comment().get_list( - comment_thread_id=ObjectId(obj["_id"]), - depth=0, - parent_id=None, - sort=sorting_order, - ) + sorting_order = -1 if self.context_data.get("reverse_order", True) else 1 + children = self.backend.get_comments( + comment_thread_id=obj["_id"], + depth=0, + parent_id=None, + sort=sorting_order, ) children_data = prepare_comment_data_for_get_children(children) serializer = CommentSerializer( @@ -228,6 +214,7 @@ def get_children(self, obj: dict[str, Any]) -> Optional[Any]: "sort": sorting_order, }, exclude_fields=["sk"], + backend=self.backend, ) if not serializer.is_valid(raise_exception=True): raise ValidationError(serializer.errors) @@ -299,5 +286,5 @@ def update(self, instance: Any, validated_data: dict[str, Any]) -> Any: def get_closed_by(self, obj: dict[str, Any]) -> Optional[str]: """Retrieve the username of the person who closed the object.""" if closed_by_id := obj.get("closed_by_id"): - return get_username_from_id(closed_by_id) + return self.backend.get_username_from_id(closed_by_id) return None diff --git a/forum/serializers/users.py b/forum/serializers/users.py index 5ba4a6b6..1637087e 100644 --- a/forum/serializers/users.py +++ b/forum/serializers/users.py @@ -10,6 +10,7 @@ class UserSerializer(serializers.Serializer[Any]): id = serializers.CharField(allow_null=True) username = serializers.CharField() + email = serializers.CharField(allow_null=True) external_id = serializers.CharField() subscribed_thread_ids = serializers.ListField( child=serializers.CharField(), default=[] diff --git a/forum/settings/common.py b/forum/settings/common.py index 67294ea9..f54c715f 100644 --- a/forum/settings/common.py +++ b/forum/settings/common.py @@ -25,3 +25,5 @@ def plugin_settings(settings: Any) -> None: settings.FEATURES["ENABLE_DISCUSSION_SERVICE"] = True # URL prefix must match the regex in the url_config of the plugin app settings.COMMENTS_SERVICE_URL = "http://localhost:8000/forum" + + settings.USE_TZ = True diff --git a/forum/settings/test.py b/forum/settings/test.py index 6b1856da..a9d17f91 100644 --- a/forum/settings/test.py +++ b/forum/settings/test.py @@ -76,3 +76,5 @@ def root(*args: str) -> str: ] else: FORUM_ELASTIC_SEARCH_CONFIG = [{}] + +USE_TZ = True diff --git a/forum/views/commentables.py b/forum/views/commentables.py index 058114ea..793e331b 100644 --- a/forum/views/commentables.py +++ b/forum/views/commentables.py @@ -6,7 +6,7 @@ from rest_framework.response import Response from rest_framework.views import APIView -from forum.api import get_commentables_stats +from forum.api.commentables import get_commentables_stats class CommentablesCountAPIView(APIView): diff --git a/forum/views/users.py b/forum/views/users.py index 5d2d1f94..1a8f94a4 100644 --- a/forum/views/users.py +++ b/forum/views/users.py @@ -88,7 +88,7 @@ def post(self, request: Request) -> Response: data: dict[str, Any] = { "user_id": params.get("id"), "username": params.get("username"), - "default_sort_key": params.get("default_sort_key"), + "default_sort_key": params.get("default_sort_key", "date"), "course_id": params.get("course_id"), "group_ids": params.get("group_ids"), "complete": params.get("complete"), diff --git a/tests/conftest.py b/tests/conftest.py index 151929f7..094fa2c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,13 +2,15 @@ Init file for tests. """ -from typing import Any, Generator +from typing import Any, Callable, Generator from unittest.mock import patch import mongomock import pytest from pymongo import MongoClient +from forum.backends.mysql.api import MySQLBackend +from forum.backends.mongodb.api import MongoBackend from test_utils.client import APIClient from test_utils.mock_es_backend import MockElasticsearchBackend @@ -34,3 +36,15 @@ def mock_elasticsearch_backend() -> Generator[Any, Any, Any]: """Mock the dummy elastic search.""" with patch("forum.search.backend.ElasticsearchBackend", MockElasticsearchBackend): yield + + +@pytest.fixture(params=[MongoBackend, MySQLBackend], autouse=True) +def patch_get_backend(request: pytest.FixtureRequest) -> Generator[Any, Any, Any]: + """Mock the get_backend function for both Mongo and MySQL backends.""" + backend_class = request.param + + def backend_factory() -> Callable[[], MongoBackend | MySQLBackend]: + return backend_class() + + with patch("forum.backend.get_backend", return_value=backend_factory): + yield diff --git a/tests/e2e/test_search.py b/tests/e2e/test_search.py index ea52e749..3378b368 100644 --- a/tests/e2e/test_search.py +++ b/tests/e2e/test_search.py @@ -9,7 +9,7 @@ from requests import Response from forum.backends.mongodb import Comment, CommentThread, Users -from forum.backends.mongodb.api import mark_as_read +from forum.backends.mongodb.api import MongoBackend as backend from forum.search.backend import get_search_backend from test_utils.client import APIClient @@ -309,10 +309,7 @@ def test_filter_threads_by_unread(api_client: APIClient) -> None: course_id_0, course_id_1 ) refresh_elastic_search_indices() - - user = Users().get(_id=user_id) or {} - thread = CommentThread().get(_id=threads_ids[0]) or {} - mark_as_read(user, thread) + backend.mark_as_read(user_id, threads_ids[0]) params = { "text": "text", diff --git a/tests/e2e/test_users.py b/tests/e2e/test_users.py index f14a1936..d64720bb 100644 --- a/tests/e2e/test_users.py +++ b/tests/e2e/test_users.py @@ -11,7 +11,7 @@ from faker import Faker from forum.backends.mongodb import Comment, CommentThread, Users -from forum.backends.mongodb.api import build_course_stats +from forum.backends.mongodb.api import MongoBackend as backend from test_utils.client import APIClient fake = Faker() @@ -145,7 +145,7 @@ def build_structure_and_response( if build_initial_stats: for author in authors: - build_course_stats(author["_id"], course_id) + backend.build_course_stats(author["_id"], course_id) return expected_data diff --git a/tests/test_backends/test_mysql/test_api.py b/tests/test_backends/test_mysql/test_api.py index ac568bdc..833ee480 100644 --- a/tests/test_backends/test_mysql/test_api.py +++ b/tests/test_backends/test_mysql/test_api.py @@ -3,13 +3,11 @@ import pytest from django.contrib.auth import get_user_model -from forum.backends.mysql.models import AbuseFlagger, CommentThread -from forum.backends.mysql.api import ( - flag_as_abuse, - un_flag_all_as_abuse, - un_flag_as_abuse, +from forum.backends.mysql.models import ( + AbuseFlagger, + CommentThread, ) - +from forum.backends.mysql.api import MySQLBackend as backend User = get_user_model() @@ -27,14 +25,14 @@ def test_flag_as_abuse() -> None: thread_type="discussion", context="course", ) - flagged_comment_thread = flag_as_abuse( + flagged_comment_thread = backend.flag_as_abuse( str(flag_user.pk), str(comment_thread.pk), - comment_thread.type, + entity_type=comment_thread.type, ) assert flagged_comment_thread["_id"] == str(comment_thread.pk) - assert flagged_comment_thread["abuse_flaggers"] == [flag_user.pk] + assert flagged_comment_thread["abuse_flaggers"] == [str(flag_user.pk)] @pytest.mark.django_db @@ -51,10 +49,10 @@ def test_un_flag_as_abuse_success() -> None: ) AbuseFlagger.objects.create(user=user, content=comment_thread) comment_thread.save() - un_flagged_entity = un_flag_as_abuse( + un_flagged_entity = backend.un_flag_as_abuse( user.pk, comment_thread.pk, - comment_thread.type, + entity_type=comment_thread.type, ) assert user.pk not in comment_thread.abuse_flaggers @@ -80,9 +78,9 @@ def test_un_flag_all_as_abuse_historical_flags_updated() -> None: context="course", ) AbuseFlagger.objects.create(user=user, content=comment_thread) - un_flagged_comment_thread = un_flag_all_as_abuse( + un_flagged_comment_thread = backend.un_flag_all_as_abuse( comment_thread.pk, - comment_thread.type, + entity_type=comment_thread.type, ) assert un_flagged_comment_thread["_id"] == str(comment_thread.pk) diff --git a/tests/test_views/test_commentables.py b/tests/test_views/test_commentables.py index 44476d0b..cd5710ec 100644 --- a/tests/test_views/test_commentables.py +++ b/tests/test_views/test_commentables.py @@ -3,46 +3,56 @@ import random import uuid -from forum.backends.mongodb import CommentThread +import pytest + +from forum.backend import get_backend from test_utils.client import APIClient +pytestmark = pytest.mark.django_db +backend = get_backend()() + def test_get_commentables_counts_api(api_client: APIClient) -> None: """ Test retrieving counts of discussion and question threads for multiple commentables within a course. """ + username = "test_user" + user_id = backend.find_or_create_user("1", username=username) course_id = "abcd" id_map = {} - for _ in range(5): commentable_id = str(uuid.uuid4()) question_count = random.randint(5, 15) discussion_count = random.randint(5, 15) for _ in range(question_count): - CommentThread().insert( - title="Question Thread", - body="This is a question thread.", - course_id=course_id, - commentable_id=commentable_id, - thread_type="question", - author_id="test_user_id", - author_username="test_user", - abuse_flaggers=[], - historical_abuse_flaggers=[], + backend.create_thread( + { + "title": "Question Thread", + "body": "This is a question thread.", + "course_id": course_id, + "commentable_id": commentable_id, + "thread_type": "question", + "author_id": user_id, + "author_username": username, + "abuse_flaggers": [], + "historical_abuse_flaggers": [], + } ) for _ in range(discussion_count): - CommentThread().insert( - title="Discussion Thread", - body="This is a discussion thread.", - course_id=course_id, - commentable_id=commentable_id, - thread_type="discussion", - author_id="test_user_id", - author_username="test_user", - abuse_flaggers=[], - historical_abuse_flaggers=[], + backend.create_thread( + { + "title": "Discussion Thread", + "body": "This is a discussion thread.", + "course_id": course_id, + "commentable_id": commentable_id, + "thread_type": "discussion", + "author_id": user_id, + "author_username": username, + "abuse_flaggers": [], + "historical_abuse_flaggers": [], + } ) id_map[commentable_id] = { diff --git a/tests/test_views/test_comments.py b/tests/test_views/test_comments.py index 2cdb0745..32f106c7 100644 --- a/tests/test_views/test_comments.py +++ b/tests/test_views/test_comments.py @@ -1,8 +1,13 @@ """Test comments api endpoints.""" -from forum.backends.mongodb import Comment, CommentThread, Users +import pytest + +from forum.backend import get_backend from test_utils.client import APIClient +pytestmark = pytest.mark.django_db +backend = get_backend()() + def setup_models() -> tuple[str, str, str]: """ @@ -15,23 +20,28 @@ def setup_models() -> tuple[str, str, str]: user_id = "1" username = "user1" course_id = "course-xyz" - Users().insert(user_id, username=username, email="email1") - comment_thread_id = CommentThread().insert( - title="Thread 1", - body="Thread 1", - course_id=course_id, - commentable_id="CommentThread", - author_id=user_id, - author_username=username, - abuse_flaggers=[], - historical_abuse_flaggers=[], + backend.find_or_create_user(user_id, username=username) + comment_thread_id = backend.create_thread( + { + "title": "Thread 1", + "body": "Thread 1", + "course_id": course_id, + "commentable_id": "CommentThread", + "author_id": user_id, + "author_username": username, + "abuse_flaggers": [], + "historical_abuse_flaggers": [], + } ) - parent_comment_id = Comment().insert( - body="

Parent Comment

", - course_id=course_id, - author_id=user_id, - comment_thread_id=comment_thread_id, - author_username=username, + + parent_comment_id = backend.create_comment( + { + "body": "

Parent Comment

", + "course_id": course_id, + "author_id": user_id, + "comment_thread_id": comment_thread_id, + "author_username": username, + } ) return user_id, comment_thread_id, parent_comment_id @@ -57,7 +67,7 @@ def test_comment_post_api(api_client: APIClient) -> None: assert comment["user_id"] == user_id assert comment["thread_id"] == thread_id assert comment["parent_id"] == parent_comment_id - parent_comment = Comment().get(parent_comment_id) + parent_comment = backend.get_comment(parent_comment_id) assert parent_comment is not None assert parent_comment["child_count"] == 1 @@ -91,7 +101,7 @@ def test_update_comment_endorsed_api(api_client: APIClient) -> None: data={"endorsed": "True", "endorsement_user_id": user_id}, ) assert response.status_code == 200 - comment = Comment().get(parent_comment_id) + comment = backend.get_comment(parent_comment_id) assert comment is not None assert comment["endorsed"] is True assert comment["endorsement"]["user_id"] == user_id @@ -101,7 +111,7 @@ def test_update_comment_endorsed_api(api_client: APIClient) -> None: data={"endorsed": "False"}, ) assert response.status_code == 200 - comment = Comment().get(parent_comment_id) + comment = backend.get_comment(parent_comment_id) assert comment is not None assert comment["endorsed"] is False assert comment["endorsement"] is None @@ -125,7 +135,7 @@ def test_delete_parent_comment(api_client: APIClient) -> None: assert response.status_code == 200 response = api_client.delete_json(f"/api/v2/comments/{parent_comment_id}") assert response.status_code == 200 - assert Comment().get(parent_comment_id) is None + assert backend.get_comment(parent_comment_id) is None def test_delete_child_comment(api_client: APIClient) -> None: @@ -147,15 +157,15 @@ def test_delete_child_comment(api_client: APIClient) -> None: child_comment_id = response.json()["id"] assert child_comment_id is not None - parent_comment = Comment().get(parent_comment_id) or {} + parent_comment = backend.get_comment(parent_comment_id) or {} previous_child_count = parent_comment.get("child_count") response = api_client.delete_json(f"/api/v2/comments/{child_comment_id}") assert previous_child_count is not None assert response.status_code == 200 - assert Comment().get(child_comment_id) is None + assert backend.get_comment(child_comment_id) is None - parent_comment = Comment().get(parent_comment_id) or {} + parent_comment = backend.get_comment(parent_comment_id) or {} new_child_count = parent_comment.get("child_count") assert new_child_count is not None @@ -163,7 +173,7 @@ def test_delete_child_comment(api_client: APIClient) -> None: def test_returns_400_when_comment_does_not_exist(api_client: APIClient) -> None: - incorrect_comment_id = "66c42d4aa3a68c001c6c22db" + incorrect_comment_id = backend.generate_id() response = api_client.get_json(f"/api/v2/comments/{incorrect_comment_id}", {}) assert response.status_code == 400 @@ -191,12 +201,12 @@ def test_updates_body_correctly(api_client: APIClient) -> None: Test updating the body of a comment. """ _, _, parent_comment_id = setup_models() - comment = Comment().get(parent_comment_id) + comment = backend.get_comment(parent_comment_id) assert comment is not None original_body = comment["body"] editing_user_id = "2" editing_username = "user2" - Users().insert(editing_user_id, username=editing_username, email="email2") + backend.find_or_create_user(editing_user_id, username=editing_username) edit_reason_code = "test_reason" new_body = "new body" response = api_client.put_json( @@ -209,7 +219,7 @@ def test_updates_body_correctly(api_client: APIClient) -> None: ) assert response.status_code == 200 - updated_comment = Comment().get(parent_comment_id) + updated_comment = backend.get_comment(parent_comment_id) assert updated_comment is not None assert updated_comment["body"] == new_body edit_history = updated_comment["edit_history"] @@ -231,7 +241,7 @@ def test_updates_body_correctly_without_user_id(api_client: APIClient) -> None: data={"body": new_body}, ) assert response.status_code == 200 - updated_comment = Comment().get(parent_comment_id) + updated_comment = backend.get_comment(parent_comment_id) assert updated_comment is not None assert updated_comment["body"] == new_body assert ("edit_history" not in updated_comment) is True @@ -249,7 +259,7 @@ def test_update_endorsed_and_body_simultaneously(api_client: APIClient) -> None: data={"endorsed": "True", "body": new_body}, ) assert response.status_code == 200 - updated_comment = Comment().get(parent_comment_id) + updated_comment = backend.get_comment(parent_comment_id) assert updated_comment is not None assert updated_comment["body"] == new_body assert updated_comment["endorsement"] is None @@ -277,6 +287,6 @@ def test_thread_comment_post_api(api_client: APIClient) -> None: assert comment["user_id"] == user_id assert comment["thread_id"] == thread_id assert comment["parent_id"] is None - parent_comment = Comment().get(comment["id"]) + parent_comment = backend.get_comment(comment["id"]) assert parent_comment is not None assert parent_comment["child_count"] == 0 diff --git a/tests/test_views/test_flags.py b/tests/test_views/test_flags.py index 1ccd8145..c4a135b1 100644 --- a/tests/test_views/test_flags.py +++ b/tests/test_views/test_flags.py @@ -1,10 +1,13 @@ """Test flags api endpoints.""" -from bson import ObjectId +import pytest -from forum.backends.mongodb import Comment, CommentThread, Users +from forum.backend import get_backend from test_utils.client import APIClient +pytestmark = pytest.mark.django_db +backend = get_backend()() + def test_comment_thread_api(api_client: APIClient) -> None: """ @@ -12,19 +15,21 @@ def test_comment_thread_api(api_client: APIClient) -> None: This test checks that a user can flag a comment thread for abuse and then unflag it. """ - flag_user = str(ObjectId()) - author_user = str(ObjectId()) - Users().insert(flag_user, flag_user) - Users().insert(author_user, author_user) - comment_thread_id = CommentThread().insert( - title="Thread 1", - body="Body 1", - course_id="course1", - commentable_id="3", - author_id=author_user, - author_username=author_user, - abuse_flaggers=[], - historical_abuse_flaggers=[], + flag_user = backend.generate_id() + author_user = backend.generate_id() + backend.find_or_create_user(flag_user, flag_user) + backend.find_or_create_user(author_user, author_user) + comment_thread_id = backend.create_thread( + { + "title": "Thread 1", + "body": "Body 1", + "course_id": "course1", + "commentable_id": "3", + "author_id": author_user, + "author_username": author_user, + "abuse_flaggers": [], + "historical_abuse_flaggers": [], + } ) response = api_client.put_json( @@ -51,7 +56,7 @@ def test_comment_thread_api(api_client: APIClient) -> None: data={"user_id": str(flag_user)}, ) assert response.status_code == 200 - comment = CommentThread().get(comment_thread_id) + comment = backend.get_thread(comment_thread_id) assert comment is not None assert comment["abuse_flaggers"] == [] @@ -62,20 +67,34 @@ def test_comment_flag_api(api_client: APIClient) -> None: This test checks that a user can flag a comment for abuse and then unflag it. """ - flag_user = str(ObjectId()) - author_user = str(ObjectId()) - Users().insert(flag_user, flag_user) - Users().insert(author_user, author_user) - comment_thread_id = str(ObjectId()) + flag_user = backend.generate_id() + author_user = backend.generate_id() + backend.find_or_create_user(flag_user, flag_user) + backend.find_or_create_user(author_user, author_user) course_id = "course-xyz" - comment_id = Comment().insert( - "

Comment 1

", - course_id, - author_user, - comment_thread_id=comment_thread_id, - author_username=author_user, - abuse_flaggers=[], - historical_abuse_flaggers=[], + comment_thread_id = backend.create_thread( + { + "title": "Thread 1", + "body": "Body 1", + "course_id": course_id, + "commentable_id": "3", + "author_id": author_user, + "author_username": author_user, + "abuse_flaggers": [], + "historical_abuse_flaggers": [], + } + ) + comment_id = backend.create_comment( + { + "body": "

Comment 1

", + "course_id": course_id, + "author_id": author_user, + "comment_thread_id": comment_thread_id, + "author_username": author_user, + "anonymous": False, + "anonymous_to_peers": False, + "depth": 0, + } ) response = api_client.put_json( @@ -91,7 +110,7 @@ def test_comment_flag_api(api_client: APIClient) -> None: data={"user_id": str(flag_user)}, ) assert response.status_code == 200 - comment = Comment().get(comment_id) + comment = backend.get_comment(comment_id) assert comment is not None assert comment["abuse_flaggers"] == [] @@ -100,7 +119,7 @@ def test_comment_flag_api(api_client: APIClient) -> None: data={"user_id": str(flag_user)}, ) assert response.status_code == 200 - comment = Comment().get(comment_id) + comment = backend.get_comment(comment_id) assert comment is not None assert comment["abuse_flaggers"] == [] @@ -111,11 +130,12 @@ def test_comment_flag_api_invalid_data(api_client: APIClient) -> None: This test checks that the API returns a 400 error when the user or comment does not exist. """ - user = str(ObjectId()) - Users().insert(user, user) + user = backend.generate_id() + backend.find_or_create_user(user) + comment_id = backend.generate_id() response = api_client.put_json( - path="/api/v2/comments/66ace22474ba69001e1440bd/abuse_flag", + path=f"/api/v2/comments/{comment_id}/abuse_flag", data={"user_id": str(user)}, ) assert response.status_code == 400 @@ -128,32 +148,36 @@ def test_comment_flag_api_with_all_param(api_client: APIClient) -> None: This test checks that a user can flag a comment for abuse and then unflag it using all. """ - flag_user = str(ObjectId()) - flag_user_2 = str(ObjectId()) - author_user = str(ObjectId()) - Users().insert(flag_user, flag_user) - Users().insert(flag_user_2, flag_user_2) - Users().insert(author_user, author_user) + flag_user = backend.generate_id() + flag_user_2 = backend.generate_id() + author_user = backend.generate_id() + backend.find_or_create_user(flag_user, flag_user) + backend.find_or_create_user(flag_user_2, flag_user_2) + backend.find_or_create_user(author_user, author_user) course_id = "course-xyz" - comment_thread_id = CommentThread().insert( - title="Test Thread", - body="This is a test thread", - course_id="course1", - commentable_id="commentable1", - author_id=author_user, - author_username=author_user, - ) - comment_id = Comment().insert( - "

Comment 1

", - course_id, - author_user, - comment_thread_id=comment_thread_id, - author_username=author_user, - abuse_flaggers=[], - historical_abuse_flaggers=[], - ) - - comment = Comment().get(comment_id) + comment_thread_id = backend.create_thread( + { + "title": "Test Thread", + "body": "This is a test thread", + "course_id": "course1", + "commentable_id": "commentable", + "author_id": author_user, + "author_username": author_user, + } + ) + comment_id = backend.create_comment( + { + "body": "

Comment 1

", + "course_id": course_id, + "author_id": author_user, + "comment_thread_id": comment_thread_id, + "author_username": author_user, + "abuse_flaggers": [], + "historical_abuse_flaggers": [], + } + ) + + comment = backend.get_comment(comment_id) assert comment is not None assert comment["abuse_flaggers"] == [] @@ -174,7 +198,7 @@ def test_comment_flag_api_with_all_param(api_client: APIClient) -> None: assert response.status_code == 200 flagged_comment = response.json() assert flagged_comment is not None - comment = Comment().get(flagged_comment["id"]) + comment = backend.get_comment(flagged_comment["id"]) assert comment is not None assert comment["abuse_flaggers"] == [flag_user, flag_user_2] assert comment["historical_abuse_flaggers"] == [] @@ -187,7 +211,7 @@ def test_comment_flag_api_with_all_param(api_client: APIClient) -> None: assert response.status_code == 200 unflagged_comment = response.json() assert unflagged_comment is not None - comment = Comment().get(unflagged_comment["id"]) + comment = backend.get_comment(unflagged_comment["id"]) assert comment is not None assert comment["abuse_flaggers"] == [] assert set(comment["historical_abuse_flaggers"]) == set([flag_user, flag_user_2]) @@ -201,7 +225,7 @@ def test_comment_flag_api_with_all_param(api_client: APIClient) -> None: assert response.status_code == 200 flagged_thread = response.json() assert flagged_thread is not None - thread = CommentThread().get(flagged_thread["id"]) + thread = backend.get_thread(flagged_thread["id"]) assert thread is not None assert thread["abuse_flaggers"] == [flag_user] assert thread["historical_abuse_flaggers"] == [] @@ -214,7 +238,7 @@ def test_comment_flag_api_with_all_param(api_client: APIClient) -> None: assert response.status_code == 200 unflagged_thread = response.json() assert unflagged_thread is not None - thread = CommentThread().get(unflagged_thread["id"]) + thread = backend.get_thread(unflagged_thread["id"]) assert thread is not None assert thread["abuse_flaggers"] == [] assert thread["historical_abuse_flaggers"] == [flag_user] diff --git a/tests/test_views/test_pins.py b/tests/test_views/test_pins.py index 0d2f23f1..93d2e2f6 100644 --- a/tests/test_views/test_pins.py +++ b/tests/test_views/test_pins.py @@ -1,43 +1,52 @@ """Test pin/unpin thread api endpoints.""" -from forum.backends.mongodb import Comment, CommentThread, Users +import pytest + +from forum.backend import get_backend from test_utils.client import APIClient +pytestmark = pytest.mark.django_db +backend = get_backend()() + def test_pin_and_unpin_thread_api(api_client: APIClient) -> None: """ Test the pin/unpin thread API. This test checks that a user can pin/unpin a thread. """ - user_id = "unique_1" + user_id = "1" - Users().insert(user_id, username="user1", email="email1") - comment_thread_id = CommentThread().insert( - title="title", - body="Hello World!", - pinned=False, - author_id=user_id, - course_id="course-v1:Arbisoft+SE002+2024_S2", - commentable_id="66b4e0440dead7001deb948b", - author_username="Faraz", + backend.find_or_create_user(user_id, username="user1") + comment_thread_id = backend.create_thread( + { + "title": "title", + "body": "Hello World!", + "pinned": False, + "author_id": user_id, + "course_id": "course-v1:Arbisoft+SE002+2024_S2", + "commentable_id": "66b4e0440dead7001deb948b", + "author_username": "Faraz", + } ) - Comment().insert( - body="Hello World!", - course_id="course-v1:Arbisoft+SE002+2024_S2", - comment_thread_id=comment_thread_id, - author_id="1", - author_username="Faraz", + backend.create_comment( + { + "body": "Hello World!", + "course_id": "course-v1:Arbisoft+SE002+2024_S2", + "comment_thread_id": comment_thread_id, + "author_id": "1", + "author_username": "Faraz", + } ) - response = api_client.put_json( f"/api/v2/threads/{comment_thread_id}/pin", data={"user_id": user_id}, ) + assert response.status_code == 200 thread_data = response.json() assert thread_data is not None assert thread_data["pinned"] is True - thread = CommentThread().get(comment_thread_id) + thread = backend.get_thread(comment_thread_id) assert thread is not None assert thread["pinned"] is True @@ -45,11 +54,12 @@ def test_pin_and_unpin_thread_api(api_client: APIClient) -> None: f"/api/v2/threads/{comment_thread_id}/unpin", data={"user_id": user_id}, ) + assert response.status_code == 200 thread_data = response.json() assert thread_data is not None assert thread_data["pinned"] is False - thread = CommentThread().get(comment_thread_id) + thread = backend.get_thread(comment_thread_id) assert thread is not None assert thread["pinned"] is False @@ -60,17 +70,18 @@ def test_pin_unpin_thread_api_invalid_data(api_client: APIClient) -> None: This test checks that if user/thread exists or not. """ user_id = "1" - Users().insert(user_id, username="user1", email="email1") + thread_id = backend.generate_id() + backend.find_or_create_user(user_id, username="user1") response = api_client.put_json( - path="/api/v2/threads/66b4e0440dead7001deb948b/pin", + path=f"/api/v2/threads/{thread_id}/pin", data={"user_id": str(user_id)}, ) assert response.status_code == 400 assert response.json() == {"error": "User / Thread doesn't exist"} response = api_client.put_json( - path="/api/v2/threads/66b4e0440dead7001deb948b/unpin", + path=f"/api/v2/threads/{thread_id}/unpin", data={"user_id": str(user_id)}, ) assert response.status_code == 400 diff --git a/tests/test_views/test_search.py b/tests/test_views/test_search.py index 673ce740..111f6fa4 100644 --- a/tests/test_views/test_search.py +++ b/tests/test_views/test_search.py @@ -24,14 +24,16 @@ from unittest.mock import patch from urllib.parse import urlencode +import pytest from requests import Response -from forum.backends.mongodb import Comment, CommentThread, Users -from forum.backends.mongodb.api import mark_as_read +from forum.backends.mongodb.api import MongoBackend as backend from forum.search.backend import get_search_backend from forum.search.comment_search import ThreadSearch from test_utils.client import APIClient +pytestmark = pytest.mark.django_db + def assert_result_total(response: Response, expected_total: int) -> None: """Assert that the total number of results matches the expected total.""" @@ -98,22 +100,27 @@ def test_invalid_request(api_client: APIClient) -> None: user_id = "1" course_id = "course-v1:Arbisoft+SE002+2024_S2" - Users().insert(user_id, username="user1", email="email1") - comment_thread_id = CommentThread().insert( - title="title", - body="Hello World!", - pinned=False, - author_id=user_id, - course_id=course_id, - commentable_id="66b4e0440dead7001deb948b", - author_username="Faraz", + backend().find_or_create_user(user_id, username="user1") + comment_thread_id = backend().create_thread( + { + "title": "title", + "body": "Hello World!", + "pinned": False, + "author_id": user_id, + "course_id": course_id, + "commentable_id": "66b4e0440dead7001deb948b", + "author_username": "Faraz", + } ) - Comment().insert( - body="Hello World!", - course_id=course_id, - comment_thread_id=comment_thread_id, - author_id="1", - author_username="Faraz", + + backend().create_comment( + { + "body": "Hello World!", + "course_id": course_id, + "comment_thread_id": comment_thread_id, + "author_id": "1", + "author_username": "Faraz", + } ) refresh_elastic_search_indices() @@ -136,16 +143,18 @@ def test_search_returns_empty_for_deleted_thread(api_client: APIClient) -> None: """ course_id = "course-v1:Arbisoft+SE002+2024_S2" - thread_id = CommentThread().insert( - title="title-1", - course_id=course_id, - body="body-1", - author_id="1", - author_username="test_user", - commentable_id="course", + thread_id = backend().create_thread( + { + "title": "title-1", + "course_id": course_id, + "body": "body-1", + "author_id": "1", + "author_username": "test_user", + "commentable_id": "course", + } ) - CommentThread().delete(thread_id) + backend().delete_thread(thread_id) refresh_elastic_search_indices() @@ -168,15 +177,17 @@ def test_search_returns_only_updated_thread(api_client: APIClient) -> None: updated_title = "updated-title" course_id = "course-v1:Arbisoft+SE002+2024_S2" - thread_id = CommentThread().insert( - title=original_title, - course_id=course_id, - body="body-1", - author_id="1", - author_username="test_user", - commentable_id="course", + thread_id = backend().create_thread( + { + "title": original_title, + "course_id": course_id, + "body": "body-1", + "author_id": "1", + "author_username": "test_user", + "commentable_id": "course", + } ) - CommentThread().update(thread_id=thread_id, title=updated_title) + backend().update_thread(thread_id=thread_id, title=updated_title) refresh_elastic_search_indices() @@ -199,21 +210,26 @@ def test_search_returns_empty_for_deleted_comment(api_client: APIClient) -> None """ course_id = "course-v1:Arbisoft+SE002+2024_S2" - thread_id = CommentThread().insert( - title="thread-1", - course_id=course_id, - body="thread-body", - author_id="1", - author_username="test_user", - commentable_id="course", + thread_id = backend().create_thread( + { + "title": "thread-1", + "course_id": course_id, + "body": "thread-body", + "author_id": "1", + "author_username": "test_user", + "commentable_id": "course", + } ) - comment_id = Comment().insert( - body="comment-body", - course_id=course_id, - comment_thread_id=thread_id, - author_id="1", + + comment_id = backend().create_comment( + { + "body": "comment-body", + "course_id": course_id, + "comment_thread_id": thread_id, + "author_id": "1", + } ) - Comment().delete(comment_id) + backend().delete_comment(comment_id) refresh_elastic_search_indices() @@ -235,22 +251,26 @@ def test_search_returns_only_updated_comment(api_client: APIClient) -> None: updated_comment = "comment-updated" course_id = "course-v1:Arbisoft+SE002+2024_S2" - thread_id = CommentThread().insert( - title="thread-1", - course_id=course_id, - body="thread-body", - author_id="1", - author_username="test_user", - commentable_id="course", - ) - comment_id = Comment().insert( - body=original_comment, - course_id=course_id, - comment_thread_id=thread_id, - author_id="1", + thread_id = backend().create_thread( + { + "title": "thread-1", + "course_id": course_id, + "body": "thread-body", + "author_id": "1", + "author_username": "test_user", + "commentable_id": "course", + } ) - Comment().update(comment_id=comment_id, body=updated_comment) + comment_id = backend().create_comment( + { + "body": original_comment, + "course_id": course_id, + "comment_thread_id": thread_id, + "author_id": "1", + } + ) + backend().update_comment(comment_id=comment_id, body=updated_comment) refresh_elastic_search_indices() params = {"course_id": course_id, "text": original_comment} @@ -274,36 +294,42 @@ def create_threads_and_comments_for_filter_tests( for i in range(35): context = "standalone" if i > 29 else "course" group_id = i % 5 - thread_id = CommentThread().insert( - title=f"title-{i}", - body="text", - author_id="1", - course_id=course_id_0 if i % 2 == 0 else course_id_1, - commentable_id=f"commentable{i % 3}", - context=context, - group_id=group_id, + thread_id = backend().create_thread( + { + "title": f"title-{i}", + "body": "text", + "author_id": "1", + "course_id": course_id_0 if i % 2 == 0 else course_id_1, + "commentable_id": f"commentable{i % 3}", + "context": context, + "group_id": group_id, + } ) threads_ids.append(thread_id) if i < 2: - comment_id = Comment().insert( - body="objectionable", - course_id=course_id_0 if i % 2 == 0 else course_id_1, - comment_thread_id=thread_id, - author_id="1", + comment_id = backend().create_comment( + { + "body": "objectionable", + "course_id": course_id_0 if i % 2 == 0 else course_id_1, + "comment_thread_id": thread_id, + "author_id": "1", + } ) - Comment().update(comment_id=comment_id, abuse_flaggers=["1"]) + backend().update_comment(comment_id=comment_id, abuse_flaggers=["1"]) comment_ids = threads_comments.get(thread_id, []) comment_ids.append(comment_id) threads_comments[thread_id] = comment_ids if i in [0, 2, 4]: - CommentThread().update(thread_id=thread_id, thread_type="question") - comment_id = Comment().insert( - body="response", - course_id=course_id_0 if i % 2 == 0 else course_id_1, - comment_thread_id=thread_id, - author_id="1", + backend().update_thread(thread_id=thread_id, thread_type="question") + comment_id = backend().create_comment( + { + "body": "response", + "course_id": course_id_0 if i % 2 == 0 else course_id_1, + "comment_thread_id": thread_id, + "author_id": "1", + } ) comment_ids = threads_comments.get(thread_id, []) comment_ids.append(comment_id) @@ -322,7 +348,7 @@ def test_filter_threads(api_client: APIClient) -> None: course_id_0 = "course-v1:Arbisoft+SE002+2024_S2" course_id_1 = "course-v1:Arbisoft+SE003+2024_S2" - user_id = Users().insert("1", username="user1", email="example@test.com") + user_id = backend().find_or_create_user("1", username="user1") threads_ids, threads_comments = create_threads_and_comments_for_filter_tests( course_id_0, course_id_1 ) @@ -351,12 +377,7 @@ def assert_response_contains( response = get_search_response(api_client, params, threads_ids[30:35]) assert_response_contains(response, list(range(30, 35))) - # Test filtering with unread filter - user = Users().get(_id=user_id) or {} - thread_course_1 = CommentThread().get(_id=threads_ids[0]) or {} - thread_course_2 = CommentThread().get(_id=threads_ids[1]) or {} - - mark_as_read(user, thread_course_2) + backend().mark_as_read(user_id, threads_ids[1]) params = { "text": "text", "course_id": course_id_0, @@ -366,7 +387,7 @@ def assert_response_contains( response = get_search_response(api_client, params, threads_ids[:35:2]) assert_response_contains(response, [i for i in range(30) if i % 2 == 0]) - mark_as_read(user, thread_course_1) + backend().mark_as_read(user_id, threads_ids[0]) params = { "text": "text", "course_id": course_id_0, @@ -406,7 +427,7 @@ def assert_response_contains( assert_response_contains(response, [0, 4]) comment = threads_comments[threads_ids[4]][0] - Comment().update(comment_id=comment, endorsed=True) + backend().update_comment(comment_id=comment, endorsed=True) refresh_elastic_search_indices() response = get_search_response(api_client, params, threads_ids[:30:2]) @@ -454,12 +475,14 @@ def test_pagination(api_client: APIClient) -> None: threads_ids = [] for i in range(50): - thread_id = CommentThread().insert( - title=f"title-{i}", - body="text", - author_id="1", - course_id=course_id, - commentable_id="dummy", + thread_id = backend().create_thread( + { + "title": f"title-{i}", + "body": "text", + "author_id": "1", + "course_id": course_id, + "commentable_id": "dummy", + } ) threads_ids.append(thread_id) # Add a slight delay to ensure created_date is different @@ -498,23 +521,25 @@ def test_sorting(api_client: APIClient) -> None: # Create and save threads threads_ids = [] for i in range(6): - thread = CommentThread().insert( - title=f"title-{i}", - body="text", - author_id="1", - course_id=course_id, - commentable_id="dummy", + thread = backend().create_thread( + { + "title": f"title-{i}", + "body": "text", + "author_id": "1", + "course_id": course_id, + "commentable_id": "dummy", + } ) threads_ids.append(thread) # Add a slight delay to ensure created_date is different time.sleep(0.001) # Update specific threads to simulate activity, votes, and comments - votes = CommentThread().get_votes_dict(up=["1"], down=[]) - CommentThread().update(thread_id=threads_ids[1], votes=votes) - CommentThread().update(thread_id=threads_ids[2], votes=votes) - CommentThread().update(thread_id=threads_ids[1], comments_count=5) - CommentThread().update(thread_id=threads_ids[3], comments_count=5) + votes = backend().get_votes_dict(up=["1"], down=[]) + backend().update_thread(thread_id=threads_ids[1], votes=votes) + backend().update_thread(thread_id=threads_ids[2], votes=votes) + backend().update_thread(thread_id=threads_ids[1], comments_count=5) + backend().update_thread(thread_id=threads_ids[3], comments_count=5) refresh_elastic_search_indices() @@ -550,19 +575,23 @@ def test_spelling_correction(api_client: APIClient) -> None: thread_title = "a thread about green artichokes" comment_body = "a comment about greed pineapples" - thread_id = CommentThread().insert( - title=thread_title, - body="", - author_id="1", - course_id="course_id", - commentable_id=commentable_id, + thread_id = backend().create_thread( + { + "title": thread_title, + "body": "", + "author_id": "1", + "course_id": "course_id", + "commentable_id": commentable_id, + } ) - Comment().insert( - body=comment_body, - course_id="course_id", - comment_thread_id=thread_id, - author_id="1", + backend().create_comment( + { + "body": comment_body, + "course_id": "course_id", + "comment_thread_id": thread_id, + "author_id": "1", + } ) refresh_elastic_search_indices() @@ -628,12 +657,14 @@ def test_spelling_correction_with_mush_clause(api_client: APIClient) -> None: # to the filter, and that suggestion in this case does not match any # results, we should get back no results and no correction. for _ in range(10): - CommentThread().insert( - title="abbot", - body="text", - author_id="1", - course_id="other_course_id", - commentable_id="other_commentable_id", + backend().create_thread( + { + "title": "abbot", + "body": "text", + "author_id": "1", + "course_id": "other_course_id", + "commentable_id": "other_commentable_id", + } ) refresh_elastic_search_indices() @@ -674,12 +705,14 @@ def test_total_results_and_num_pages(api_client: APIClient) -> None: text += " one" # Create the comment - thread_id = CommentThread().insert( - title=f"title-{i}", - body=text, - course_id=course_id, - author_id="1", - commentable_id="course", + thread_id = backend().create_thread( + { + "title": f"title-{i}", + "body": text, + "course_id": course_id, + "author_id": "1", + "commentable_id": "course", + } ) threads_ids.append(thread_id) @@ -734,18 +767,23 @@ def test_unicode_data(api_client: APIClient) -> None: search_term = "artichoke" # Create a comment thread and a comment containing the specified text - thread_id = CommentThread().insert( - title="A thread title", - body=f"{search_term} {text}", - author_id="1", - course_id="course-v1:Arbisoft+SE002+2024_S2", - commentable_id="course", + thread_id = backend().create_thread( + { + "title": "A thread title", + "body": f"{search_term} {text}", + "author_id": "1", + "course_id": "course-v1:Arbisoft+SE002+2024_S2", + "commentable_id": "course", + } ) - Comment().insert( - body=text, - course_id="course-v1:Arbisoft+SE002+2024_S2", - comment_thread_id=thread_id, - author_id="1", + + backend().create_comment( + { + "body": text, + "course_id": "course-v1:Arbisoft+SE002+2024_S2", + "comment_thread_id": thread_id, + "author_id": "1", + } ) # Refresh Elasticsearch indices to make the new data searchable diff --git a/tests/test_views/test_subscriptions.py b/tests/test_views/test_subscriptions.py index 62f6f9cf..a380d0a2 100644 --- a/tests/test_views/test_subscriptions.py +++ b/tests/test_views/test_subscriptions.py @@ -1,26 +1,34 @@ """Tests for subscription apis.""" -from forum.backends.mongodb import CommentThread, Subscriptions, Users +import pytest + +from forum.backend import get_backend from test_utils.client import APIClient +pytestmark = pytest.mark.django_db +backend = get_backend()() + def test_get_subscribed_threads(api_client: APIClient) -> None: """ Test getting subscribed threads for a user. """ user_id = "1" + username = "user1" course_id = "demo_course" - Users().insert(user_id, username="user1", email="email1") - comment_thread_id = CommentThread().insert( - "Thread 1", - "Body 1", - course_id, - "CommentThread", - "3", - "user3", - ) + backend.find_or_create_user(user_id, username=username) - Subscriptions().insert(user_id, comment_thread_id, source_type="CommentThread") + comment_thread_id = backend.create_thread( + { + "title": "Thread 1", + "body": "Body 1", + "course_id": course_id, + "thread_type": "discussion", + "author_id": user_id, + "author_username": username, + } + ) + backend.subscribe_user(user_id, comment_thread_id, source_type="CommentThread") response = api_client.get( f"/api/v2/users/{user_id}/subscribed_threads?course_id={course_id}" ) @@ -35,17 +43,19 @@ def test_get_subscribed_threads_with_filters(api_client: APIClient) -> None: Test getting subscribed threads for a user with filters. """ user_id = "1" + username = "user1" course_id = "demo_course" - Users().insert(user_id, username="user1", email="email1") - comment_thread_id = CommentThread().insert( - "Thread 1", - "Body 1", - course_id, - "CommentThread", - "3", - "user3", + backend.find_or_create_user(user_id, username=username) + comment_thread_id = backend.create_thread( + { + "title": "Thread 1", + "body": "Body 1", + "course_id": course_id, + "author_id": user_id, + "author_username": username, + } ) - Subscriptions().insert(user_id, comment_thread_id, source_type="thread") + backend.subscribe_user(user_id, comment_thread_id, source_type="CommentThread") response = api_client.get( f"/api/v2/users/{user_id}/subscribed_threads?flagged=true&course_id={course_id}" @@ -54,7 +64,7 @@ def test_get_subscribed_threads_with_filters(api_client: APIClient) -> None: threads = response.json()["collection"] assert len(threads) == 0 - CommentThread().update(comment_thread_id, abuse_flaggers=[user_id]) + backend.update_thread(comment_thread_id, abuse_flaggers=[user_id]) response = api_client.get( f"/api/v2/users/{user_id}/subscribed_threads?flagged=true&course_id={course_id}" ) @@ -70,21 +80,28 @@ def test_subscribe_thread(api_client: APIClient) -> None: """ user_id = "1" course_id = "demo_course" - Users().insert(user_id, username="user1", email="email1") - comment_thread_id = CommentThread().insert( - "Thread 1", - "Body 1", - course_id, - "CommentThread", - "3", - "user3", + username = "user1" + author_id = "2" + author_username = "author" + backend.find_or_create_user(user_id, username) + backend.find_or_create_user(author_id, author_username) + comment_thread_id = backend.create_thread( + { + "title": "Thread 1", + "body": "Body 1", + "course_id": course_id, + "author_id": author_id, + "author_username": author_username, + } ) response = api_client.post( f"/api/v2/users/{user_id}/subscriptions", data={"source_type": "thread", "source_id": comment_thread_id}, ) assert response.status_code == 200 - subscription = Subscriptions().get_subscription(user_id, comment_thread_id) + subscription = backend.subscribe_user( + user_id, comment_thread_id, source_type="CommentThread" + ) assert subscription is not None @@ -94,23 +111,27 @@ def test_unsubscribe_thread(api_client: APIClient) -> None: """ user_id = "1" course_id = "demo_course" - Users().insert(user_id, username="user1", email="email1") - comment_thread_id = CommentThread().insert( - "Thread 1", - "Body 1", - course_id, - "CommentThread", - "3", - "user3", + username = "user1" + author_id = "2" + author_username = "author" + backend.find_or_create_user(user_id, username) + backend.find_or_create_user(author_id, author_username) + comment_thread_id = backend.create_thread( + { + "title": "Thread 1", + "body": "Body 1", + "course_id": course_id, + "author_id": author_id, + "author_username": author_username, + } ) - Subscriptions().insert(user_id, comment_thread_id, source_type="thread") + backend.subscribe_user(user_id, comment_thread_id, source_type="CommentThread") response = api_client.delete( f"/api/v2/users/{user_id}/subscriptions?source_id={comment_thread_id}" ) assert response.status_code == 200 - subscription = Subscriptions().get_subscription(user_id, comment_thread_id) - assert subscription is None + assert backend.get_subscription(user_id, comment_thread_id) is None # Attempt to unsubscribe from a thread that the user is not subscribed to response = api_client.delete( @@ -125,34 +146,45 @@ def test_get_subscribed_threads_with_pagination(api_client: APIClient) -> None: """ user_id = "1" course_id = "demo_course" - Users().insert(user_id, username="user1", email="email1") - comment_thread_id = CommentThread().insert( - "Thread 1", - "Body 1", - course_id, - "CommentThread", - "3", - "user3", + username = "user1" + author_id = "2" + author_username = "author" + backend.find_or_create_user(user_id, username) + backend.find_or_create_user(author_id, author_username) + comment_thread_id_1 = backend.create_thread( + { + "title": "Thread 1", + "body": "Body 1", + "course_id": course_id, + "author_id": author_id, + "author_username": author_username, + } ) - comment_thread_id_2 = CommentThread().insert( - "Thread 2", - "Body 2", - course_id, - "CommentThread", - "3", - "user3", + + comment_thread_id_2 = backend.create_thread( + { + "title": "Thread 2", + "body": "Body 2", + "course_id": course_id, + "type": "CommentThread", + "author_id": author_id, + "author_username": author_username, + } ) - comment_thread_id_3 = CommentThread().insert( - "Thread 2", - "Body 2", - course_id, - "CommentThread", - "3", - "user3", + + comment_thread_id_3 = backend.create_thread( + { + "title": "Thread 2", + "body": "Body 2", + "course_id": course_id, + "type": "CommentThread", + "author_id": author_id, + "author_username": author_username, + } ) - Subscriptions().insert(user_id, comment_thread_id, source_type="thread") - Subscriptions().insert(user_id, comment_thread_id_2, source_type="thread") - Subscriptions().insert(user_id, comment_thread_id_3, source_type="thread") + backend.subscribe_user(user_id, comment_thread_id_1, source_type="CommentThread") + backend.subscribe_user(user_id, comment_thread_id_2, source_type="CommentThread") + backend.subscribe_user(user_id, comment_thread_id_3, source_type="CommentThread") response = api_client.get( f"/api/v2/users/{user_id}/subscribed_threads?page=1&per_page=2&course_id={course_id}" @@ -161,7 +193,7 @@ def test_get_subscribed_threads_with_pagination(api_client: APIClient) -> None: threads = response.json()["collection"] assert len(threads) == 2 assert threads[0]["id"] in [ - comment_thread_id, + comment_thread_id_1, comment_thread_id_2, comment_thread_id_3, ] @@ -173,7 +205,7 @@ def test_get_subscribed_threads_with_pagination(api_client: APIClient) -> None: threads = response.json()["collection"] assert len(threads) == 1 assert threads[0]["id"] in [ - comment_thread_id, + comment_thread_id_1, comment_thread_id_2, comment_thread_id_3, ] @@ -185,18 +217,24 @@ def test_get_thread_subscriptions(api_client: APIClient) -> None: """ user_id = "1" course_id = "demo_course" - Users().insert(user_id, username="user1", email="email1") - comment_thread_id = CommentThread().insert( - "Thread 1", - "Body 1", - course_id, - "CommentThread", - "3", - "user3", + username = "user1" + author_id = "2" + author_username = "author" + backend.find_or_create_user(user_id, username) + backend.find_or_create_user(author_id, author_username) + comment_thread_id = backend.create_thread( + { + "title": "Thread 1", + "body": "Body 1", + "course_id": course_id, + "author_id": author_id, + "author_username": author_username, + } ) - subscription_id = Subscriptions().insert( + subscription = backend.subscribe_user( user_id, comment_thread_id, source_type="CommentThread" ) + assert subscription response = api_client.get( f"/api/v2/threads/{comment_thread_id}/subscriptions?page=1" @@ -204,7 +242,7 @@ def test_get_thread_subscriptions(api_client: APIClient) -> None: assert response.status_code == 200 subscriptions = response.json()["collection"] assert len(subscriptions) == 1 - assert subscriptions[0]["id"] == subscription_id + assert subscriptions[0]["id"] == subscription["_id"] response = api_client.get( f"/api/v2/threads/{comment_thread_id}/subscriptions?page=2" @@ -220,22 +258,22 @@ def test_get_thread_subscriptions_with_pagination(api_client: APIClient) -> None """ user_id = "1" course_id = "demo_course" - comment_thread_id = CommentThread().insert( - "Thread 1", - "Body 1", - course_id, - "CommentThread", - "3", - "user3", + author_id = "10" + author_username = "author" + backend.find_or_create_user(author_id, author_username) + comment_thread_id = backend.create_thread( + { + "title": "Thread 1", + "body": "Body 1", + "course_id": course_id, + "author_id": author_id, + "author_username": author_username, + } ) user_ids = ["1", "2", "3", "4", "5"] for user_id in user_ids: - Users().insert( - user_id, - username=f"user{user_id}", - email=f"email{user_id}@example.com", - ) - Subscriptions().insert(user_id, comment_thread_id, source_type="CommentThread") + backend.find_or_create_user(user_id, username=f"user{user_id}") + backend.subscribe_user(user_id, comment_thread_id, source_type="CommentThread") response = api_client.get( f"/api/v2/threads/{comment_thread_id}/subscriptions?page=1&per_page=2" diff --git a/tests/test_views/test_threads.py b/tests/test_views/test_threads.py index a6c06815..c0fe3446 100644 --- a/tests/test_views/test_threads.py +++ b/tests/test_views/test_threads.py @@ -2,9 +2,14 @@ from typing import Optional -from forum.backends.mongodb import Comment, CommentThread, Subscriptions, Users +import pytest + +from forum.backend import get_backend from test_utils.client import APIClient +pytestmark = pytest.mark.django_db +backend = get_backend()() + def setup_models( user_id: Optional[str] = None, @@ -22,17 +27,19 @@ def setup_models( user_id = user_id or "1" username = username or "user1" course_id = course_id or "course1" - Users().insert(user_id, username=username, email="email1") - comment_thread_id = CommentThread().insert( - title="Thread 1", - body="Thread 1", - course_id=course_id, - commentable_id="CommentThread", - author_id=user_id, - author_username=username, - abuse_flaggers=[], - historical_abuse_flaggers=[], - thread_type=thread_type or "discussion", + backend.find_or_create_user(user_id, username=username) + comment_thread_id = backend.create_thread( + { + "title": "Thread 1", + "body": "Thread 1", + "course_id": course_id, + "commentable_id": "CommentThread", + "author_id": user_id, + "author_username": username, + "abuse_flaggers": [], + "historical_abuse_flaggers": [], + "thread_type": thread_type or "discussion", + } ) return user_id, comment_thread_id @@ -42,19 +49,24 @@ def create_comments_in_a_thread(thread_id: str) -> tuple[str, str]: user_id = "1" username = "user1" course_id = "course1" - comment_id_1 = Comment().insert( - body="Comment 1", - course_id=course_id, - author_id=user_id, - comment_thread_id=thread_id, - author_username=username, - ) - comment_id_2 = Comment().insert( - body="Comment 2", - course_id=course_id, - author_id=user_id, - comment_thread_id=thread_id, - author_username=username, + comment_id_1 = backend.create_comment( + { + "body": "Comment 1", + "course_id": course_id, + "author_id": user_id, + "comment_thread_id": thread_id, + "author_username": username, + } + ) + + comment_id_2 = backend.create_comment( + { + "body": "Comment 2", + "course_id": course_id, + "author_id": user_id, + "comment_thread_id": thread_id, + "author_username": username, + } ) return comment_id_1, comment_id_2 @@ -75,7 +87,7 @@ def test_update_thread(api_client: APIClient) -> None: ) assert response.status_code == 200 updated_thread = response.json() - updated_thread_from_db = CommentThread().get(updated_thread["id"]) + updated_thread_from_db = backend.get_thread(updated_thread["id"]) assert updated_thread_from_db is not None assert updated_thread_from_db["body"] == "new thread body" assert updated_thread_from_db["title"] == "new thread title" @@ -97,7 +109,7 @@ def test_update_thread_without_user_id(api_client: APIClient) -> None: ) assert response.status_code == 200 updated_thread = response.json() - updated_thread_from_db = CommentThread().get(updated_thread["id"]) + updated_thread_from_db = backend.get_thread(updated_thread["id"]) assert updated_thread_from_db is not None assert updated_thread_from_db["body"] == "new thread body" assert updated_thread_from_db["title"] == "new thread title" @@ -118,7 +130,7 @@ def test_update_close_reason(api_client: APIClient) -> None: ) assert response.status_code == 200 updated_thread = response.json() - updated_thread_from_db = CommentThread().get(updated_thread["id"]) + updated_thread_from_db = backend.get_thread(updated_thread["id"]) assert updated_thread_from_db is not None assert updated_thread_from_db["closed"] assert updated_thread_from_db["close_reason_code"] == "test_code" @@ -147,7 +159,7 @@ def test_closing_and_reopening_thread_clears_reason_code(api_client: APIClient) ) assert response.status_code == 200 updated_thread = response.json() - updated_thread_from_db = CommentThread().get(updated_thread["id"]) + updated_thread_from_db = backend.get_thread(updated_thread["id"]) assert updated_thread_from_db is not None assert not updated_thread_from_db["closed"] assert updated_thread_from_db["close_reason_code"] is None @@ -156,7 +168,7 @@ def test_closing_and_reopening_thread_clears_reason_code(api_client: APIClient) def test_update_thread_not_exist(api_client: APIClient) -> None: """Test thread does not exists through update thread API.""" - wrong_thread_id = "66cd75eba3a68c001d51927b" + wrong_thread_id = backend.generate_id() response = api_client.put_json( f"/api/v2/threads/{wrong_thread_id}", data={ @@ -182,7 +194,7 @@ def test_unicode_data(api_client: APIClient) -> None: ) assert response.status_code == 200 updated_thread = response.json() - updated_thread_from_db = CommentThread().get(updated_thread["id"]) + updated_thread_from_db = backend.get_thread(updated_thread["id"]) assert updated_thread_from_db is not None assert updated_thread_from_db["body"] == text assert updated_thread_from_db["title"] == text @@ -192,28 +204,25 @@ def test_delete_thread(api_client: APIClient) -> None: """Test delete a thread.""" user_id, thread_id = setup_models() comment_id_1, comment_id_2 = create_comments_in_a_thread(thread_id) - thread_from_db = CommentThread().get(thread_id) + thread_from_db = backend.get_thread(thread_id) assert thread_from_db is not None assert thread_from_db["comment_count"] == 2 response = api_client.delete_json(f"/api/v2/threads/{thread_id}") assert response.status_code == 200 - assert CommentThread().get(thread_id) is None - assert Comment().get(comment_id_1) is None - assert Comment().get(comment_id_2) is None - assert ( - Subscriptions().get_subscription(subscriber_id=user_id, source_id=thread_id) - is None - ) + assert backend.get_thread(thread_id) is None + assert backend.get_comment(comment_id_1) is None + assert backend.get_comment(comment_id_2) is None + assert backend.get_subscription(subscriber_id=user_id, source_id=thread_id) is None def test_delete_thread_not_exist(api_client: APIClient) -> None: """Test thread does not exists through delete thread API.""" - wrong_thread_id = "66cd75eba3a68c001d51927b" + wrong_thread_id = backend.generate_id() response = api_client.delete_json(f"/api/v2/threads/{wrong_thread_id}") assert response.status_code == 400 -def test_invalide_data(api_client: APIClient) -> None: +def test_invalid_data(api_client: APIClient) -> None: """Test invalid data""" setup_models() response = api_client.get_json("/api/v2/threads", {}) @@ -240,16 +249,18 @@ def test_filter_by_course(api_client: APIClient) -> None: def test_filter_exclude_standalone(api_client: APIClient) -> None: """Test filter exclude standalone threads through get thread API.""" setup_models() - CommentThread().insert( - title="Thread 2", - body="Thread 2", - course_id="course1", - commentable_id="CommentThread", - author_id="1", - author_username="user1", - abuse_flaggers=[], - historical_abuse_flaggers=[], - context="standalone", + backend.create_thread( + { + "title": "Thread 2", + "body": "Thread 2", + "course_id": "course1", + "commentable_id": "CommentThread", + "author_id": "1", + "author_username": "user1", + "abuse_flaggers": [], + "historical_abuse_flaggers": [], + "context": "standalone", + } ) params = {"course_id": "course1"} response = api_client.get_json("/api/v2/threads", params) @@ -269,7 +280,7 @@ def test_api_with_count_flagged(api_client: APIClient) -> None: comment_id_1, comment_id_2 = create_comments_in_a_thread(thread_id) # Mark Comment 1 as abused - Comment().update(comment_id_1, abuse_flaggers=["1"]) + backend.update_comment(comment_id_1, abuse_flaggers=["1"]) params = {"course_id": "course1", "count_flagged": "true"} response = api_client.get_json("/api/v2/threads", params) @@ -280,7 +291,7 @@ def test_api_with_count_flagged(api_client: APIClient) -> None: assert results[0]["abuse_flagged_count"] == 1 # Mark Comment 2 as abused - Comment().update(comment_id_2, abuse_flaggers=["1"]) + backend.update_comment(comment_id_2, abuse_flaggers=["1"]) params = {"course_id": "course1", "count_flagged": "true"} response = api_client.get_json("/api/v2/threads", params) @@ -353,29 +364,33 @@ def test_anonymous_threads(api_client: APIClient) -> None: course_id = "course-1" author_id = "1" author_username = "author-1" - Users().insert(author_id, username=author_username, email="author@example.com") - - CommentThread().insert( - title="Thread 1", - body="Thread 1", - course_id=course_id, - commentable_id="CommentThread", - author_id=author_id, - author_username=author_username, - ) - - CommentThread().insert( - title="Thread 2", - body="Thread 2", - course_id=course_id, - commentable_id="CommentThread", - author_id=author_id, - author_username=author_username, - anonymous=True, - anonymous_to_peers=True, - ) - - user_id = Users().insert("2", username="anonymus-user", email="email2") + backend.find_or_create_user(author_id, username=author_username) + + backend.create_thread( + { + "title": "Thread 1", + "body": "Thread 1", + "course_id": course_id, + "commentable_id": "CommentThread", + "author_id": author_id, + "author_username": author_username, + } + ) + + backend.create_thread( + { + "title": "Thread 2", + "body": "Thread 2", + "course_id": course_id, + "commentable_id": "CommentThread", + "author_id": author_id, + "author_username": author_username, + "anonymous": True, + "anonymous_to_peers": True, + } + ) + + user_id = backend.find_or_create_user("2", username="anonymus-user") params = {"course_id": course_id, "author_id": author_id, "user_id": user_id} response = api_client.get_json("/api/v2/threads", params) @@ -409,16 +424,20 @@ def test_filter_by_post_type(api_client: APIClient) -> None: """Test filter threads by thread_type through get thread API.""" setup_models() setup_models("2", "user2", "course1") - CommentThread().insert( - title="Thread 3", - body="Thread 3", - course_id="course1", - commentable_id="CommentThread", - author_id="3", - author_username="user3", - abuse_flaggers=[], - historical_abuse_flaggers=[], - thread_type="question", + username_3 = "user3" + user_id_3 = backend.find_or_create_user("3", username_3) + backend.create_thread( + { + "title": "Thread 3", + "body": "Thread 3", + "course_id": "course1", + "commentable_id": "CommentThread", + "author_id": user_id_3, + "author_username": username_3, + "abuse_flaggers": [], + "historical_abuse_flaggers": [], + "thread_type": "question", + } ) params = {"course_id": "course1", "thread_type": "discussion"} response = api_client.get_json("/api/v2/threads", params) @@ -443,16 +462,20 @@ def test_filter_unanswered_questions(api_client: APIClient) -> None: username = "user1" user_id_1, thread1 = setup_models("1", "user1", thread_type="question") user_id_2, thread2 = setup_models("2", "user2", thread_type="question") - CommentThread().insert( - title="Thread 3", - body="Thread 3", - course_id=course_id, - commentable_id="CommentThread", - author_id=user_id_1, - author_username=username, - abuse_flaggers=[], - historical_abuse_flaggers=[], - thread_type="question", + username_3 = "user3" + user_id_3 = backend.find_or_create_user("3", username_3) + backend.create_thread( + { + "title": "Thread 3", + "body": "Thread 3", + "course_id": "course1", + "commentable_id": "CommentThread", + "author_id": user_id_3, + "author_username": username_3, + "abuse_flaggers": [], + "historical_abuse_flaggers": [], + "thread_type": "question", + } ) params = {"course_id": "course1", "unanswered": True} @@ -461,23 +484,28 @@ def test_filter_unanswered_questions(api_client: APIClient) -> None: results = response.json()["collection"] assert len(results) == 3 - comment_id_1 = Comment().insert( - body="

Thread 1 Comment

", - course_id=course_id, - author_id=user_id_1, - comment_thread_id=thread1, - author_username=username, + comment_id_1 = backend.create_comment( + { + "body": "

Thread 1 Comment

", + "course_id": course_id, + "author_id": user_id_1, + "comment_thread_id": thread1, + "author_username": username, + } ) - comment_id_2 = Comment().insert( - body="

Thread 2 Comment

", - course_id=course_id, - author_id=user_id_2, - comment_thread_id=thread2, - author_username=username, + + comment_id_2 = backend.create_comment( + { + "body": "

Thread 2 Comment

", + "course_id": course_id, + "author_id": user_id_2, + "comment_thread_id": thread2, + "author_username": username, + } ) - Comment().update(comment_id=comment_id_1, endorsed=True) - Comment().update(comment_id=comment_id_2, endorsed=True) + backend.update_comment(comment_id=comment_id_1, endorsed=True) + backend.update_comment(comment_id=comment_id_2, endorsed=True) # api_client.put_json( # f"/api/v2/threads/{thread1}", @@ -525,14 +553,16 @@ def test_get_thread(api_client: APIClient) -> None: def test_computes_endorsed_correctly(api_client: APIClient) -> None: """Test computes endorsed correctly through get thread API.""" _, thread_id = setup_models() - comment_id = Comment().insert( - body="Comment 1", - course_id="course1", - author_id="1", - comment_thread_id=thread_id, - author_username="user1", - ) - Comment().update(comment_id=comment_id, endorsed=True) + comment_id = backend.create_comment( + { + "body": "Comment 1", + "course_id": "course1", + "author_id": "1", + "comment_thread_id": thread_id, + "author_username": "user1", + } + ) + backend.update_comment(comment_id=comment_id, endorsed=True) response = api_client.get_json( f"/api/v2/threads/{thread_id}", params={ @@ -554,12 +584,14 @@ def test_computes_endorsed_correctly(api_client: APIClient) -> None: def test_no_children_for_informational_request(api_client: APIClient) -> None: """Test no children returned from get thread by thread_id API""" _, thread_id = setup_models() - Comment().insert( - body="Comment 1", - course_id="course1", - author_id="1", - comment_thread_id=thread_id, - author_username="user1", + backend.create_comment( + { + "body": "Comment 1", + "course_id": "course1", + "author_id": "1", + "comment_thread_id": thread_id, + "author_username": "user1", + } ) response = api_client.get_json( f"/api/v2/threads/{thread_id}", @@ -632,19 +664,21 @@ def test_endorement_is_none_after_unanswering_a_comment_in_question( when that question was initialy marked as answered. """ user_id, thread_id = setup_models(thread_type="question") - comment_id = Comment().insert( - body="Comment 1", - course_id="course1", - author_id=user_id, - comment_thread_id=thread_id, - author_username="user1", + comment_id = backend.create_comment( + { + "body": "Comment 1", + "course_id": "course1", + "author_id": user_id, + "comment_thread_id": thread_id, + "author_username": "user1", + } ) response = api_client.put_json( f"/api/v2/comments/{comment_id}", data={"endorsed": "True", "endorsement_user_id": user_id}, ) assert response.status_code == 200 - comment = Comment().get(comment_id) + comment = backend.get_comment(comment_id) assert comment is not None assert comment["endorsed"] is True assert comment["endorsement"]["user_id"] == user_id @@ -680,19 +714,24 @@ def test_response_for_thread_type_question(api_client: APIClient) -> None: It varies according to queryparams. """ user_id, thread_id = setup_models(thread_type="question") - comment_id1 = Comment().insert( - body="Comment 1", - course_id="course1", - author_id=user_id, - comment_thread_id=thread_id, - author_username="user1", - ) - comment_id2 = Comment().insert( - body="Comment 2", - course_id="course1", - author_id=user_id, - comment_thread_id=thread_id, - author_username="user1", + comment_id1 = backend.create_comment( + { + "body": "Comment 1", + "course_id": "course1", + "author_id": user_id, + "comment_thread_id": thread_id, + "author_username": "user1", + } + ) + + comment_id2 = backend.create_comment( + { + "body": "Comment 2", + "course_id": "course1", + "author_id": user_id, + "comment_thread_id": thread_id, + "author_username": "user1", + } ) response = api_client.put_json( f"/api/v2/comments/{comment_id1}", diff --git a/tests/test_views/test_users.py b/tests/test_views/test_users.py index c67df5de..4932bd33 100644 --- a/tests/test_views/test_users.py +++ b/tests/test_views/test_users.py @@ -1,29 +1,37 @@ """Tests for Users apis.""" -from forum.backends.mongodb import Comment, CommentThread, Contents, Users -from forum.backends.mongodb.api import subscribe_user, upvote_content +import pytest + +from forum.backend import get_backend from forum.constants import RETIRED_BODY, RETIRED_TITLE from test_utils.client import APIClient +pytestmark = pytest.mark.django_db +backend = get_backend()() + def setup_10_threads(author_id: str, author_username: str) -> list[str]: """Create 10 threads for a user.""" ids = [] for thread in range(10): - thread_id = CommentThread().insert( - title=f"Test Thread {thread}", - body="This is a test thread", - course_id="course1", - commentable_id="commentable1", - author_id=author_id, - author_username=author_username, + thread_id = backend.create_thread( + { + "title": f"Test Thread {thread}", + "body": "This is a test thread", + "course_id": "course1", + "commentable_id": "commentable1", + "author_id": author_id, + "author_username": author_username, + } ) - Comment().insert( - body="This is a test comment", - course_id="course1", - author_id=author_id, - comment_thread_id=str(thread_id), - author_username=author_username, + backend.create_comment( + { + "body": "This is a test comment", + "course_id": "course1", + "author_id": author_id, + "comment_thread_id": str(thread_id), + "author_username": author_username, + } ) ids.append(thread_id) return ids @@ -31,22 +39,22 @@ def setup_10_threads(author_id: str, author_username: str) -> list[str]: def test_create_user(api_client: APIClient) -> None: """Test creating a new user.""" - user_id = "test_id" + user_id = backend.generate_id() username = "test-user" response = api_client.post_json( "/api/v2/users", data={"id": user_id, "username": username} ) assert response.status_code == 200 - user = Users().get(user_id) + user = backend.get_user(user_id) assert user assert user["username"] == username def test_create_user_with_existing_id(api_client: APIClient) -> None: """Test create user with an existing id.""" - user_id = "test_id" + user_id = backend.generate_id() username = "test-user" - Users().insert( + backend.find_or_create_user( user_id, username, ) @@ -58,24 +66,24 @@ def test_create_user_with_existing_id(api_client: APIClient) -> None: def test_create_user_with_existing_username(api_client: APIClient) -> None: """Test create user with an existing username.""" - user_id = "test_id" + user_id = backend.generate_id() username = "test-user" - Users().insert( + backend.find_or_create_user( user_id, username, ) response = api_client.post_json( - "/api/v2/users", data={"id": "test_id_2", "username": username} + "/api/v2/users", data={"id": backend.generate_id(), "username": username} ) assert response.status_code == 400 def test_update_user(api_client: APIClient) -> None: """Test updating user information.""" - user_id = "test_id" + user_id = backend.generate_id() username = "test-user" new_username = "new-test-user" - Users().insert( + backend.find_or_create_user( user_id, username, ) @@ -83,30 +91,14 @@ def test_update_user(api_client: APIClient) -> None: f"/api/v2/users/{user_id}", data={"username": new_username} ) assert response.status_code == 200 - user = Users().get(user_id) + user = backend.get_user(user_id) assert user assert user["username"] == new_username -def test_update_user_id(api_client: APIClient) -> None: - """Test updating user id.""" - user_id = "test_id" - username = "test-user" - new_id = "new-test-id" - Users().insert( - user_id, - username, - ) - response = api_client.put_json(f"/api/v2/users/{user_id}", data={"id": new_id}) - assert response.status_code == 200 - user = Users().get(user_id) - assert user - assert user["username"] == username - - def test_update_non_existent_user(api_client: APIClient) -> None: """Test updating non-existent user.""" - user_id = "test_id" + user_id = backend.generate_id() response = api_client.put_json( f"/api/v2/users/{user_id}", data={"username": "new-test-user"} ) @@ -115,15 +107,15 @@ def test_update_non_existent_user(api_client: APIClient) -> None: def test_update_user_with_conflicting_info(api_client: APIClient) -> None: """Test updating user with conflicting information.""" - user_id = "test_id" + user_id = backend.generate_id() username = "test-user" conflicting_username = "test-user-2" - Users().insert( + backend.find_or_create_user( user_id, username, ) - Users().insert( - "test_id_2", + backend.find_or_create_user( + backend.generate_id(), conflicting_username, ) response = api_client.put_json( @@ -134,9 +126,9 @@ def test_update_user_with_conflicting_info(api_client: APIClient) -> None: def test_get_user(api_client: APIClient) -> None: """Test getting user information.""" - user_id = "test_id" + user_id = backend.generate_id() username = "test-user" - Users().insert( + backend.find_or_create_user( user_id, username, ) @@ -149,16 +141,16 @@ def test_get_user(api_client: APIClient) -> None: def test_get_non_existent_user(api_client: APIClient) -> None: """Test getting non-existent user.""" - user_id = "test_id" + user_id = backend.generate_id() response = api_client.get(f"/api/v2/users/{user_id}") assert response.status_code == 404 def test_get_user_with_no_votes(api_client: APIClient) -> None: """Test getting user with no votes.""" - user_id = "test_id" + user_id = backend.generate_id() username = "test-user" - Users().insert( + backend.find_or_create_user( user_id, username, ) @@ -170,27 +162,31 @@ def test_get_user_with_no_votes(api_client: APIClient) -> None: def test_get_user_with_votes(api_client: APIClient) -> None: """Test getting user with votes.""" - user_id = "test_id" + user_id = backend.generate_id() username = "test-user" - Users().insert( + backend.find_or_create_user( user_id, username, ) - thread_id = CommentThread().insert( - title="Test Thread", - body="This is a test thread", - course_id="course1", - commentable_id="commentable1", - author_id="author1", - author_username="author_user", - ) - thread = CommentThread().get(thread_id) - user = Users().get(user_id) + author_id = backend.generate_id() + author_username = "author" + backend.find_or_create_user(author_id, author_username) + thread_id = backend.create_thread( + { + "title": "Test Thread", + "body": "This is a test thread", + "course_id": "course1", + "commentable_id": "commentable1", + "author_id": author_id, + "author_username": author_username, + } + ) + thread = backend.get_thread(thread_id) + user = backend.get_user(user_id) assert thread assert user - upvote_content( - thread, - user, + backend.upvote_content( + thread["_id"], user["external_id"], content_type="CommentThread" ) response = api_client.get(f"/api/v2/users/{user_id}?complete=true") assert response.status_code == 200 @@ -201,9 +197,9 @@ def test_get_user_with_votes(api_client: APIClient) -> None: def test_get_active_threads_requires_course_id(api_client: APIClient) -> None: """Test getting active threads requires course id.""" - user_id = "test_id" + user_id = backend.generate_id() username = "test-user" - Users().insert( + backend.find_or_create_user( user_id, username, ) @@ -215,9 +211,9 @@ def test_get_active_threads_requires_course_id(api_client: APIClient) -> None: def test_get_active_threads(api_client: APIClient) -> None: """Test getting active threads.""" - user_id = "test_id" + user_id = backend.generate_id() username = "test-user" - Users().insert( + backend.find_or_create_user( user_id, username, ) @@ -233,22 +229,27 @@ def test_get_active_threads(api_client: APIClient) -> None: def test_marks_thread_as_read_for_user(api_client: APIClient) -> None: """Test marking a thread as read for a user.""" - user_id = "test_id" + user_id = backend.generate_id() username = "test-user" - Users().insert( + backend.find_or_create_user( user_id, username, ) - thread_id = CommentThread().insert( - title="Test Thread", - body="This is a test thread", - course_id="course1", - commentable_id="commentable1", - author_id="test_id", - author_username="test-user", - ) - - thread = CommentThread().get(thread_id) + author_id = backend.generate_id() + author_username = "author" + backend.find_or_create_user(author_id, author_username) + thread_id = backend.create_thread( + { + "title": "Test Thread", + "body": "This is a test thread", + "course_id": "course1", + "commentable_id": "commentable1", + "author_id": author_id, + "author_username": author_username, + } + ) + + thread = backend.get_thread(thread_id) assert thread response = api_client.post_json( f"/api/v2/users/{user_id}/read", @@ -257,7 +258,7 @@ def test_marks_thread_as_read_for_user(api_client: APIClient) -> None: assert response.status_code == 200 read_date = {} - updated_user = Users().get(user_id) + updated_user = backend.get_user(user_id) assert updated_user read_states = updated_user.get("read_states", []) for state in read_states: @@ -271,13 +272,13 @@ def test_marks_thread_as_read_for_user(api_client: APIClient) -> None: def test_replaces_username(api_client: APIClient) -> None: """Test replace_username api.""" - user_id = "test_id" + user_id = backend.generate_id() username = "test-user" - Users().insert( + backend.find_or_create_user( user_id, username, ) - user = Users().get(user_id) + user = backend.get_user(user_id) assert user assert user["username"] == username @@ -286,7 +287,7 @@ def test_replaces_username(api_client: APIClient) -> None: f"/api/v2/users/{user_id}/replace_username", data={"new_username": new_username} ) assert response.status_code == 200 - updated_user = Users().get(user_id) + updated_user = backend.get_user(user_id) assert updated_user assert updated_user["username"] == new_username @@ -306,14 +307,14 @@ def test_attempts_to_replace_username_and_username_on_content( api_client: APIClient, ) -> None: """Test replace_username api with content.""" - user_id = "test_id" + user_id = backend.generate_id() username = "test-user" - Users().insert( + backend.find_or_create_user( user_id, username, ) setup_10_threads(user_id, username) - user = Users().get(user_id) + user = backend.get_user(user_id) new_username = "test_username_replacement" response = api_client.post_json( @@ -321,10 +322,10 @@ def test_attempts_to_replace_username_and_username_on_content( ) assert response.status_code == 200 - user = Users().get(user_id) + user = backend.get_user(user_id) assert user assert user["username"] == new_username - contents = list(Contents().get_list(author_id=user_id)) + contents = list(backend.get_contents(author_id=user_id)) assert len(contents) > 0 for content in contents: assert content["author_username"] == new_username @@ -334,9 +335,9 @@ def test_attempts_to_replace_username_without_sending_new_username( api_client: APIClient, ) -> None: """Test replace_username api without sending new username.""" - user_id = "test_id" + user_id = backend.generate_id() username = "test-user" - Users().insert( + backend.find_or_create_user( user_id, username, ) @@ -351,7 +352,7 @@ def test_attempts_to_retire_user_without_sending_retired_username( api_client: APIClient, ) -> None: """Test retire user api without sending retired username.""" - user_id = "1" + user_id = backend.generate_id() response = api_client.post_json( f"/api/v2/users/{user_id}/retire", data={}, @@ -361,7 +362,7 @@ def test_attempts_to_retire_user_without_sending_retired_username( def test_attempts_to_retire_non_existent_user(api_client: APIClient) -> None: """Test retire non-existent user.""" - user_id = "1234" + user_id = backend.generate_id() retired_username = "retired_user_test" response = api_client.post_json( f"/api/v2/users/{user_id}/retire", @@ -372,15 +373,15 @@ def test_attempts_to_retire_non_existent_user(api_client: APIClient) -> None: def test_retire_user(api_client: APIClient) -> None: """Test retire user.""" - user_id = "test_id" + user_id = backend.generate_id() username = "test-user" - Users().insert( + backend.find_or_create_user( user_id, username, ) setup_10_threads(user_id, username) retired_username = "retired_username_ABCD1234" - user = Users().get(user_id) + user = backend.get_user(user_id) assert user assert user["username"] == username @@ -389,11 +390,11 @@ def test_retire_user(api_client: APIClient) -> None: data={"retired_username": retired_username}, ) assert response.status_code == 200 - user = Users().get(user_id) + user = backend.get_user(user_id) assert user assert user["username"] == retired_username assert user["email"] == "" - contents = list(Contents().get_list(author_id=user_id)) + contents = list(backend.get_contents(author_id=user_id)) assert len(contents) > 0 for content in contents: if content["_type"] == "CommentThread": @@ -404,26 +405,31 @@ def test_retire_user(api_client: APIClient) -> None: def test_retire_user_with_subscribed_threads(api_client: APIClient) -> None: """Test retire user with subscribed threads.""" - user_id = "test_id" + user_id = backend.generate_id() username = "test-user" - Users().insert( + backend.find_or_create_user( user_id, username, ) + author_id = backend.generate_id() + author_username = "author" + backend.find_or_create_user(author_id, author_username) setup_10_threads(user_id, username) retired_username = "retired_username_ABCD1234" - user = Users().get(user_id) + user = backend.get_user(user_id) assert user assert user["username"] == username - thread_id = CommentThread().insert( - title="Test Thread", - body="This is a test thread", - course_id="course1", - commentable_id="commentable1", - author_id="test_id", - author_username="test-user", - ) - subscribe_user(user_id, thread_id, "CommentThread") + thread_id = backend.create_thread( + { + "title": "Test Thread", + "body": "This is a test thread", + "course_id": "course1", + "commentable_id": "commentable1", + "author_id": author_id, + "author_username": author_username, + } + ) + backend.subscribe_user(user_id, thread_id, "CommentThread") response = api_client.get( f"/api/v2/users/{user_id}/subscribed_threads?course_id=course1" ) @@ -438,7 +444,7 @@ def test_retire_user_with_subscribed_threads(api_client: APIClient) -> None: ) assert response.status_code == 200 - user = Users().get(user_id) + user = backend.get_user(user_id) assert user assert user["username"] == retired_username assert user["email"] == "" @@ -457,7 +463,7 @@ def test_retire_user_with_subscribed_threads(api_client: APIClient) -> None: assert body["thread_count"] == 0 # User's comments should be blanked out. - contents = list(Contents().get_list(author_id=user_id)) + contents = list(backend.get_contents(author_id=user_id)) assert len(contents) > 0 for content in contents: if content["_type"] == "CommentThread": diff --git a/tests/test_views/test_votes.py b/tests/test_views/test_votes.py index a9e1d0f7..58c05e68 100644 --- a/tests/test_views/test_votes.py +++ b/tests/test_views/test_votes.py @@ -4,9 +4,12 @@ import pytest -from forum.backends.mongodb import Comment, CommentThread, Users +from forum.backend import get_backend from test_utils.client import APIClient +pytestmark = pytest.mark.django_db +backend = get_backend()() + @pytest.fixture(name="user") def get_user() -> dict[str, Any]: @@ -19,8 +22,8 @@ def get_user() -> dict[str, Any]: dict[str, Any]: The created user, represented as a dictionary. """ user_id = "1" - Users().insert(user_id, username="testuser", email="testuser@example.com") - return Users().get(_id=user_id) or {} + backend.find_or_create_user(user_id, username="testuser") + return backend.get_user(user_id) or {} @pytest.fixture(name="thread") @@ -37,17 +40,19 @@ def get_thread(user: dict[str, Any]) -> dict[str, Any]: Returns: dict[str, Any]: The created thread, represented as a dictionary. """ - thread_id = CommentThread().insert( - title="Test Thread", - body="This is a test thread.", - author_id=user["_id"], - course_id="course-v1:Test+Course+2024_S2", - commentable_id="commentable_id", - author_username="testuser", + thread_id = backend.create_thread( + { + "title": "Test Thread", + "body": "This is a test thread.", + "author_id": user["_id"], + "course_id": "course-v1:Test+Course+2024_S2", + "commentable_id": "commentable_id", + "author_username": "testuser", + } ) - votes = Comment().get_votes_dict(up=["2", "3"], down=["4", "5"]) - CommentThread().update_votes(content_id=thread_id, votes=votes) - return CommentThread().get(_id=thread_id) or {} + votes = backend.get_votes_dict(up=["2", "3"], down=["4", "5"]) + backend.update_thread(thread_id, votes=votes) + return backend.get_thread(thread_id) or {} @pytest.fixture(name="comment") @@ -65,16 +70,18 @@ def get_comment(user: dict[str, Any], thread: dict[str, Any]) -> dict[str, Any]: Returns: dict[str, Any]: The created comment, represented as a dictionary. """ - comment_id = Comment().insert( - body="This is a test comment.", - course_id="course-v1:Test+Course+2024_S2", - comment_thread_id=thread["_id"], - author_id=user["_id"], - author_username="testuser", + comment_id = backend.create_comment( + { + "body": "This is a test comment.", + "course_id": "course-v1:Test+Course+2024_S2", + "comment_thread_id": thread["_id"], + "author_id": user["_id"], + "author_username": "testuser", + } ) - votes = Comment().get_votes_dict(up=["2", "3"], down=["4", "5"]) - Comment().update_votes(content_id=comment_id, votes=votes) - return Comment().get(_id=comment_id) or {} + votes = backend.get_votes_dict(up=["2", "3"], down=["4", "5"]) + backend.update_comment(comment_id, votes=votes) + return backend.get_comment(comment_id) or {} def test_upvote_thread_api( @@ -110,7 +117,7 @@ def test_upvote_thread_api( assert response_data is not None assert response_data["votes"]["up_count"] == prev_up_count + 1 - thread_data = CommentThread().get(_id=thread_id) or {} + thread_data = backend.get_thread(thread_id) or {} assert thread_data["votes"]["up_count"] == prev_up_count + 1 @@ -151,7 +158,7 @@ def test_vote_thread_api( data={"user_id": user_id, "value": "down"}, ) - thread_data = CommentThread().get(_id=thread_id) or {} + thread_data = backend.get_thread(thread_id) or {} assert thread_data["votes"]["up_count"] == prev_up_count assert thread_data["votes"]["down_count"] == prev_down_count + 1 @@ -195,7 +202,7 @@ def test_downvote_thread_api( assert response_data is not None assert response_data["votes"]["down_count"] == prev_down_count + 1 - thread_data = CommentThread().get(_id=thread_id) + thread_data = backend.get_thread(thread_id) assert thread_data is not None assert thread_data["votes"]["down_count"] == prev_down_count + 1 @@ -235,7 +242,7 @@ def test_remove_vote_thread_api( assert response_data["votes"]["up_count"] == prev_up_count assert response_data["votes"]["down_count"] == prev_down_count - thread_data = CommentThread().get(_id=thread_id) or {} + thread_data = backend.get_thread(thread_id) or {} assert thread_data is not None assert thread_data["votes"]["up_count"] == prev_up_count assert thread_data["votes"]["down_count"] == prev_down_count @@ -254,7 +261,7 @@ def test_remove_vote_thread_api( assert response_data["votes"]["up_count"] == prev_up_count assert response_data["votes"]["down_count"] == prev_down_count - thread_data = CommentThread().get(_id=thread_id) or {} + thread_data = backend.get_thread(thread_id) or {} assert thread_data is not None assert thread_data["votes"]["up_count"] == prev_up_count assert thread_data["votes"]["down_count"] == prev_down_count @@ -288,7 +295,7 @@ def test_upvote_comment_api( assert response_data is not None assert response_data["votes"]["up_count"] == prev_up_count + 1 - comment_data = Comment().get(_id=comment_id) + comment_data = backend.get_comment(comment_id) assert comment_data is not None assert comment_data["votes"]["up_count"] == prev_up_count + 1 @@ -321,7 +328,7 @@ def test_downvote_comment_api( assert response_data is not None assert response_data["votes"]["down_count"] == prev_down_count + 1 - comment_data = Comment().get(_id=comment_id) + comment_data = backend.get_comment(comment_id) assert comment_data is not None assert comment_data["votes"]["down_count"] == prev_down_count + 1 @@ -361,7 +368,7 @@ def test_remove_vote_comment_api( assert response_data["votes"]["up_count"] == prev_up_count assert response_data["votes"]["down_count"] == prev_down_count - comment_data = Comment().get(_id=comment_id) + comment_data = backend.get_comment(comment_id) assert comment_data is not None assert comment_data["votes"]["up_count"] == prev_up_count assert comment_data["votes"]["down_count"] == prev_down_count