From 9c0b07776db28929ff5e3344bb9c1d81808dd38c Mon Sep 17 00:00:00 2001 From: Muhammad Faraz Maqsood Date: Wed, 28 Aug 2024 15:26:57 +0500 Subject: [PATCH 1/2] feat: implement thread API This PR includes - threads GET (for getting all threads of a course with filters) - /threads/:thread_id DELETE (for deleting a thread) - /threads/:thread_id GET (for getting threads data, including its children data) - /threads/:thread_id PUT (for updating thread) - /course/threads POST (for creating a new thread) - some minor fixes in serializers and models - close #34 --- forum/models/comments.py | 23 ++- forum/models/contents.py | 6 +- forum/models/model_utils.py | 142 +++++++++++++-- forum/models/threads.py | 26 ++- forum/serializers/comment.py | 38 ++-- forum/serializers/thread.py | 67 +++++-- forum/urls.py | 27 ++- forum/utils.py | 17 ++ forum/views/comments.py | 3 +- forum/views/subscriptions.py | 74 +------- forum/views/threads.py | 333 +++++++++++++++++++++++++++++++++++ 11 files changed, 629 insertions(+), 127 deletions(-) create mode 100644 forum/views/threads.py diff --git a/forum/models/comments.py b/forum/models/comments.py index 171a193f..32229a54 100644 --- a/forum/models/comments.py +++ b/forum/models/comments.py @@ -127,11 +127,15 @@ def insert( if parent_id: comment_data["parent_id"] = ObjectId(parent_id) + comment_data["endorsement"] = None + result = self._collection.insert_one(comment_data) if parent_id: self.update_child_count_in_parent_comment(parent_id, 1) if comment_thread_id: self.update_comment_count_in_comment_thread(comment_thread_id, 1) + + self.update_sk(str(result.inserted_id), parent_id) return str(result.inserted_id) def update( @@ -158,6 +162,7 @@ def update( editing_user_id: Optional[str] = None, edit_reason_code: Optional[str] = None, endorsement_user_id: Optional[str] = None, + sk: Optional[str] = None, ) -> int: """ Updates a comment document in the database. @@ -200,6 +205,7 @@ def update( ("child_count", child_count), ("depth", depth), ("closed", closed), + ("sk", sk), ] update_data: dict[str, Any] = { field: value for field, value in fields if value is not None @@ -216,6 +222,7 @@ def update( edit_history = [] if edit_history is None else edit_history edit_history.append( { + "author_id": editing_user_id, "original_body": original_body, "reason_code": edit_reason_code, "editor_username": self.get_author_username(editing_user_id), @@ -315,5 +322,19 @@ def update_comment_count_in_comment_thread( Returns: None. """ - update_comment_count_query = {"$inc": {"comment_count": count}} + update_comment_count_query = { + "$inc": {"comment_count": count}, + "$set": {"last_activity_at": datetime.now()}, + } CommentThread().update_count(comment_thread_id, update_comment_count_query) + + def get_sk(self, id: str, parent_id: Optional[str]) -> str: + """Returns sk field.""" + if parent_id is not None: + return f"{parent_id}-{id}" + return f"{id}" + + def update_sk(self, id: str, parent_id: Optional[str]) -> None: + """Updates sk field.""" + sk = self.get_sk(id, parent_id) + self.update(id, sk=sk) diff --git a/forum/models/contents.py b/forum/models/contents.py index 94aefc55..b4c130a2 100644 --- a/forum/models/contents.py +++ b/forum/models/contents.py @@ -50,7 +50,11 @@ def list(self, **kwargs: Any) -> Any: """ if self.content_type: kwargs["_type"] = self.content_type - return self._collection.find(kwargs) + result = self._collection.find(kwargs) + sort = kwargs.pop("sort", None) + if sort: + return result.sort("sk", sort) + return result @classmethod def get_votes_dict(cls, up: List[str], down: List[str]) -> dict[str, Any]: diff --git a/forum/models/model_utils.py b/forum/models/model_utils.py index e2152864..2387ddea 100644 --- a/forum/models/model_utils.py +++ b/forum/models/model_utils.py @@ -4,6 +4,8 @@ from bson import ObjectId from django.core.exceptions import ObjectDoesNotExist +from rest_framework import status +from rest_framework.response import Response from forum.models import Comment, CommentThread, Contents, Subscriptions, Users @@ -332,23 +334,23 @@ def get_read_states( whether the thread is read and the unread comment count. """ read_states = {} - user = Users().find_one({"_id": user_id, "read_states.course_id": course_id}) - read_state = user["read_states"][0] if user else {} - - if read_state: - read_dates = read_state.get("last_read_times", {}) - for thread in threads: - thread_key = str(thread["_id"]) - if thread_key in read_dates: - is_read = read_dates[thread_key] >= thread["last_activity_at"] - unread_comment_count = Contents().count_documents( - { - "comment_thread_id": ObjectId(thread_key), - "created_at": {"$gte": read_dates[thread_key]}, - "author_id": {"$ne": str(user_id)}, - } - ) - read_states[thread_key] = [is_read, unread_comment_count] + if user_id: + user = Users().find_one({"_id": user_id, "read_states.course_id": course_id}) + read_state = user["read_states"][0] if user else {} + if read_state: + read_dates = read_state.get("last_read_times", {}) + for thread in threads: + thread_key = str(thread["_id"]) + if thread_key in read_dates: + is_read = read_dates[thread_key] >= thread["last_activity_at"] + unread_comment_count = Contents().count_documents( + { + "comment_thread_id": ObjectId(thread_key), + "created_at": {"$gte": read_dates[thread_key]}, + "author_id": {"$ne": str(user_id)}, + } + ) + read_states[thread_key] = [is_read, unread_comment_count] return read_states @@ -776,3 +778,109 @@ def subscribe_user( def unsubscribe_user(user_id: str, source_id: str) -> None: """Unsubscribe a user from a source.""" Subscriptions().delete_subscription(user_id, source_id) + + +def delete_comments_of_a_thread(thread_id): + """Delete comments of a thread.""" + for comment in Comment().list( + comment_thread_id=ObjectId(thread_id), + depth=0, + parent_id=None, + ): + Comment().delete(comment["_id"]) + + +def validate_params( + params: dict[str, Any], user_id: Optional[str] = None +) -> Response | None: + """ + Validate the request parameters. + + Args: + params (dict): The request parameters. + user_id (optional[str]): The Id of the user for validation. + + Returns: + Response: A Response object with an error message if doesn't exist. + """ + valid_params = [ + "course_id", + "author_id", + "thread_type", + "flagged", + "unread", + "unanswered", + "unresponded", + "count_flagged", + "sort_key", + "page", + "per_page", + "request_id", + ] + if not user_id: + valid_params.append("user_id") + if "user_id" not in params: + return Response( + {"error": "Missing required parameter: user_id"}, + status=status.HTTP_400_BAD_REQUEST, + ) + user_id = params.get("user_id") + + for key in params: + if key not in valid_params: + return Response( + {"error": f"Invalid parameter: {key}"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + if "course_id" not in params: + return Response( + {"error": "Missing required parameter: course_id"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + user = Users().get(user_id) + if not user: + return Response( + {"error": "User doesn't exist"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + return None + + +def get_threads( + params: dict[str, Any], + user_id: str, + serializer: Any, + thread_ids: Optional[list[str]] = None, + include_context: Optional[bool] = False, +) -> dict[str, Any]: + """get subscribed or all threads of a specific course for a specific user.""" + threads = handle_threads_query( + thread_ids, + user_id, + params["course_id"], + get_group_ids_from_params(params), + params.get("author_id", ""), + params.get("thread_type"), + bool(params.get("flagged", False)), + bool(params.get("unread", False)), + bool(params.get("unanswered", False)), + bool(params.get("unresponded", False)), + bool(params.get("count_flagged", False)), + params.get("sort_key", ""), + int(params.get("page", 1)), + int(params.get("per_page", 100)), + ) + context = {} + if include_context: + context = { + "include_endorsed": True, + "include_read_state": True, + } + if user_id: + context["user_id"] = user_id + serializer = serializer(threads.pop("collection"), many=True, context=context) + threads["collection"] = serializer.data + return threads diff --git a/forum/models/threads.py b/forum/models/threads.py index 4f6419b3..5afafdec 100644 --- a/forum/models/threads.py +++ b/forum/models/threads.py @@ -6,6 +6,7 @@ from bson import ObjectId from forum.models.contents import BaseContents +from forum.models.users import Users class CommentThread(BaseContents): @@ -79,7 +80,7 @@ def insert( course_id: str, commentable_id: str, author_id: str, - author_username: str, + author_username: Optional[str] = None, anonymous: bool = False, anonymous_to_peers: bool = False, thread_type: str = "discussion", @@ -140,7 +141,7 @@ def insert( "anonymous_to_peers": anonymous_to_peers, "closed": False, "author_id": author_id, - "author_username": author_username, + "author_username": author_username or self.get_author_username(author_id), "created_at": date, "updated_at": date, "last_activity_at": date, @@ -174,6 +175,10 @@ def update( pinned: Optional[bool] = None, comments_count: Optional[int] = None, endorsed: Optional[bool] = None, + edit_history: Optional[list[dict[str, Any]]] = None, + original_body: Optional[str] = None, + editing_user_id: Optional[str] = None, + edit_reason_code: Optional[str] = None, ) -> int: """ Updates a thread document in the database. @@ -227,6 +232,18 @@ def update( update_data: dict[str, Any] = { field: value for field, value in fields if value is not None } + if editing_user_id: + edit_history = [] if edit_history is None else edit_history + edit_history.append( + { + "author_id": editing_user_id, + "original_body": original_body, + "reason_code": edit_reason_code, + "editor_username": self.get_author_username(editing_user_id), + "created_at": datetime.now(), + } + ) + update_data["edit_history"] = edit_history date = datetime.now() update_data["updated_at"] = date @@ -236,3 +253,8 @@ def update( {"$set": update_data}, ) return result.modified_count + + def get_author_username(self, author_id: str) -> str | None: + """Return username for the respective author_id(user_id)""" + user = Users().get(author_id) + return user.get("username") if user else None diff --git a/forum/serializers/comment.py b/forum/serializers/comment.py index 39928d19..d2c853c7 100644 --- a/forum/serializers/comment.py +++ b/forum/serializers/comment.py @@ -4,10 +4,13 @@ from typing import Any +from bson import ObjectId from rest_framework import serializers +from forum.models import Comment from forum.serializers.contents import ContentSerializer from forum.serializers.custom_datetime import CustomDateTimeField +from forum.utils import prepare_comment_data_for_get_children class EndorsementSerializer(serializers.Serializer[dict[str, Any]]): @@ -49,11 +52,12 @@ class CommentSerializer(ContentSerializer): endorsed = serializers.BooleanField(default=False) depth = serializers.IntegerField(default=0) - thread_id = serializers.CharField() + thread_id = serializers.CharField(source="comment_thread_id") parent_id = serializers.CharField(default=None, allow_null=True) child_count = serializers.IntegerField(default=0) - sk = serializers.SerializerMethodField() + sk = serializers.CharField(default=None, required=False, allow_null=True) endorsement = EndorsementSerializer(default=None, required=False, allow_null=True) + children = serializers.SerializerMethodField() def __init__(self, *args: Any, **kwargs: Any) -> None: exclude_fields = kwargs.pop("exclude_fields", None) @@ -62,22 +66,32 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: for field in exclude_fields: self.fields.pop(field, None) + def get_children(self, obj: Any) -> list[dict[str, Any]]: + if not self.context.get("recursive", False): + return [] + + children = list( + Comment().list( + parent_id=ObjectId(obj["_id"]), + depth=1, + sort=self.context.get("sort", -1), + ) + ) + children_data = prepare_comment_data_for_get_children(children) + serializer = CommentSerializer( + children_data, + many=True, + context={"recursive": False}, + exclude_fields=["sk"], + ) + return serializer.data + def to_representation(self, instance: Any) -> dict[str, Any]: comment = super().to_representation(instance) if comment["parent_id"] == "None": comment["parent_id"] = None return comment - def get_sk(self, obj: dict[str, Any]) -> str: - """Return sk field""" - is_child = obj.get("parent_id") - if is_child is not None: - return "{parent_id}-{id}".format( - parent_id=obj.get("parent_id"), id=obj.get("_id") - ) - else: - return "{id}".format(id=obj.get("_id")) - def create(self, validated_data: dict[str, Any]) -> Any: """Raise NotImplementedError""" raise NotImplementedError diff --git a/forum/serializers/thread.py b/forum/serializers/thread.py index 15506a98..343bd91f 100644 --- a/forum/serializers/thread.py +++ b/forum/serializers/thread.py @@ -4,16 +4,22 @@ from typing import Any, Optional +from bson import ObjectId +from pymongo import DESCENDING, ASCENDING from rest_framework import serializers +from rest_framework.serializers import ValidationError +from forum.models import Comment from forum.models.model_utils import ( get_abuse_flagged_count, get_endorsed, get_read_states, get_username_from_id, ) +from forum.serializers.comment import CommentSerializer from forum.serializers.contents import ContentSerializer from forum.serializers.custom_datetime import CustomDateTimeField +from forum.utils import prepare_comment_data_for_get_children class ThreadSerializer(ContentSerializer): @@ -50,7 +56,7 @@ class ThreadSerializer(ContentSerializer): thread_type = serializers.CharField() title = serializers.CharField() - context = serializers.CharField() # type: ignore + context = serializers.CharField() last_activity_at = CustomDateTimeField() closed_by_id = serializers.CharField(allow_null=True, default=None) closed_by = serializers.SerializerMethodField() @@ -58,7 +64,7 @@ class ThreadSerializer(ContentSerializer): tags = serializers.ListField(default=[]) group_id = serializers.CharField(allow_null=True, default=None) pinned = serializers.BooleanField(default=False) - comment_count = serializers.IntegerField(default=0) + comments_count = serializers.IntegerField(required=False, source="comment_count") read = serializers.SerializerMethodField() unread_comments_count = serializers.SerializerMethodField() endorsed = serializers.SerializerMethodField() @@ -83,11 +89,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: - 'include_endorsed' (bool): Whether to include endorsement status. - 'include_read_state' (bool): Whether to include read state information. """ - context = kwargs.pop("context", {}) - self.with_responses = context.pop("with_responses", False) - self.count_flagged = context.pop("count_flagged", False) - self.include_endorsed = context.pop("include_endorsed", False) - self.include_read_state = context.pop("include_read_state", False) + self.context_data = kwargs.get("context", {}) + self.with_responses = self.context_data.pop("with_responses", False) + self.count_flagged = self.context_data.pop("count_flagged", False) + self.include_endorsed = self.context_data.pop("include_endorsed", False) + self.include_read_state = self.context_data.pop("include_read_state", False) # Customize fields based on context if not self.with_responses: @@ -119,11 +125,11 @@ def get_read(self, obj: dict[str, Any]) -> Optional[bool]: Optional[bool]: True if the thread is read, otherwise False or None. """ if self.include_read_state: - if isinstance(obj, dict) and obj.get("read") is not None: + if isinstance(obj, dict) and obj.get("read") is None: return obj.get("read", True) - user_id = obj["user_id"] + user_id = self.context_data.get("user_id", None) course_id = obj["course_id"] - thread_key = obj["id"] + thread_key = obj["_id"] is_read, _ = get_read_states([obj], user_id, course_id).get( thread_key, (False, obj["comment_count"]) ) @@ -141,11 +147,11 @@ def get_unread_comments_count(self, obj: dict[str, Any]) -> Optional[int]: Optional[int]: The number of unread comments or None. """ if self.include_read_state: - if isinstance(obj, dict) and obj.get("unread_comments_count") is not None: + if isinstance(obj, dict) and obj.get("unread_comments_count") is None: return obj.get("unread_comments_count", 0) - user_id = obj["user_id"] + user_id = self.context_data.get("user_id", None) course_id = obj["course_id"] - thread_key = obj["id"] + thread_key = obj["_id"] _, unread_count = get_read_states([obj], user_id, course_id).get( thread_key, (False, obj["comment_count"]) ) @@ -165,7 +171,7 @@ def get_endorsed(self, obj: dict[str, Any]) -> Optional[bool]: if self.include_endorsed: if isinstance(obj, dict) and obj.get("endorsed") is not None: return obj.get("endorsed", True) - thread_key = obj["id"] + thread_key = obj["_id"] return get_endorsed([thread_key]).get(thread_key, False) return None @@ -182,7 +188,7 @@ def get_abuse_flagged_count(self, obj: dict[str, Any]) -> int: if self.count_flagged: if isinstance(obj, dict) and obj.get("abuse_flagged_count") is not None: return obj.get("abuse_flagged_count", 0) - thread_key = obj["id"] + thread_key = obj["_id"] return get_abuse_flagged_count([thread_key]).get(thread_key, 0) return 0 @@ -197,8 +203,32 @@ def get_children(self, obj: dict[str, Any]) -> Optional[Any]: Optional[Any]: The responses or children related to the thread, or None if not included. """ if self.with_responses: - # Implement when needed - print(obj) + sorting_order = ( + DESCENDING + if self.context_data.get("reverse_order", True) + else ASCENDING + ) + children = list( + Comment().list( + comment_thread_id=ObjectId(obj["_id"]), + depth=0, + parent_id=None, + sort=sorting_order, + ) + ) + children_data = prepare_comment_data_for_get_children(children) + serializer = CommentSerializer( + data=children_data, + many=True, + context={ + "recursive": self.context_data.get("recursive", False), + "sort": sorting_order, + }, + exclude_fields=["sk"], + ) + if not serializer.is_valid(raise_exception=True): + raise ValidationError(serializer.errors) + return serializer.data return [] def get_resp_total(self, obj: dict[str, Any]) -> int: @@ -212,8 +242,7 @@ def get_resp_total(self, obj: dict[str, Any]) -> int: int: The total number of responses, defaulting to 0 if not included. """ if self.with_responses: - # Implement when needed - print(obj) + return len(self.get_children(obj)) return 0 def to_representation(self, instance: dict[str, Any]) -> dict[str, Any]: diff --git a/forum/urls.py b/forum/urls.py index 856f7187..2542af2c 100644 --- a/forum/urls.py +++ b/forum/urls.py @@ -14,6 +14,7 @@ ThreadSubscriptionAPIView, UserSubscriptionAPIView, ) +from forum.views.threads import CreateThreadAPIView, ThreadsAPIView, UserThreadsAPIView from forum.views.votes import CommentVoteView, ThreadVoteView api_patterns = [ @@ -52,17 +53,17 @@ CommentsAPIView.as_view(), name="comments-api", ), + path( + "threads//comments", + CreateThreadCommentAPIView.as_view(), + name="create-parent-comment-api", + ), # search threads API path( "search/threads", SearchThreadsView.as_view(), name="search-thread-api", ), - path( - "threads//comments", - CreateThreadCommentAPIView.as_view(), - name="create-parent-comment-api", - ), # subscription APIs path( "users//subscriptions", @@ -79,6 +80,22 @@ ThreadSubscriptionAPIView.as_view(), name="thread-subscriptions", ), + # threads API + path( + "course/threads", + CreateThreadAPIView.as_view(), + name="create-thread-api", + ), + path( + "threads", + UserThreadsAPIView.as_view(), + name="user-threads-api", + ), + path( + "threads/", + ThreadsAPIView.as_view(), + name="threads-api", + ), # Proxy view for various API endpoints path( "", diff --git a/forum/utils.py b/forum/utils.py index 3b532a80..b825520c 100644 --- a/forum/utils.py +++ b/forum/utils.py @@ -57,3 +57,20 @@ def get_int_value_from_collection( return int(collection[key]) except (TypeError, ValueError, KeyError): return default_value + + +def prepare_comment_data_for_get_children(children): + children_data = [] + for child in children: + children_data.append( + { + **child, + "id": str(child.get("_id")), + "user_id": child.get("author_id"), + "thread_id": str(child.get("comment_thread_id")), + "username": child.get("author_username"), + "parent_id": str(child.get("parent_id")), + "type": str(child.get("_type", "")).lower(), + } + ) + return children_data diff --git a/forum/views/comments.py b/forum/views/comments.py index 9615ca4c..c0ddbb50 100644 --- a/forum/views/comments.py +++ b/forum/views/comments.py @@ -74,6 +74,7 @@ def prepare_comment_api_response( "parent_id": str(comment.get("parent_id")), "type": str(comment.get("_type", "")).lower(), } + exclude_fields.append("children") serializer = CommentSerializer( data=comment_data, exclude_fields=exclude_fields, @@ -114,7 +115,7 @@ def get(self, request: Request, comment_id: str) -> Response: ) data = prepare_comment_api_response( comment, - exclude_fields=["sk", "endorsement"], + exclude_fields=["sk"], ) return Response(data, status=status.HTTP_200_OK) diff --git a/forum/views/subscriptions.py b/forum/views/subscriptions.py index 9a58ea6e..736cad35 100644 --- a/forum/views/subscriptions.py +++ b/forum/views/subscriptions.py @@ -11,10 +11,10 @@ from forum.models import CommentThread, Subscriptions, Users from forum.models.model_utils import ( find_subscribed_threads, - get_group_ids_from_params, - handle_threads_query, + get_threads, subscribe_user, unsubscribe_user, + validate_params, ) from forum.pagination import ForumPagination from forum.serializers.subscriptions import SubscriptionSerializer @@ -125,79 +125,15 @@ def get(self, request: Request, user_id: str) -> Response: Raises: HTTP_400_BAD_REQUEST: If the user does not exist. """ - user = Users().get(user_id) - if not user: - return Response( - {"error": "User doesn't exist"}, - status=status.HTTP_400_BAD_REQUEST, - ) - params = request.GET.dict() - validations = self._validate_params(params) + validations = validate_params(params, user_id) if validations: return validations - threads = handle_threads_query( - find_subscribed_threads(user_id, params["course_id"]), - user_id, - params["course_id"], - get_group_ids_from_params(params), - params.get("author_id", ""), - params.get("thread_type"), - bool(params.get("flagged", False)), - bool(params.get("unread", False)), - bool(params.get("unanswered", False)), - bool(params.get("unresponded", False)), - bool(params.get("count_flagged", False)), - params.get("sort_key", ""), - int(params.get("page", 1)), - int(params.get("per_page", 100)), - ) - serializer = ThreadSerializer(threads.pop("collection"), many=True) - threads["collection"] = serializer.data + thread_ids = find_subscribed_threads(user_id, params["course_id"]) + threads = get_threads(params, user_id, ThreadSerializer, thread_ids) return Response(data=threads, status=status.HTTP_200_OK) - def _validate_params(self, params: dict[str, Any]) -> Response | None: - """ - Validate the request parameters. - - Args: - params (dict): The request parameters. - - Returns: - Response: A Response object with an error message if doesn't exist. - """ - - valid_params = [ - "course_id", - "author_id", - "thread_type", - "flagged", - "unread", - "unanswered", - "unresponded", - "count_flagged", - "sort_key", - "page", - "per_page", - "request_id'", - ] - - for key in params: - if key not in valid_params: - return Response( - {"error": f"Invalid parameter: {key}"}, - status=status.HTTP_400_BAD_REQUEST, - ) - - if "course_id" not in params: - return Response( - {"error": "Missing required parameter: course_id"}, - status=status.HTTP_400_BAD_REQUEST, - ) - - return None - class ThreadSubscriptionAPIView(APIView): """ diff --git a/forum/views/threads.py b/forum/views/threads.py new file mode 100644 index 00000000..4d6d4de2 --- /dev/null +++ b/forum/views/threads.py @@ -0,0 +1,333 @@ +"""Forum Threads API Views.""" + +import logging +from typing import Any, Optional + +from django.core.exceptions import ObjectDoesNotExist +from rest_framework import status +from rest_framework.permissions import AllowAny +from rest_framework.response import Response +from rest_framework.request import Request +from rest_framework.serializers import ValidationError +from rest_framework.views import APIView + +from forum.models.comments import Comment +from forum.models.model_utils import ( + delete_comments_of_a_thread, + get_threads, + validate_object, + validate_params, +) +from forum.models.threads import CommentThread +from forum.serializers.thread import ThreadSerializer +from forum.utils import str_to_bool + +log = logging.getLogger(__name__) + + +def get_thread_data(thread: dict[str, Any]) -> dict[str, Any]: + type = str(thread.get("_type", "")).lower() + thread_data = { + **thread, + "id": str(thread.get("_id")), + "type": "thread" if type == "commentthread" else type, + "user_id": thread.get("author_id"), + "username": str(thread.get("author_username")), + "comments_count": thread["comment_count"], + } + return thread_data + + +def prepare_thread_api_response( + thread: dict[str, Any], + include_context: Optional[bool] = False, + data_or_params: Optional[dict[str, Any]] = {}, + include_data_from_params: Optional[bool] = False, +): + thread_data = get_thread_data(thread) + + context = {} + if include_context: + context = { + "include_endorsed": True, + "include_read_state": True, + } + if include_data_from_params: + thread_data["resp_skip"] = ( + "resp_skip" in data_or_params and int(data_or_params["resp_skip"]) or 0 + ) + thread_data["resp_limit"] = ( + "resp_limit" in data_or_params + and int(data_or_params["resp_limit"]) + or 100 + ) + context["recursive"] = ( + str_to_bool(data_or_params.get("recursive", "False")), + ) + context["with_responses"] = str_to_bool( + data_or_params.get("with_responses", "True") + ) + context["mark_as_read"] = str_to_bool( + data_or_params.get("mark_as_read", "False") + ) + context["reverse_order"] = str_to_bool( + data_or_params.get("reverse_order", "True") + ) + context["merge_question_type_responses"] = str_to_bool( + data_or_params.get("merge_question_type_responses", "False") + ) + + if user_id := data_or_params.get("user_id"): + context["user_id"] = str(user_id) + + serializer = ThreadSerializer( + data=thread_data, + context=context, + ) + if not serializer.is_valid(raise_exception=True): + log.error(f"validation error in thread API call: {serializer.errors}") + raise ValidationError(serializer.errors) + + return serializer.data + + +class ThreadsAPIView(APIView): + """ + API view to handle operations related to threads. + + This view uses the CommentThread model for database interactions and the ThreadSerializer + for serializing and deserializing data. + """ + + permission_classes = (AllowAny,) + + def get(self, request: Request, thread_id: str) -> Response: + """ + Retrieve a thread by its ID. + + Args: + request: The HTTP request object. + thread_id: The ID of the thread to retrieve. + + Returns: + Response: A Response object containing the serialized thread data or an error message. + """ + try: + thread = validate_object(CommentThread, thread_id) + except ObjectDoesNotExist: + return Response( + {"error": "Thread does not exist"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + params = request.query_params + try: + serialized_data = prepare_thread_api_response(thread, True, params, True) + return Response(serialized_data) + + except ValidationError as error: + return Response( + error.detail, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + def delete(self, request: Request, thread_id: str) -> Response: + """ + Deletes a thread by it's ID. + + Parameters: + request (Request): The incoming request. + thread_id: The ID of the thread to be deleted. + Body: + Empty. + Response: + The details of the thread that is deleted. + """ + try: + thread = validate_object(CommentThread, thread_id) + except ObjectDoesNotExist: + return Response( + {"error": "thread does not exist"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + delete_comments_of_a_thread(thread_id) + thread = validate_object(CommentThread, thread_id) + + try: + serialized_data = prepare_thread_api_response(thread) + except ValidationError as error: + return Response( + error.detail, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + CommentThread().delete(thread_id) + + return Response(serialized_data, status=status.HTTP_200_OK) + + def put(self, request: Request, thread_id: str) -> Response: + """ + Updates an existing thread. + + Parameters: + request (Request): The incoming request. + thread_id: The ID of the thread to be edited. + Body: + fields to be updated. + Response: + The details of the thread that is updated. + """ + try: + thread = validate_object(Comment, thread_id) + except ObjectDoesNotExist: + return Response( + {"error": "thread does not exist"}, + status=status.HTTP_400_BAD_REQUEST, + ) + data = request.data + update_thread_data: dict[str, Any] = self._get_update_thread_data(data) + if thread: + update_thread_data["edit_history"] = thread.get("edit_history", []) + update_thread_data["original_body"] = thread.get("body") + + CommentThread().update(thread_id, **update_thread_data) + updated_thread = CommentThread().get(thread_id) + try: + serialized_data = prepare_thread_api_response(updated_thread, True, data) + except ValidationError as error: + return Response( + error.detail, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + return Response(serialized_data, status=status.HTTP_200_OK) + + def _get_update_thread_data(self, data: dict[str, Any]) -> dict[str, Any]: + """convert request data to a dict excluding empty data""" + fields = [ + ("title", data.get("title")), + ("body", data.get("body")), + ("course_id", data.get("course_id")), + ("anonymous", str_to_bool(data.get("anonymous", "False"))), + ( + "anonymous_to_peers", + str_to_bool(data.get("anonymous_to_peers", "False")), + ), + ("closed", str_to_bool(data.get("closed", "False"))), + ("commentable_id", data.get("commentable_id", "course")), + ("author_id", data.get("user_id")), + ("editing_user_id", data.get("editing_user_id")), + ("pinned", str_to_bool(data.get("pinned", "False"))), + ("thread_type", data.get("thread_type", "discussion")), + ("edit_reason_code", data.get("edit_reason_code")), + ] + return {field: value for field, value in fields if value is not None} + + +class CreateThreadAPIView(APIView): + """ + API view to create a new thread. + + This view uses the CommentThread model for database interactions and the ThreadSerializer + for serializing and deserializing data. + """ + + permission_classes = (AllowAny,) + + def post(self, request: Request) -> Response: + """ + Create a new thread. + + Parameters: + request (Request): The incoming request. + Body: + fields to be added in a new thread. + Response: + The details of the thread that is created. + """ + data = request.data + try: + self.validate_request_data(data) + except ValueError as error: + return Response( + {"error": str(error)}, + status=status.HTTP_400_BAD_REQUEST, + ) + + thread = self.create_thread(data) + try: + serialized_data = prepare_thread_api_response(thread, True, data) + except ValidationError as error: + return Response( + error.detail, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + return Response(serialized_data, status=status.HTTP_200_OK) + + def validate_request_data(self, data: dict[str, Any]) -> None: + """ + Validates the request data if it exists or not. + + Parameters: + data: request data to validate. + Response: + raise exception if some data does not exists. + """ + fields_to_validate = ["title", "body", "course_id", "user_id"] + for field in fields_to_validate: + if field not in data or not data[field]: + raise ValueError(f"{field} is missing.") + + def create_thread(self, data: dict[str, Any]) -> Any: + """handle thread creation and returns a thread.""" + new_comment_id = CommentThread().insert( + title=data["title"], + body=data["body"], + course_id=data["course_id"], + anonymous=str_to_bool(data.get("anonymous", "False")), + anonymous_to_peers=str_to_bool(data.get("anonymous_to_peers", "False")), + author_id=data["user_id"], + commentable_id=data.get("commentable_id", "course"), + thread_type=data.get("thread_type", "discussion"), + ) + return CommentThread().get(new_comment_id) + + +class UserThreadsAPIView(APIView): + """ + API View for getting all threads of a course. + + This view provides an endpoint for retrieving all threads based on course id. + """ + + permission_classes = (AllowAny,) + + def get(self, request: Request) -> Response: + """ + Retrieve a course's threads. + + Args: + request (HttpRequest): The HTTP request object. + + Returns: + Response: A Response object with the threads data. + + Raises: + HTTP_400_BAD_REQUEST: If the user does not exist. + """ + params = request.GET.dict() + validations = validate_params(params) + if validations: + return validations + + user_id = params.get("user_id") + course_id = params.get("course_id") + thread_filter = { + "_type": {"$in": [CommentThread.content_type]}, + "course_id": {"$in": [course_id]}, + } + filtered_threads = CommentThread().find(thread_filter) + thread_ids = [thread["_id"] for thread in filtered_threads] + threads = get_threads(params, user_id, ThreadSerializer, thread_ids, True) + return Response(data=threads, status=status.HTTP_200_OK) From 926bddd6ccc0da2b52533857512af8179e5e4051 Mon Sep 17 00:00:00 2001 From: Muhammad Faraz Maqsood Date: Wed, 28 Aug 2024 16:44:58 +0500 Subject: [PATCH 2/2] fix: quality fixes --- forum/models/comments.py | 12 +++--- forum/models/model_utils.py | 19 +++++---- forum/serializers/comment.py | 3 +- forum/serializers/thread.py | 7 +-- forum/utils.py | 5 ++- forum/views/comments.py | 2 + forum/views/threads.py | 82 ++++++++++++++++++++---------------- 7 files changed, 74 insertions(+), 56 deletions(-) diff --git a/forum/models/comments.py b/forum/models/comments.py index fb14b7a5..44db3f87 100644 --- a/forum/models/comments.py +++ b/forum/models/comments.py @@ -342,13 +342,13 @@ def update_comment_count_in_comment_thread( } CommentThread().update_count(comment_thread_id, update_comment_count_query) - def get_sk(self, id: str, parent_id: Optional[str]) -> str: + def get_sk(self, _id: str, parent_id: Optional[str]) -> str: """Returns sk field.""" if parent_id is not None: - return f"{parent_id}-{id}" - return f"{id}" + return f"{parent_id}-{_id}" + return f"{_id}" - def update_sk(self, id: str, parent_id: Optional[str]) -> None: + def update_sk(self, _id: str, parent_id: Optional[str]) -> None: """Updates sk field.""" - sk = self.get_sk(id, parent_id) - self.update(id, sk=sk) + sk = self.get_sk(_id, parent_id) + self.update(_id, sk=sk) diff --git a/forum/models/model_utils.py b/forum/models/model_utils.py index 7df4f730..8149b9e3 100644 --- a/forum/models/model_utils.py +++ b/forum/models/model_utils.py @@ -780,7 +780,7 @@ def unsubscribe_user(user_id: str, source_id: str) -> None: Subscriptions().delete_subscription(user_id, source_id) -def delete_comments_of_a_thread(thread_id): +def delete_comments_of_a_thread(thread_id: str) -> None: """Delete comments of a thread.""" for comment in Comment().list( comment_thread_id=ObjectId(thread_id), @@ -839,12 +839,13 @@ def validate_params( status=status.HTTP_400_BAD_REQUEST, ) - user = Users().get(user_id) - if not user: - return Response( - {"error": "User doesn't exist"}, - status=status.HTTP_400_BAD_REQUEST, - ) + if user_id: + user = Users().get(user_id) + if not user: + return Response( + {"error": "User doesn't exist"}, + status=status.HTTP_400_BAD_REQUEST, + ) return None @@ -853,7 +854,7 @@ def get_threads( params: dict[str, Any], user_id: str, serializer: Any, - thread_ids: Optional[list[str]] = None, + thread_ids: list[str], include_context: Optional[bool] = False, ) -> dict[str, Any]: """get subscribed or all threads of a specific course for a specific user.""" @@ -873,7 +874,7 @@ def get_threads( int(params.get("page", 1)), int(params.get("per_page", 100)), ) - context = {} + context: dict[str, Any] = {} if include_context: context = { "include_endorsed": True, diff --git a/forum/serializers/comment.py b/forum/serializers/comment.py index d2c853c7..b8fd4270 100644 --- a/forum/serializers/comment.py +++ b/forum/serializers/comment.py @@ -67,6 +67,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.fields.pop(field, None) def get_children(self, obj: Any) -> list[dict[str, Any]]: + """Get comments of a thread.""" if not self.context.get("recursive", False): return [] @@ -84,7 +85,7 @@ def get_children(self, obj: Any) -> list[dict[str, Any]]: context={"recursive": False}, exclude_fields=["sk"], ) - return serializer.data + return list(serializer.data) def to_representation(self, instance: Any) -> dict[str, Any]: comment = super().to_representation(instance) diff --git a/forum/serializers/thread.py b/forum/serializers/thread.py index 343bd91f..6759c55c 100644 --- a/forum/serializers/thread.py +++ b/forum/serializers/thread.py @@ -5,7 +5,7 @@ from typing import Any, Optional from bson import ObjectId -from pymongo import DESCENDING, ASCENDING +from pymongo import ASCENDING, DESCENDING from rest_framework import serializers from rest_framework.serializers import ValidationError @@ -56,7 +56,7 @@ class ThreadSerializer(ContentSerializer): thread_type = serializers.CharField() title = serializers.CharField() - context = serializers.CharField() + context = serializers.CharField() # type: ignore last_activity_at = CustomDateTimeField() closed_by_id = serializers.CharField(allow_null=True, default=None) closed_by = serializers.SerializerMethodField() @@ -242,7 +242,8 @@ def get_resp_total(self, obj: dict[str, Any]) -> int: int: The total number of responses, defaulting to 0 if not included. """ if self.with_responses: - return len(self.get_children(obj)) + children = self.get_children(obj) or [] + return len(children) return 0 def to_representation(self, instance: dict[str, Any]) -> dict[str, Any]: diff --git a/forum/utils.py b/forum/utils.py index f2281420..acca3e54 100644 --- a/forum/utils.py +++ b/forum/utils.py @@ -101,7 +101,10 @@ def get_handler_by_name(name: str) -> Signal: raise KeyError(f"No signal found for the name: {name}") from exc -def prepare_comment_data_for_get_children(children): +def prepare_comment_data_for_get_children( + children: list[dict[str, Any]] +) -> list[dict[str, Any]]: + """Prepare children data to be used in serializer.""" children_data = [] for child in children: children_data.append( diff --git a/forum/views/comments.py b/forum/views/comments.py index c0ddbb50..96459a1a 100644 --- a/forum/views/comments.py +++ b/forum/views/comments.py @@ -74,6 +74,8 @@ def prepare_comment_api_response( "parent_id": str(comment.get("parent_id")), "type": str(comment.get("_type", "")).lower(), } + if not exclude_fields: + exclude_fields = [] exclude_fields.append("children") serializer = CommentSerializer( data=comment_data, diff --git a/forum/views/threads.py b/forum/views/threads.py index 4d6d4de2..234581bc 100644 --- a/forum/views/threads.py +++ b/forum/views/threads.py @@ -6,8 +6,8 @@ from django.core.exceptions import ObjectDoesNotExist from rest_framework import status from rest_framework.permissions import AllowAny -from rest_framework.response import Response from rest_framework.request import Request +from rest_framework.response import Response from rest_framework.serializers import ValidationError from rest_framework.views import APIView @@ -20,17 +20,18 @@ ) from forum.models.threads import CommentThread from forum.serializers.thread import ThreadSerializer -from forum.utils import str_to_bool +from forum.utils import get_int_value_from_collection, str_to_bool log = logging.getLogger(__name__) def get_thread_data(thread: dict[str, Any]) -> dict[str, Any]: - type = str(thread.get("_type", "")).lower() + """Prepare thread data for the api response.""" + _type = str(thread.get("_type", "")).lower() thread_data = { **thread, "id": str(thread.get("_id")), - "type": "thread" if type == "commentthread" else type, + "type": "thread" if _type == "commentthread" else _type, "user_id": thread.get("author_id"), "username": str(thread.get("author_username")), "comments_count": thread["comment_count"], @@ -41,9 +42,10 @@ def get_thread_data(thread: dict[str, Any]) -> dict[str, Any]: def prepare_thread_api_response( thread: dict[str, Any], include_context: Optional[bool] = False, - data_or_params: Optional[dict[str, Any]] = {}, + data_or_params: Optional[dict[str, Any]] = None, include_data_from_params: Optional[bool] = False, -): +) -> dict[str, Any] | None: + """Serialize thread data for the api response.""" thread_data = get_thread_data(thread) context = {} @@ -52,33 +54,32 @@ def prepare_thread_api_response( "include_endorsed": True, "include_read_state": True, } - if include_data_from_params: - thread_data["resp_skip"] = ( - "resp_skip" in data_or_params and int(data_or_params["resp_skip"]) or 0 - ) - thread_data["resp_limit"] = ( - "resp_limit" in data_or_params - and int(data_or_params["resp_limit"]) - or 100 - ) - context["recursive"] = ( - str_to_bool(data_or_params.get("recursive", "False")), - ) - context["with_responses"] = str_to_bool( - data_or_params.get("with_responses", "True") - ) - context["mark_as_read"] = str_to_bool( - data_or_params.get("mark_as_read", "False") - ) - context["reverse_order"] = str_to_bool( - data_or_params.get("reverse_order", "True") - ) - context["merge_question_type_responses"] = str_to_bool( - data_or_params.get("merge_question_type_responses", "False") - ) - - if user_id := data_or_params.get("user_id"): - context["user_id"] = str(user_id) + if data_or_params: + if include_data_from_params: + thread_data["resp_skip"] = get_int_value_from_collection( + data_or_params, "resp_skip", 0 + ) + thread_data["resp_limit"] = get_int_value_from_collection( + data_or_params, "resp_limit", 100 + ) + context["recursive"] = str_to_bool( + data_or_params.get("recursive", "False") + ) + context["with_responses"] = str_to_bool( + data_or_params.get("with_responses", "True") + ) + context["mark_as_read"] = str_to_bool( + data_or_params.get("mark_as_read", "False") + ) + context["reverse_order"] = str_to_bool( + data_or_params.get("reverse_order", "True") + ) + context["merge_question_type_responses"] = str_to_bool( + data_or_params.get("merge_question_type_responses", "False") + ) + + if user_id := data_or_params.get("user_id"): + context["user_id"] = user_id serializer = ThreadSerializer( data=thread_data, @@ -193,14 +194,23 @@ def put(self, request: Request, thread_id: str) -> Response: CommentThread().update(thread_id, **update_thread_data) updated_thread = CommentThread().get(thread_id) try: - serialized_data = prepare_thread_api_response(updated_thread, True, data) + if updated_thread: + serialized_data = prepare_thread_api_response( + updated_thread, + True, + data, + ) + return Response(serialized_data, status=status.HTTP_200_OK) except ValidationError as error: return Response( error.detail, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) - return Response(serialized_data, status=status.HTTP_200_OK) + return Response( + {"error": "Thread is not updated"}, + status=status.HTTP_400_BAD_REQUEST, + ) def _get_update_thread_data(self, data: dict[str, Any]) -> dict[str, Any]: """convert request data to a dict excluding empty data""" @@ -321,7 +331,7 @@ def get(self, request: Request) -> Response: if validations: return validations - user_id = params.get("user_id") + user_id = params.get("user_id", "") course_id = params.get("course_id") thread_filter = { "_type": {"$in": [CommentThread.content_type]},